In [1]:
#do necessary imports
%load_ext autoreload
%autoreload 2
from pathlib import Path
import sys
import torch
import json
import torch.nn as nn
import torch.optim as optim
from rean.utils import make_run_dir, to_serializable
from rean.data.Dataset import make_datasets
from rean.models.CNN import PlainCNN
from rean.models.P4 import P4CNN
from rean.models.RelaxedP4 import RelaxedP4CNN
from rean.train import train_full, evaluate
from rean.plot import LossPlot
import matplotlib.pyplot as plt


In [None]:
# Settings shared by every experimental run.
# gamma and learning_rate are also embedded in the run-directory names that the
# plotting cell loads, so they must match the values the runs were trained with.
hidden_dim = 20  # from Cohen
out_channels = hidden_dim  # output width tied to the hidden width
group_order = 4  # NOTE(review): presumably the order of the P4 rotation group — confirm
classes = 10
kernel_size = 3
num_gconvs = 6  # from Cohen
num_epochs = 40
batch_size = 64  # leave this fixed
learning_rate = 0.002  # final chosen value
gamma = 2

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Full loop: experiment grid (models × noise types × noise stds)
models = ['PlainCNN', 'P4CNN', 'RelaxedP4CNN']
noise_types = ['none', 'iso', 'aniso']
stds = [0.1, 0.2, 0.3]

In [None]:
# do the same for each noise type, but plotting both std runs on the same axes
#so, each axes object will have 6 lines: 3 models x 2 stds
noise_types = ['iso', 'aniso']
for noise_type in noise_types:
    fig, (train_ax, val_ax) = plt.subplots(1,2)
    runs = []
    tests = []
    for model_name in models:
        for std in stds:
            runpath = Path(f"./runs/{model_name}_{noise_type}_std{std}_gamma{gamma}_lr{learning_rate}")

            #load the rundata in from the json
            rundatapath = runpath / "run_data.json"
            testdatapath = runpath / "test_data.json"
            with rundatapath.open("r", encoding="utf-8") as f:
                run_data = json.load(f)
            with testdatapath.open("r", encoding="utf-8") as f:
                test_data = json.load(f)
            run_data.update(test_data)
            runs.append(run_data)
            tests.append(test_data)

    #make loss plots
    train_ax = LossPlot(train_ax, runs, labels = ["model_name", "std"], title = "Training Loss", val = False)
    val_ax = LossPlot(val_ax, runs, labels = ['model_name', "std"], title = "Validation Loss", train = False)
    fig.tight_layout()
    fig.suptitle(f"Training and Validation Loss, {noise_type.capitalize()} Noise Experiments", y=1.02)
    fig.savefig(f"{noise_type}_loss_plots.png")
