In [None]:
# Notebook configuration.
# autoreload (mode 2) reloads imported modules before executing each cell, so
# edits to project source files take effect without restarting the kernel.
# Comments are kept on separate lines: text after a line magic is passed to it
# as arguments.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# Path setup: locate the repository root, make it importable, and make sure
# the figures folder for the shuffled-MNIST results exists.
import os
import sys
from os.path import join

cwd = os.getcwd()
# The notebook is expected to run from <ROOT>/scripts/mnist; everything before
# that suffix is taken as the repository root. If the suffix is not found
# (different working directory, or Windows-style separators), split() returns
# the whole path and ROOT silently falls back to the current directory.
parts = cwd.split('/scripts/mnist')
ROOT = parts[0]
os.chdir(ROOT)
# Put ROOT first on the import path, without piling up duplicate entries each
# time this cell is re-run.
if not sys.path or sys.path[0] != ROOT:
    sys.path.insert(0, ROOT)

RES = join(ROOT, 'data', 'mnist_shuffled', 'results')
figs_folder = join(RES, 'figs')
# makedirs also creates RES if it is missing (os.mkdir would raise), and
# exist_ok makes the call a no-op when the folder is already there.
os.makedirs(figs_folder, exist_ok=True)

In [None]:
import pickle
from time import time
import torch
import numpy as np
from matplotlib import pyplot as plt

# Seed the global RNGs explicitly. The previous `np.seed = 1101` only attached
# an attribute to the numpy module and did not seed anything.
SEED = 1101
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training

In [None]:
# Load the training history written by the training run.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
# The context manager closes the handle (the previous version leaked it).
with open(join(RES, 'training_data'), 'rb') as fh:
    T = pickle.load(fh)

# Curves are plotted against 'epoch' below; presumably one entry per epoch
# (TODO confirm against the training script).
train_loss = np.array(T['train_loss'])
test_loss = np.array(T['test_loss'])
test_acc = np.array(T['test_acc'])
train_acc = np.array(T['train_acc'])

In [None]:
# Display the last recorded test accuracy (final entry of the history).
test_acc[-1]

