In [None]:
### Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


In [None]:
### Dataloading
# Seed PyTorch's global RNG before anything stochastic runs, so that weight
# initialisation and the training shuffle order repeat on Restart & Run All.
# (Some CUDA kernels can still be nondeterministic; this covers the RNG side.)
SEED = 42
torch.manual_seed(SEED)

# ToTensor converts each PIL image to a float tensor in [0, 1]; Normalize then
# applies (x - 0.5) / 0.5 to the single grayscale channel, giving [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits share the same on-disk root and preprocessing; the files are
# downloaded on first use.
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Shuffle only the training data; keep the test order fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.1MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 506kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.54MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.24MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
### Architecture Definition
class SimpleCNN(nn.Module):
  """Small two-stage convolutional classifier for single-channel images.

  Each stage is a 3x3 convolution (padding 1, so spatial size is preserved)
  followed by ReLU and a 2x2 max-pool with stride 2 (which halves each spatial
  dimension). For a 28x28 input the feature map is therefore 32 x 7 x 7 before
  the two fully connected layers, which produce 10 raw class scores (logits).
  """

  def __init__(self):
    """Create the convolutional, pooling and fully connected layers."""
    super().__init__()
    self.conv1 = nn.Conv2d(1, 16, kernel_size = 3, padding = 1) # 1 input channel (grayscale) -> 16 feature maps
    self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2) # 2x2 max-pool, stride 2; shared by both stages
    self.conv2 = nn.Conv2d(16, 32, kernel_size = 3, padding = 1) # 16 feature maps -> 32 feature maps
    self.fc1 = nn.Linear(32*7*7, 128) # flattened 32 x 7 x 7 feature map -> 128 hidden units
    self.fc2 = nn.Linear(128, 10) # 128 hidden units -> 10 class logits

  def forward(self, x):
    """Compute class logits for a batch of images.

    Args:
      x: image tensor whose trailing dimensions are (1, H, W), with H and W
        such that two 2x2 poolings leave a 7x7 map (e.g. 28x28).

    Returns:
      Tensor of shape (N, 10) holding unnormalised class scores.

    Raises:
      RuntimeError: if the pooled feature map does not hold exactly
        ``fc1.in_features`` values per sample.
    """
    x = self.pool(torch.relu(self.conv1(x)))
    x = self.pool(torch.relu(self.conv2(x)))

    # Validate the per-sample feature count explicitly. A bare view(-1, 1568)
    # lets the batch dimension absorb any surplus, so a wrong-sized input could
    # silently yield the wrong number of rows instead of failing.
    features_per_sample = x.shape[-3:].numel()
    if features_per_sample != self.fc1.in_features:
      raise RuntimeError(
        f"SimpleCNN expected {self.fc1.in_features} features per sample after "
        f"pooling but got {features_per_sample} (feature map shape "
        f"{tuple(x.shape[-3:])}); inputs must be 1-channel 28x28 images"
      )

    # reshape (not view) also accepts non-contiguous activations.
    x = x.reshape(-1, features_per_sample)
    x = torch.relu(self.fc1(x))
    return self.fc2(x)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Instantiate the network and move its parameters onto the chosen device.
model = SimpleCNN().to(device)
print(model)

cuda
SimpleCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [None]:
# Train the model with Adam on cross-entropy loss and report the mean
# per-batch loss after every epoch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

epochs = 10
# Switch to training mode explicitly so that re-running this cell after the
# evaluation cell (which calls model.eval()) still trains in the right mode.
model.train()
for epoch in range(epochs):
  running_loss = 0.0
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
  # ".4f" prints four decimals. The earlier " 4f" spec meant "space sign,
  # width 4, default precision", which printed six decimals with a stray space.
  print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")

Epoch [1/10], Loss:  0.191731
Epoch [2/10], Loss:  0.052720
Epoch [3/10], Loss:  0.036508
Epoch [4/10], Loss:  0.028369
Epoch [5/10], Loss:  0.021977
Epoch [6/10], Loss:  0.016491
Epoch [7/10], Loss:  0.012492
Epoch [8/10], Loss:  0.010984
Epoch [9/10], Loss:  0.009501
Epoch [10/10], Loss:  0.008709


In [None]:
# Measure top-1 accuracy of the trained model on the test loader.
model.eval()  # put the module in evaluation mode
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
  for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)
    # Index of the largest logit along the class dimension is the predicted label.
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

# A single summary line; the former per-batch "Predicted: ..." debug print
# produced one line per batch and buried this result.
print(f"Test Accuracy: {100 * correct / total:.2f}%")


Predicted: 7
Predicted: 7
Predicted: 8
Predicted: 0
Predicted: 2
Predicted: 9
Predicted: 6
Predicted: 9
Predicted: 2
Predicted: 3
Predicted: 1
Predicted: 0
Predicted: 1
Predicted: 7
Predicted: 0
Predicted: 7
Predicted: 4
Predicted: 9
Predicted: 9
Predicted: 7
Predicted: 1
Predicted: 6
Predicted: 0
Predicted: 3
Predicted: 6
Predicted: 3
Predicted: 8
Predicted: 1
Predicted: 7
Predicted: 3
Predicted: 4
Predicted: 2
Predicted: 7
Predicted: 2
Predicted: 2
Predicted: 7
Predicted: 0
Predicted: 4
Predicted: 2
Predicted: 2
Predicted: 3
Predicted: 3
Predicted: 1
Predicted: 2
Predicted: 1
Predicted: 1
Predicted: 0
Predicted: 9
Predicted: 8
Predicted: 7
Predicted: 9
Predicted: 1
Predicted: 7
Predicted: 9
Predicted: 3
Predicted: 4
Predicted: 0
Predicted: 1
Predicted: 6
Predicted: 5
Predicted: 9
Predicted: 2
Predicted: 5
Predicted: 1
Predicted: 8
Predicted: 6
Predicted: 9
Predicted: 3
Predicted: 3
Predicted: 6
Predicted: 9
Predicted: 4
Predicted: 8
Predicted: 0
Predicted: 7
Predicted: 7
Predicted: 1