In [5]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [8]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# MNIST train (train=True) and test (train=False) splits, downloaded into ./data
# if missing and converted to tensors with ToTensor().
training_data, test_data = (
    datasets.MNIST(root="data", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

# Mini-batch size and number of training epochs.
bs = 1
epochs = 1


Using cpu device


In [50]:
# Define model
class NeuralNetwork(nn.Module):
    """Small MLP classifier: flatten -> Linear(28*28, 18) -> ReLU -> Linear(18, 10).

    ``forward`` returns the output of the final Linear layer (the logits), one row
    per input sample after flattening.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged: they form the state_dict keys
        # (e.g. "linear_relu_stack.0.weight") that torch.save writes.
        self.flatten = nn.Flatten()
        in_features, hidden, n_classes = 28 * 28, 18, 10
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        # Flatten everything after the batch dimension, then apply the MLP.
        return self.linear_relu_stack(self.flatten(x))
        
def get_model(lr=1e-3):
    """Build a fresh ``NeuralNetwork`` together with its loss function and optimizer.

    Args:
        lr: SGD learning rate. Defaults to ``1e-3``, the value previously hard-coded.

    Returns:
        ``(model, loss_fn, optimizer)``, where ``loss_fn`` is ``nn.CrossEntropyLoss()``
        and ``optimizer`` is ``torch.optim.SGD`` over ``model.parameters()``.
    """
    model = NeuralNetwork()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    return model, loss_fn, optimizer

def get_data(training_data, test_data, bs, shuffle=False):
    """Wrap the train/test datasets in DataLoaders.

    Args:
        training_data: dataset for training; batched with size ``bs``.
        test_data: dataset for evaluation; batched with size ``bs * 2``.
        bs: training batch size.
        shuffle: if True, the training loader reshuffles the data every epoch.
            Defaults to False, which keeps the previous fixed ordering; pass True
            for standard SGD training (seed with ``torch.manual_seed`` for
            reproducibility).

    Returns:
        ``(train_dl, valid_dl)``. The validation loader is never shuffled.
    """
    train_dl = DataLoader(training_data, batch_size=bs, shuffle=shuffle)
    valid_dl = DataLoader(test_data, batch_size=bs * 2)
    return train_dl, valid_dl

def accuracy(out, yb):
    """Fraction of rows of ``out`` whose argmax along dim 1 equals the target in ``yb``.

    Returns a 0-dim float tensor.
    """
    hits = out.argmax(dim=1).eq(yb)
    return hits.float().mean()

def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss on one batch and, if an optimizer is given, take one step.

    Args:
        model: callable mapping ``xb`` to predictions.
        loss_func: callable ``loss_func(predictions, yb)`` returning a scalar tensor.
        xb, yb: input batch and its targets.
        opt: optional optimizer; when provided, backpropagates, steps and then
            zeroes the gradients.

    Returns:
        ``(loss_value, batch_size)`` where ``loss_value`` is a Python float and
        ``batch_size`` is ``len(xb)``.
    """
    batch_loss = loss_func(model(xb), yb)
    n_samples = len(xb)

    if opt is None:
        return batch_loss.item(), n_samples

    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item(), n_samples

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train ``model`` for ``epochs`` passes over ``train_dl``, evaluating on ``valid_dl``.

    Each epoch:
      * runs one optimisation step per training batch via ``loss_batch`` and prints
        the batch loss every 100 batches;
      * counts training predictions that match the labels (measured right after
        each batch's update step);
      * computes the sample-weighted mean loss over ``valid_dl`` with gradients
        disabled;
      * prints the epoch's training accuracy and validation loss.

    Args:
        epochs: number of passes over the training data.
        model: the ``nn.Module`` being trained.
        loss_func: callable ``loss_func(logits, targets)`` returning a scalar tensor.
        opt: optimizer over ``model``'s parameters.
        train_dl: training ``DataLoader`` yielding ``(xb, yb)`` batches.
        valid_dl: validation ``DataLoader`` yielding ``(xb, yb)`` batches.

    Returns:
        None; progress is reported with ``print``.
    """
    size = len(train_dl.dataset)
    # Send batches to wherever the model's parameters live, so the loop works on
    # CPU, CUDA or MPS once the model itself has been moved there.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        model.train()
        correct = 0  # reset every epoch so counts are not mixed across epochs
        for batch, (xb, yb) in enumerate(train_dl):
            xb, yb = xb.to(dev), yb.to(dev)
            loss, _ = loss_batch(model, loss_func, xb, yb, opt)
            # Per-sample argmax over the class dimension (valid for any batch
            # size); no_grad avoids building an autograd graph for bookkeeping.
            with torch.no_grad():
                correct += (model(xb).argmax(dim=1) == yb).sum().item()

            if batch % 100 == 0:
                current = (batch + 1) * len(xb)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

        model.eval()
        with torch.no_grad():
            results = [
                loss_batch(model, loss_func, xb.to(dev), yb.to(dev))
                for xb, yb in valid_dl
            ]
        if results:
            losses, nums = zip(*results)
            # Weight each batch's mean loss by its size so a short last batch
            # does not skew the average.
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        else:
            val_loss = float("nan")  # empty validation set: nothing to average

        train_acc = correct / size if size else 0.0
        print(
            f"Epoch {epoch + 1}: Training Accuracy: {(100 * train_acc):>0.1f}%, "
            f"Validation Loss: {val_loss:>8f}"
        )


In [51]:
# Number of training examples. Reads `training_data` directly so this cell does not
# depend on `train_dl`, which is only created in the training cell further down.
len(training_data)

60000

In [52]:
# Sanity check: a single MNIST example passes through the network and yields a
# class index comparable to its label. Uses a fresh (untrained) NeuralNetwork and
# fetches the example once, so the cell does not depend on `model` / `train_dl`
# from the later training cell and input and label always belong together.
x0, y0 = training_data[0]
with torch.no_grad():
    pred0 = NeuralNetwork()(x0.unsqueeze(0)).argmax(dim=1).item()
pred0 == y0

True

In [53]:
# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(0)
train_dl, valid_dl = get_data(training_data, test_data, bs)
model, loss_func, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl, valid_dl)
# Checkpoint the trained weights; saving before fit() stored only the
# untrained initialisation.
torch.save(model.state_dict(), "model.pth")

loss: 2.457382  [    1/60000]
loss: 2.469971  [  101/60000]
loss: 2.292928  [  201/60000]
loss: 1.991651  [  301/60000]
loss: 2.224533  [  401/60000]
loss: 2.491241  [  501/60000]
loss: 2.182206  [  601/60000]
loss: 2.204501  [  701/60000]
loss: 2.122061  [  801/60000]
loss: 2.184120  [  901/60000]
loss: 1.960044  [ 1001/60000]
loss: 2.345141  [ 1101/60000]
loss: 1.907452  [ 1201/60000]
loss: 1.807153  [ 1301/60000]
loss: 1.795830  [ 1401/60000]
loss: 2.075871  [ 1501/60000]
loss: 2.231993  [ 1601/60000]
loss: 1.818370  [ 1701/60000]
loss: 2.145934  [ 1801/60000]
loss: 2.020865  [ 1901/60000]
loss: 2.630142  [ 2001/60000]
loss: 0.805722  [ 2101/60000]
loss: 1.297330  [ 2201/60000]
loss: 1.768152  [ 2301/60000]
loss: 1.540125  [ 2401/60000]
loss: 0.791902  [ 2501/60000]
loss: 2.116289  [ 2601/60000]
loss: 1.788357  [ 2701/60000]
loss: 1.698769  [ 2801/60000]
loss: 2.317448  [ 2901/60000]
loss: 2.271123  [ 3001/60000]
loss: 1.052840  [ 3101/60000]
loss: 1.676888  [ 3201/60000]
loss: 2.48

In [15]:
# Display the test split's summary (number of datapoints, root, split, transform).
test_data

Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()