In [11]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
from torchmetrics import Accuracy

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


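<img alt="Structure of the LeNet network" width="864" height="200" src="https://d2l.ai/_images/lenet.svg" style="color: transparent; height: auto; max-width: 100%;">

Note: the implementation below applies both 5×5 convolutions without padding, so a 28×28 input shrinks 28 → 24 → 12 → 8 → 4 and the flattened feature vector has 16×4×4 = 256 entries (the padded variant in the diagram ends at 16×5×5).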

In [3]:
class LeNet(nn.Module):
    """LeNet-5-style CNN for single-channel 28x28 images, producing 10 class logits.

    Shape walk-through (no padding anywhere):
    1x28x28 -conv5-> 6x24x24 -pool2-> 6x12x12 -conv5-> 16x8x8 -pool2-> 16x4x4,
    which is flattened to 256 features for the fully connected head.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk: two (conv -> tanh -> 2x2 average pool) stages.
        # Layer order/indices are part of the saved state_dict keys; keep them stable.
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Fully connected head: 256 -> 120 -> 84 -> 10 raw logits. There is no
        # softmax because nn.CrossEntropyLoss applies log-softmax itself.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 4 * 4, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10),
        )

    def forward(self, input_):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) logits."""
        return self.classifier(self.features(input_))

In [4]:
class customData(Dataset):
    """Map-style dataset that pairs the i-th element of ``x`` with the i-th of ``y``.

    The length is taken from ``x``. ``y`` is expected to be at least as long,
    but this is not checked.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target
        

In [14]:
def train(batched_dataset, X_valid, y_valid, learning_rate, epochs, model=None, val_every=20):
    """Train a classifier with Adam and cross-entropy loss, printing progress.

    Args:
        batched_dataset: iterable of ``(X, y)`` mini-batches, e.g. a ``DataLoader``.
        X_valid: validation inputs, evaluated in a single forward pass.
        y_valid: integer class labels for ``X_valid``.
        learning_rate: Adam learning rate.
        epochs: number of passes over ``batched_dataset``.
        model: the ``nn.Module`` to train (its parameters are updated in place).
            Defaults to the notebook-global ``model`` so existing calls keep working.
        val_every: validate on epochs 1, 1 + val_every, 1 + 2 * val_every, ...
            (1-based). Validation also runs after the final epoch.

    Raises:
        ValueError: if no model is available, ``val_every`` is not positive, or
            ``batched_dataset`` yields no batches in an epoch.
    """
    if model is None:
        # Backward compatibility: earlier versions always trained the global `model`.
        model = globals().get("model")
        if model is None:
            raise ValueError("train() needs a model: pass model=... or define a global `model`")
    if val_every < 1:
        raise ValueError("val_every must be a positive integer")

    # Train wherever the model's parameters live instead of relying on a global `device`.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # The validation tensors never change, so move them once rather than at every check.
    X_valid, y_valid = X_valid.to(device), y_valid.to(device)
    print("======================================================== \n\n")

    for epoch in range(epochs):
        # Validation below switches the model to eval mode. Without this reset,
        # every epoch after the first would train with dropout/batch-norm layers
        # (if the model has any) in inference mode.
        model.train()
        total_epoch_loss = 0.0
        n_batches = 0
        for X, y in batched_dataset:
            X, y = X.to(device), y.to(device)

            pred = model(X)              # forward pass
            loss = criterion(pred, y)
            optimizer.zero_grad()        # clear gradients left by the previous step
            loss.backward()              # backprop
            optimizer.step()             # parameter update

            total_epoch_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("batched_dataset yielded no batches")
        avg_loss = total_epoch_loss / n_batches
        print(f'Epochs: {epoch + 1} | Loss: {avg_loss:.4f}')

        # Periodic check, plus always one after the last epoch so the final model is scored.
        if epoch % val_every == 0 or epoch == epochs - 1:
            model.eval()
            with torch.no_grad():
                logits = model(X_valid)
                val_loss = criterion(logits, y_valid)
                # Fraction of samples whose arg-max prediction matches the label.
                val_acc = (logits.argmax(dim=1) == y_valid).float().mean()
            print("###############################")
            print(f'Validation Loss: {val_loss.item():.4f} | Accuracy: {val_acc.item():.4f}')
            print("###############################")

    print("========================================================\n\n")

In [6]:
# MNIST training split: a `label` column followed by the flattened pixel columns.
df = pd.read_csv(filepath_or_buffer="data/mnist_train.csv")
df.shape

(60000, 785)

In [7]:
# Seeds the global RNG that the shuffling DataLoader draws from (and, when cells
# run in order, the weight initialisation of the model created later).
torch.manual_seed(0)
# .to_numpy() hands torch a contiguous array instead of converting the Series element by element.
y = torch.tensor(df['label'].to_numpy(), dtype=torch.long)
# All non-label columns are pixels. -1 lets the sample count follow the CSV rather
# than a hard-coded 60000.
# NOTE(review): pixels are used unscaled (no /255); presumably 0-255. The saved
# weights depend on this, so keep train and test preprocessing identical.
X = torch.tensor(df.drop('label', axis=1).to_numpy(), dtype=torch.float32).reshape(-1, 1, 28, 28)
assert len(X) == len(y), "images and labels are misaligned"
dataset = customData(X, y)
batches = DataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
df_t = pd.read_csv("data/mnist_test.csv")

# Same preprocessing as the training split (no pixel scaling). .to_numpy() avoids
# element-wise Series conversion, and -1 replaces the hard-coded 10000 sample count.
y_test = torch.tensor(df_t['label'].to_numpy(), dtype=torch.long)
X_test = torch.tensor(df_t.drop('label', axis=1).to_numpy(), dtype=torch.float32).reshape(-1, 1, 28, 28)
assert len(X_test) == len(y_test), "test images and labels are misaligned"

In [15]:
# Fresh LeNet on the selected device. `train` picks up this global `model`.
# NOTE(review): the MNIST *test* split doubles as the validation set here, so its
# metrics stop being an unbiased held-out estimate once they guide any tuning.
model = LeNet()
model = model.to(device)
train(
    batched_dataset=batches,
    X_valid=X_test,
    y_valid=y_test,
    learning_rate=0.001,
    epochs=100,
)



Epochs: 1 | Loss: 0.1874
###############################
Validation Loss: 0.0779 | Accuracy: 0.9734
###############################
Epochs: 2 | Loss: 0.0634
Epochs: 3 | Loss: 0.0465
Epochs: 4 | Loss: 0.0403
Epochs: 5 | Loss: 0.0334
Epochs: 6 | Loss: 0.0276
Epochs: 7 | Loss: 0.0249
Epochs: 8 | Loss: 0.0231
Epochs: 9 | Loss: 0.0199
Epochs: 10 | Loss: 0.0183
Epochs: 11 | Loss: 0.0149
Epochs: 12 | Loss: 0.0144
Epochs: 13 | Loss: 0.0134
Epochs: 14 | Loss: 0.0150
Epochs: 15 | Loss: 0.0123
Epochs: 16 | Loss: 0.0128
Epochs: 17 | Loss: 0.0121
Epochs: 18 | Loss: 0.0097
Epochs: 19 | Loss: 0.0090
Epochs: 20 | Loss: 0.0093
Epochs: 21 | Loss: 0.0096
###############################
Validation Loss: 0.0491 | Accuracy: 0.9881
###############################
Epochs: 22 | Loss: 0.0086
Epochs: 23 | Loss: 0.0078
Epochs: 24 | Loss: 0.0082
Epochs: 25 | Loss: 0.0091
Epochs: 26 | Loss: 0.0062
Epochs: 27 | Loss: 0.0070
Epochs: 28 | Loss: 0.0106
Epochs: 29 | Loss: 0.0087
Epochs: 30 | Loss: 0.0073
Epochs: 31 | 

In [18]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f='mnist_model.pt')