In [1]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import SGD, Adam
from torchvision import datasets

In [2]:
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
data_dir = "./data/"
cifar_train_set = datasets.CIFAR10(data_dir + 'cifar10/', train = True, download = True)
cifar_test_set = datasets.CIFAR10(data_dir + 'cifar10/', train = False, download = True)

cifar_train_input = torch.from_numpy(cifar_train_set.data)
cifar_train_input = cifar_train_input.transpose(3, 1).transpose(2, 3).float()
cifar_train_target = torch.tensor(cifar_train_set.targets, dtype = torch.int64)

cifar_test_input = torch.from_numpy(cifar_test_set.data).float()
cifar_test_input = cifar_test_input.transpose(3, 1).transpose(2, 3).float()
cifar_test_target = torch.tensor(cifar_test_set.targets, dtype = torch.int64)


mnist_train_set = datasets.MNIST(data_dir + 'mnist/', train = True, download = True)
mnist_test_set = datasets.MNIST(data_dir + 'mnist/', train = False, download = True)

mnist_train_input = mnist_train_set.data.view(-1, 1, 28, 28).float()
mnist_train_target = mnist_train_set.targets
mnist_test_input = mnist_test_set.data.view(-1, 1, 28, 28).float()
mnist_test_target = mnist_test_set.targets

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class MnistClassifier(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.maxp1 = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4)
        self.maxp2 = nn.MaxPool2d(kernel_size=2)
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.bn3 = nn.BatchNorm1d(num_features=64)
        self.fc1 = nn.Linear(in_features=64, out_features=256)
        self.bn4 = nn.BatchNorm1d(num_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)
    def forward(self, x):
        y = self.bn1(F.relu(self.maxp1(self.conv1(x))))
        y = self.bn2(F.relu(self.maxp2(self.conv2(y))))
        y = self.bn3(F.relu(self.conv3(y).view(-1, 64)))
        y = self.bn4(F.relu(self.fc1(y)))
        y = F.relu(self.fc2(y))
        return y

In [5]:
def train_model(model, train_input, train_target, num_epochs=25, lr=1e-1, mini_batch_size=5000, adam=False):
    if adam:
        optimizer = Adam(model.parameters(), lr=lr)
    else:
        optimizer = SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().cuda()
    num_samples = train_input.size(0)
    for e in range(1, num_epochs+1):
        sum_loss=0
        for b in range(0, num_samples, mini_batch_size):
            train_input_mini_batch = train_input[b:min(b+mini_batch_size, num_samples)]
            train_target_mini_batch = train_target[b:min(b+mini_batch_size, num_samples)]
            optimizer.zero_grad()
            prediction_mini_batch = model(train_input_mini_batch)
            loss = criterion(prediction_mini_batch, train_target_mini_batch)
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()
        print("Epoch {} Loss: {}".format(e, sum_loss))
    return model

In [6]:
def test_model(model, test_input, test_target):
    prediction = model(test_input)
    predicted_labels = torch.argmax(prediction, dim=1)
    accuracy = (predicted_labels == test_target).float().mean().item()
    return accuracy

In [7]:
mnist_model = train_model(MnistClassifier().cuda(), mnist_train_input.cuda(), mnist_train_target.cuda())

Epoch 1 Loss: 10.002651944756508
Epoch 2 Loss: 2.5359396636486053
Epoch 3 Loss: 1.6746613830327988
Epoch 4 Loss: 1.2898986637592316
Epoch 5 Loss: 1.0644440650939941
Epoch 6 Loss: 0.913803968578577
Epoch 7 Loss: 0.804564356803894
Epoch 8 Loss: 0.7211291193962097
Epoch 9 Loss: 0.6546393632888794
Epoch 10 Loss: 0.5997908115386963
Epoch 11 Loss: 0.5535517930984497
Epoch 12 Loss: 0.5137772783637047
Epoch 13 Loss: 0.4790307879447937
Epoch 14 Loss: 0.44852789118885994
Epoch 15 Loss: 0.4211531784385443
Epoch 16 Loss: 0.3963406253606081
Epoch 17 Loss: 0.3738534301519394
Epoch 18 Loss: 0.3531911950558424
Epoch 19 Loss: 0.3342803977429867
Epoch 20 Loss: 0.31666890904307365
Epoch 21 Loss: 0.30029630847275257
Epoch 22 Loss: 0.28502142056822777
Epoch 23 Loss: 0.27104074880480766
Epoch 24 Loss: 0.2578673753887415
Epoch 25 Loss: 0.24559614434838295


In [8]:
test_model(mnist_model, mnist_test_input.cuda(), mnist_test_target.cuda())

0.9872999787330627