In [1]:
import torch; torch.manual_seed(0)
from torch import nn,utils,optim

import torchvision
from torchvision import datasets, transforms

import lightning.pytorch as pl
import torchmetrics

from tqdm.notebook import tqdm

In [2]:
class MnistDataLoader(pl.LightningDataModule):
    def __init__(self,root,batch_size,num_workers):
        super(MnistDataLoader,self).__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        
    def prepare_data(self):
        datasets.MNIST(root = self.root,train=True,download=True)
        datasets.MNIST(root = self.root,train=False,download=True)

    def setup(self, stage):
        train_dataset = datasets.MNIST(root = self.root,train=True,download=False,transform = transforms.ToTensor())
        self.test_dataset =  datasets.MNIST(root = self.root,train=False,download=False,transform = transforms.ToTensor())
        # Define the proportions for the split
        train_proportion = 0.8  # 80% for training
        val_proportion = 0.2  # 20% for validation

        # Calculate the sizes of training and validation sets based on the proportions
        train_size = int(train_proportion * len(train_dataset))
        val_size = len(train_dataset) - train_size
        
        # Use random_split to split the dataset
        self.train_dataset, self.val_dataset = utils.data.random_split(train_dataset, [train_size, val_size])
        

    def train_dataloader(self):
        return utils.data.DataLoader(dataset=self.train_dataset,batch_size=self.batch_size,shuffle=True,num_workers=self.num_workers,pin_memory=True)
    
    def val_dataloader(self):
        return utils.data.DataLoader(dataset=self.val_dataset,batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,pin_memory=True)
    
    def test_dataloader(self):
        return utils.data.DataLoader(dataset=self.test_dataset,batch_size=self.batch_size,shuffle=False,num_workers=self.num_workers,pin_memory=True)
   

In [3]:
root = './data'
batch_size = 128
num_workers = 4
ds = MnistDataLoader(root, batch_size, num_workers)

In [4]:
# # Only for inspecting data 
# ds.prepare_data()
# ds.setup('train')
# for data, label in ds.train_dataloader().dataset:
#     print(data.shape,label)
#     break

