In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=0.1307, std=0.3081)
    ]
)

In [5]:
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

In [10]:
# data loader
# shuffle will prevent the model from memorizing the data
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# we don't want to shuffle since model is done learning
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [23]:
# we will not use a custom model

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # 64, 1, 28, 28 tensor. '64' images '1' is grayscale
        # linear vectors wants a single long list of vector
        # 28*28 = 784
        self.flatten = nn.Flatten() 
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
            x = self.flatten(x) #flatten the data
            x = self.layers(x) # pass it through layers
            return x
            

In [24]:
# device selection
device = torch.device("mps")
print(f"Using {device}")

Using mps


In [25]:
# initialize model and move it to devic
model = MNISTClassifier().to(device)

In [27]:
# loss function
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [28]:
def train_epoch(
    model,
    train_loader,    
    loss_function,
    optimizer,
    device
):

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # track progress
        running_loss += loss.item()
        _, predicted = output.max(1) # what is this doing?
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % 100 == 0 and batch_idx > 0:
            avg_loss = running_loss / 100
            accuracy = 100. * correct / total
            print(f'\t[{batch_idx * 64}/ 60000] '
                  f'Loss: {avg_loss:.3f} | Accuracy: {accuracy:.1f}%')
            running_loss = 0.0
        
        

	[6400/ 60000] Loss: 0.642 | Accuracy: 81.5%
	[12800/ 60000] Loss: 0.328 | Accuracy: 85.9%
	[19200/ 60000] Loss: 0.245 | Accuracy: 88.2%
	[25600/ 60000] Loss: 0.227 | Accuracy: 89.5%
	[32000/ 60000] Loss: 0.216 | Accuracy: 90.2%
	[38400/ 60000] Loss: 0.186 | Accuracy: 91.0%
	[44800/ 60000] Loss: 0.171 | Accuracy: 91.6%
	[51200/ 60000] Loss: 0.169 | Accuracy: 92.0%
	[57600/ 60000] Loss: 0.159 | Accuracy: 92.4%


In [30]:
# now evaluate the model
def evaluate(
    model,
    test_loader,
    device
):
    # set to eval mode
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        return 100.0 * correct / total

96.69833333333334

In [32]:
num_epochs = 10
for epoch in range(num_epochs):
    print(f'\nEpoch: {epoch + 1}')
    train_epoch(
        model,
        train_loader,
        loss_function,
        optimizer,
        device    
    )
    accuracy = evaluate(
        model,
        train_loader,
        device
    )
    print(f'Test Accuracy: {accuracy:.2f}%')


Epoch: 1
	[6400/ 60000] Loss: 0.128 | Accuracy: 96.3%
	[12800/ 60000] Loss: 0.111 | Accuracy: 96.6%
	[19200/ 60000] Loss: 0.118 | Accuracy: 96.6%
	[25600/ 60000] Loss: 0.110 | Accuracy: 96.7%
	[32000/ 60000] Loss: 0.105 | Accuracy: 96.7%
	[38400/ 60000] Loss: 0.105 | Accuracy: 96.7%
	[44800/ 60000] Loss: 0.106 | Accuracy: 96.7%
	[51200/ 60000] Loss: 0.112 | Accuracy: 96.7%
	[57600/ 60000] Loss: 0.096 | Accuracy: 96.8%
Test Accuracy: 97.67%

Epoch: 2
	[6400/ 60000] Loss: 0.084 | Accuracy: 97.7%
	[12800/ 60000] Loss: 0.076 | Accuracy: 97.8%
	[19200/ 60000] Loss: 0.075 | Accuracy: 97.7%
	[25600/ 60000] Loss: 0.083 | Accuracy: 97.7%
	[32000/ 60000] Loss: 0.075 | Accuracy: 97.7%
	[38400/ 60000] Loss: 0.072 | Accuracy: 97.7%
	[44800/ 60000] Loss: 0.072 | Accuracy: 97.7%
	[51200/ 60000] Loss: 0.069 | Accuracy: 97.7%
	[57600/ 60000] Loss: 0.074 | Accuracy: 97.7%
Test Accuracy: 98.22%

Epoch: 3
	[6400/ 60000] Loss: 0.061 | Accuracy: 98.2%
	[12800/ 60000] Loss: 0.053 | Accuracy: 98.3%
	[19200/ 