In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms, datasets

from tqdm.auto import tqdm

In [2]:
# Pick the compute device once; every later cell moves tensors/modules to it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalize with mean 0 / std 1. With these constants the normalization step
# leaves the values produced by ToTensor unchanged.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.,), std=(1.,))
transform = transforms.Compose([to_tensor, normalize])


In [4]:
# Training split of MNIST (downloaded into ./data if missing), served in
# shuffled mini-batches of 64.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)

In [5]:
# Held-out split of MNIST (downloaded into ./data if missing), served in
# fixed-order mini-batches of 64.
testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=False)

In [6]:
# Pull a single batch from the training loader and show the tensor shapes
# as a quick sanity check of the data pipeline.
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [7]:
# Fully connected classifier: flatten -> 784-64 -> three 64-64 hidden
# layers -> 64-10 logits, with a ReLU after every layer except the last.
# Layers are appended in the same order as a hand-written nn.Sequential, so
# the module indices (and therefore the state_dict keys) are 0..9.
hidden_width = 64
layers = [nn.Flatten(), nn.Linear(784, hidden_width), nn.ReLU()]
for _ in range(3):
    layers.append(nn.Linear(hidden_width, hidden_width))
    layers.append(nn.ReLU())
layers.append(nn.Linear(hidden_width, 10))

model = nn.Sequential(*layers).to(device)

In [8]:
# Loss for multi-class classification on raw logits, and an Adam optimizer
# over all model parameters using the library's default hyperparameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [20]:
# Number of full passes over the training set performed by the training cell.
n_epochs: int = 5

In [21]:
def count_correct(net, loader, device):
    """Count how many samples yielded by ``loader`` are classified correctly.

    Args:
        net: Module mapping a batch of inputs to per-class logits of shape
            (batch, n_classes).
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        int: number of samples whose arg-max logit equals the label.

    Note:
        Switches ``net`` to eval mode and leaves it there; the caller is
        responsible for calling ``net.train()`` before training resumes.
    """
    net.eval()
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            logits = net(inputs.to(device))
            hits = logits.argmax(dim=1) == labels.to(device)
            n_correct += hits.sum().item()
    return n_correct


# Evaluation is done in large, fixed-order batches instead of one sample at
# a time: the per-sample version performs 70,000 separate forward passes and
# host-to-device copies every epoch.
eval_batch_size = 1024
train_eval_loader = DataLoader(trainset, batch_size=eval_batch_size, shuffle=False)
test_eval_loader = DataLoader(testset, batch_size=eval_batch_size, shuffle=False)

for epoch in tqdm(range(n_epochs)):

    losses = []

    # count_correct() leaves the model in eval mode, so re-enable training
    # mode at the start of every epoch.
    model.train()
    for inputs, target in tqdm(trainloader):
        inputs = inputs.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        loss = criterion(model(inputs), target)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    correct_preds = count_correct(model, train_eval_loader, device)
    test_correct_preds = count_correct(model, test_eval_loader, device)

    # Denominators come from the datasets themselves rather than hard-coded
    # 60000 / 10000, so the percentages stay right if the data changes.
    print(f"Epoch: {epoch}; Avg loss: {sum(losses)/len(losses):.2f}; Train accuracy: {correct_preds/len(trainset)*100:.2f}; Test accuracy: {test_correct_preds/len(testset)*100:.2f}")


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

Epoch: 0; Avg loss: 0.04; Train accuracy: 99.07; Test accuracy: 97.27


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch: 1; Avg loss: 0.03; Train accuracy: 99.18; Test accuracy: 97.39


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch: 2; Avg loss: 0.03; Train accuracy: 99.20; Test accuracy: 97.55


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch: 3; Avg loss: 0.03; Train accuracy: 99.15; Test accuracy: 97.37


  0%|          | 0/938 [00:00<?, ?it/s]

Epoch: 4; Avg loss: 0.03; Train accuracy: 99.42; Test accuracy: 97.58


In [22]:
# Random stand-in for one single-channel 28x28 image (batch of 1), used
# below for a smoke-test forward pass and as the example input for export.
sample_shape = (1, 1, 28, 28)
sample_input = torch.randn(sample_shape)

In [23]:
# Smoke test: run the random sample through the network and display the
# resulting logits as the cell output.
sample_on_device = sample_input.to(device)
model(sample_on_device)

tensor([[-18.7213, -16.9506,  17.8785,  -0.1805,  -9.9533, -12.6448, -19.7157,
          -9.9653,  -3.9585, -23.1484]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [24]:
# Export the trained network to ONNX, using the random sample as the example
# input that fixes the graph's input signature.
# NOTE(review): no dynamic axes are requested, so the exported graph is
# presumably tied to the sample's batch size of 1 -- confirm this is intended.
# NOTE(review): what `onnx_program` holds depends on the installed PyTorch
# version / exporter (it may be None) -- verify before relying on it.
onnx_path = "64x4_scratch.onnx"
example_input = sample_input.to(device)
onnx_program = torch.onnx.export(model, example_input, onnx_path)

In [25]:
# Persist only the learned parameters (state_dict), not the module object,
# next to the notebook.
trained_state = model.state_dict()
torch.save(trained_state, "./64x4_scratch.model")