In [None]:
# Training curves: loss (left) and accuracy (right), train vs. test, saved
# as a PNG in the figures folder.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    # (title, y-label, train curve, test curve)
    ('loss', 'loss', train_loss, test_loss),
    ('accuracy', 'accuracy (% correct)', train_acc, test_acc),
]
for ax, (title, ylabel, train_curve, test_curve) in zip(axes, panels):
    ax.plot(train_curve, '-', label='train')
    ax.plot(test_curve, '-', label='test')
    ax.set_xlabel('epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
# Keep the two panels' labels from overlapping in the saved file.
fig.tight_layout()
fig.savefig(join(figs_folder, 'training_data.png'))

# Intrinsic Dimension

In [None]:
# Layer names, in network order, used as x-tick labels.
layers_all = ['input', 'conv1', 'pool1', 'conv2', 'pool2', 'd1', 'output']

# Intrinsic-dimension results for the three dataset variants, loaded in the
# order original / shuffled / grad from each variant's results folder.
ID_orig, ID_shuf, ID_grad = (
    np.load(join(ROOT, 'data', dataset, 'results', 'ID_all.npy'))
    for dataset in ('mnist', 'mnist_shuffled', 'mnist_grad')
)

In [None]:
# Intrinsic dimension per layer for the three dataset variants. Column 0 of
# each ID array is plotted as the value and column 1 as its error bar.
fig, ax = plt.subplots(figsize=(10, 10))
# Derive positions from the layer list instead of a hard-coded range(7).
positions = range(len(layers_all))
curves = [
    (ID_orig, '-ko', 'mnist'),
    (ID_shuf, '-ro', 'mnist_shuffled'),
    (ID_grad, '-bo', 'mnist_grad'),
]
for id_values, fmt, label in curves:
    ax.errorbar(positions, id_values[:, 0], yerr=id_values[:, 1], fmt=fmt,
                label=label)
ax.set_xticks(positions)
ax.set_xticklabels(layers_all, rotation=90)
ax.set_ylabel('ID')
# Label each curve by its dataset folder so the figure stands on its own.
ax.legend()
# Make room for the rotated tick labels in the saved files.
fig.tight_layout()
for ext in ('svg', 'png', 'eps'):
    fig.savefig(join(figs_folder, 'all_mnists.' + ext))

plt.show()

In [None]:
# Indices into layers_all of the layers shown in the summary plot. The labels
# are derived from layers_all so they cannot drift out of sync with idx
# (this yields ['input', 'pool1', 'pool2', 'd1', 'output']).
idx = [0, 2, 4, 5, 6]
layers_selection = [layers_all[i] for i in idx]

In [None]:
# Inspect the raw intrinsic-dimension array for the original MNIST run. The
# plots use one row per entry of layers_all, with column 0 as the value and
# column 1 as the error bar.
ID_orig

In [None]:
# Summary figure: intrinsic dimension of the original-MNIST network on the
# selected layers, drawn with large fonts and markers.
fs = 30  # font size
ms = 20  # marker size
lw = 2   # line width

fig, ax = plt.subplots(figsize=(10, 10))
# x positions come from the selection itself (previously a hard-coded
# range(5)), so idx and the tick labels cannot disagree.
positions = range(len(idx))
ax.errorbar(positions, ID_orig[idx, 0], yerr=ID_orig[idx, 1], fmt='-ko',
            label='ID', linewidth=lw, markersize=ms)
ax.set_xticks(positions)
ax.set_xticklabels(layers_selection, rotation='vertical', fontsize=fs)
ax.tick_params(axis='y', labelsize=fs)
ax.set_xlabel('layers', fontsize=fs)
ax.set_ylabel('ID', fontsize=fs)
ax.legend(fontsize=fs)
# Large rotated tick labels can be clipped by the default margins in the
# saved files; tight_layout makes room for them.
fig.tight_layout()
for ext in ('svg', 'png', 'eps'):
    fig.savefig(join(figs_folder, 'id_layers.' + ext))

### Block Analysis

In [None]:
# Block-analysis results, loaded from the original-MNIST results folder
# (figures are nonetheless saved under the mnist_shuffled results).
# NOTE(review): if BA.npy stores an object array, recent numpy needs
# np.load(..., allow_pickle=True) - TODO confirm the stored dtype.
BA = np.load(join(ROOT,'data','mnist','results','BA.npy'))
# The plots below index each entry as [ID values, errors, n. of points] and
# reuse these x-values for every entry; presumably all entries share them.
n_points = BA[0][2]

In [None]:
import itertools

# Styling for the block-analysis figure.
fs = 30  # font size
ms = 10  # marker size
lw = 1   # line width

# Positions in n_points whose values become x-ticks: the first four entries,
# then entries 5, 9 and 15.
sel = [list(range(4)), [5], [9], [15]]
merged = list(itertools.chain(*sel))
nps = [n_points[k] for k in merged]

In [None]:
# Block analysis: ID estimate vs. number of points for the first four entries
# of BA, with error bars, saved as SVG/PNG/EPS.
fig, ax = plt.subplots(figsize=(10, 5))
for layer_ba in BA[:4]:
    # layer_ba[0]: ID estimates, layer_ba[1]: their errors; all curves share
    # the n_points x-values. fmt '-k' draws no markers, so markersize only
    # takes effect if a marker is added to fmt.
    ax.errorbar(n_points, layer_ba[0], yerr=layer_ba[1], fmt='-k',
                linewidth=lw, markersize=ms)
# TODO(review): per-curve labels were previously sketched as `layers[i + 1]`,
# but no `layers` list is defined in this notebook; restore the labels once
# the mapping from BA entries to layer names is confirmed.

ax.set_ylabel('ID', fontsize=fs)
ax.tick_params(axis='y', labelsize=fs)
ax.set_xticks(nps)
ax.set_xticklabels(nps, rotation=90, fontsize=fs)
ax.set_xlabel('n. of points', fontsize=fs)
# Make room for the 30-pt rotated tick labels in the saved files.
fig.tight_layout()
for ext in ('svg', 'png', 'eps'):
    fig.savefig(join(figs_folder, 'ba.' + ext))