[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/blob/master/2022/Lab1b_lightning.ipynb)

In [1]:
# Download data, unzip
!gdown https://drive.google.com/uc?id=1XxBBah4J3wmSAMFq8lBFc06vGWFiy1TZ
!unzip GBDA2020_ML1.zip

In [None]:
# Install PyTorch Lightning into the kernel's environment (imported further down as `pytorch_lightning`).
# NOTE(review): the version is not pinned, so a fresh run may pull a release newer than the one this
# notebook was written against — consider pinning. `torchmetrics` is also imported below and is
# presumably pulled in as a dependency of this package — TODO confirm.
%pip install pytorch-lightning

In [2]:
# Data location and class names
import os

# Directory that holds the .npy arrays read by the dataset class below.
# Defaults to the original lab path; set the GBDA_DATA_ROOT environment variable
# before running to point the notebook somewhere else (e.g. on Colab).
DATA_ROOT = os.environ.get("GBDA_DATA_ROOT", "/labs/lab1/data/partB")

# Names of the 16 land-cover classes.
# presumably ordered so that CLASS_NAMES[i] names zero-based label i — TODO confirm against the ground truth
CLASS_NAMES = [
    "Alfalfa",
    "Corn-notill",
    "Corn-mintill",
    "Corn",
    "Grass-pasture",
    "Grass-trees",
    "Grass-pasture-mown",
    "Hay-windrowed",
    "Oats",
    "Soybean-notill",
    "Soybean-mintill",
    "Soybean-clean",
    "Wheat",
    "Woods",
    "Buildings-Grass-Trees-Drives",
    "Stone-Steel-Towers"
]

## Handle data. Datasets and DataLoaders

In [3]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from sklearn.preprocessing import StandardScaler
import numpy as np
import os
from copy import copy

class IndianPinesDataset(Dataset):
    '''
    In-memory dataset holding the labelled pixels of a hyperspectral scene.

    Reads "indianpinearray.npy" (image) and "IPgt.npy" (ground truth) from `data_root`,
    keeps only the pixels whose ground-truth value is > 0 and serves them as
    (spectrum, label) pairs.

    Attributes set by the constructor:
        X: float32 array of shape (num_samples, num_bands)
        y: int64 array of shape (num_samples,) with zero-based class labels
        transforms: list of callables `T(X, y) -> (X, y)`, applied in order on every access
    '''
    def __init__(self, data_root, transforms=None):
        '''
        Dataset constructor

        Args:
            data_root: directory containing the two .npy files
            transforms: optional iterable of callables `T(X, y) -> (X, y)`; copied, never mutated
        '''
        super().__init__()
        # `None` instead of a mutable default; always store a private list
        self.transforms: list = list(transforms) if transforms is not None else []
        self._build(data_root)

    def _build(self, data_root) -> None:
        '''
        Helper method to retrieve all samples (loads both arrays and drops unlabelled pixels)
        '''
        img = np.load(os.path.join(data_root, "indianpinearray.npy"))
        gt_img = np.load(os.path.join(data_root, "IPgt.npy"))

        # Ground-truth value 0 marks "no_data" pixels; keep only the labelled ones
        valid_mask = gt_img > 0

        # The number of spectral bands is taken from the image itself rather than hard-coded
        num_bands = img.shape[-1]
        self.X = img[valid_mask].reshape(-1, num_bands).astype(np.float32)
        # "-1" is to compensate for "no_data" class "0" in original_data.
        # int64 is requested explicitly: plain `int` maps to a 32-bit type on some platforms,
        # while torch classification losses expect 64-bit targets.
        self.y = gt_img[valid_mask].reshape(-1).astype(np.int64) - 1

    def apply_std_scaler(self, indices):
        '''
        Perform std scaling normalization given a list of indices to compute the transform.

        The scaler is fitted on `self.X[indices]` only and then applied to every sample,
        replacing `self.X`.

        Raises:
            ValueError: if `indices` is empty (there is nothing to fit the scaler on)
        '''
        indices = np.asarray(indices)
        if indices.size == 0:
            raise ValueError("apply_std_scaler needs at least one index to fit the scaler")

        scaler = StandardScaler()
        scaler.fit(self.X[indices])
        # Keep float32 so batches keep matching the model's parameter dtype
        self.X = scaler.transform(self.X).astype(np.float32, copy=False)

    def __getitem__(self, index):
        '''
        Method to retrieve samples: returns the (X, y) pair at `index` after running all transforms
        '''
        X, y = self.X[index], self.y[index]
        for T in self.transforms:
            X, y = T(X, y)
        return X, y

    def __len__(self) -> int:
        '''
        Return the total number of samples in dataset
        '''
        return len(self.X)
        


In [None]:
# Initialize a dataset instance
dset = IndianPinesDataset(DATA_ROOT)
print("Samples in dataset: ", len(dset))


# Split data into train (70%) and validation (30%) sets.
# The explicitly seeded generator makes the split identical on every run.
num_train = int(0.7 * len(dset))
train_dset, val_dset = random_split(
    dset,
    [num_train, len(dset) - num_train],
    generator=torch.Generator().manual_seed(2022))

# Use only train dset to compute a std scale normalization transform. Apply for both train and validation sets.
# Both subsets index into `dset`, so they return the scaled values once `dset.X` has been replaced.
print("Max value for the first sample in 'val' (before scaling): ", val_dset[0][0].max())
dset.apply_std_scaler(train_dset.indices)
print("Max value for the first sample in 'val' (after scaling)\t: ", val_dset[0][0].max())


