In [1]:
import numpy as np
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import tqdm

In [2]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
BATCH_SIZE = 64
set_seed(42)

In [7]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

display(train_dataset[0][0].shape)
display(train_dataset)

torch.Size([1, 28, 28])

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [8]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
          torch.nn.Linear(28*28, 512),
          torch.nn.ReLU(),
          torch.nn.Linear(512, 256),
          torch.nn.ReLU(),
          torch.nn.Linear(256, 10),
          torch.nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.model(x)
        return x


In [14]:
# SGD optimizer
model = Model()
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  print("epoch:", epoch, "loss:", total_loss)

  0%|          | 0/938 [00:00<?, ?it/s]

100%|██████████| 938/938 [00:15<00:00, 61.64it/s]


epoch: 0 loss: 2142.2526166439056


100%|██████████| 938/938 [00:14<00:00, 65.37it/s]


epoch: 1 loss: 2101.7511806488037


100%|██████████| 938/938 [00:14<00:00, 64.60it/s]


epoch: 2 loss: 2038.03302693367


100%|██████████| 938/938 [00:14<00:00, 64.89it/s]


epoch: 3 loss: 1927.701354265213


100%|██████████| 938/938 [00:14<00:00, 66.47it/s]


epoch: 4 loss: 1745.999971151352


100%|██████████| 938/938 [00:14<00:00, 66.00it/s]


epoch: 5 loss: 1492.310448884964


100%|██████████| 938/938 [00:14<00:00, 65.56it/s]


epoch: 6 loss: 1215.9128968715668


100%|██████████| 938/938 [00:14<00:00, 64.00it/s]


epoch: 7 loss: 983.2942970991135


100%|██████████| 938/938 [00:14<00:00, 66.22it/s]


epoch: 8 loss: 818.3162395954132


100%|██████████| 938/938 [00:14<00:00, 65.47it/s]

epoch: 9 loss: 707.0885758697987





In [15]:
# Momentum optimizer
model = Model()
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  print("epoch:", epoch, "loss:", total_loss)

100%|██████████| 938/938 [00:15<00:00, 60.61it/s]


epoch: 0 loss: 1582.850679576397


100%|██████████| 938/938 [00:15<00:00, 62.12it/s]


epoch: 1 loss: 499.8529797196388


100%|██████████| 938/938 [00:15<00:00, 62.11it/s]


epoch: 2 loss: 358.416724845767


100%|██████████| 938/938 [00:15<00:00, 60.89it/s]


epoch: 3 loss: 311.745112195611


100%|██████████| 938/938 [00:15<00:00, 62.17it/s]


epoch: 4 loss: 284.5541200712323


100%|██████████| 938/938 [00:15<00:00, 58.99it/s]


epoch: 5 loss: 263.3385721668601


100%|██████████| 938/938 [00:15<00:00, 60.28it/s]


epoch: 6 loss: 246.05668930336833


100%|██████████| 938/938 [00:15<00:00, 59.45it/s]


epoch: 7 loss: 229.74817308038473


100%|██████████| 938/938 [00:17<00:00, 54.95it/s]


epoch: 8 loss: 215.75678018853068


100%|██████████| 938/938 [00:16<00:00, 55.25it/s]

epoch: 9 loss: 202.44109741598368





In [16]:
# RMSprop optimizer
model = Model()
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  print("epoch:", epoch, "loss:", total_loss)

100%|██████████| 938/938 [00:19<00:00, 48.83it/s]


epoch: 0 loss: 185.63053293526173


100%|██████████| 938/938 [00:16<00:00, 58.22it/s]


epoch: 1 loss: 74.84468302549794


100%|██████████| 938/938 [00:17<00:00, 54.64it/s]


epoch: 2 loss: 49.93111440958455


100%|██████████| 938/938 [00:16<00:00, 57.74it/s]


epoch: 3 loss: 36.449686108899186


100%|██████████| 938/938 [00:16<00:00, 58.14it/s]


epoch: 4 loss: 27.909077187621733


100%|██████████| 938/938 [00:16<00:00, 55.35it/s]


epoch: 5 loss: 21.137203990227135


100%|██████████| 938/938 [00:17<00:00, 53.40it/s]


epoch: 6 loss: 17.686010420309685


100%|██████████| 938/938 [00:18<00:00, 51.07it/s]


epoch: 7 loss: 16.534293605292987


100%|██████████| 938/938 [00:18<00:00, 50.46it/s]


epoch: 8 loss: 12.671758254221913


100%|██████████| 938/938 [00:17<00:00, 52.55it/s]

epoch: 9 loss: 12.29824515752398





In [17]:
# Adam optimizer
model = Model()
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  print("epoch:", epoch, "loss:", total_loss)

100%|██████████| 938/938 [00:18<00:00, 51.79it/s]


epoch: 0 loss: 219.8822102751583


100%|██████████| 938/938 [00:19<00:00, 48.60it/s]


epoch: 1 loss: 82.25013864878565


100%|██████████| 938/938 [00:19<00:00, 47.98it/s]


epoch: 2 loss: 53.801378918695264


100%|██████████| 938/938 [00:20<00:00, 44.89it/s]


epoch: 3 loss: 39.508118202822516


100%|██████████| 938/938 [00:19<00:00, 48.91it/s]


epoch: 4 loss: 29.152213070803555


100%|██████████| 938/938 [00:18<00:00, 50.52it/s]


epoch: 5 loss: 24.11633635108592


100%|██████████| 938/938 [00:17<00:00, 53.26it/s]


epoch: 6 loss: 20.61195437531569


100%|██████████| 938/938 [00:17<00:00, 53.00it/s]


epoch: 7 loss: 18.003157269175063


100%|██████████| 938/938 [00:17<00:00, 52.92it/s]


epoch: 8 loss: 15.012630279899895


100%|██████████| 938/938 [00:17<00:00, 53.02it/s]

epoch: 9 loss: 14.977680848566706



