# **OSSCA Pytorch Lightning Assignment**

In [14]:
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(0, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

In [15]:
class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        
        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        
        # Hardcode some dataset specific sttributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        c, w, h = self.dims
        self.transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,)),
        ])
        
        # Define Pytorch model
        self.model = nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(c*w*h, hidden_size),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_size, hidden_size),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_size, hidden_size)
        )
        
        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [17]:
model = LitMNIST()
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=10,
    progress_bar_refresh_rate=20,
)
trainer.fit(model)

  rank_zero_deprecation(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 58.6 K
1 | accuracy | Accuracy   | 0     
----------------------------------------
58.6 K    Trainable params
0         Non-trainable params
58.6 K    Total params
0.234     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [18]:
# model test
trainer.test()


  rank_zero_warn(
Restoring states from the checkpoint path at /home/matrix/Desktop/Study/ossca/pl/lightning_logs/version_1/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from checkpoint at /home/matrix/Desktop/Study/ossca/pl/lightning_logs/version_1/checkpoints/epoch=9-step=8600.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.9682999849319458
        val_loss            0.10068415105342865
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.10068415105342865, 'val_acc': 0.9682999849319458}]

In [19]:
trainer.fit(model)


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 58.6 K
1 | accuracy | Accuracy   | 0     
----------------------------------------
58.6 K    Trainable params
0         Non-trainable params
58.6 K    Total params
0.234     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

In [20]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/