In [16]:
import torch; torch.manual_seed(0)
from torch import nn,utils,optim

import torchvision
from torchvision import datasets, transforms

import lightning.pytorch as pl
import torchmetrics

from tqdm.notebook import tqdm

In [17]:
class MnistDataLoader(pl.LightningDataModule):
    def __init__(self,root,batch_size,num_workers):
        super(MnistDataLoader,self).__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        
    def prepare_data(self):
        datasets.MNIST(root = self.root,train=True,download=True)
        datasets.MNIST(root = self.root,train=False,download=True)

    def setup(self, stage):
        train_dataset = datasets.MNIST(root = self.root,train=True,download=False,transform = transforms.ToTensor())
        self.test_dataset =  datasets.MNIST(root = self.root,train=False,download=False,transform = transforms.ToTensor())
        # Define the proportions for the split
        train_proportion = 0.8  # 80% for training
        val_proportion = 0.2  # 20% for validation

        # Calculate the sizes of training and validation sets based on the proportions
        train_size = int(train_proportion * len(train_dataset))
        val_size = len(train_dataset) - train_size
        
        # Use random_split to split the dataset
        self.train_dataset, self.val_dataset = utils.data.random_split(train_dataset, [train_size, val_size])
        

    def train_dataloader(self):
        return utils.data.DataLoader(dataset=self.train_dataset,batch_size=self.batch_size,shuffle=True,num_workers=self.num_workers,pin_memory=True)
    
    def val_dataloader(self):
        return utils.data.DataLoader(dataset=self.val_dataset,batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,pin_memory=True)
    
    def test_dataloader(self):
        return utils.data.DataLoader(dataset=self.test_dataset,batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,pin_memory=True)
   

In [18]:
root = './data'
batch_size = 128
num_workers = 4
ds = MnistDataLoader(root, batch_size, num_workers)

In [19]:
# # Only for inspecting data 
# ds.prepare_data()
# ds.setup('train')
# for data, label in ds.train_dataloader().dataset:
#     print(data.shape,label)
#     break

In [20]:
class CNN(pl.LightningModule):
    
    def __init__(self,n_channel,output_shape):
        super(CNN,self).__init__()
        self.save_hyperparameters()
        
        self.conv1 = nn.Conv2d(in_channels=n_channel,out_channels=8,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.conv2 =nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.fc1 = nn.Linear(in_features=16*7*7,out_features=n_out)
        
        self.accuracy = torchmetrics.Accuracy(task="multiclass",num_classes=output_shape)
        self.f1_score = torchmetrics.F1Score(task="multiclass",num_classes=output_shape)
        
        self.lr= 1e-3
        
    def forward(self,x):
        x = nn.functional.relu(self.conv1(x))
        x = self.pool(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.shape[0],-1)
        x = self.fc1(x)
        
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss, x_hat, y = self._common_step(batch,batch_idx)
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        
        self.log_dict({'train_loss':loss,
                      'train_accuracy':accuracy,
                      'train_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        if batch_idx % 100 == 0:
            x = x[:8]
            grid = torchvision.utils.make_grid(x.view(-1,1,28,28))
            self.logger.experiment.add_image('mnist_images',grid,self.global_step)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, x_hat, y = self._common_step(batch,batch_idx)
        
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        self.log_dict({'val_loss':loss,
                      'val_accuracy':accuracy,
                      'val_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        loss, x_hat, y = self._common_step(batch,batch_idx)
        
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        self.log_dict({'test_loss':loss,
                      'test_accuracy':accuracy,
                      'test_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        return loss

    def _common_step(self,batch,batch_index):
        x, y = batch
#         x = x.flatten(start_dim=1)
        x_hat = self.forward(x)
        loss = nn.functional.cross_entropy(x_hat,y)
        return loss , x_hat, y

    def predict_step(self,batch,batch_idx):
        x, y = batch
#         x = x.flatten(start_dim=1)
        x_hat = self.forward(x)
        pred = torch.argmax(x_hat,dim=1)
        return pred


    def configure_optimizers(self):
        return optim.Adam(self.parameters(),lr=self.lr)

In [21]:
# Hyperparametersa
n_channel,output_shape = 1, 10
batch = 128
num_epoch = 2
learning_rate = 1e-3

model = CNN(n_channel,output_shape)

In [22]:


logger = pl.loggers.TensorBoardLogger(save_dir='./log/', name='mnist_cnn', version=0.1)

profiler = pl.profilers.PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/',),
    schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
)

# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=3,
    monitor="val_f1score",
    mode="max",
    dirpath="checkpoints/cnn/",
    filename="mnist-cnn-{epoch}-{val_accuracy}",
)


trainer = pl.Trainer(
    logger=logger,
    accelerator='auto',
    devices=[0],
    min_epochs=1,
    max_epochs=100,
    precision='16-mixed',
    enable_model_summary=True,
#     profiler=profiler,
    callbacks=[pl.callbacks.EarlyStopping('val_loss',patience=6,verbose=True),checkpoint_callback],
#     default_root_dir="mnist_checkpoints/",
    enable_checkpointing  = True
    
)
if checkpoint_callback.best_model_path:
    trainer.fit(model, ds, ckpt_path=checkpoint_callback.best_model_path,)
else : 
    trainer.fit(model, ds)
trainer.validate(model, ds)
# trainer.test(model, ds)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params
------------------------------------------------
0 | conv1    | Conv2d             | 80    
1 | pool     | MaxPool2d          | 0     
2 | conv2    | Conv2d             | 1.2 K 
3 | fc1      | Linear             | 7.9 K 
4 | accuracy | MulticlassAccuracy | 0     
5 | f1_score | MulticlassF1Score  | 0     
------------------------------------------------
9.1 K     Trainable params
0         Non-trainable params
9.1 K     Total params
0.036     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.200


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.078 >= min_delta = 0.0. New best score: 0.121


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 0.090


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 0.077


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.069


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.069


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.062


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.059


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.056


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.054


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.051


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.048


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 6 records. Best score: 0.048. Signaling Trainer to stop.


Validation: 0it [00:00, ?it/s]

[{'val_loss': 0.025710148736834526,
  'val_accuracy': 0.9913333058357239,
  'val_f1score': 0.9913333058357239}]

In [23]:
trainer.test(model, ds,ckpt_path=checkpoint_callback.best_model_path)

Restoring states from the checkpoint path at /Users/pranavjha/Library/CloudStorage/GoogleDrive-pranajh7@gmail.com/My Drive/Projects/applied_theories/lightning examples/checkpoints/cnn/mnist-cnn-epoch=19-val_f1score=0.98458331823349.ckpt
Loaded model weights from the checkpoint at /Users/pranavjha/Library/CloudStorage/GoogleDrive-pranajh7@gmail.com/My Drive/Projects/applied_theories/lightning examples/checkpoints/cnn/mnist-cnn-epoch=19-val_f1score=0.98458331823349.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.04413893446326256,
  'test_accuracy': 0.98580002784729,
  'test_f1score': 0.98580002784729}]

In [24]:
!tensorboard --logdir log/

TensorFlow installation not found - running with reduced feature set.
I0722 09:58:35.907079 6169096192 plugin.py:429] Monitor runs begin
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C