In [5]:
class LSTM(pl.LightningModule):
    
    def __init__(self,input_size,hidden_size,output_shape,num_layers=2,bidirectional=False):
        super(LSTM,self).__init__()
        self.save_hyperparameters()
        
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(input_size,hidden_size,num_layers,batch_first=True,dropout=0,bidirectional=bidirectional)
        
        self.fc = nn.Linear(hidden_size,output_shape)
        
        self.accuracy = torchmetrics.Accuracy(task="multiclass",num_classes=output_shape)
        self.f1_score = torchmetrics.F1Score(task="multiclass",num_classes=output_shape)
        
        self.lr= 1e-3
        
    def forward(self,x):
        h0 = torch.zeros((self.bidirectional*1+1)*self.num_layers, x.shape[0],self.hidden_size).to('mps')
        c0 = torch.zeros((self.bidirectional*1+1)*self.num_layers, x.shape[0],self.hidden_size).to('mps')
        x = x.squeeze(1)
        out, (hn,cn) = self.lstm(x,(h0,c0))
        return self.fc(out[:,-1,:])        
    
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss, x_hat, y = self._common_step(batch,batch_idx)
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        
        self.log_dict({'train_loss':loss,
                      'train_accuracy':accuracy,
                      'train_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        if batch_idx % 100 == 0:
            x = x[:8]
            grid = torchvision.utils.make_grid(x.view(-1,1,28,28))
            self.logger.experiment.add_image('mnist_images',grid,self.global_step)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, x_hat, y = self._common_step(batch,batch_idx)
        
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        self.log_dict({'val_loss':loss,
                      'val_accuracy':accuracy,
                      'val_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        loss, x_hat, y = self._common_step(batch,batch_idx)
        
        accuracy,f1_score = self.accuracy(x_hat,y), self.f1_score(x_hat,y)
        self.log_dict({'test_loss':loss,
                      'test_accuracy':accuracy,
                      'test_f1score':f1_score},prog_bar=True,on_step=False,on_epoch=True)
        return loss

    def _common_step(self,batch,batch_index):
        x, y = batch
#         x = x.flatten(start_dim=1)
        x_hat = self.forward(x)
        loss = nn.functional.cross_entropy(x_hat,y)
        return loss , x_hat, y

    def predict_step(self,batch,batch_idx):
        x, y = batch
#         x = x.flatten(start_dim=1)
        x_hat = self.forward(x)
        pred = torch.argmax(x_hat,dim=1)
        return pred


    def configure_optimizers(self):
        return optim.Adam(self.parameters(),lr=self.lr)

In [6]:
# Hyperparametersa
input_size,hidden_size,output_shape = 28, 128, 10

batch = 256
num_epoch = 2
learning_rate = 1e-3

model = LSTM(input_size,hidden_size,output_shape)

In [7]:


logger = pl.loggers.TensorBoardLogger(save_dir='./log/', name='mnist_lstm', version=0.1)

profiler = pl.profilers.PyTorchProfiler(
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/',),
    schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
)

# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=3,
    monitor="val_f1score",
    mode="max",
    dirpath="checkpoints/rnn/",
    filename="mnist-lstm-{epoch}-{val_accuracy}",
)


trainer = pl.Trainer(
    logger=logger,
    accelerator='auto',
    devices=[0],
    min_epochs=1,
    max_epochs=100,
    precision='16-mixed',
    enable_model_summary=True,
#     profiler=profiler,
    callbacks=[pl.callbacks.EarlyStopping('val_loss',patience=6,verbose=True),checkpoint_callback],
#     default_root_dir="mnist_checkpoints/",
    enable_checkpointing  = True
    
)
if checkpoint_callback.best_model_path:
    trainer.fit(model, ds, ckpt_path=checkpoint_callback.best_model_path,)
else : 
    trainer.fit(model, ds)
trainer.validate(model, ds)
# trainer.test(model, ds)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name     | Type               | Params
------------------------------------------------
0 | lstm     | LSTM               | 212 K 
1 | fc       | Linear             | 1.3 K 
2 | accuracy | MulticlassAccuracy | 0     
3 | f1_score | MulticlassF1Score  | 0     
------------------------------------------------
214 K     Trainable params
0         Non-trainable params
214 K     Total params
0.857     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.219


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.091 >= min_delta = 0.0. New best score: 0.128


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.034 >= min_delta = 0.0. New best score: 0.094


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 0.082


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 0.067


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.063


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.060


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.051


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.049


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.042


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.039


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 6 records. Best score: 0.039. Signaling Trainer to stop.


Validation: 0it [00:00, ?it/s]

[{'val_loss': 0.015820208936929703,
  'val_accuracy': 0.9950833320617676,
  'val_f1score': 0.9950833320617676}]

In [8]:
trainer.test(model, ds,ckpt_path=checkpoint_callback.best_model_path)

Restoring states from the checkpoint path at /Users/pranavjha/Library/CloudStorage/GoogleDrive-pranajh7@gmail.com/My Drive/Projects/applied_theories/lightning examples/checkpoints/rnn/mnist-lstm-epoch=17-val_accuracy=0.9899166822433472.ckpt
Loaded model weights from the checkpoint at /Users/pranavjha/Library/CloudStorage/GoogleDrive-pranajh7@gmail.com/My Drive/Projects/applied_theories/lightning examples/checkpoints/rnn/mnist-lstm-epoch=17-val_accuracy=0.9899166822433472.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.04293237626552582,
  'test_accuracy': 0.9872999787330627,
  'test_f1score': 0.9872999787330627}]

In [24]:
!tensorboard --logdir log/

TensorFlow installation not found - running with reduced feature set.
I0722 09:58:35.907079 6169096192 plugin.py:429] Monitor runs begin
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C
