In [5]:
import sys
sys.path.append('../')

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import glob
import io
import os
import pickle


import ColorDataUtils.mattplotlib as mplt
from ColorDataUtils.mattprintlib import fmt
from NDNT.utils import imagesc   # because I'm lazy

class Model:
    """Plain attribute container for a fitted model and its scores.

    Stores the three constructor arguments as ``ndn``, ``LLs`` and ``trial``.
    Presumably defined here so that pickles referencing a ``Model`` class in
    ``__main__`` can be resolved when unpickling -- confirm against the script
    that wrote them.

    NOTE(review): unpickling restores an instance's ``__dict__`` without calling
    ``__init__``, so loaded objects carry whatever attribute names were pickled.
    The rest of this notebook reads ``model.ndn_model`` (not ``.ndn``); confirm
    which name the saved files use before relying on ``.ndn``.
    """

    def __init__(self, ndn_model, LLs, trial):
        # Tuple assignment binds left to right: ndn, then LLs, then trial.
        self.ndn, self.LLs, self.trial = ndn_model, LLs, trial
from models import iter_core, cnn_core
from models import iter_0715_1x, iter_0722_1x, iter_0801_1x, iter_0808_1x
from models import iter_0715_2x, iter_0722_2x, iter_0801_2x, iter_0808_2x
from models import cnn_0715_1x, cnn_0722_1x, cnn_0801_1x, cnn_0808_1x
from models import cnn_0715_2x, cnn_0722_2x, cnn_0801_2x, cnn_0808_2x

# Prefer the MPS (Metal) backend when this torch build/machine supports it;
# otherwise fall back to CPU so the notebook still runs elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class CPUUnpickler(pickle.Unpickler):
    """Unpickler that loads embedded torch storages with ``map_location='cpu'``.

    Whenever the pickle stream asks for ``torch.storage._load_from_bytes``,
    this class hands back a loader that calls ``torch.load`` on the raw bytes
    with ``map_location='cpu'``. Every other global is resolved by the normal
    ``pickle.Unpickler.find_class``. Presumably used so models saved from an
    accelerator can be opened on a machine without it.

    Security: unpickling can execute arbitrary code -- only use on trusted files.
    """

    def find_class(self, module, name):
        # Anything other than torch's storage-from-bytes hook resolves normally.
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def load_storage_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return load_storage_on_cpu

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# NOTE(review): disabled cell (every line is commented out), kept for reference.
# When enabled it unpickles every *.pkl under ../models/<subdir>/ (skipping
# dot-directories and pred.pkl / study.pkl) into `models`, keyed
# '<subdir>_<file>', skipping files that raise EOFError; it then plots
# np.mean(model.LLs) for every model whose __dict__ has 'ndn_model' and prints
# the names of the rest. Unpickling executes code -- trusted files only.
# # load all the models in the models directory
# models_dirname = '../models'
# models = {}
# for dirname in os.listdir(models_dirname):
#     if dirname.startswith('.'):
#         continue
#     for fname in glob.glob(os.path.join(models_dirname, dirname, '*.pkl')):
#         if os.path.basename(fname) in ['pred.pkl', 'study.pkl']:
#             continue
#         with open(fname, 'rb') as f:
#             fname = fname.split('/')[2] + '_' + fname.split('/')[3]
#             print(fname)
#             try:
#                 models[fname] = CPUUnpickler(f).load()
#             except EOFError as e:
#                 print(e)
#                 print('skipping', fname)
#
# LLs = []
# for name, model in models.items():
#     if 'ndn_model' in model.__dict__:
#         LLs.append(np.mean(model.LLs))
#     else:
#         print(name)
#
# fig = plt.figure(figsize=(10, 5))
# plt.plot(LLs)
# plt.show()

In [None]:
# NOTE(review): disabled cell (every line is commented out), kept for reference.
# Depends on `models` from the (also disabled) load-all-models cell. For each
# model with a non-None ndn_model it draws 40 panels of
# networks[0].layers[1].get_weights()[j, :, :, i] and saves
# proj_filters/<name>.png (the proj_filters/ directory must already exist).
# If re-enabled, note: `proj_weights` is unused (duplicate of `proj_filter`),
# `k` is never read, the 40 panels only fill the top 5 rows of the 10x8 grid,
# and figures are never closed (plt.close), so memory grows with model count.
# # save the projection layer from all of the models to images that I can look through
# k = 0
# for name, model in models.items():
#     if 'ndn_model' in model.__dict__:
#         if model.ndn_model is not None:
#             print(name)
#             fig = plt.figure(figsize=(15,10))
#             grid = matplotlib.gridspec.GridSpec(10, 8, wspace=0.1, hspace=0.1)
#             for idx in range(10*4):
#                 i,j = np.unravel_index(idx,(10,4)) # layer position
#                 row,col = np.unravel_index(idx,(10,8)) # plot position
#                 ax = plt.subplot(grid[row,col])
#                 proj_weights = model.ndn_model.networks[0].layers[1].get_weights()[j,:,:,i]
#                 proj_filter = model.ndn_model.networks[0].layers[1].get_weights()[j,:,:,i]
#                 imagesc(proj_filter, ax=ax, cmap='viridis')
#             # save the figure
#             plt.savefig('proj_filters/{}.png'.format(name), bbox_inches='tight')
#     k += 1

