In [1]:
# imports 

from pathlib import Path

import torch; torch.manual_seed(0)
from torch import nn,utils,optim

import torchvision as tv

from torchvision import datasets, transforms

import lightning.pytorch as pl

import torchmetrics as tm
from torchmetrics import Metric
from tqdm.notebook import tqdm

In [2]:
class NN(pl.LightningModule):
    """Two-layer fully connected classifier trained with cross-entropy.

    Each batch is flattened from dimension 1 onward, passed through a
    50-unit hidden layer with ReLU, and projected to ``output_shape`` logits.

    Args:
        input_shape: Number of features per flattened sample.
        output_shape: Number of target classes (also used as ``num_classes``
            for the accuracy and F1 metrics).
        lr: Learning rate handed to Adam in ``configure_optimizers``.
            Defaults to ``1e-3``, the value that was previously hard-coded.
    """

    def __init__(self, input_shape, output_shape, lr=1e-3):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, 50)
        self.fc2 = nn.Linear(50, output_shape)
        # NOTE(review): the same metric objects are called from the train,
        # validation and test steps. Only the per-batch value returned by the
        # call is logged here, but if their accumulated state is ever read via
        # ``compute()``, use one metric instance per stage instead.
        self.accuracy = tm.Accuracy(task="multiclass", num_classes=output_shape)
        self.f1_score = tm.F1Score(task="multiclass", num_classes=output_shape)

        self.lr = lr

    def forward(self, x):
        """Return class logits for an already flattened batch ``x``."""
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log metrics, and periodically log input images."""
        images, _ = batch
        loss, logits, y = self._common_step(batch, batch_idx)
        self._log_stage_metrics('train', loss, logits, y)

        # Every 100th batch, push a small grid of inputs to the logger.
        if batch_idx % 100 == 0:
            self._log_sample_images(images)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss, accuracy and F1."""
        loss, logits, y = self._common_step(batch, batch_idx)
        self._log_stage_metrics('val', loss, logits, y)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log test loss, accuracy and F1."""
        loss, logits, y = self._common_step(batch, batch_idx)
        self._log_stage_metrics('test', loss, logits, y)
        return loss

    def _common_step(self, batch, batch_index):
        """Flatten the inputs, run the network and return ``(loss, logits, targets)``."""
        x, y = batch
        x = x.flatten(start_dim=1)
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        return loss, logits, y

    def _log_stage_metrics(self, stage, loss, logits, y):
        """Log ``<stage>_loss``, ``<stage>_accuracy`` and ``<stage>_f1score`` per epoch."""
        accuracy = self.accuracy(logits, y)
        f1_score = self.f1_score(logits, y)
        self.log_dict(
            {
                f'{stage}_loss': loss,
                f'{stage}_accuracy': accuracy,
                f'{stage}_f1score': f1_score,
            },
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )

    def _log_sample_images(self, images, max_images=8):
        """Send a grid of the first ``max_images`` inputs to the logger, if it supports images.

        Does nothing when no logger is attached or when the logger's
        experiment object has no ``add_image`` method, so training does not
        crash under ``logger=False`` or a non-image logger.
        """
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_image"):
            return
        # Reshape to (N, 1, 28, 28) single-channel images before building the grid.
        grid = tv.utils.make_grid(images[:max_images].view(-1, 1, 28, 28))
        experiment.add_image('mnist_images', grid, self.global_step)

    def predict_step(self, batch, batch_idx):
        """Return the arg-max class index for every sample in ``batch``."""
        x, _ = batch
        x = x.flatten(start_dim=1)
        logits = self(x)
        return torch.argmax(logits, dim=1)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

In [4]:
# Hyperparameters
input_shape = 28*28
output_shape = 10
batch = 128
# NOTE(review): num_epoch is not referenced by the Trainer cell, which sets
# max_epochs=50 directly.
num_epoch = 2
learning_rate = 1e-3

# No explicit .to('mps'): the Trainer (accelerator='auto') moves the module to
# the selected device, and a hard-coded 'mps' fails on machines without it.
model = NN(input_shape,output_shape)
# Wire the hyperparameter above into the optimizer's learning rate.
model.lr = learning_rate


In [5]:
class MnistDataLoader(pl.LightningDataModule):
    """Data module serving MNIST train/validation/test loaders.

    The official MNIST training set is divided into a training and a
    validation subset; the official test set is served unchanged.

    Args:
        root: Directory where MNIST is downloaded to and read from.
        batch_size: Batch size used by every loader.
        num_workers: Worker process count used by every loader.
        train_fraction: Share of the training set kept for training; the
            remainder becomes the validation subset. Defaults to ``0.8``.
        seed: Seed of the generator that drives the train/validation split.
    """

    def __init__(self, root, batch_size, num_workers, train_fraction=0.8, seed=0):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_fraction = train_fraction
        self.seed = seed

    def prepare_data(self):
        """Download both MNIST partitions into ``self.root`` if needed."""
        for is_train in (True, False):
            datasets.MNIST(root=self.root, train=is_train, download=True)

    def setup(self, stage=None):
        """Build the train/validation subsets with a reproducible split.

        ``setup`` runs again for every fit/validate/test call. Splitting with
        the global RNG would give a different partition each time, so training
        samples could leak into the validation subset used by a later
        ``trainer.validate``. A dedicated seeded generator keeps the partition
        identical across calls.
        """
        full_train = datasets.MNIST(
            root=self.root,
            train=True,
            download=False,
            transform=transforms.ToTensor(),
        )

        # Validation takes whatever is left so the two sizes always sum to the total.
        train_size = int(self.train_fraction * len(full_train))
        val_size = len(full_train) - train_size

        split_rng = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = utils.data.random_split(
            full_train, [train_size, val_size], generator=split_rng
        )

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader with this module's batch size and worker count."""
        return utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned host memory only benefits CUDA transfers; skip it elsewhere.
            pin_memory=torch.cuda.is_available(),
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the official MNIST test partition."""
        test_dataset = datasets.MNIST(
            root=self.root,
            train=False,
            download=False,
            transform=transforms.ToTensor(),
        )
        return self._make_loader(test_dataset, shuffle=False)
   

In [6]:
class FineTuneBatchSizeFinder(pl.callbacks.BatchSizeFinder):
    """Batch-size finder that is triggered at epoch 0 and at given milestone epochs.

    Args:
        milestones: Container of epoch indices at which the search is re-run.
        *args, **kwargs: Forwarded unchanged to ``BatchSizeFinder``.
    """

    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Deliberate no-op override: the search is started from
        # on_train_epoch_start instead (presumably replacing the base class's
        # fit-start behaviour -- confirm against the Lightning docs).
        return None

    def on_train_epoch_start(self, trainer, pl_module):
        """Run the batch-size search when the current epoch is a milestone or epoch 0."""
        epoch = trainer.current_epoch
        is_trigger_epoch = epoch in self.milestones or epoch == 0
        if not is_trigger_epoch:
            return
        self.scale_batch_size(trainer, pl_module)

from lightning.pytorch.callbacks import LearningRateFinder


class FineTuneLearningRateFinder(LearningRateFinder):
    """Learning-rate finder that is triggered at epoch 0 and at given milestone epochs.

    Args:
        milestones: Container of epoch indices at which the search is re-run.
        *args, **kwargs: Forwarded unchanged to ``LearningRateFinder``.
    """

    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Deliberate no-op override: the search is started from
        # on_train_epoch_start instead (presumably replacing the base class's
        # fit-start behaviour -- confirm against the Lightning docs).
        return None

    def on_train_epoch_start(self, trainer, pl_module):
        """Run the learning-rate search when the current epoch is a milestone or epoch 0."""
        epoch = trainer.current_epoch
        is_trigger_epoch = epoch in self.milestones or epoch == 0
        if not is_trigger_epoch:
            return
        self.lr_find(trainer, pl_module)

In [8]:


# Checkpoint to resume from, expressed relative to the working directory (the
# same './dashboard/' tree the logger below writes to) instead of an absolute
# path inside one user's home directory.
RESUME_CKPT = Path('./dashboard/mnist_clsfy/version_1.1/checkpoints/epoch=19-step=7500.ckpt')

logger = pl.loggers.TensorBoardLogger(save_dir='./dashboard/', name='mnist_clsfy', version=1.1)
profiler = pl.profilers.PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./dashboard/',),
    schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
)
dm = MnistDataLoader(root='./dataset/', batch_size=batch, num_workers=8)
trainer = pl.Trainer(
    logger=logger,
    accelerator='auto',
    devices=[0],
    min_epochs=1,
    max_epochs=50,
    precision='16-mixed',
    enable_model_summary=True,
    profiler=profiler,
    callbacks=[pl.callbacks.EarlyStopping(monitor='val_loss'),],
    default_root_dir="mnist_checkpoints/",
    enable_checkpointing  = True

)
# Resume only when the checkpoint is actually present; otherwise train from
# scratch so the notebook still runs end-to-end on a fresh checkout.
trainer.fit(model, dm, ckpt_path=str(RESUME_CKPT) if RESUME_CKPT.exists() else None)
trainer.validate(model, dm)
trainer.test(model, dm)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /Users/pranavjha/Library/CloudStorage/GoogleDrive-pranajh7@gmail.com/My Drive/Projects/Swastik/swastik_web/Learning&Experiments/PytorchTutorials/pytorch lightning/dashboard/mnist_clsfy/version_1.1/checkpoints/epoch=19-step=7500.ckpt

  | Name     | Type               | Params
------------------------------------------------
0 | fc1      | Linear             | 39.2 K
1 | fc2      | Linear             | 510   
2 | accuracy | MulticlassAccuracy | 0     
3 | f1_score | MulticlassF1Score  | 0     
------------------------------------------------
39.8 K    Trainable params
0         Non-trainable params
39.8 K    Total params
0.159     Total estimated model params size (MB)
Restored all states from the checkpoint at /Users/pranavjha/Library/CloudStorage/Go

Sanity Checking: 0it [00:00, ?it/s]

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata


Training: 0it [00:00, ?it/s]

[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:330] Profiler is not initialized: skipping profiling metadata
[W kineto_shim.cpp:343] Profiler i

Validation: 0it [00:00, ?it/s]

STAGE:2023-06-18 01:00:52 36704:476496 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-06-18 01:00:52 36704:476496 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-06-18 01:00:52 36704:476496 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

FIT Profiler Report
Profile stats for: records
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         5.09%       1.249ms       100.00%      24.550ms      12.275ms             2  
[pl][profile][Strategy]SingleDeviceStrategy.validati...       -61.02%  -14981.000us        78.11%      19.177ms       9.588ms             2  
[pl][module]torchmetrics.classification.accuracy.Mul...        32.50%       7.979ms        69.71%      17.115ms       4.279ms             4  
[pl][module]torchmetrics.classification.f_beta.Multi...        31.23%       7.666ms        57.82%    

Validation: 0it [00:00, ?it/s]

STAGE:2023-06-18 01:01:34 36704:476496 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-06-18 01:01:34 36704:476496 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-06-18 01:01:34 36704:476496 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


VALIDATE Profiler Report
Profile stats for: records
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         3.62%     795.000us       100.00%      21.991ms      10.995ms             2  
[pl][profile][Strategy]SingleDeviceStrategy.validati...       -55.10%  -12116.000us        77.47%      17.036ms       8.518ms             2  
[pl][module]torchmetrics.classification.accuracy.Mul...        37.12%       8.163ms        67.92%      14.936ms       3.734ms             4  
[pl][module]torchmetrics.classification.f_beta.Multi...        27.25%       5.993ms        51.01

Testing: 0it [00:00, ?it/s]

STAGE:2023-06-18 01:01:40 36704:476496 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-06-18 01:01:41 36704:476496 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-06-18 01:01:41 36704:476496 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


TEST Profiler Report
Profile stats for: records
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         4.02%     801.000us       100.00%      19.950ms       9.975ms             2  
[pl][profile][Strategy]SingleDeviceStrategy.test_ste...       -63.11%  -12590.000us        79.54%      15.868ms       7.934ms             2  
[pl][module]torchmetrics.classification.accuracy.Mul...        41.96%       8.371ms        78.79%      15.718ms       3.929ms             4  
[pl][module]torchmetrics.classification.f_beta.Multi...        28.34%       5.653ms        52.88%   

[{'test_loss': 0.09235319495201111,
  'test_accuracy': 0.9743000268936157,
  'test_f1score': 0.9743000268936157}]

In [20]:
# !tensorboard --logdir="dashboard"