In [None]:
import torch
from torch import nn
from torch import optim

import torchvision

import matplotlib.pyplot as plt
import numpy as np

In [None]:
class MLPNet(nn.Module):
  """Three-layer fully connected network with ReLU after the first two layers.

  Args:
    input_size: number of features in each (already flattened) input row.
    hidden_size: indexable whose items 0 and 1 are the widths of the first
      and second hidden layers.
    output_size: width of the final linear layer.
  """

  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    first_width, second_width = hidden_size[0], hidden_size[1]
    layers = [
      nn.Linear(input_size, first_width),
      nn.ReLU(),
      nn.Linear(first_width, second_width),
      nn.ReLU(),
      nn.Linear(second_width, output_size),
    ]
    # Stored as `self.model` so parameter names stay `model.<index>.*`.
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the output of the layer stack for `x`; no flattening is done here."""
    return self.model(x)

In [None]:
class CNNet(nn.Module):
  """Small two-convolution classifier.

  Feature extractor: Conv2d -> ReLU -> MaxPool2d -> Conv2d -> ReLU ->
  AdaptiveMaxPool2d(1) -> Flatten, followed by Dropout(0.3) and a linear head.

  Args:
    in_channel: number of channels of the input images.
    out_channel: indexable whose items 0 and 1 are the output channel counts
      of the first and second convolution.
    h_w: accepted for interface compatibility; not used, because the adaptive
      pooling makes the head independent of the spatial input size.
    kernel: kernel size shared by both convolutions and by the max-pool.
    stride: stride of both convolutions.
    pad: padding of both convolutions.
    dilation: dilation of both convolutions.
    num_classes: number of outputs of the linear head (default 10).
  """

  def __init__(self, in_channel, out_channel, h_w, kernel=3, stride=1, pad=1, dilation=1, num_classes=10):
    super(CNNet, self).__init__()
    out_chnl1 = out_channel[0]
    out_chnl2 = out_channel[1]
    self.model = nn.Sequential(
                   # `dilation` is forwarded to both convolutions; it used to be
                   # silently replaced by a hard-coded 1.
                   nn.Conv2d(in_channel, out_chnl1, kernel_size=kernel, stride=stride, padding=pad, dilation=dilation),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=kernel),
                   nn.Conv2d(out_chnl1, out_chnl2, kernel_size=kernel, stride=stride, padding=pad, dilation=dilation),
                   nn.ReLU(),
                   # Collapse each feature map to a single value so the head
                   # only depends on the channel count.
                   nn.AdaptiveMaxPool2d(1),
                   nn.Flatten()
                 )
    self.out = nn.Sequential(
                 nn.Dropout(0.3),
                 nn.Linear(out_chnl2, num_classes)
               )

  def forward(self, x):
    """Return the head's outputs for an image batch `x`."""
    X = self.model(x)
    return self.out(X)

In [None]:
def epoch(policy, opt, loss_func, train_loader, epoch_num, batch_size=32, model_type='mlp'):
  """Run one optimisation pass over `train_loader`.

  Args:
    policy: the `nn.Module` being trained; it is switched to training mode.
    opt: optimizer stepping `policy`'s parameters.
    loss_func: callable `(outputs, targets) -> loss tensor`.
    train_loader: iterable yielding `(img_batch, target_batch)` pairs.
    epoch_num: index of the current epoch (kept for the caller; not used here).
    batch_size: kept for interface compatibility; the size of each batch is
      read from the batch itself so a smaller final batch is handled.
    model_type: 'mlp' (inputs are flattened per sample) or 'cnn'.

  Returns:
    `(policy, opt, train_losses, correct_count)` where `train_losses` holds one
    detached loss tensor per batch and `correct_count` one tensor per batch with
    the number of correct argmax predictions.

  Raises:
    ValueError: if `model_type` is neither 'mlp' nor 'cnn'.
  """
  if model_type not in ('mlp', 'cnn'):
    raise ValueError("Unknown model_type: {!r} (expected 'mlp' or 'cnn')".format(model_type))

  # An earlier evaluation may have left the module in eval mode, which would
  # disable dropout for the whole training run.
  policy.train()

  train_losses = []
  correct_count = []
  for img_batch, target_batch in train_loader:
    if model_type == 'mlp':
      # Flatten using the batch's real size: the last batch can be smaller
      # than `batch_size`.
      inputs = img_batch.view(img_batch.shape[0], -1)
    elif img_batch.dim() == 3:
      # (batch, H, W) without a channel axis: insert one for Conv2d.
      inputs = img_batch.unsqueeze(1)
    else:
      inputs = img_batch
    out = policy(inputs).float()

    loss = loss_func(out, target_batch.long())
    # Store a detached copy so the list does not hold on to autograd history.
    train_losses.append(loss.detach())

    pred = out.detach().argmax(dim=1)
    correct_count.append((pred == target_batch).sum())

    opt.zero_grad()
    loss.backward()
    opt.step()

  return policy, opt, train_losses, correct_count

