In [None]:
import wandb
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from pprint import pprint
# Use seaborn's 'darkgrid' style for the figures drawn in this notebook.
sns.set_style('darkgrid')

# wandb API client; later cells use it to list runs and fetch their logged history.
api = wandb.Api()

# Learning curves
This notebook creates the learning curve plots seen in the paper. It requires policy model runs that have been logged to wandb. Set the variables below to the correct wandb entity and project. Using the wandb API, we then build a dictionary of the runs to include in the computations. Make sure to check that this dictionary contains the correct runs, and add your own filters if it does not.

In [None]:
# Dataset whose runs are selected and plotted.
dataset = 'knee'  # or 'brain'
# Placeholders: replace with the wandb entity and project that hold the runs.
wandb_entity = 'WANDB_ENTITY NAME'
wandb_project = 'WANDB_PROJECT_NAME'

# Partitions to plot learning curves for
partitions = ['train', 'val']

In [None]:
# This creates a dictionary of run names, dirs and ids, based on the Wandb API.
# Layout: run_id_dict[horizon][policy][run_name] = {'id': ..., 'dir': ...}
# where horizon is "16-32" or "4-32" and policy is 'greedy', 'nongreedy' or a gamma value.

# config.sample_rate value used to pre-filter the runs of each dataset.
sample_rates = {'knee': 0.5, 'brain': 0.2}
dataset_key = dataset.lower()
if dataset_key not in sample_rates:
    # Fail early with a clear message instead of a NameError on `sample_rate` below.
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(sample_rates)}.")
sample_rate = sample_rates[dataset_key]

run_id_dict = {"16-32": defaultdict(dict),
               "4-32": defaultdict(dict)}

runs = api.runs(f"{wandb_entity}/{wandb_project}", {"config.sample_rate": sample_rate})
for run in runs:
    # Only completed runs are used.
    if run.state != 'finished':
        continue

    name = run.name
    args = run.config

    if args['dataset'].lower() != dataset_key:
        continue  # Skip models not on given dataset

    ### YOUR FILTERS HERE ###

    # Group runs by policy type: greedy models, gamma == 1 models ('nongreedy'),
    # and every other gamma value under the gamma itself.
    if args['model_type'] == 'greedy':
        key = 'greedy'
    elif args['gamma'] == 1:
        key = 'nongreedy'
    else:
        key = args['gamma']

    # Keep only the last path component of the stored run directory.
    run_dir = args['run_dir'].split('/')[-1]

    # Runs with any other `accelerations` setting are ignored.
    if args['accelerations'] == [8]:
        run_id_dict["16-32"][key][name] = {'id': run.id, 'dir': run_dir}
    elif args['accelerations'] == [32]:
        run_id_dict["4-32"][key][name] = {'id': run.id, 'dir': run_dir}

pprint(run_id_dict)

In [None]:
def get_learning_curve(runs, entity, project, final_step, partition, max_epochs=50):
    """Compute the mean and standard deviation of a logged SSIM curve over runs.

    For every run, the history column ``f'{partition}_ssims.{final_step}'`` is
    fetched through the notebook-level wandb ``api`` client and its first
    ``max_epochs`` entries are kept. Runs that did not log this column are
    reported (partition and run name are printed) and skipped.

    Args:
        runs: Mapping of run name to a dict that has at least an ``'id'`` entry.
        entity: wandb entity name.
        project: wandb project name.
        final_step: Step index used in the history column name.
        partition: Partition prefix of the history column, e.g. 'train' or 'val'.
        max_epochs: Number of leading history entries to keep per run.

    Returns:
        Tuple ``(mean, std)`` of arrays taken over the run axis. ``std`` uses
        ``ddof=1``, so it is NaN when only a single run contributes.

    Raises:
        ValueError: If no run provides the requested column, or (from
            ``np.stack``) if the collected curves differ in length.
    """
    column = f'{partition}_ssims.{final_step}'
    curves = []
    for run_name, run_data in runs.items():
        run = api.run(f"{entity}/{project}/{run_data['id']}")
        try:
            vals = run.history()[column][:max_epochs]
        except KeyError:
            # This run never logged the column; report it and move on.
            print(partition, run_name)
            continue
        curves.append(np.array(vals))

    if not curves:
        # np.stack([]) would fail with an error that does not say what is missing.
        raise ValueError(
            f"No runs with logged '{column}' values were found; "
            "cannot compute a learning curve."
        )

    curves = np.stack(curves)
    return np.mean(curves, axis=0), np.std(curves, axis=0, ddof=1)

