## libraries

In [2]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch import optim

## reading data

In [3]:
# Training-time pipeline: light random augmentation, then tensor conversion
# and normalisation with the constants (0.1307, 0.3081).
t = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Held-out data must be evaluated deterministically, so the test pipeline
# omits the random rotation / translation steps.
t_eval = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train = MNIST(root='./data', train=True, transform=t, download=True)
test = MNIST(root='./data', train=False, transform=t_eval, download=True)

# NOTE: `.data` / `.targets` expose the raw stored tensors. The transform
# pipelines above are only applied when a dataset object is indexed
# (e.g. `train[i]`), so the arrays below are neither augmented nor normalised.
x_train, y_train = train.data, train.targets
x_test, y_test = test.data, test.targets

100%|██████████| 9.91M/9.91M [00:00<00:00, 54.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.66MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 16.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.0MB/s]


## preprocessing

In [4]:
def preprocessing(x, y, mean=None, std=None):
    """Standardise a batch of images and insert a channel axis.

    Args:
        x: Image tensor; it is cast to float32 before any arithmetic.
        y: Labels, returned unchanged.
        mean: Value subtracted from ``x``. When ``None`` (the default) the
            mean of ``x`` itself is used.
        std: Value ``x`` is divided by after centring. When ``None`` (the
            default) the standard deviation of ``x`` itself is used.

    Returns:
        Tuple ``(x, y)`` where ``x`` is the standardised float32 tensor with a
        new axis inserted at position 1.
    """
    x = x.type(torch.float32)
    if mean is None:
        mean = x.mean()
    if std is None:
        std = x.std()
    x = (x - mean) / std
    x = x.unsqueeze(1)
    return x, y


# Derive the normalisation statistics from the training pixels only and reuse
# them for the test split, so both splits are scaled identically and no
# information from the held-out data enters the preprocessing.
train_pixels = x_train.type(torch.float32)
train_mean, train_std = train_pixels.mean(), train_pixels.std()
del train_pixels  # the float copy is only needed for the statistics

x_train, y_train = preprocessing(x_train, y_train, mean=train_mean, std=train_std)
x_test, y_test = preprocessing(x_test, y_test, mean=train_mean, std=train_std)

## dataloader

In [5]:
# Wrap the preprocessed training tensors, then carve out a validation subset.
full_train_set = TensorDataset(x_train, y_train)

# A seeded generator makes the 50000 / 10000 partition identical on every
# run; without it the validation subset changes each time the cell executes.
split_generator = torch.Generator().manual_seed(0)
train_set, valid_set = random_split(
    full_train_set, [50000, 10000], generator=split_generator
)

test_set = TensorDataset(x_test, y_test)

# Only the training loader needs shuffling; evaluation loaders iterate in a
# fixed order so that reported per-batch numbers are comparable between runs.
train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=256, shuffle=False)
test_loader = DataLoader(test_set, batch_size=1024, shuffle=False)

## model

In [6]:
class ConvNet(nn.Module):
    """Convolutional classifier: four conv stages plus a three-layer dense head.

    Each conv stage is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU
    -> pooling, with channel widths 1 -> 32 -> 64 -> 128 -> 256. The adaptive
    pooling layers fix the spatial size to 14x14 after stage 2 and 7x7 after
    stages 3 and 4, which is what the first dense layer (256 * 7 * 7 inputs)
    relies on. The head ends in 10 output units with no activation.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor.
        self.layer1 = self._conv_stage(1, 32, nn.MaxPool2d(2))
        self.layer2 = self._conv_stage(32, 64, nn.AdaptiveMaxPool2d((14, 14)))
        self.layer3 = self._conv_stage(64, 128, nn.AdaptiveAvgPool2d((7, 7)))
        self.layer4 = self._conv_stage(128, 256, nn.AdaptiveAvgPool2d((7, 7)))

        # Dense head: flatten, two hidden layers (dropout after the first),
        # then the 10-way output layer.
        self.layer5 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 7 * 7, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.layer6 = nn.Sequential(nn.Linear(256, 128), nn.ReLU())
        self.layer7 = nn.Sequential(nn.Linear(128, 10))

    @staticmethod
    def _conv_stage(in_channels, out_channels, pool):
        """Build one Conv2d -> BatchNorm2d -> ReLU -> ``pool`` stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            pool,
        )

    def forward(self, x):
        """Run ``x`` through layer1 ... layer7 in order and return the result."""
        out = x
        for stage in (
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
            self.layer6,
            self.layer7,
        ):
            out = stage(out)
        return out

## hyper parameters

In [7]:
# Scalar settings read by the training cell.
epoch = 30        # upper bound passed to range() by the training loop
loss_thresh = 1   # a checkpoint is written when a batch loss falls below this

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Network, objective and optimiser.
model = ConvNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## train

In [None]:
# Train the network for `epoch` passes over `train_loader`, checkpointing the
# weights whenever a mini-batch loss drops below the running `loss_thresh`.
model.train()

# A dedicated loop variable keeps `epoch` (the configured number of passes)
# intact. Reusing the name `epoch` here would overwrite the count with the
# last index, so re-running this cell would silently train for fewer passes.
for epoch_idx in range(epoch):
    for i, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output = model(image)
        loss = loss_fn(output, label)

        # Read the scalar once; every .item() call copies the value to the host.
        batch_loss = loss.item()
        if batch_loss < loss_thresh:
            # Saved before optimizer.step(), so the checkpoint holds the
            # weights that produced this loss value.
            torch.save(model.state_dict(), 'model.ckpt')
            loss_thresh = batch_loss

        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"epoch: {epoch_idx}, loss: {batch_loss}")

## validation

In [None]:
# Compute the loss on the validation batches without tracking gradients.
model.eval()  # switches the Dropout / BatchNorm layers to inference behaviour
with torch.no_grad():
    for i, (image, label) in enumerate(valid_loader):
        image, label = image.to(device), label.to(device)
        output = model(image)
        loss = loss_fn(output, label)
        # Report only every tenth batch.
        if i % 10:
            continue
        print(f"epoch: {epoch}, loss: {loss.item()}")