# Test plots

List the test-result files of the models you want to visualize in `filenames` and `ls_filenames` (both are read from the `test_pickles` directory). The former should contain the test results on the $60 \times 4$ lattice, whereas the latter should contain the test results on different lattice sizes. Note that architectures containing a flattening step are restricted to a single lattice size, so they can only be listed in `filenames`.

In [None]:
import pickle, os
import numpy as np, matplotlib.pyplot as plt

# Directory that holds the pickled test results.
path = 'test_pickles'

# Results on the 60x4 lattice. Each pickle is unpacked as
# (loss, MSE, result, train_sample_number).
filenames = ['test_ref.pickle']

losses, MSEs, results, train_sample_numbers = [], [], [], []

for name in filenames:
    # NOTE(review): pickle.load can execute arbitrary code -- only load
    # result files you produced yourself.
    with open(os.path.join(path, name), 'rb') as handle:
        print('loading ' + name)
        loss_i, mse_i, result_i, n_train_i = pickle.load(handle)
    losses.append(loss_i)
    MSEs.append(mse_i)
    results.append(result_i)
    train_sample_numbers.append(n_train_i)

# Stack the per-model entries; `results` deliberately stays a plain list.
losses = np.array(losses)
MSEs = np.array(MSEs)
train_sample_numbers = np.array(train_sample_numbers)

# Results on several lattice sizes. Each pickle is unpacked as
# (loss, MSE, dims).
ls_filenames = ['ls_test_ref.pickle']

ls_losses, ls_MSEs, ls_dims = [], [], []

for name in ls_filenames:
    with open(os.path.join(path, name), 'rb') as handle:
        print('loading ' + name)
        loss_i, mse_i, dims_i = pickle.load(handle)
    ls_losses.append(loss_i)
    ls_MSEs.append(mse_i)
    ls_dims.append(dims_i)

# `ls_dims` stays a list; only the loss and MSE entries are stacked.
ls_losses = np.array(ls_losses)
ls_MSEs = np.array(ls_MSEs)

In [None]:
# Mean / min / max / std of every loaded array, reduced over axis 2.
# Axis 2 is presumably the repeated-run axis (one value per trained model
# instance) -- TODO confirm against the script that writes the pickles.
mean_losses, min_losses, max_losses, std_losses = (
    stat(losses, axis=2) for stat in (np.mean, np.min, np.max, np.std)
)
mean_MSEs, min_MSEs, max_MSEs, std_MSEs = (
    stat(MSEs, axis=2) for stat in (np.mean, np.min, np.max, np.std)
)

# Same statistics for the results on different lattice sizes.
mean_ls_losses, min_ls_losses, max_ls_losses, std_ls_losses = (
    stat(ls_losses, axis=2) for stat in (np.mean, np.min, np.max, np.std)
)
mean_ls_MSEs, min_ls_MSEs, max_ls_MSEs, std_ls_MSEs = (
    stat(ls_MSEs, axis=2) for stat in (np.mean, np.min, np.max, np.std)
)

In [None]:
colors = ['g', 'b', 'r', 'c', 'm', 'y', 'k']
# Legend labels are the file names with their extension removed.
labels = [os.path.splitext(name)[0] for name in filenames]
ls_labels = [os.path.splitext(name)[0] for name in ls_filenames]
fontsize = 14
alpha = 0.125  # opacity of each shaded band; nested bands stack to darker shades


def _plot_spread(ax, x, samples, lower, upper, mean, color, label, band_alpha):
    """Plot the spread of one model's test results on *ax*.

    Draws solid lines for *lower* and *upper* (only the *lower* line carries
    the legend entry *label*), a dashed line for *mean*, and three nested
    shaded bands: lower-upper, then the 20-80 % and 40-60 % quantiles of
    *samples* taken over its axis 1.
    Assumes *samples* is shaped (len(x), n_samples) -- TODO confirm.
    """
    ax.plot(x, lower, color, label=label, marker='.')
    ax.plot(x, upper, color, marker='.')
    ax.plot(x, mean, color + '--', marker='.')
    ax.fill_between(x, lower, upper, facecolor=color, alpha=band_alpha)
    for q_low, q_high in ((0.2, 0.8), (0.4, 0.6)):
        ax.fill_between(x,
                        np.quantile(samples, q=q_low, axis=1),
                        np.quantile(samples, q=q_high, axis=1),
                        facecolor=color, alpha=band_alpha)


fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 4), sharey=True,
                        gridspec_kw={'wspace': 0.03, 'hspace': 0})

# Left panel: test loss vs. number of training samples. Colours cycle, so
# models beyond the length of `colors` are still drawn rather than dropped.
for j, label in enumerate(labels):
    _plot_spread(axs[0], train_sample_numbers[j], losses[j],
                 min_losses[j], max_losses[j], mean_losses[j],
                 colors[j % len(colors)], label, alpha)

# Right panel: test loss per lattice size, drawn at integer positions
# 0 .. len(ls_dims[j]) - 1 and labelled with the lattice sizes further below.
for j, ls_label in enumerate(ls_labels):
    positions = np.arange(len(ls_dims[j]))
    _plot_spread(axs[1], positions, ls_losses[j],
                 min_ls_losses[j], max_ls_losses[j], mean_ls_losses[j],
                 colors[j % len(colors)], ls_label, alpha)

axs[0].set_ylabel('test loss', fontsize=fontsize)
axs[0].set_ylim(1e-8, 1e-1)
axs[0].set_yscale('log')  # y axis is shared (sharey=True) with the right panel
axs[0].set_xscale('log')
axs[0].set_xlabel('training samples', fontsize=fontsize)

# Decorate both panels so the lattice-size legend entries are visible too.
for ax in axs:
    ax.grid(alpha=0.7, linewidth=0.5)
    ax.legend()

if ls_dims:  # np.argmax of an empty list would raise
    # Index of the model evaluated on the most lattice sizes; its sizes label
    # the shared x positions. NOTE(review): assumes every model's ls_dims is a
    # prefix of this list, otherwise ticks and plotted points do not line up.
    max_dims = np.argmax([len(dims) for dims in ls_dims])
    axs[1].set_xticks(np.arange(len(ls_dims[max_dims])))
    axs[1].set_xticklabels(ls_dims[max_dims])
axs[1].set_xlabel('lattice size', fontsize=fontsize)

plt.show()