def get_curve_statistics(run_dict, horizon, partition, dataset, entity, wandb_project, final_step=None):
    """Collect learning-curve statistics for the three plotted policy groups.

    Args:
        run_dict: Nested mapping ``run_dict[horizon_key][policy]`` giving the
            runs of each policy, with horizon keys '16-32' and '4-32' and
            policies 'greedy', 'nongreedy' and 0.9.
        horizon: 'base' or 'long'; the horizon keys '16-32' and '4-32' are
            accepted as well.
        partition: Partition whose curve is read, e.g. 'train' or 'val'.
        dataset: 'knee' or 'brain' (case-insensitive).
        entity: wandb entity name.
        wandb_project: wandb project name.
        final_step: Step index of the SSIM column to read. Defaults to 16 for
            the base horizon and 28 for the long horizon.

    Returns:
        ``((greedy_mean, greedy_std), (nongreedy_mean, nongreedy_std),
        (gamma09_mean, gamma09_std))``.

    Raises:
        ValueError: If ``dataset`` or ``horizon`` is not recognised.
    """
    # Compare case-insensitively: the notebook configuration uses lowercase names.
    if dataset.lower() not in ('knee', 'brain'):
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

    # Resolve the horizon locally rather than depending on notebook-level globals.
    horizon_settings = {
        'base': ('16-32', 16),
        '16-32': ('16-32', 16),
        'long': ('4-32', 28),
        '4-32': ('4-32', 28),
    }
    if horizon not in horizon_settings:
        raise ValueError(
            f"Unknown horizon {horizon!r}; expected one of {sorted(horizon_settings)}."
        )
    horizon_key, default_final_step = horizon_settings[horizon]
    if final_step is None:
        final_step = default_final_step

    runs_for_horizon = run_dict[horizon_key]
    # Order matters: callers pair these results with their plot labels.
    return tuple(
        get_learning_curve(runs_for_horizon[policy], entity, wandb_project, final_step, partition)
        for policy in ('greedy', 'nongreedy', 0.9)
    )

def plot_learning_curves(vals, labels, cdict, final_step, num, partition, partitions, dataset):
    """Draw one panel of the learning-curve figure on the current pyplot figure.

    The figure is a grid with ``len(partitions)`` rows and two columns
    (column 0: base horizon, column 1: long horizon). Train panels go on the
    second row when more than one partition is plotted; all other panels go
    on the first row.

    Args:
        vals: Sequence of ``(means, stds)`` array pairs, one per label.
        labels: Legend label for each entry of ``vals``.
        cdict: Mapping from label to matplotlib colour.
        final_step: Unused; kept so existing calls keep working.
        num: Column index, 0 for the base horizon and 1 for the long horizon.
        partition: Partition being plotted; 'train' selects the train styling,
            anything else is styled as validation.
        partitions: All partitions in the figure (determines the row count).
        dataset: 'knee' or 'brain' (case-insensitive); selects the y-limits.
    """
    is_knee = dataset.lower() == 'knee'
    is_train = partition == 'train'

    # With a single partition there is no second row to put the train panels on.
    row = 1 if is_train and len(partitions) > 1 else 0
    plt.subplot(len(partitions), 2, 2 * row + num + 1)

    # Hand-tuned y-ranges per (row kind, column); panels without an entry autoscale.
    knee_ylims = {
        ('val', 0): (.7135, 0.7175),
        ('val', 1): (.736, 0.742),
        ('train', 0): (.722, 0.7325),
        ('train', 1): (.739, 0.751),
    }
    other_ylims = {
        ('val', 0): (.9135, 0.9152),
        ('val', 1): (.865, 0.905),
    }
    row_kind = 'train' if is_train else 'val'
    ylims = (knee_ylims if is_knee else other_ylims).get((row_kind, num))

    # Column titles belong on the top row only.
    horizon_titles = {0: 'Base horizon', 1: 'Long horizon'}
    if row == 0 and num in horizon_titles:
        plt.title(horizon_titles[num], fontsize=18)
    # Only the left column carries the y-axis label.
    if num == 0:
        plt.ylabel('train SSIM' if is_train else 'val SSIM', fontsize=15)
    # Knee validation panels have no x-label; all other panels do.
    if is_train or not is_knee:
        plt.xlabel('epoch', fontsize=15)

    epochs = list(range(vals[0][0].shape[-1]))
    for label, (means, stds) in zip(labels, vals):
        plt.plot(epochs, means, label=f'{label}', c=cdict[label])
        # Shaded band of one standard deviation around the mean.
        plt.fill_between(epochs, means - stds, means + stds, alpha=.3, color=cdict[label])
    plt.legend(loc='upper left')
    if ylims is not None:
        plt.ylim(*ylims)

In [None]:
cdict = {'Greedy': 'tab:blue', 'NGreedy': 'tab:cyan', 'γ = 0.9': 'tab:orange'}
labels = ['Greedy', 'NGreedy', 'γ = 0.9']

# Horizon name -> (key into run_id_dict, step index of the SSIM column to read).
horizon_settings = {'base': ('16-32', 16), 'long': ('4-32', 28)}

# The helper functions compare the dataset name against its capitalised form
# ('Knee' / 'Brain'), while the configuration cell uses lowercase.
dataset_name = dataset.capitalize()

plt.figure(figsize=(18, 3.5 * len(partitions)))
# `hor` and `final_step` are intentionally bound at notebook level inside this loop,
# since helper functions may look them up as globals.
for i, (horizon, (hor, final_step)) in enumerate(horizon_settings.items()):
    for partition in partitions:
        vals = get_curve_statistics(run_id_dict, horizon, partition, dataset_name, wandb_entity, wandb_project)
        plot_learning_curves(vals, labels, cdict, final_step, i, partition, partitions, dataset_name)

plt.suptitle(f'{dataset} learning curves', fontsize=21)
# Leave room at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()