In [1]:
# Automatically re-import edited project modules (e.g. models_pheno, mnist) before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
# NOTE(review): `%load_ext` prints an "already loaded" notice when this cell is re-run;
# `%reload_ext autoreload` is the idempotent equivalent if that matters.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import os
import shutil

from torchinfo import summary
import torch.utils.tensorboard as tb

import models_pheno
import mnist

# Seed torch's global RNG once up front (the model cell re-seeds right before weight init).
SEED = 10
torch.manual_seed(SEED)
# Prefer the GPU when one is visible to this kernel, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# TensorBoard log-directory setup — currently disabled, kept for reference.
# If re-enabled as written, the mkdtemp() path is immediately overwritten by the
# per-user /tmp path, so only the second assignment would take effect.
# NOTE(review): os.getlogin() can raise OSError when the process has no controlling
# terminal; getpass.getuser() is the more robust way to get the user name.
# import tempfile
# tb_log_dir = tempfile.mkdtemp()
# user = os.getlogin()
# tb_log_dir = f'/tmp/tensorboard/{user}'
# print(tb_log_dir)


In [4]:
# (Disabled) Delete any previous logs under tb_log_dir so a new writer starts from an
# empty directory. Requires the (also disabled) tb_log_dir definition to be active.
# if os.path.exists(tb_log_dir):
#     shutil.rmtree(tb_log_dir)

In [5]:
# (Disabled) TensorBoard writer; no visible cell in this notebook logs to `logger`.
# logger = tb.SummaryWriter(tb_log_dir)

In [6]:
# Build the MNIST task object (provides loaders, loss_func, perform_stats, bs_train).
# NOTE(review): load_all_data(device) presumably loads the train/test splits onto
# `device` up front — TODO confirm in mnist.py.
task = mnist.MNIST()
task.load_all_data(device)

In [7]:
# Re-seed immediately before construction so the initial weights are reproducible
# regardless of how many random numbers earlier cells consumed.
torch.manual_seed(10)
net = models_pheno.BigConvNet()
net = net.to(device)

# Summarize the network on one training batch of single-channel 28x28 images.
summary_kwargs = dict(
    input_size=(task.bs_train, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)
summary(net, **summary_kwargs)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
BigConvNet                               --                        --                        --                        --
├─Conv2d: 1-1                            [1000, 1, 28, 28]         [1000, 10, 26, 26]        100                       [1, 10, 3, 3]
├─Conv2d: 1-2                            [1000, 10, 13, 13]        [1000, 10, 11, 11]        910                       [10, 10, 3, 3]
├─Conv2d: 1-3                            [1000, 10, 5, 5]          [1000, 10, 3, 3]          910                       [10, 10, 3, 3]
├─Linear: 1-4                            [1000, 10]                [1000, 10]                110                       [10, 10]
├─Linear: 1-5                            [1000, 10]                [1000, 10]                110                       [10, 10]
Total params: 2,140
Trainable params: 2,140
Non-trainable params: 0
Total mult-adds (M): 

In [8]:
# Baseline evaluation of the untrained network (the printed accuracy sits near 10-class chance).
# Binding the result keeps it available and, like the trailing `;`, suppresses the repr.
baseline_stats = task.perform_stats(net, tqdm=tqdm, device=device)

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 2.325, Accuracy: 9.240%


In [9]:
# Optimizer for the training loop below.
# Alternative previously tried: torch.optim.SGD(net.parameters(), lr=1e-1)
LEARNING_RATE = 1e-2
opt = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [10]:
# Train for N_EPOCHS passes over the training loader, evaluating after each epoch.
N_EPOCHS = 10
epoch_stats = []  # return value of task.perform_stats after each epoch, in order

for epoch in tqdm(range(N_EPOCHS)):
    # perform_stats may switch the network to eval mode; make sure every epoch trains
    # in train mode (matters as soon as the model gains dropout/batch-norm layers).
    net.train()
    # tqdm takes the total from len(task.loader_train), so no enumerate/total= is needed.
    for X_batch, Y_batch in tqdm(task.loader_train, leave=False):
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        Y_batch_pred = net(X_batch)
        # NOTE(review): .log() assumes the network outputs probabilities (e.g. a final softmax)
        # and that task.loss_func expects log-probabilities (NLLLoss-style). log(softmax) can
        # underflow to -inf for tiny probabilities; emitting log_softmax from the model would be
        # more stable — TODO confirm against models_pheno / mnist.
        loss = task.loss_func(Y_batch_pred.log(), Y_batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
    epoch_stats.append(task.perform_stats(net, tqdm=tqdm, device=device))
    
    

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.377, Accuracy: 88.060%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.211, Accuracy: 93.710%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.182, Accuracy: 94.610%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.156, Accuracy: 95.360%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.128, Accuracy: 96.080%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.117, Accuracy: 96.630%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.112, Accuracy: 96.590%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.113, Accuracy: 96.410%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.106, Accuracy: 96.620%


HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.099, Accuracy: 96.950%



In [11]:
# Final metrics: first on the training loader, then on the default loader
# (presumably the held-out split — it runs 4 batches versus 60 for training).
train_stats = task.perform_stats(net, loader=task.loader_train, tqdm=tqdm, device=device)
eval_stats = task.perform_stats(net, tqdm=tqdm, device=device)
eval_stats

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))

Average Loss: 0.101, Accuracy: 96.827%


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

Average Loss: 0.099, Accuracy: 96.950%


(0.09941823184490203, 0.9695)

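Recorded final `(average loss, accuracy)` per architecture (the BigConvNet entry differs slightly from this run's result above):

SmallNet: (0.6605741858482361, 0.7915)

ConvNet: (0.26250347793102263, 0.9235)

BigConvNet: (0.1045118197798729, 0.9668)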
