In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

# Pin PyTorch's global RNG so weight initialisation is repeatable across runs.
torch.manual_seed(seed=1234)

In [None]:
class Net(nn.Module):
    """Small convolutional classifier: three conv stages and one linear head.

    The linear head takes 1600 input features, i.e. a 64 x 5 x 5 activation
    map coming out of ``conv3``. A 1 x 28 x 28 input yields exactly that
    (28 -> 26 -> 24 -> pooled 12 -> 5 per spatial side). The forward pass
    returns raw, un-normalised class scores of shape ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, 3x3 kernel, stride 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.act1 = nn.ReLU(inplace=True)
        # Stage 2: 32 -> 64 channels, followed by 2x2 max pooling.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.act2 = nn.ReLU(inplace=True)
        # Stage 3: 64 -> 64 channels with stride 2 (spatial downsampling).
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2)
        self.act3 = nn.ReLU(inplace=True)
        # Classification head over the flattened feature map.
        self.fc4 = nn.Linear(in_features=1600, out_features=10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        x = self.conv1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.act2(x)

        x = self.conv3(x)
        x = self.act3(x)

        # Keep the batch dimension, collapse everything else for the linear layer.
        x = torch.flatten(x, start_dim=1)
        return self.fc4(x)


In [None]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``train_loader`` and log progress.

    Each batch is moved to ``device``, pushed through ``model``, scored with
    ``criterion`` and used for exactly one ``optimizer`` step. A running
    accuracy over the samples seen so far in this call is tracked, and a
    status line is printed on every 10th batch (batch 0 included).

    Args:
        model: Module to optimise; it is put into training mode here.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute (both are used only
            for the status line).
        criterion: Callable ``(output, target) -> loss tensor``.
        optimizer: Optimizer holding the parameters of ``model``.
        epoch: Epoch number; only used in the printed status line.

    Returns:
        None. Progress is reported on stdout.
    """
    model.train()
    hits = 0
    seen = 0
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        scores = model(inputs)
        loss = criterion(scores, labels)
        loss.backward()
        optimizer.step()

        # Predicted class = position of the largest score in each row.
        predicted = scores.argmax(dim=1)
        hits += (predicted == labels.view_as(predicted)).sum().item()
        seen += inputs.size(0)

        if step % 10 == 0:
            done = step * len(inputs)
            total = len(train_loader.dataset)
            percent = 100. * step / len(train_loader)
            running_acc = 100. * hits / seen
            print(f'Train Epoch: {epoch} [{done}/{total} ({percent:.0f}%)]\tLoss: {loss.item():.6f}\tAcc: {running_acc}')

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
# Probe the available accelerators once; `use_cuda` is reused further down
# when the data-loader options are built.
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

# Preference order: CUDA, then Apple MPS, then plain CPU.
device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")


In [None]:
# Shuffle the training set on every device. Previously `shuffle` lived in the
# CUDA-only options, so CPU/MPS runs trained on a fixed sample order while
# CUDA runs did not. Evaluation metrics do not depend on sample order, so the
# test loader is left unshuffled.
train_kwargs = {'batch_size': 128, 'shuffle': True}
test_kwargs = {'batch_size': 128, 'shuffle': False}
if use_cuda:
    # CUDA-only loader options: one background worker and pinned host memory.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


# Per-image preprocessing: convert to a tensor, then normalise with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Both splits are read from the local 'dataset' directory; only the training
# split requests a download when the files are missing.
dataset1 = datasets.MNIST(root='dataset', train=True, transform=transform, download=True)
dataset2 = datasets.MNIST(root='dataset', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=dataset2, **test_kwargs)

# Network on the selected device, optimised with Adadelta against a
# cross-entropy objective.
model = Net()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(params=model.parameters(), lr=1.0)

# The learning rate is multiplied by `gamma` on every scheduler.step();
# 0.1 is StepLR's default, spelled out here so the decay is visible.
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=0.1)


In [None]:
n_epochs = 2
for epoch in range(1, n_epochs + 1):
    # One pass over the training data, then an evaluation pass, then one
    # learning-rate scheduler step per epoch.
    train(
        model=model,
        device=device,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        epoch=epoch,
    )
    test(
        model=model,
        device=device,
        test_loader=test_loader,
        criterion=criterion,
    )
    scheduler.step()

In [None]:
# Export by tracing the model with a single all-zero 1 x 28 x 28 image.
# The dummy input has to live on the same device as the model's weights;
# a CPU tensor fed to a CUDA/MPS model raises a device-mismatch error.
dummy_input = torch.zeros(1, 1, 28, 28, device=next(model.parameters()).device)
torch.onnx.export(model, dummy_input, "mnist.onnx")

In [None]:
# Save a few raw images from the training split (calibration set) and from
# the test split. No transform is passed, so each item is (image, label)
# exactly as the dataset yields it.
train_dataset = datasets.MNIST('dataset', train=True)
test_dataset = datasets.MNIST('dataset', train=False)

N_CAL = 50
N_TEST = 1000


def _save_samples(dataset, out_dir, prefix, limit):
    """Write the first ``limit`` items of ``dataset`` into ``out_dir`` as JPEGs.

    Files are named ``{prefix}_{index:05d}_{label}.jpg``. The directory is
    created when missing, so this also works on a fresh checkout.

    Args:
        dataset: Indexable dataset of ``(image, label)`` pairs whose images
            provide a ``save(path)`` method.
        out_dir: Destination directory.
        prefix: File-name prefix.
        limit: Maximum number of items to write; fewer are written when the
            dataset is smaller.
    """
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): JPEG is lossy, so the stored pixels may differ slightly
    # from the dataset images - confirm the consumers of these files expect that.
    for idx in range(min(limit, len(dataset))):
        img, target = dataset[idx]
        img.save(f'{out_dir}/{prefix}_{idx:05d}_{target}.jpg')


_save_samples(train_dataset, 'samples/calibration', 'train', N_CAL)
_save_samples(test_dataset, 'samples/test', 'test', N_TEST)