In [1]:
# Automatically re-import edited project modules (models_pheno, mnist, util, ...)
# before each cell executes, so changes to those .py files are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import os
import shutil

from torchinfo import summary
import torch.utils.tensorboard as tb

import models_pheno
import mnist

# Fix the global torch RNG so weight init / shuffling are reproducible
# (trailing ';' suppresses the Generator repr in the notebook output).
torch.manual_seed(10);
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# TensorBoard logging setup — currently disabled (the writer cell further down is
# commented out as well). If re-enabled, note that the per-user /tmp path on the
# later line overrides the tempfile.mkdtemp() directory.
# import tempfile
# tb_log_dir = tempfile.mkdtemp()
# user = os.getlogin()
# tb_log_dir = f'/tmp/tensorboard/{user}'
# print(tb_log_dir)


In [4]:
# if os.path.exists(tb_log_dir):
#     shutil.rmtree(tb_log_dir)

In [5]:
# logger = tb.SummaryWriter(tb_log_dir)

In [6]:
# Build the MNIST task wrapper and load its data up front.
# NOTE(review): `device` is passed in, presumably so the data is placed on that
# device once rather than per batch — confirm in mnist.MNIST.load_all_data.
task = mnist.MNIST()
task.load_all_data(device)

In [34]:
# Re-seed right before construction so the initial weights do not depend on
# whatever RNG draws happened in earlier cells.
torch.manual_seed(10)
model = models_pheno.ConvNet  # the class itself; later used as a dict key when saving
net = model().to(device)

# Summarise layer shapes for one training batch of single-channel 28x28 images.
summary_kwargs = dict(
    input_size=(task.bs_train, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)
summary(net, **summary_kwargs)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
ConvNet                                  --                        --                        --                        --
├─Conv2d: 1-1                            [1000, 1, 28, 28]         [1000, 2, 26, 26]         20                        [1, 2, 3, 3]
├─Conv2d: 1-2                            [1000, 2, 13, 13]         [1000, 5, 11, 11]         95                        [2, 5, 3, 3]
├─Conv2d: 1-3                            [1000, 5, 5, 5]           [1000, 10, 3, 3]          460                       [5, 10, 3, 3]
├─Linear: 1-4                            [1000, 10]                [1000, 10]                110                       [10, 10]
Total params: 685
Trainable params: 685
Non-trainable params: 0
Total mult-adds (M): 27.21
Input size (MB): 3.14
Forward/backward pass size (MB): 16.46
Params size (MB): 0.00
Estimated Total Size (MB): 19.59

In [35]:
# Baseline: evaluate the untrained network and print its average loss / accuracy.
# NOTE(review): no `loader` is passed, so perform_stats uses its default loader
# (4 batches here vs. 60 for loader_train — presumably the test set; confirm).
task.perform_stats(net, tqdm=tqdm, device=device);

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 2.321, Accuracy: 7.670%


In [36]:
# Adam at lr=0.01; an SGD alternative that was tried is
#   torch.optim.SGD(net.parameters(), lr=1e-1)
opt = torch.optim.Adam(params=net.parameters(), lr=0.01)

In [37]:
# Train for 10 epochs, printing evaluation stats after each one.
for epoch in tqdm(range(10)):
    # perform_stats runs at the end of every epoch and may switch the net into
    # eval mode (NOTE(review): confirm in mnist.py). Re-enable training mode so
    # dropout/batch-norm layers, if ever added to the model, behave correctly.
    net.train()
    # tqdm takes its length from len(task.loader_train), so no explicit total is needed.
    for X_batch, Y_batch in tqdm(task.loader_train, leave=False):
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        Y_batch_pred = net(X_batch)
        loss = task.loss_func(Y_batch_pred, Y_batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
    task.perform_stats(net, tqdm=tqdm, device=device)
    
    

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.643, Accuracy: 80.920%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.440, Accuracy: 87.020%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.384, Accuracy: 88.530%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.347, Accuracy: 89.300%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.344, Accuracy: 90.050%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.318, Accuracy: 90.380%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.292, Accuracy: 91.480%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.292, Accuracy: 91.360%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.271, Accuracy: 92.050%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.263, Accuracy: 92.350%



In [38]:
# Final evaluation: first on the training loader, then on perform_stats' default
# loader (presumably the test set). Only the second call's returned dict
# ({'loss': ..., 'accuracy': ...}) is displayed as the cell output.
task.perform_stats(net, loader=task.loader_train, tqdm=tqdm, device=device)
task.perform_stats(net, tqdm=tqdm, device=device)

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

Average Loss: 0.284, Accuracy: 91.655%


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.263, Accuracy: 92.350%


{'loss': 0.26250348091125486, 'accuracy': 92.35}

In [39]:
# Score the trained net with the task's fitness routine on n_sample=40000 examples.
# The saved results printed at the end show fitness == -loss, alongside loss and accuracy.
# NOTE(review): which split the 40000 samples come from is not visible here — confirm in mnist.py.
fitdata = task.calc_pheo_fitness(net, n_sample=40000, device=device)

In [40]:
import util

# On-disk stores of evaluation results and flattened weights. Both are dicts keyed
# by model class, so results from earlier runs with other architectures are kept.
EVAL_PATH = './data/mnist_sgd_eval'
WEIGHTS_PATH = './data/mnist_sgd_weights'


def _load_or_empty(path):
    """Return the object stored at `path` by torch.save, or {} if the file does not exist yet.

    Only a missing file is treated as "no previous results". Any other failure
    (corrupt file, unpickling error, interrupt) propagates, so the existing file
    is never silently overwritten with an empty dict.
    NOTE: the keys are Python classes, so these are full pickles — only load trusted
    files; on newer torch versions this may require torch.load(..., weights_only=False).
    """
    try:
        return torch.load(path)
    except FileNotFoundError:
        return {}


# Load each store independently: a missing weights file must not discard eval results.
data = _load_or_empty(EVAL_PATH)
weights = _load_or_empty(WEIGHTS_PATH)
data[model] = fitdata
weights[model] = util.model2vec(net)

# Make sure the target directory exists before the first save.
os.makedirs(os.path.dirname(EVAL_PATH), exist_ok=True)
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
torch.save(data, EVAL_PATH)
torch.save(weights, WEIGHTS_PATH)

In [45]:
# One line per saved model class: "<ClassName> :  {fitness, loss, accuracy}".
for model_cls, metrics in data.items():
    print(f"{model_cls.__name__} :  {metrics}")

BigConvNet :  {'fitness': -0.10598196089267731, 'loss': 0.10598196089267731, 'accuracy': 96.65}
SmallNet :  {'fitness': -0.6912355422973633, 'loss': 0.6912355422973633, 'accuracy': 78.32249999999999}
ConvNet :  {'fitness': -0.28138837218284607, 'loss': 0.28138837218284607, 'accuracy': 91.6425}
