In [1]:
from torchvision import datasets, transforms
import torch
import torch.nn.functional as F


# Hyperparameters
batch_size = 50
loss_func = F.cross_entropy
epochs = 50
seed = 0         # seeds torch's global RNG so the split below is repeatable
val_size = 5000  # number of training images held out for validation

# Seed before anything draws random numbers: random_split (below) uses torch's
# default generator, so without this the train/validation partition differs on
# every Restart & Run All.
torch.manual_seed(seed)

# GPU/CPU: query CUDA availability once and derive the device from it.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

# DataLoader settings: only the training loader shuffles.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
val_kwargs = {'batch_size': batch_size}
if use_cuda:
    # Extra loader options that are only applied when running on CUDA.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    val_kwargs.update(cuda_kwargs)

# Datasets: convert images to tensors, then normalize the single channel with
# mean 0.1307 and std 0.3081 (the values passed to Normalize).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST('data', train=True, download=True, transform=transform)

# Derive the training size from the dataset length instead of hard-coding it,
# so the two lengths always sum to len(dataset).
dataset1, dataset2 = torch.utils.data.random_split(
    dataset, [len(dataset) - val_size, val_size])

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
val_loader = torch.utils.data.DataLoader(dataset2, **val_kwargs)

cuda


In [2]:
import torch.nn as nn
from collections import OrderedDict

# Network definition: two conv -> ReLU -> max-pool stages followed by a
# two-layer fully connected head. Layers are registered by name, in the order
# in which nn.Sequential applies them.
layer_table = OrderedDict()
layer_table['conv1'] = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)
layer_table['relu1'] = nn.ReLU()
layer_table['maxpool1'] = nn.MaxPool2d(kernel_size=2)
layer_table['conv2'] = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
layer_table['relu2'] = nn.ReLU()
layer_table['maxpool2'] = nn.MaxPool2d(kernel_size=2)
layer_table['flatten'] = nn.Flatten()
# 64 channels x 7 x 7 positions = 3136 inputs (assumes 28x28 images, halved by
# each of the two pooling stages -- TODO confirm against the data).
layer_table['linear1'] = nn.Linear(in_features=64 * 7 * 7, out_features=1024)
layer_table['relu3'] = nn.ReLU()
layer_table['linear2'] = nn.Linear(in_features=1024, out_features=10)

model = nn.Sequential(layer_table)

# Move the parameters to the selected device; as the cell's last expression
# this also displays the module summary.
model.to(device)

Sequential(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=3136, out_features=1024, bias=True)
  (relu3): ReLU()
  (linear2): Linear(in_features=1024, out_features=10, bias=True)
)

In [3]:
# Look up two parameter tensors by layer name rather than by position:
# the first convolution's bias and the final linear layer's weight matrix.
bias = model.conv1.bias
params = model.linear2.weight
params.shape

torch.Size([10, 1024])

In [4]:
# optimizer
import torch.optim as optim

# Adam over every parameter of the network, with one fixed step size.
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [5]:
def accuracy(out, yb):
    """Return the fraction of predictions in ``out`` that match ``yb``.

    The predicted class of each row is the index of its largest entry along
    dim 1. The result is a 0-dim float tensor: the mean of the element-wise
    match indicator.
    """
    hits = out.argmax(dim=1).eq(yb)
    return hits.float().mean()

# Training: one optimisation pass over train_loader per epoch, followed by a
# validation pass that reports accuracy and tracks the best value seen so far.
best_val_acc = 0.0

for epoch in range(epochs):
    # --- optimisation pass ---
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()

    # --- validation pass (no gradients, eval-mode layers) ---
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Count hits per sample instead of averaging per-batch means, so a
            # smaller final batch is weighted correctly.
            correct += (torch.argmax(output, dim=1) == target).sum().item()
            total += target.size(0)

    # Plain Python float; guards against an empty validation loader.
    val_acc = correct / total if total else 0.0

    print(f'Validation accuracy after {epoch + 1} epoch(s): {val_acc}')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # UNCOMMENT THE LINE BELOW TO SAVE THE MODEL!
        # torch.save(model, 'models/cnn_mnist.pt')
        

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Validation accuracy after 1 epoch(s): 0.9784002304077148
Validation accuracy after 2 epoch(s): 0.9880004525184631
Validation accuracy after 3 epoch(s): 0.9884003400802612
Validation accuracy after 4 epoch(s): 0.9910004138946533
Validation accuracy after 5 epoch(s): 0.9884001612663269
Validation accuracy after 6 epoch(s): 0.9902002215385437
Validation accuracy after 7 epoch(s): 0.9922001957893372
Validation accuracy after 8 epoch(s): 0.9902001619338989
Validation accuracy after 9 epoch(s): 0.9898003935813904
Validation accuracy after 10 epoch(s): 0.9914001822471619
Validation accuracy after 11 epoch(s): 0.9890002012252808
Validation accuracy after 12 epoch(s): 0.989800214767456
Validation accuracy after 13 epoch(s): 0.9888001680374146
Validation accuracy after 14 epoch(s): 0.990800142288208
Validation accuracy after 15 epoch(s): 0.992000162601471
Validation accuracy after 16 epoch(s): 0.9914001226425171
Validation accuracy after 17 epoch(s): 0.9916003346443176
Validation accuracy after 