In [1]:
# Import MNIST dataset from torchvision
import torch
import torchvision
from torch.utils.data import random_split, DataLoader, TensorDataset
import torch.nn as nn

# Load Data
mnist_dataset = torchvision.datasets.MNIST(root = ".", download = True)

train_size = int(0.9 * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size

# Train / Val Split
train_dataset, val_dataset = random_split(mnist_dataset, [train_size, val_size])

train_x_data = train_dataset.dataset.data[train_dataset.indices]
train_y_data = train_dataset.dataset.targets[train_dataset.indices]
val_x_data   = val_dataset.dataset.data[val_dataset.indices]
val_y_data   = val_dataset.dataset.targets[val_dataset.indices]

# Comb TensorDataset
train_tensor_dataset = TensorDataset(train_x_data, train_y_data)
val_tensor_dataset = TensorDataset(val_x_data, val_y_data)

# DataLoader
train_loader = DataLoader(train_tensor_dataset, batch_size=128, shuffle=True, num_workers=0)
val_loader = DataLoader(val_tensor_dataset, batch_size=128, shuffle=False, num_workers=0)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 505kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.98MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.92MB/s]


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
  # Your Model is Defined Here
  def __init__(self):
    super().__init__()
    self.net = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2),
        nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5),
        nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120),
        nn.Linear(120, 84),
        nn.Linear(84,10)
    )

  def forward(self, x):
    return self.net(x)

model = LeNet()

In [9]:
from tqdm import tqdm