In [None]:
def _samples_per_pass(loader, num_batches, batch_size):
  """Best-effort number of samples one pass over `loader` produced.

  Uses `len(loader.sampler)` when it is available and consistent with the
  number of batches seen; otherwise falls back to `num_batches * batch_size`.
  """
  upper = num_batches * batch_size
  try:
    n_samples = len(loader.sampler)
  except (AttributeError, TypeError):
    return upper
  # Only trust the sampler length if it fits the observed batch count
  # (this also rejects it when trailing samples were dropped).
  if upper - batch_size < n_samples <= upper:
    return n_samples
  return upper


def train(train_loader, policy, opt, loss_func, batch_size, model_type='mlp', n_epochs=None):
  """Train `policy` for several epochs, printing mean loss and accuracy per epoch.

  Args:
    train_loader: iterable of `(img_batch, target_batch)` pairs passed to `epoch`.
    policy: model to train.
    opt: optimizer for `policy`.
    loss_func: loss callable passed to `epoch`.
    batch_size: nominal batch size of `train_loader`.
    model_type: forwarded to `epoch`.
    n_epochs: number of epochs; when None the notebook-level `num_epochs`
      is used, as before.

  Returns:
    `(policy, loss_func)`.
  """
  if n_epochs is None:
    n_epochs = num_epochs  # notebook-level default from the config cell

  for e in range(n_epochs):
    policy, opt, losses, corr_count = epoch(policy, opt, loss_func, train_loader, e, batch_size=batch_size, model_type=model_type)
    # Divide by the samples actually seen; a short final batch would otherwise
    # be counted as a full one.
    n_seen = _samples_per_pass(train_loader, len(corr_count), batch_size)
    print("Epoch: {}\tLoss: {:.5f}\tCorrect %{:.2f}".format( e,
                                                             sum(losses)/len(losses),
                                                             (100*sum(corr_count))/n_seen))

  return policy, loss_func

In [None]:
def test(policy, test_loader, loss_func, batch_size=1, model_type='mlp'):  
  policy.eval()
  idx = 0
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for img, target in test_loader:
      if img.shape[0]==batch_size:
        idx+=1
        if model_type=='mlp':
          out = policy(img.view(batch_size,-1)).float()
        elif model_type=='cnn':
          if batch_size==1:
            out = policy(img.unsqueeze(0)).float()
          else:
            out = policy(img).float()
        test_loss += loss_func(out, target.long())
        pred = out.data.argmax(dim=1)
        correct+=( pred == target ).sum()
    print("Avg. Loss: {:.5f}\tCorrect %{:.2f}".format(float(test_loss/idx), 100*correct/(idx*batch_size)))

In [None]:
# Reproducibility: seed torch's global RNG before any model or loader is
# created, so weight initialisation and shuffling repeat across runs.
seed = 0
_ = torch.manual_seed(seed)

# Optimisation settings.
lr = 3e-3
batch_size = 128
log_freq = 500  # NOTE(review): not referenced by any cell visible here
num_epochs = 10

# Model selection: 'mlp' or 'cnn', plus the layer widths for each variant.
model_type = 'cnn'
hidden_sizes = [128,64]
out_channels = [64,64]

In [None]:
database = torchvision.datasets.MNIST
data_root = '/files/'

# Shared preprocessing for both splits: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081 (presumably the MNIST statistics - confirm).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize( (0.1307,), (0.3081,))])