#  Initialize dataloaders (batching / tensor-casting / shuffling / etc.)
BATCH_SIZE = 64
# Never ask for more worker processes than there are CPU cores: a fixed 24 oversubscribes
# small machines (e.g. a Colab VM) and makes PyTorch warn about it.
NUM_WORKERS = min(24, os.cpu_count() or 1)
train_dloader = DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dloader = DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Inspect a single batch returned from train dataloader
X, y = next(iter(train_dloader))
print(f"Sample's X type: {type(X)}, dtype: {X.dtype}, shape: {X.size()}")
print(f"Sample's y type: {type(y)}, dtype: {y.dtype}, shape: {y.size()}")

## Define an MLP model

In [None]:

import pytorch_lightning as pl
import torch.nn.functional as F
import torch
from torch import nn
from torch.optim import Adam
from torchmetrics import Accuracy

class MLP(pl.LightningModule):
    '''
    Multi-layer perceptron classifier wrapped as a LightningModule.

    Architecture: Linear(in_features, 512) -> ReLU -> Dropout(0.5) -> Linear(512, 256) -> ReLU
    -> Linear(256, 256) -> ReLU -> Linear(256, num_classes). The network outputs raw logits.

    Training minimises cross-entropy plus an L1 penalty on the weights of the first linear layer.
    Validation logs the cross-entropy loss and the accuracy.
    '''
    def __init__(self, in_features: int, num_classes: int, l1_lambda: float = 0.01):
        '''
        Model constructor (includes MLP's architecture, metrics)

        Args:
            in_features: size of each input vector
            num_classes: number of output classes (size of the logits vector)
            l1_lambda: weight of the L1 penalty on the first layer in the training loss
        '''
        super().__init__()

        self.l1_lambda = l1_lambda

        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

        self.val_acc = self._build_accuracy(num_classes)

    @staticmethod
    def _build_accuracy(num_classes: int):
        '''
        Create the validation accuracy metric in a way that works across torchmetrics versions.
        '''
        try:
            # Recent torchmetrics releases require the task (and class count) to be stated
            return Accuracy(task="multiclass", num_classes=num_classes)
        except (TypeError, ValueError):
            # Older releases do not know the `task` argument and infer the task from the inputs
            return Accuracy()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward-pass; returns the raw (un-normalised) class logits
        '''
        return self.model(x)

    def training_step(self, batch, batch_idx):
        '''
        Training logic: returns cross-entropy + l1_lambda * L1(first-layer weights)
        '''
        X, y = batch

        logits = self.model(X)

        # Same value as nll_loss(log_softmax(logits)), computed in one call
        loss = F.cross_entropy(logits, y)

        # The penalty is computed on the parameter itself, NOT on `.weight.data`:
        # `.data` is detached from autograd, so the penalty would only add a constant to the
        # loss and never influence the gradients. Entry-wise sum of |w| equals the p=1 norm.
        # NOTE: since the term now contributes gradients, `l1_lambda` may need tuning.
        l1_normalization = self.model[0].weight.abs().sum()
        composite_loss = loss + self.l1_lambda * l1_normalization

        #  Log to selected logger (def: Tensorboard)
        self.log("loss/train", loss, on_epoch=True, on_step=False)
        self.log("normalization_loss/train", l1_normalization, on_epoch=True, on_step=False)
        self.log("composite_loss/train", composite_loss, on_epoch=True, on_step=False)

        return composite_loss

    def configure_optimizers(self):
        '''
        Select an optimization algorithm + parameters (learning rate, ...)
        '''
        return Adam(self.parameters(), lr=1e-4)

    def validation_step(self, batch, batch_idx):
        '''
        Validation logic: logs the epoch-level cross-entropy loss and accuracy
        '''
        X, y = batch

        # Predict with model
        logits = self.model(X)

        # Compute *interesting* metrics
        loss = F.cross_entropy(logits, y)
        self.log("loss/val", loss, on_epoch=True, on_step=False)

        # The metric object accumulates over batches; logging the object reports its epoch value
        self.val_acc.update(logits, y)
        self.log("accuracy/val", self.val_acc, on_epoch=True, on_step=False)

# Size the network from the data instead of repeating the magic numbers 200 (bands) and 16 (classes)
model = MLP(dset.X.shape[1], len(CLASS_NAMES))
# Sanity-check the output shape on one validation batch; no gradients are needed for that
with torch.no_grad():
    print("MLP's output shape: ", model(next(iter(val_dloader))[0]).size())



## Training 

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

callbacks = [
    # Stop when validation accuracy has not improved for 3 consecutive validation checks.
    # Validation runs every 2 epochs (see the Trainer below), so that is 6 training epochs.
    EarlyStopping("accuracy/val", patience=3, mode='max'),
    # mode="max" is required here: the default is "min", which would keep the checkpoints
    # with the LOWEST validation accuracy as the "top-k" ones.
    ModelCheckpoint(monitor="accuracy/val", mode="max", save_last=True, save_top_k=2)]

# Define a Trainer instance and train/validate the model.
# min_epochs keeps training going for at least 50 epochs even if early stopping triggers sooner.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    min_epochs=50,
    max_epochs=1000,
    check_val_every_n_epoch=2,
    callbacks=callbacks)

trainer.fit(model, train_dataloaders=train_dloader, val_dataloaders=val_dloader)


In [None]:
# Load the TensorBoard notebook extension and open it inline on the `lightning_logs` directory
# (presumably where the Trainer above wrote its logs — verify the path if nothing shows up).
%load_ext tensorboard
%tensorboard --logdir lightning_logs