In [None]:
!pip install pytorch_lightning

In [1]:
!rm -r tb_logs

In [2]:
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import transforms as T
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
from torchmetrics import Metric
import random

# PyTorch Lightning Model Class

In [3]:
class LightningModel(pl.LightningModule):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
        
    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def _forward_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        scores = self.forward(x)
        loss = nn.functional.cross_entropy(scores, y)
        return loss, scores, y
    
    def training_step(self, batch, batch_idx):
        loss, scores, y = self._forward_step(batch, batch_idx)
        accuracy = self.accuracy(scores, y)
        f1score = self.f1score(scores, y)
        self.log_dict({"train_loss": loss, "train_accuracy": accuracy,
                       "train_f1_score": f1score}, on_epoch=True, prog_bar=True)
        if batch_idx % 10 == 0:
            images, _ = batch
            indices = torch.randperm(len(images))[:8]
            images = images[indices]
            grid = torchvision.utils.make_grid(images.view(-1, 1, 28, 28))
            self.logger.experiment.add_image(f"MNIST Input Image - {batch_idx}", grid, 0)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, scores, y = self._forward_step(batch, batch_idx)
        self.log(f"Validation Loss {batch_idx}", loss)
        return loss
    
    def test_step(self, batch, batch_idx):
        loss, scores, y = self._forward_step(batch, batch_idx)
        return loss
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

### Loading Traning and Validation Data

In [22]:
my_transforms = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])
train_data = torchvision.datasets.MNIST('./data/', train=True, transform=my_transforms, download=True)
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=64, num_workers=1)
val_data = torchvision.datasets.MNIST('./data/', train=False, transform=my_transforms, download=True)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=64, num_workers=1)

In [23]:
for x, y in train_data_loader:
    print(x.shape, y.shape)
    break
for x, y in val_data_loader:
    print(x.shape, y.shape)
    break

torch.Size([64, 784]) torch.Size([64])
torch.Size([64, 784]) torch.Size([64])


### Initializing Lightning Trainer

In [24]:
tb_logger = TensorBoardLogger("tb_logs", name='mnist_model_v0')
trainer = pl.Trainer(max_epochs=8, logger=tb_logger)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [25]:
model = LightningModel(784, 10)
trainer.fit(model, train_data_loader, val_data_loader)

Missing logger folder: tb_logs/mnist_model_v0

  | Name     | Type               | Params
------------------------------------------------
0 | fc1      | Linear             | 100 K 
1 | fc2      | Linear             | 8.3 K 
2 | fc3      | Linear             | 650   
3 | accuracy | MulticlassAccuracy | 0     
4 | f1score  | MulticlassF1Score  | 0     
------------------------------------------------
109 K     Trainable params
0         Non-trainable params
109 K     Total params
0.438     Total estimated model params size (MB)


Sanity Checking: |                                                                                 | 0/? [00:0…

Training: |                                                                                        | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

`Trainer.fit` stopped: `max_epochs=8` reached.


### Initializing Lightning Trainer with GPU

In [None]:
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=8)

In [None]:
model = LightningModel(784, 10)
trainer.fit(model, train_dataloader=train_data_loader, val_dataloader=val_data_loader)

### Training with Lightning Data Module

In [8]:
class DataModule(pl.LightningDataModule):
  def __init__(self, data_dir, batch_size, num_workers):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.my_transforms = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])

  def prepare_data(self):
    torchvision.datasets.MNIST(self.data_dir, download=True, train=True)
    torchvision.datasets.MNIST(self.data_dir, download=True, train=False)

  def setup(self, stage):
    entire_dataset = torchvision.datasets.MNIST(self.data_dir, train=True, download=False,
                                                transform=self.my_transforms)
    self.train_ds, self.test_ds = torch.utils.data.random_split(entire_dataset, [50_000, 10_000])

    self.val_ds = torchvision.datasets.MNIST(self.data_dir, train=False, download=False,
                                             transform=self.my_transforms)

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size,
                                       num_workers=self.num_workers)
    
  def test_dataloader(self):
    return torch.utils.data.DataLoader(self.test_ds, batch_size=self.batch_size,
                                       num_workers=self.num_workers)
  
  def val_dataloader(self):
    return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size,
                                       num_workers=self.num_workers)

In [9]:
dm = DataModule('./', 64, 2)

In [10]:
tb_logger = TensorBoardLogger("tb_logs", name='mnist_dm_model_v0')
trainer = pl.Trainer(max_epochs=8, logger=tb_logger)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
model = LightningModel(784, 10)
trainer.fit(model, dm)


  | Name     | Type               | Params
------------------------------------------------
0 | fc1      | Linear             | 100 K 
1 | fc2      | Linear             | 8.3 K 
2 | fc3      | Linear             | 650   
3 | accuracy | MulticlassAccuracy | 0     
4 | f1score  | MulticlassF1Score  | 0     
------------------------------------------------
109 K     Trainable params
0         Non-trainable params
109 K     Total params
0.438     Total estimated model params size (MB)


Sanity Checking: |                                                                                 | 0/? [00:0…

/Users/atifadib/opt/anaconda3/envs/torch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |                                                                                        | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

Validation: |                                                                                      | 0/? [00:0…

`Trainer.fit` stopped: `max_epochs=8` reached.
