In [None]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '..')
import glob
import os
import random
import re
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch 
import pytorch_lightning as pl

from systems.memae_autoencoder_system import MemaeSystem
from models.conditional.conditional_memae_mnist import ConditionalMemaeMNIST
from models.memae_mnist import MemaeMNIST
from models.memae_mnist_flat import MemaeMNISTFlat
from models.memae_cifar import MemaeCIFAR

from helpers import select_target, get_sorted_checkpoints, load_memories_from_checkpoints, parse_runs, load_model_change_from_run_df, plot_multi

import matplotlib
# Set to False to keep matplotlib's current backend and default text rendering.
OUTPUT_LATEX = True

# Always start from matplotlib's default style sheet.
plt.style.use("default")

if OUTPUT_LATEX:
    # Switch to the PDF backend and render all text through LaTeX with a serif font.
    matplotlib.use("pdf")
    latex_rc_params = {
        "pgf.texsystem": "pdflatex",
        'font.family': 'serif',
        'text.usetex': True,
        'pgf.rcfonts': False,
    }
    matplotlib.rcParams.update(latex_rc_params)


In [None]:
# Directories whose data should be loaded; add one entry per directory.
# Every subfolder of a listed directory is analyzed automatically.
data_dirs = []

In [None]:
# Parse the runs under `data_dirs`; passing ["Seed"] always creates a seed level.
runs = parse_runs(data_dirs, ["Seed"])

# Load the model-change values for the parsed runs into one frame.
df = load_model_change_from_run_df(runs)

# The experiment name is built from every index level name except "Epoch" and "Seed".
experiment_name = "_".join(
    level_name
    for level_name in df.index.names
    if level_name not in ["Epoch", "Seed"]
)
df

In [None]:
# Average over the "Seed" index level: move the seeds into the columns, then
# take the mean of every group of columns sharing the same top-level label.
#
# The original `DataFrame.mean(axis=1, level=0)` relies on the `level=` argument
# of reductions, which was deprecated in pandas 1.3 and removed in pandas 2.0.
# Grouping the transposed frame is the equivalent that works on old and new
# pandas alike; `sort=False` keeps the labels in their original order, as the
# `level=` reduction did.
mean_seed = (
    df.unstack(level="Seed")
    .T.groupby(level=0, sort=False)
    .mean()
    .T
)
mean_seed

In [None]:
# Move the innermost index level of `mean_seed` into the columns, then average
# all columns that share the same label on column level 1 (the unstacked level).
# NOTE(review): judging by the file name below, the remaining columns are
# presumably epochs - confirm against the layout returned by the loader.
#
# `DataFrame.mean(axis=1, level=1)` no longer exists in pandas >= 2.0, so the
# grouping is done explicitly on the transposed frame (sort=False preserves the
# original label order).
mean_classes = (
    mean_seed.unstack()
    .T.groupby(level=1, sort=False)
    .mean()
    .T
)

plt.figure(figsize=(12,3))
g = sns.heatmap(mean_classes)
# Re-apply the tick labels so they can be vertically centered on each row.
g.set_yticklabels(labels=g.get_yticklabels(), va='center')
#if OUTPUT_LATEX:
#    plt.savefig(f"{experiment_name}_memory_change_heatmap_by_epoch.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Move the innermost index level of `mean_seed` into the columns, then SUM all
# columns that share the same label on column level 0 (the original columns of
# `mean_seed`). Despite its name, `mean_epochs` holds sums, not means.
# NOTE(review): judging by the file name below, the remaining columns are
# presumably classes - confirm against the layout returned by the loader.
#
# `DataFrame.sum(axis=1, level=0)` no longer exists in pandas >= 2.0, so the
# grouping is done explicitly on the transposed frame (sort=False preserves the
# original label order).
mean_epochs = (
    mean_seed.unstack()
    .T.groupby(level=0, sort=False)
    .sum()
    .T
)

plt.figure(figsize=(12,3))
g = sns.heatmap(mean_epochs)
# Re-apply the tick labels so they can be vertically centered on each row.
g.set_yticklabels(labels=g.get_yticklabels(), va='center')
#if OUTPUT_LATEX:
#    plt.savefig(f"{experiment_name}_memory_change_heatmap_by_class.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Plot `mean_classes` with the shared helper from `helpers`.
# NOTE(review): "Memory Change" is presumably the plot title / output file
# name - confirm against the signature of `plot_multi`.
plot_multi(
    mean_classes,
    "Memory Change",
    margin_top=0.001,
    margin_bottom=0,
    bottom=0,
    save=True,
)

In [None]:

# Small line plot of `mean_classes`, configured by the settings below.

# --- plot settings ---------------------------------------------------------
margin_top = 0.0003       # added above the largest value for the upper y-limit
margin_bottom = 0         # subtracted from the smallest value for the lower y-limit
bottom = 0                # fixed lower y-limit; None -> derive from the data
top = None                # fixed upper y-limit; None -> derive from the data
save = True               # write the figure to "<filename>.pdf"
filename = "Memory Change"

# Use a dedicated name for the transposed frame. The original rebound `df`
# here, which overwrote the run data loaded at the top of the notebook and made
# the earlier cells fail when re-run.
plot_df = mean_classes.T
# Shift the index by one so the x axis starts at 1 instead of 0.
# NOTE(review): assumes a numeric, zero-based index (presumably epochs) - confirm.
plot_df.index += 1

# Data-driven y-limits, overridden by the explicit `top` / `bottom` if given.
min_score = plot_df.min().min() - margin_bottom
max_score = plot_df.max().max() + margin_top

if top is not None:
    max_score = top
if bottom is not None:
    min_score = bottom


plt.figure(figsize=(4,3))
ax = plt.gca()
ax.set_ylabel("Mean Memory Change")
ax.set_ylim([min_score, max_score])
ax.set_xlim([0,100])
plot_df.plot(ax=ax)

if save:
    if not filename:
        # Fallback name: timestamp plus a random suffix to avoid collisions.
        # Requires `datetime` and `random`, which the import cell must provide.
        filename = f"{datetime.now().strftime('%Y-%m-%d %H-%M-%S-%f')}{random.randint(0, 100000)}"

    plt.savefig(
        f"{filename}.pdf",
        bbox_inches="tight",
    )

In [None]:
# Display the `mean_classes` frame as the cell output for inspection.
mean_classes