# Generating Plots

This notebook shows how to generate accuracy plots.

Run it after the `train.py` script has finished and written its logs to the `logs` folder for your experiment.

Make sure the `exp_name` and `episode_name` below are set correctly.

In [None]:
from plot import average_repeats, compute_forgetting, plot_forgetting
from plot import get_baselines, compute_intransigence, add_switch_pts
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import numpy as np
import matplotlib.gridspec as grd

Specify the experiment name and episode as defined in `train.py`.

In [None]:
exp_name = 'demo'
episode_name = 'c100-2'

Define the line and marker styles (`p_key`) and the colors (`c_key`) used for each method in the figures.

In [None]:
p_key = {'Vanilla': '-o',
         'L2': '-v',
         'EWC': '-x',
         'RWalk': '-^'}

c_key = {'Vanilla': 'orchid',
         'L2': 'lightskyblue',
         'EWC': 'lightcoral',
         'RWalk': 'limegreen'}

These paths should correspond to the log directories that `train.py` writes for each method.

In [None]:
trials = {
    episode_name : {
        'baseline': Path('logs') / exp_name / 'baselines' / episode_name,
        'Vanilla': Path('logs') / exp_name / 'vanilla' / episode_name / 'main',
        'L2': Path('logs') / exp_name / 'l2' / episode_name / 'main',
        'EWC': Path('logs') / exp_name / 'ewc' / episode_name / 'main',
        'RWalk': Path('logs') / exp_name / 'rwalk' / episode_name / 'main'
    },
}
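
Before plotting, it can help to confirm that each of these log directories actually exists. The following is a minimal sketch and not part of the original notebook: `missing_log_dirs` is a hypothetical helper, and it assumes `trials` has the nested-dictionary shape shown above (episode name mapping to a dictionary of method name to path).

```python
from pathlib import Path


def missing_log_dirs(trials):
    """Return (episode, method, path) for every log path that does not exist.

    Hypothetical helper: assumes `trials` maps episode name -> {method: path}.
    """
    missing = []
    for episode, methods in trials.items():
        for method, path in methods.items():
            if not Path(path).exists():
                missing.append((episode, method, Path(path)))
    return missing
```

Calling `missing_log_dirs(trials)` returns an empty list when every expected directory is present. Any entries it does return point to runs that have not yet written their logs.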



Generate the accuracy plot.

In [None]:

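fontsize = 12
plt.figure(figsize=(8, 8))

# Load the baseline metrics (averaged over repeated runs).
bmetrics, _, _ = average_repeats(trials[episode_name]['baseline'])
bacc = get_baselines(bmetrics)

for key, fname in trials[episode_name].items():
    if key == 'baseline':
        continue
    # Average the logged metrics over repeats; spts is passed to add_switch_pts below.
    metrics, mxs, spts = average_repeats(fname)
    forgetting = compute_forgetting(metrics, mxs)

    # Unpack the logged (x, y) pairs into separate x and y sequences.
    plt.plot(*zip(*metrics['T0-avg-MH']), p_key[key],
             label=key, markevery=6, color=c_key[key])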

add_switch_pts(spts)
plt.xlabel('Batches seen', fontsize=fontsize)
plt.ylabel('Accuracy', fontsize=fontsize)
plt.ylim([0.5, 1.])

plt.legend(loc='upper center', fontsize=fontsize)
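
The `plt.plot(*zip(*...))` call above relies on a transposition idiom. As a minimal sketch, assuming `metrics['T0-avg-MH']` is a sequence of `(batches seen, accuracy)` pairs (inferred from the plotting call, not confirmed against `plot.py`), `zip(*pairs)` splits the pairs into one sequence of x values and one of y values. The sample numbers below are made up for illustration.

```python
# Hypothetical sample data: (batches seen, accuracy) pairs.
pairs = [(0, 0.50), (100, 0.72), (200, 0.81)]

# zip(*pairs) transposes the list of pairs into two tuples.
xs, ys = zip(*pairs)

print(xs)  # (0, 100, 200)
print(ys)  # (0.5, 0.72, 0.81)
```

Passing `*zip(*pairs)` to `plt.plot` therefore supplies the x and y arguments in one step, followed by the format string from `p_key`.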