In [None]:
import torch
from torch import nn
from torchvision import datasets, transforms
import torch.nn.functional as F

In [None]:
# Preprocessing pipeline: turn each image into a tensor, then normalize its
# single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split. download=True asks torchvision to fetch the files into
# the root directory when they are not already present there.
trainset = datasets.MNIST(
    '~/.pytorch/MNIST_data/', train=True, transform=transform, download=True
)

# Serve the training set in mini-batches of 64, reshuffled on every pass.
trainloader = torch.utils.data.DataLoader(dataset=trainset, shuffle=True, batch_size=64)

In [None]:
# Fully connected classifier: 784 input features -> 128 -> 64 -> 10 output
# logits, with a ReLU after each hidden layer. The final layer emits raw
# logits (no softmax), which is what nn.CrossEntropyLoss expects.
model = nn.Sequential(*[
    nn.Linear(in_features=784, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),
])

In [None]:
# Cross-entropy over raw logits and integer class labels, averaged over the
# batch ('mean' is the library default, spelled out here for clarity).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Grab a single mini-batch to experiment with; `dataiter` is kept so further
# batches can be pulled with next(dataiter).
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Collapse every non-batch dimension into one feature axis: (batch, features).
images = images.view(len(images), -1)

In [None]:
# Forward pass on the sample batch, then score the logits against the labels.
logits = model(images)
loss = criterion(input=logits, target=labels)
print(loss)

In [None]:
# Plain SGD over all model parameters with learning rate 0.01.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Clear any gradients left over from earlier backward passes.
optimizer.zero_grad()

# Recompute the forward pass inside this cell. backward() frees the autograd
# graph it walks, so calling it again on a `loss` tensor that was already
# back-propagated (e.g. when this cell is executed a second time) raises a
# RuntimeError. Building a fresh graph here keeps the cell re-runnable; on a
# first top-to-bottom run the values are the same as before.
logits = model(images)
loss = criterion(logits, labels)

# One gradient-descent step on the sample batch.
loss.backward()
optimizer.step()

In [None]:
# Train for a fixed number of passes over the training set and report the
# mean per-batch loss after each pass.
epochs = 5
for e in range(epochs):
    running_loss = 0  # sum of the batch losses seen in this epoch
    for images, labels in trainloader:
        # Flatten each image into one row so it matches the first Linear layer.
        images = images.view(images.size(0), -1)

        # Forward pass and loss for this batch.
        logits = model(images)
        loss = criterion(logits, labels)

        # Drop stale gradients, back-propagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python number for the running total.
        running_loss += loss.item()
    print(f"Epoch {e+1}/{epochs} - Loss: {running_loss/len(trainloader):.4f}")

In [None]:
# Draw one batch for a quick sanity check of the trained model.
# NOTE: the batch comes from `trainloader`, so anything measured on it
# reflects performance on training data, not on held-out data.
images, labels = next(iter(trainloader))
images = images.view(images.size(0), -1)

# Inference only: no gradients are needed for the logits or probabilities.
with torch.no_grad():
    logits = model(images)
    # Softmax across the class axis turns logits into per-class probabilities.
    ps = F.softmax(logits, dim=1)

In [None]:
# Highest probability and its class index for every sample; both results keep
# a trailing axis of size 1 because k=1.
top_p, top_class = ps.topk(k=1, dim=1)

# Eyeball the first ten predictions next to the first ten true labels.
print(top_class[:10])
print(labels[:10])

In [None]:
# Boolean hit mask. squeeze() removes the size-1 top-k axis so predictions and
# labels compare element-wise instead of broadcasting into a 2-D grid.
correct = labels == top_class.squeeze()

# Booleans must be cast to float before averaging; the mean is the fraction of
# correct predictions in this batch.
accuracy = correct.float().mean()
accuracy

In [None]:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(obj=model.state_dict(), f='mnist_model.pth')

In [None]:
# Restore the saved parameters into the existing model. torch.load unpickles
# the file, so only load checkpoints that come from a trusted source.
model.load_state_dict(torch.load(f='mnist_model.pth'))