max_epochs = 20
learning_rate = 5 * 1e-5
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(max_epochs)):
    model.train()
    total = 0
    correct = 0
    cur_loss = 0.0

    for train_x, train_y in train_loader:
        train_x = train_x.float().to(device)       # [batch, 28, 28]
        train_y = train_y.to(device)               # [batch]
        if train_x.dim() == 3:
            train_x = train_x.unsqueeze(1)         # [batch, 1, 28, 28]

        predict_y = model(train_x)                 # [batch, 10]
        loss = criterion(predict_y, train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        cur_loss += loss.item() * train_x.size(0)  # total loss
        preds = predict_y.argmax(dim=1)            # [batch]
        correct += (preds == train_y).sum().item()
        total += train_x.size(0)

    avg_loss = cur_loss / total
    acc = correct / total

    val_total = 0
    val_correct = 0
    for val_x, val_y in val_loader:
      val_x = val_x.float().to(device)
      val_y = val_y.to(device)
      if val_x.dim() == 3:
          val_x = val_x.unsqueeze(1)         # [batch, 1, 28, 28]
      predict_val_y = model(val_x)                 # [batch, 10]

      preds = predict_val_y.argmax(dim=1)            # [batch]
      val_correct += (preds == val_y).sum().item()
      val_total += val_x.size(0)

    val_acc = val_correct / val_total
    print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Acc={acc:.4f}, Val_Acc={val_acc:.4f}")


  3%|▎         | 1/30 [00:01<00:43,  1.49s/it]

Epoch 0: Loss=0.1045, Acc=0.9691, Val_Acc=0.9657


  7%|▋         | 2/30 [00:02<00:40,  1.46s/it]

Epoch 1: Loss=0.1029, Acc=0.9694, Val_Acc=0.9668


 10%|█         | 3/30 [00:04<00:39,  1.45s/it]

Epoch 2: Loss=0.1014, Acc=0.9706, Val_Acc=0.9680


 13%|█▎        | 4/30 [00:05<00:37,  1.45s/it]

Epoch 3: Loss=0.0999, Acc=0.9706, Val_Acc=0.9665


 17%|█▋        | 5/30 [00:07<00:38,  1.53s/it]

Epoch 4: Loss=0.0991, Acc=0.9708, Val_Acc=0.9650


 20%|██        | 6/30 [00:09<00:38,  1.61s/it]

Epoch 5: Loss=0.0982, Acc=0.9712, Val_Acc=0.9657


 23%|██▎       | 7/30 [00:10<00:35,  1.54s/it]

Epoch 6: Loss=0.0965, Acc=0.9717, Val_Acc=0.9652


 27%|██▋       | 8/30 [00:12<00:33,  1.52s/it]

Epoch 7: Loss=0.0955, Acc=0.9718, Val_Acc=0.9693


 30%|███       | 9/30 [00:13<00:31,  1.50s/it]

Epoch 8: Loss=0.0949, Acc=0.9719, Val_Acc=0.9675


 33%|███▎      | 10/30 [00:15<00:29,  1.49s/it]

Epoch 9: Loss=0.0937, Acc=0.9719, Val_Acc=0.9685


 37%|███▋      | 11/30 [00:16<00:28,  1.48s/it]

Epoch 10: Loss=0.0927, Acc=0.9728, Val_Acc=0.9692


 40%|████      | 12/30 [00:17<00:26,  1.48s/it]

Epoch 11: Loss=0.0921, Acc=0.9724, Val_Acc=0.9680


 43%|████▎     | 13/30 [00:19<00:26,  1.54s/it]

Epoch 12: Loss=0.0910, Acc=0.9730, Val_Acc=0.9692


 47%|████▋     | 14/30 [00:21<00:25,  1.60s/it]

Epoch 13: Loss=0.0899, Acc=0.9729, Val_Acc=0.9688


 50%|█████     | 15/30 [00:22<00:23,  1.56s/it]

Epoch 14: Loss=0.0889, Acc=0.9737, Val_Acc=0.9702


 53%|█████▎    | 16/30 [00:24<00:21,  1.54s/it]

Epoch 15: Loss=0.0885, Acc=0.9737, Val_Acc=0.9707


 57%|█████▋    | 17/30 [00:25<00:19,  1.52s/it]

Epoch 16: Loss=0.0879, Acc=0.9739, Val_Acc=0.9710


 60%|██████    | 18/30 [00:27<00:18,  1.51s/it]

Epoch 17: Loss=0.0868, Acc=0.9742, Val_Acc=0.9723


 63%|██████▎   | 19/30 [00:28<00:16,  1.50s/it]

Epoch 18: Loss=0.0859, Acc=0.9743, Val_Acc=0.9705


 67%|██████▋   | 20/30 [00:30<00:14,  1.48s/it]

Epoch 19: Loss=0.0856, Acc=0.9745, Val_Acc=0.9713


 70%|███████   | 21/30 [00:32<00:14,  1.57s/it]

Epoch 20: Loss=0.0845, Acc=0.9746, Val_Acc=0.9718


 73%|███████▎  | 22/30 [00:33<00:12,  1.60s/it]

Epoch 21: Loss=0.0843, Acc=0.9749, Val_Acc=0.9713


 77%|███████▋  | 23/30 [00:35<00:10,  1.57s/it]

Epoch 22: Loss=0.0827, Acc=0.9753, Val_Acc=0.9727


 80%|████████  | 24/30 [00:36<00:09,  1.53s/it]

Epoch 23: Loss=0.0824, Acc=0.9754, Val_Acc=0.9728


 83%|████████▎ | 25/30 [00:38<00:07,  1.51s/it]

Epoch 24: Loss=0.0819, Acc=0.9759, Val_Acc=0.9715


 87%|████████▋ | 26/30 [00:39<00:05,  1.50s/it]

Epoch 25: Loss=0.0811, Acc=0.9759, Val_Acc=0.9733


 90%|█████████ | 27/30 [00:40<00:04,  1.48s/it]

Epoch 26: Loss=0.0802, Acc=0.9760, Val_Acc=0.9732


 93%|█████████▎| 28/30 [00:42<00:02,  1.48s/it]

Epoch 27: Loss=0.0800, Acc=0.9761, Val_Acc=0.9712


 97%|█████████▋| 29/30 [00:44<00:01,  1.58s/it]

Epoch 28: Loss=0.0791, Acc=0.9765, Val_Acc=0.9722


100%|██████████| 30/30 [00:45<00:00,  1.53s/it]

Epoch 29: Loss=0.0785, Acc=0.9768, Val_Acc=0.9728



