Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

Setup Data

In [12]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081))
    ])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="./daa", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = 1000, shuffle=False)


Define Model

In [23]:
class SimpleCNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.fc1 = nn.Linear(32 * 7 * 7, 128)
    self.fc2 = nn.Linear(128, 10)
    self.dropout = nn.Dropout(0.25)


  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 32 * 7 * 7)
    x = F.relu(self.fc1(x))
    x = self.dropout(x)
    x = self.fc2(x)
    return x

Create Objects

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

Train Loop

In [35]:
for epoch in range (1, 15):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    loss.backward()
    optimizer.step()


  print(f"Epoch {epoch} Loss: {loss.item():.4f}")


Epoch 1 Loss: 0.0123
Epoch 2 Loss: 0.2139
Epoch 3 Loss: 0.0032
Epoch 4 Loss: 0.0811
Epoch 5 Loss: 0.0156
Epoch 6 Loss: 0.0005
Epoch 7 Loss: 0.0028
Epoch 8 Loss: 0.0014
Epoch 9 Loss: 0.0006
Epoch 10 Loss: 0.0059
Epoch 11 Loss: 0.0018
Epoch 12 Loss: 0.0124
Epoch 13 Loss: 0.0069
Epoch 14 Loss: 0.0019


Test Accuracy

In [37]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
  for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    output = model(data)
    pred = output.argmax(dim=1)
    correct += (pred == target).sum().item()
    total += target.size(0)

acc = 100. * correct / total
print(f"Test Accuracy: {acc:.2f}%")


Test Accuracy: 99.09%
