In [1]:
# Reload imported modules automatically before each cell executes, so edits to the
# project modules (e.g. models_pheno, mnist) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.tensorboard as tb
import torchvision
from torch import nn
from torchinfo import summary
from tqdm.notebook import tqdm

import models_pheno
import mnist
import util

# Seed torch's global RNG for reproducibility, then prefer the GPU when one is present.
torch.manual_seed(10)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
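# TensorBoard logging is currently disabled: uncomment this cell and the next two to re-enable it.
# NOTE: no later cell in this notebook writes to `logger`, so re-enabling it alone logs nothing.
# import tempfile
# tb_log_dir = tempfile.mkdtemp()
# user = os.getlogin()
# tb_log_dir = f'/tmp/tensorboard/{user}'
# print(tb_log_dir)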


In [4]:
# if os.path.exists(tb_log_dir):
#     shutil.rmtree(tb_log_dir)

In [5]:
# logger = tb.SummaryWriter(tb_log_dir)

In [6]:
# Build the MNIST task object; later cells use its loaders, loss and evaluation helpers
# (bs_train, loader_train, loss_func, perform_stats, calc_pheo_fitness).
task = mnist.MNIST()
# NOTE(review): presumably loads the whole dataset onto `device` up front — confirm in mnist.py.
task.load_all_data(device)

In [22]:
# Re-seed right before construction so the initial weights are the same on every re-run.
torch.manual_seed(10)
model = models_pheno.SmallNet  # the model *class*; also used as a dict key when saving results
net = model().to(device)
summary_kwargs = dict(
    input_size=(task.bs_train, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)
summary(net, **summary_kwargs)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
SmallNet                                 --                        --                        --                        --
├─Conv2d: 1-1                            [1000, 1, 28, 28]         [1000, 1, 26, 26]         10                        [1, 1, 3, 3]
├─Conv2d: 1-2                            [1000, 1, 13, 13]         [1000, 1, 11, 11]         10                        [1, 1, 3, 3]
├─Conv2d: 1-3                            [1000, 1, 5, 5]           [1000, 1, 3, 3]           10                        [1, 1, 3, 3]
├─Linear: 1-4                            [1000, 9]                 [1000, 10]                100                       [9, 10]
Total params: 130
Trainable params: 130
Non-trainable params: 0
Total mult-adds (M): 7.34
Input size (MB): 3.14
Forward/backward pass size (MB): 6.53
Params size (MB): 0.00
Estimated Total Size (MB): 9.66

In [8]:
# Baseline evaluation of the untrained net (trailing ';' hides the returned stats dict).
task.perform_stats(net, device=device, tqdm=tqdm);

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 2.320, Accuracy: 9.800%


In [9]:
# Adam at lr=1e-2 over all parameters of `net`.
# Alternative kept for reference: torch.optim.SGD(net.parameters(), lr=1e-1)
opt = torch.optim.Adam(params=net.parameters(), lr=0.01)

In [10]:
# Train for a fixed number of epochs, printing evaluation stats after each one.
N_EPOCHS = 10
for epoch in tqdm(range(N_EPOCHS)):
    # perform_stats (called at the end of each epoch) may switch the net to eval mode;
    # put it back into training mode before every epoch's updates.
    net.train()
    for X_batch, Y_batch in tqdm(task.loader_train, leave=False, total=len(task.loader_train)):
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        Y_batch_pred = net(X_batch)
        # NOTE(review): `.log()` assumes the net outputs probabilities and that loss_func
        # expects log-probabilities (NLL-style) — confirm in models_pheno / mnist. If so,
        # a log_softmax inside the model would be more numerically stable than log(softmax).
        loss = task.loss_func(Y_batch_pred.log(), Y_batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
    task.perform_stats(net, tqdm=tqdm, device=device)
    
    

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 1.650, Accuracy: 48.000%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 1.042, Accuracy: 67.680%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.832, Accuracy: 73.860%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.772, Accuracy: 75.410%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.720, Accuracy: 77.710%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.700, Accuracy: 78.100%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.682, Accuracy: 78.800%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.671, Accuracy: 79.140%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.665, Accuracy: 79.420%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.655, Accuracy: 79.870%



In [11]:
# Evaluate on the training split first (its stats are only printed), then on the
# default split; the second call's returned dict is the cell's displayed output.
task.perform_stats(net, loader=task.loader_train, device=device, tqdm=tqdm)
task.perform_stats(net, device=device, tqdm=tqdm)

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

Average Loss: 0.692, Accuracy: 78.308%


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.655, Accuracy: 79.870%


{'loss': 0.655350637435913, 'accuracy': 79.86999999999999}

In [17]:
# Compute fitness metrics for the trained net with the task's helper
# ("pheo" is the helper's actual spelling in the mnist module).
# NOTE(review): the stale outputs under the util cell suggest a (loss, accuracy-fraction)
# pair — confirm the return type in mnist.py.
fitdata = task.calc_pheo_fitness(net, device=device)

In [24]:
# Map the model class to its fitness results (the class object itself is the key).
data = dict([(model, fitdata)])

In [14]:
# Save the trained network's parameters in the form produced by util.model2vec
# (presumably a flat parameter vector, per the helper's name — confirm in util.py).
# `util` is imported in the top imports cell alongside the other local modules.
# NOTE(review): './temp' is a generic name in the working directory; consider a more
# descriptive path. The saved output shown under this cell is stale: torch.save returns
# None and this cell prints nothing, so re-run the cell to clear it.
torch.save(util.model2vec(net), './temp')

SmallNet: (0.6605741858482361, 0.7915)

ConvNet: (0.26250347793102263, 0.9235)

BigConvNet: (0.1045118197798729, 0.9668)


In [6]:
# NOTE(review): `d` is not used by any other cell in this notebook and the value 3.0 is
# unexplained — this looks like leftover scratch work; consider removing the cell.
d = {models_pheno.SmallNet: 3.0}

In [27]:
# Persist the results dict. torch.save pickles it; because its keys are model classes,
# loading it later requires `models_pheno` to be importable (and newer torch.load
# defaults such as weights_only=True may reject class keys — TODO confirm on load side).
# NOTE(review): the filename says "sgd", but the optimizer cell uses Adam — confirm the
# intended label before relying on this file name.
eval_path = './data/mnist_sgd_eval'
# torch.save fails if the parent directory is missing, so create ./data on demand.
os.makedirs(os.path.dirname(eval_path), exist_ok=True)
torch.save(data, eval_path)