In [6]:
import torch.nn as nn
import os
import torch.nn.functional as F
from collections import Counter
from torchvision import transforms, datasets
import torchvision
import torchmetrics
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch
import matplotlib.pyplot as plt
import torchinfo
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

In [2]:
# Training hyper-parameters used by the data module (BATCH_SIZE, NUM_WORKERS),
# the LightningModel (LEARNING_RATE) and the Trainer (NUM_EPOCHS).
BATCH_SIZE = 256
NUM_EPOCHS = 5
LEARNING_RATE = 0.01
# DataLoader worker processes: up to 10, but never more than the CPUs the OS
# reports (os.cpu_count() may return None, hence the fallback to 1).
NUM_WORKERS = min(10, os.cpu_count() or 1)

In [3]:
class DataModule(pl.LightningDataModule):
    """Serve MNIST, resized to 32x32 tensors, as train / validation / test loaders.

    Args:
        data_path: directory where MNIST is downloaded to and read from.
        val_size: number of training images held out for validation
            (default 5000, leaving 55000 of MNIST's training images for training).
        split_seed: seed for the train/val split, so every ``setup`` call
            produces exactly the same split.
    """

    def __init__(self, data_path = '../../data', val_size = 5000, split_seed = 42):
        super().__init__()
        self.data_path = data_path
        self.val_size = val_size
        self.split_seed = split_seed
        # Built here rather than in prepare_data: Lightning calls prepare_data
        # on a single process only, so state assigned there would be missing
        # when setup runs on the other processes.
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])

    def prepare_data(self):
        """Download MNIST into ``data_path``; setup later reads it with download=False."""
        datasets.MNIST(root = self.data_path,
                       download = True)

    def setup(self, stage = None):
        """Load the datasets and split the training set into train / validation."""
        train = datasets.MNIST(root = self.data_path,
                               train = True,
                               transform = self.transform,
                               download = False)

        self.test = datasets.MNIST(root = self.data_path,
                                   train = False,
                                   transform = self.transform,
                                   download = False)

        train_size = len(train) - self.val_size
        # A dedicated, seeded generator keeps the split identical across repeated
        # setup calls (fit, validate, test), so validation images never leak
        # into the training subset.
        self.train, self.val = random_split(
            train,
            [train_size, self.val_size],
            generator = torch.Generator().manual_seed(self.split_seed))

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the notebook-wide batch size and worker count."""
        return DataLoader(dataset,
                          batch_size = BATCH_SIZE,
                          num_workers = NUM_WORKERS,
                          shuffle = shuffle,
                          # DataLoader rejects persistent_workers=True with 0 workers.
                          persistent_workers = NUM_WORKERS > 0)

    def train_dataloader(self):
        return self._make_loader(self.train, shuffle = True)

    def val_dataloader(self):
        return self._make_loader(self.val, shuffle = False)

    def test_dataloader(self):
        return self._make_loader(self.test, shuffle = False)

In [5]:
class PytorchLeNet(nn.Module):
    """LeNet-5-style CNN: two conv/tanh/max-pool stages, then three linear layers.

    The first linear layer expects 16 * 5 * 5 flattened features, which is what
    the feature extractor produces for 32x32 inputs
    (32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5).

    Args:
        num_classes: length of the output logit vector.
        grayscale: True for 1-channel input, False for 3-channel input.
    """

    def __init__(self, num_classes, grayscale = False):
        super().__init__()

        self.grayscale = grayscale
        self.num_classes = num_classes

        channels_in = 1 if grayscale else 3

        # Keep this layer order: it fixes the Sequential indices used in
        # state_dict keys (e.g. "features.3.weight") and the order in which
        # weights are randomly initialised.
        feature_layers = [
            nn.Conv2d(channels_in, 6, kernel_size=5),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
        ]
        head_layers = [
            nn.Linear(16 * 5 * 5, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, num_classes),
        ]
        self.features = nn.Sequential(*feature_layers)
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map an image batch (32x32 spatial size assumed) to per-class logits."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps.flatten(start_dim=1))

In [10]:
# Layer-by-layer shape / parameter table for one batch of 32 single-channel 32x32 images.
model = PytorchLeNet(grayscale=True, num_classes=10)
summary = torchinfo.summary(
    model,
    input_size=(32, 1, 32, 32),
    col_names=("input_size", "output_size", "num_params"),
    col_width=18,
)

summary

Layer (type:depth-idx)                   Input Shape        Output Shape       Param #
PytorchLeNet                             [32, 1, 32, 32]    [32, 10]           --
├─Sequential: 1-1                        [32, 1, 32, 32]    [32, 16, 5, 5]     --
│    └─Conv2d: 2-1                       [32, 1, 32, 32]    [32, 6, 28, 28]    156
│    └─Tanh: 2-2                         [32, 6, 28, 28]    [32, 6, 28, 28]    --
│    └─MaxPool2d: 2-3                    [32, 6, 28, 28]    [32, 6, 14, 14]    --
│    └─Conv2d: 2-4                       [32, 6, 14, 14]    [32, 16, 10, 10]   2,416
│    └─Tanh: 2-5                         [32, 16, 10, 10]   [32, 16, 10, 10]   --
│    └─MaxPool2d: 2-6                    [32, 16, 10, 10]   [32, 16, 5, 5]     --
├─Sequential: 1-2                        [32, 400]          [32, 10]           --
│    └─Linear: 2-7                       [32, 400]          [32, 120]          48,120
│    └─Tanh: 2-8                         [32, 120]          [32, 120]          --
│  

In [13]:
class LightningModel(pl.LightningModule):
    """LightningModule that trains ``model`` with cross-entropy loss and Adam.

    Logs loss per split and multiclass accuracy per epoch for train / val / test.

    Args:
        model: network returning class logits for a batch of inputs.
        learning_rate: learning rate passed to ``torch.optim.Adam``.
        num_classes: number of classes for the accuracy metrics. When omitted,
            ``model.num_classes`` is used if the model has it, otherwise 10.
    """

    def __init__(self, model, learning_rate, num_classes = None):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

        if num_classes is None:
            num_classes = getattr(model, 'num_classes', 10)

        # Record constructor hyper-parameters in self.hparams, excluding the model object.
        self.save_hyperparameters(ignore=['model'])

        self.train_acc = torchmetrics.Accuracy(task= 'multiclass', num_classes = num_classes)
        self.val_acc = torchmetrics.Accuracy(task= 'multiclass', num_classes = num_classes)
        self.test_acc = torchmetrics.Accuracy(task= 'multiclass', num_classes = num_classes)

    def shared_step (self, batch):
        """Run one ``(inputs, targets)`` batch; return ``(loss, targets, logits)``."""
        x, y = batch
        logits = self.model(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss, y, logits

    def training_step(self, batch, batch_idx):
        """Return the training loss; accuracy is measured in a separate eval-mode pass."""
        loss, _, _ = self.shared_step(batch)
        self.log('train_loss', loss)

        # Score the batch with the model in eval mode (relevant once the model
        # has dropout / batch-norm layers), updating the metric exactly once,
        # then restore the previous mode so subsequent optimisation steps
        # do not silently run in eval mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.inference_mode():
                _, y, logits = self.shared_step(batch)
                self.train_acc(logits, y)
        finally:
            self.model.train(was_training)

        self.log('train_acc', self.train_acc, on_step = False, on_epoch = True)
        return loss

    def test_step(self, batch, batch_idx):
        """Lightning test hook (this is the name Lightning looks up)."""
        loss, y, logits = self.shared_step(batch)
        self.test_acc(logits, y)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_acc, on_step = False, on_epoch = True)

    # Backward-compatible alias for the original, misnamed hook.
    testing_step = test_step

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accumulate epoch-level validation accuracy."""
        loss, y, logits = self.shared_step(batch)
        self.val_acc(logits, y)
        self.log('val_loss', loss)
        self.log('val_acc', self.val_acc, on_step = False, on_epoch = True)

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)
        return optimizer
        