train_loader = torch.utils.data.DataLoader(
    database(data_root,
            train=True,
            download=True,
            transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True)

# Evaluation does not benefit from shuffling; a fixed order keeps the
# reported metrics repeatable.
test_loader = torch.utils.data.DataLoader(
    database(data_root,
            train=False,
            download=True,
            transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False)

# Image height and width, read from the raw training data tensor.
h,w = train_loader.dataset.data.shape[1:]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /files/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/train-images-idx3-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /files/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/train-labels-idx1-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /files/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/t10k-images-idx3-ubyte.gz to /files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /files/MNIST/raw/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw



In [None]:
# Build the network selected in the config cell.
if model_type=='mlp':
  policy = MLPNet(h*w, hidden_sizes, len(train_loader.dataset.classes))
elif model_type=='cnn':
  policy = CNNet(1, out_channels, (h,w), kernel=2)
else:
  # Fail here with a clear message instead of a NameError on `policy` below.
  raise ValueError("Unknown model_type: {!r} (expected 'mlp' or 'cnn')".format(model_type))
opt = optim.Adam(policy.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()

policy

CNNet(
  (model): Sequential(
    (0): Conv2d(1, 64, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): AdaptiveMaxPool2d(output_size=1)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (out): Sequential(
    (0): Dropout(p=0.3, inplace=False)
    (1): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [None]:
# Start offset and length of the training-set slice inspected in the next cell.
idx, step = 10, batch_size

In [None]:
# Raw uint8 images and labels taken straight from the dataset, bypassing the
# loader's transform.
img = train_loader.dataset.data[idx:idx+step].unsqueeze(1)
tar = train_loader.dataset.targets[idx:idx+step]

# Apply the same scaling the loaders use (ToTensor -> [0, 1], then Normalize),
# so the model sees inputs on the scale it is trained with.
img_norm = (img.float() / 255 - 0.1307) / 0.3081
if model_type == 'mlp':
  img_norm = img_norm.view(img_norm.shape[0], -1)

# Predict in eval mode so dropout does not randomise the check, then restore
# whatever mode the model was in.
was_training = policy.training
policy.eval()
with torch.no_grad():
  out = policy(img_norm).argmax(1)
  print(out)
policy.train(was_training)
print(out==tar)

tensor([6, 1, 7, 1, 1, 1, 6, 1, 1, 1, 1, 2, 1, 7, 6, 1, 7, 1, 6, 1, 1, 1, 1, 6,
        1, 1, 6, 0, 1, 6, 1, 1, 6, 1, 7, 6, 1, 1, 1, 1, 7, 1, 1, 1, 1, 1, 7, 1,
        1, 9, 1, 1, 6, 6, 1, 1, 1, 1, 6, 1, 1, 6, 1, 1, 6, 1, 1, 1, 1, 1, 6, 1,
        1, 1, 1, 1, 1, 5, 6, 1, 6, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 1, 1, 1, 6,
        1, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1, 6, 1, 1, 1, 7, 1, 1,
        6, 1, 1, 1, 1, 1, 1, 1])
tensor([False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False,  True, False, False,
         True, False,  True, False, False, False, False,  True,  True, False,
        False, False,

In [None]:
print("Untrained Validation...")
# Use the configured model type so this cell also works when model_type == 'mlp'.
test(policy, test_loader, loss_func, batch_size, model_type=model_type)

Untrained Validation...
Avg. Loss: 2.36553	Correct %11.36


In [None]:
print("Training...")
# Fit the selected model; `train` hands back the model and the loss function.
policy, loss_func = train(
    train_loader=train_loader,
    policy=policy,
    opt=opt,
    loss_func=loss_func,
    batch_size=batch_size,
    model_type=model_type,
)

Training...
Epoch: 0	Loss: 0.68391	Correct %77.94
Epoch: 1	Loss: 0.29659	Correct %90.53
Epoch: 2	Loss: 0.24901	Correct %91.99
Epoch: 3	Loss: 0.22248	Correct %92.87
Epoch: 4	Loss: 0.20770	Correct %93.29
Epoch: 5	Loss: 0.20310	Correct %93.41
Epoch: 6	Loss: 0.18950	Correct %93.88
Epoch: 7	Loss: 0.18951	Correct %93.89
Epoch: 8	Loss: 0.17852	Correct %94.22
Epoch: 9	Loss: 0.16978	Correct %94.45


In [None]:
print("Validation...")
# Use the configured model type so this cell also works when model_type == 'mlp'.
test(policy, test_loader, loss_func, batch_size, model_type=model_type)

Validation...
Avg. Loss: 0.18099	Correct %94.52
