In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

In [4]:
path = './datasets/'

transform = transforms.Compose([transforms.ToTensor()])

train_data = MNIST(root=path, train=True, transform=transform, download=True)
test_data = MNIST(root=path, train=False, transform=transform, download=True)

batch_size = 100

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=4)

print(train_data)
print(test_data)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 183724319.89it/s]

Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 34219122.55it/s]


Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 40146117.66it/s]

Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4410865.66it/s]


Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./datasets/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./datasets/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )




In [5]:
train_data.classes # check the classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [6]:
input_shape = train_data[0][0].shape
output_shape = len(train_data.classes)

In [14]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5,5), stride=1, padding=2)
        self.pool1 = nn.AvgPool2d(kernel_size=(2,2), stride=2, padding=0)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5,5), stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=(2,2), stride=2, padding=0)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_shape)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x)) # shape : 100x6x28x28 print(x.shape)
        x = self.pool1(x) # shape : 100x6x14x14
        x = F.leaky_relu(self.conv2(x)) # shape : 100x16x10x10
        x = self.pool2(x) # shape : 100x16x5x5
        x = self.flatten(x)
        x = F.leaky_relu(self.fc1(x)) # shape : 400 -> 120
        x = F.leaky_relu(self.fc2(x)) # shape : 120 -> 84
        x = self.fc3(x) # shape : 84 -> 10
        return x

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [23]:
model = LeNet().to(device)
loss = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [24]:
num_epoch = 10
train_loss_lst, test_loss_lst = list(), list()

for i in range(num_epoch):
    # training
    model.train()

    total_loss = 0
    cnt = 0

    for batch_idx, (x,y) in enumerate(train_loader):

        x,y = x.to(device), y.to(device)
        y_est = model.forward(x)
        cost = loss(y_est, y)

        total_loss += cost.item()

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        pred = torch.argmax(y_est, dim=1)
        cnt += (pred == y).sum().item()

    acc = cnt / len(train_data)
    ave_loss = total_loss / len(train_data)

    train_loss_lst.append(ave_loss)

    if i % 1 == 0:
        print(f"\nEpoch {i} Train : {ave_loss:.3f} / {acc:.3f}")

    #testing
    model.eval()

    total_loss = 0
    cnt = 0

    with torch.no_grad():
        for batch, (x,y) in enumerate(test_loader):

            x, y = x.to(device), y.to(device)

            y_est = model.forward(x)
            pred = torch.argmax(y_est, dim=1)

            total_loss += cost.item()

        acc = cnt / len(test_data)
        ave_loss = total_loss / len(test_data)

        test_loss_lst.append(ave_loss)

        if i % 1 == 0:
            print(f"Epoch {i} Test : {ave_loss:.3f} / {acc:.3f}")

print()
num_parameter = 0
for parameter in model.parameters():
    print(parameter.shape)
    num_parameter += np.prod(parameter.size())
print(num_parameter)


Epoch 0 Train : 0.401 / 0.877
Epoch 0 Test : 0.161 / 0.000

Epoch 1 Train : 0.107 / 0.968
Epoch 1 Test : 0.045 / 0.000

Epoch 2 Train : 0.073 / 0.977
Epoch 2 Test : 0.035 / 0.000

Epoch 3 Train : 0.058 / 0.983
Epoch 3 Test : 0.053 / 0.000

Epoch 4 Train : 0.051 / 0.984
Epoch 4 Test : 0.077 / 0.000

Epoch 5 Train : 0.042 / 0.987
Epoch 5 Test : 0.016 / 0.000

Epoch 6 Train : 0.037 / 0.988
Epoch 6 Test : 0.014 / 0.000

Epoch 7 Train : 0.033 / 0.989
Epoch 7 Test : 0.025 / 0.000

Epoch 8 Train : 0.029 / 0.991
Epoch 8 Test : 0.007 / 0.000

Epoch 9 Train : 0.025 / 0.992
Epoch 9 Test : 0.044 / 0.000

torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])
61706
