In [1]:
import torch
from torchvision.datasets import MNIST
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

In [55]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional network for 1x28x28 images.

    Layout: conv(5x5, pad 2) -> tanh -> avgpool(2) -> conv(5x5) -> tanh
    -> avgpool(2) -> linear(400 -> 120) -> tanh -> linear(120 -> 84) -> tanh
    -> linear(84 -> 10). The output is raw logits; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Spatial size: 28 -> 28 (padded conv) -> 14 -> 10 -> 5.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        # Classifier head on the 16 * 5 * 5 = 400 flattened features.
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, 10)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Map a (N, 1, 28, 28) batch to (N, 10) class logits."""
        features = self.pool1(torch.tanh(self.conv1(input)))
        features = self.pool2(torch.tanh(self.conv2(features)))
        flat = torch.flatten(features, start_dim=1)

        hidden = torch.tanh(self.linear1(flat))
        hidden = torch.tanh(self.linear2(hidden))
        return self.linear3(hidden)

In [37]:
class MLP(nn.Module):
    """Two-layer perceptron for 28x28 single-channel images.

    Each sample is flattened to 28 * 28 = 784 features. These pass through a
    bias-free hidden layer of 100 ReLU units, then a bias-free output layer
    that produces 10 raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 100, bias=False)
        self.fc2 = nn.Linear(100, 10, bias=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to (N, 10) logits.

        Args:
            input: batch whose dimensions after the first hold 28 * 28
                elements per sample, e.g. (N, 1, 28, 28), (N, 28, 28) or
                (N, 784). Non-contiguous tensors are accepted.

        Returns:
            Tensor of shape (N, 10) with unnormalized class scores.
        """
        # torch.flatten copies when needed, whereas .view raises on
        # non-contiguous input (e.g. transposed or channels_last batches).
        # nn.Linear does not require contiguous input, so no .contiguous() call.
        x = torch.flatten(input, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [118]:
# Build LeNet-5 and wrap it with torch.compile.
# NOTE: the MLP cell that follows also assigns `model`, so running the cells
# top to bottom ends up training the MLP, not this network.
model = torch.compile(LeNet5())

In [38]:
# Build the MLP and wrap it with torch.compile. This replaces any `model`
# created by the LeNet5 cell above.
model = torch.compile(MLP())

In [39]:
# Preprocessing: convert each PIL image to a float tensor (ToTensor scales to [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])

In [40]:
# MNIST training split, stored under ./data (downloaded if missing).
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)

# Mini-batches of 50 images, reshuffled every epoch.
loader = DataLoader(train_dataset, batch_size=50, shuffle=True)

In [41]:
# Cross-entropy on raw logits, minimized with SGD + momentum.
# The optimizer captures the parameters of whatever `model` exists right now;
# re-run this cell after reassigning `model`, or it keeps updating the old one.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [None]:
# Training: plain mini-batch optimization for a fixed number of epochs,
# logging the current batch loss every 100 steps.
epochs = 10
model.train()

for epoch in range(epochs):
    for i, batch in enumerate(loader):
        images, labels = batch

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        if i % 100 != 0:
            continue
        print(f'Epoch: {epoch} \t| Step: {i}/{len(loader)}  \t| Loss: {loss.item():.4f}')

In [43]:
# MNIST test split, stored under ./data (downloaded if missing).
test_dataset = MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Evaluation does not need shuffling. A fixed order makes the per-batch
# accuracy printout reproducible between runs.
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [None]:
# Evaluate classification accuracy on the test set.
# - torch.no_grad(): forward passes only, so no autograd graph is built.
# - Counts use each batch's real size, so a final partial batch is scored
#   correctly instead of assuming every batch holds `batch_size` samples.
model.eval()
pos = 0     # number of correctly classified test images
length = 0  # number of test images evaluated

with torch.no_grad():
    for i, (inp, outp) in enumerate(test_loader):
        logits = model(inp)
        # softmax is monotonic, so argmax over raw logits picks the same class.
        true_pos = logits.argmax(dim=-1) == outp
        batch_correct = int(true_pos.sum())
        batch_total = outp.size(0)
        print(f'Accuracy of batch ({i}/{len(test_loader)}): {batch_correct / batch_total:.4f}')
        pos += batch_correct
        length += batch_total

# pos and length are plain ints, so this prints a number rather than a tensor repr.
print(f'Accuracy: {pos / length:.4f}')