In [1]:
%load_ext autoreload
% autoreload 2

In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from tensorboardX import SummaryWriter

from dataset import MNISTDataset
from models import CNN

In [3]:
def imshow(img, title=None, cmap=None):
    """Display an image after min-max scaling its intensities into [0, 1].

    Args:
        img: array-like image (numpy array or CPU tensor).
            - 2-D arrays and channel-last (H, W, C) arrays are shown as-is.
            - A singleton channel axis (first or last) is dropped.
            - Channel-first (C, H, W) arrays with 3 or 4 channels are moved
              to channel-last, so ``transforms.ToTensor`` outputs can be
              passed directly.
        title: optional title for the figure.
        cmap: optional matplotlib colormap (e.g. ``'gray'``); ``None`` keeps
            matplotlib's default, matching the previous behaviour.
    """
    img = np.asarray(img, dtype=float)
    if img.ndim == 3:
        if img.shape[0] == 1:
            img = img[0]
        elif img.shape[-1] == 1:
            img = img[..., 0]
        elif img.shape[0] in (3, 4) and img.shape[-1] not in (3, 4):
            img = np.transpose(img, (1, 2, 0))

    # Subtracting the minimum first is what maps values onto [0, 1]; dividing
    # by the range alone leaves any non-zero offset in place. A constant image
    # has zero range, so show it as all zeros instead of dividing by zero.
    lo, hi = img.min(), img.max()
    span = hi - lo
    img = (img - lo) / span if span > 0 else np.zeros_like(img)

    plt.imshow(img, cmap=cmap)
    if title is not None:
        plt.title(title)
    plt.show()

In [18]:
# Data / loader configuration.
data_dir = 'data/'
validation_split = 0.2   # fraction of the dataset held out for validation
batch_size = 16
num_workers = 2
image_size = (64, 64)    # images are resized to this before ToTensor
num_classes = 10

# Seed the RNGs used downstream (random_split, DataLoader shuffling and,
# presumably, model weight initialisation) so Restart & Run All reproduces
# the same split and training run.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()
])

In [19]:
dataset = MNISTDataset(
    os.path.join(data_dir, 'train_small.npz'),
    transform=transform,
)

# Training gets the floor of the (1 - validation_split) fraction;
# validation receives every remaining sample.
train_size = int(np.floor((1 - validation_split) * len(dataset)))
train_set, val_set = random_split(
    dataset, [train_size, len(dataset) - train_size]
)

# Only the training loader shuffles; a fixed order is fine for evaluation.
train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
val_loader = DataLoader(
    val_set, batch_size=batch_size, num_workers=num_workers
)

In [22]:
# Model / optimisation hyperparameters.
dropout = 0.25
learning_rate = 0.01
momentum = 0.9
lr_patience = 3  # epochs without val-loss improvement before ReduceLROnPlateau lowers the LR

model = CNN(input_shape=(1, *image_size), dropout=dropout, num_classes=num_classes)

# NOTE(review): BCEWithLogitsLoss scores each of the num_classes outputs as an
# independent binary problem and needs float targets shaped like the logits, so
# the dataset presumably yields one-hot labels — confirm. For single-label digit
# classification, nn.CrossEntropyLoss on class indices is the usual choice.
# In the saved run the validation loss plateaus at ~0.326. That is close to the
# binary entropy of a 1-in-10 prior, H(0.1) ~= 0.325, which suggests the model
# has learned little beyond class frequencies.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=lr_patience, verbose=True)

In [23]:
num_epochs = 10
models_dir = 'models/'
model_name = 'CNN'

# Create the checkpoint directory up front so torch.save cannot fail mid-run.
os.makedirs(models_dir, exist_ok=True)
checkpoint_path = os.path.join(models_dir, model_name)

writer = SummaryWriter(comment=model_name)

best_val_loss = float('inf')
print('Started Training')
try:
    for epoch in range(num_epochs):
        # --- training pass ---
        model.train(True)
        train_loss = 0.0
        for data in train_loader:
            # NOTE(review): unpacking relies on the key order of the batch dict
            # collated from MNISTDataset items (presumably input first, then
            # label); indexing by key would be more robust.
            x, y = data.values()
            optimizer.zero_grad()
            p = model(x)
            loss = criterion(p, y.float())
            loss.backward()
            optimizer.step()
            # Weight each batch-mean loss by its batch size so a short final
            # batch is not over- or under-counted in the epoch average.
            train_loss += loss.item() * x.size(0)
        train_loss /= len(train_set)
        writer.add_scalar('train_loss', train_loss, epoch)

        # --- validation pass: no gradients needed, so skip building graphs ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data in val_loader:
                x, y = data.values()
                p = model(x)
                loss = criterion(p, y.float())
                val_loss += loss.item() * x.size(0)
        val_loss /= len(val_set)
        writer.add_scalar('validation_loss', val_loss, epoch)

        # Plateau detection and checkpointing both use the per-sample mean.
        scheduler.step(val_loss)

        print('[epoch %d] train_loss: %.5f, validation_loss: %.5f' %
              (epoch + 1, train_loss, val_loss))

        if val_loss < best_val_loss:
            torch.save(model.state_dict(), checkpoint_path)
            best_val_loss = val_loss
            print('Model Saved!')
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

print('Finished Training')

Started Training
[epoch 1] train_loss: 0.38461, validation_loss: 0.33102
Model Saved!
[epoch 2] train_loss: 0.34219, validation_loss: 0.32666
Model Saved!
[epoch 3] train_loss: 0.33858, validation_loss: 0.32622
Model Saved!
[epoch 4] train_loss: 0.33587, validation_loss: 0.32687
[epoch 5] train_loss: 0.33417, validation_loss: 0.32602
Model Saved!
[epoch 6] train_loss: 0.33372, validation_loss: 0.32664


Process Process-35:
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-36:
Traceback (most recent call last):
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/Users/abhi/.pyenv/versions/3.6.4/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[s

KeyboardInterrupt: 

KeyboardInterrupt