In [14]:
# Seed Python's `random`, NumPy and torch in one call, and flag dataloader
# workers for seeding, so weight init and shuffling are reproducible after a
# kernel restart. (torch is seeded with the same value as before.)
SEED = 42
pl.seed_everything(SEED, workers=True)

data_module = DataModule(data_path='../../data')
model = PytorchLeNet(num_classes=10, grayscale=True)

lightning_model = LightningModel(model, LEARNING_RATE)

# Keep only the checkpoint with the best epoch-level validation accuracy.
callbacks = [ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1)]
logger = CSVLogger('logs', name='LeNet_MNIST')
logger = CSVLogger('logs', name='LeNet_MNIST')

In [15]:

import time

# perf_counter is a monotonic clock meant for measuring elapsed time;
# time.time() can jump if the system clock is adjusted mid-run.
start = time.perf_counter()

trainer = pl.Trainer(max_epochs=NUM_EPOCHS,
                     callbacks=callbacks,
                     logger=logger,
                     devices='auto',
                     accelerator='auto',
                     log_every_n_steps=100)

trainer.fit(lightning_model, data_module)

end = time.perf_counter()

# Elapsed training time in seconds.
print(f"Training time: {end-start}")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/LeNet_MNIST

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | PytorchLeNet       | 61.7 K
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
61.7 K    Trainable params
0         Non-trainable params
61.7 K    Total params
0.247     Total estimated model params size (MB)
python(78751) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

python(78752) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78808) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78809) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78810) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78811) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78812) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78813) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78814) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78815) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78816) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78817) Malloc

                                                                           

python(78873) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78874) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78875) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78876) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78877) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78878) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78879) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78880) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78881) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(78883) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch 0:   0%|          | 0/215 [00:00<?, ?it/s] 