In [7]:
# Model pickles selected for closer inspection, given as
# (subdirectory of ../models/, index of its cnn_<index>.pkl file) pairs and
# expanded into 'subdir/cnn_<index>.pkl' relative paths. Order matters:
# later cells index into the loaded list by position.
good_model_fnames = [
    '{}/cnn_{}.pkl'.format(subdir, idx)
    for subdir, idx in [
        ('cnns_multi_04', 0),
        ('cnns_multi_02', 1),
        ('cnns_06', 14),
        ('cnns_06', 11),
        ('cnns_06', 6),
        ('cnns_05', 5),
        ('cnns_03', 17),
        ('cnns_05', 11),
        ('cnns_03', 4),
        ('cnns_02', 19),
    ]
]

In [8]:
# Unpickle each selected model from ../models/ with CPUUnpickler, keeping the
# order of good_model_fnames. Unpickling executes code -- trusted files only.
good_models = []
for good_model_fname in good_model_fnames:
    model_path = os.path.join('../models', good_model_fname)
    with open(model_path, 'rb') as f:
        loaded = CPUUnpickler(f).load()
    good_models.append(loaded)

In [44]:
# Regularization settings of a layer can be read with, e.g.:
#   good_models[0].ndn_model.networks[0].layers[0].__dict__['_modules']['reg'].__dict__['vals']
# Display the full attribute dict (vars() returns the same object as __dict__)
# of layer index 1 in the first network of good_models[1].
vars(good_models[1].ndn_model.networks[0].layers[1])

{'window': True,
 'tent_basis': None,
 'training': False,
 '_parameters': OrderedDict([('weight',
               Parameter containing:
               tensor([[ 0.0866, -0.0072,  0.0346,  ..., -0.0408, -0.0420,  0.1439],
                       [ 0.0950,  0.0088,  0.0569,  ..., -0.0291, -0.0177,  0.1793],
                       [ 0.0779,  0.0576,  0.0590,  ...,  0.0192,  0.0348,  0.2232],
                       ...,
                       [ 0.0066,  0.0452,  0.0670,  ..., -0.1348, -0.0619,  0.0784],
                       [ 0.0168,  0.0122,  0.0566,  ..., -0.0713, -0.0582,  0.0201],
                       [ 0.0371, -0.0113,  0.0281,  ..., -0.0504, -0.0358, -0.0107]],
                      requires_grad=True))]),
 '_buffers': OrderedDict([('bias',
               tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
    

In [41]:
def model_summary(model):
    """Print a per-layer overview of an unpickled model.

    Output format (one tab-separated line per layer)::

        LL: <mean of model.LLs>
        <network class name>
        	<layer class name>	<layer.get_weights().shape>	<reg vals>

    Args:
        model: object exposing ``LLs`` (anything ``np.mean`` accepts) and
            ``ndn_model.networks``. Each network has ``layers``; each layer has
            ``get_weights()`` and a ``reg`` entry in ``_modules`` whose ``vals``
            attribute is printed (regularization settings, judging by keys like
            'd2x'/'d2t' in the recorded output).

    Returns:
        None. Everything is written to stdout.
    """
    def type_name(obj):
        # Bare class name, e.g. 'STconvLayer'. Unlike parsing str(type(obj)),
        # this also works for classes with no dotted module path (builtins).
        return type(obj).__name__

    print('LL: {}'.format(np.mean(model.LLs)))

    for network in model.ndn_model.networks:
        print(type_name(network))
        for layer in network.layers:
            reg_vals = layer.__dict__['_modules']['reg'].__dict__['vals']
            # Exactly one '{}' per field: class name, weight shape, reg settings.
            print('\t{}\t{}\t{}'.format(
                type_name(layer),
                layer.get_weights().shape,
                reg_vals))

# Look at the parameters of each selected model, separated by a rule line.
# To inspect a single model instead: model_summary(good_models[0])
# (model_summary prints and returns None, so don't wrap it in print()).
_rule = '----------------------------------------'
for model in good_models:
    print(_rule)
    model_summary(model)

----------------------------------------
LL: 0.02847587689757347
ScaffoldNetwork
	STconvLayer	(7, 7, 11, 4)	{'d2x': 1e-06, 'd2t': 0.01, 'edge_t': 100}
	ConvLayer	(4, 21, 21, 38)	{}
	IterLayer	(38, 9, 9, 38)	{}
ReadoutNetwork
	ReadoutLayer	(228, 587)	{'max': 0.0001}
FFnetwork
	NDNLayer	(166, 587)	{'d2t': 0.5}
FFnetwork
	ChannelLayer	(587,)	{}
----------------------------------------
LL: 0.027293842285871506
ScaffoldNetwork
	STconvLayer	(9, 9, 14, 4)	{'d2x': 1e-06, 'd2t': 0.01, 'edge_t': 100}
	ConvLayer	(4, 17, 17, 48)	{}
	ConvLayer	(48, 15, 15, 48)	{}
	ConvLayer	(48, 9, 9, 48)	{}
	ConvLayer	(48, 5, 5, 48)	{}
ReadoutNetwork
	ReadoutLayer	(192, 585)	{'max': 0.0001}
FFnetwork
	NDNLayer	(166, 585)	{'d2t': 0.5}
FFnetwork
	ChannelLayer	(585,)	{}
----------------------------------------
LL: 0.03348736653382751
ScaffoldNetwork
	STconvLayer	(7, 7, 11, 4)	{'d2x': 1e-06, 'd2t': 0.01, 'edge_t': 100}
	ConvLayer	(4, 13, 13, 44)	{}
	STconvLayer	(44, 17, 17, 2, 35)	{}
	STconvLayer	(35, 17, 17, 2, 35)	{