### TODO:

* [x] Setup WandB
* [x] Split data
* Network:
  - [ ] Original
  * Mine
    1. [x] ReLU
    2. [x] Compression / Expansion
    3. [x] Batch Norm
    4. [x] Squeeze Excitation
    5. [x] Transposed vs Upsampled
    6. [x] EfficientNet Encoder
* Augmentation:
  - Preprocessing:
    * [x] Normalize

In [None]:
# Show the version of the `python` found on the shell PATH (may differ from the kernel's interpreter).
!python --version

## Imports

In [None]:
# Fetch the UNetBox repository, which provides `unetbox.net.UNetBox` imported below.
# NOTE(review): `git clone` fails if ./UNetBox already exists (e.g. when this cell is re-run);
# the sys.path entry below assumes the working directory is /kaggle/working.
!git clone https://github.com/MicheleDamian/UNetBox.git

In [None]:
import sys

# Make the cloned UNetBox checkout importable (takes precedence over installed packages).
UNETBOX_DIR = '/kaggle/working/UNetBox'
sys.path.insert(0, UNETBOX_DIR)

In [None]:
import os
import math
import torch
import cv2
import multiprocessing
import wandb
import timm

import albumentations as A
import pandas as pd
import numpy as np

import pytorch_lightning as pl

import torch.nn as nn
import torch.nn.functional as F

from itertools import product, chain
from datetime import datetime, timezone
from dataclasses import dataclass, field
from functools import partial
from path import Path

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from tqdm.auto import tqdm

from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from albumentations.pytorch.transforms import ToTensorV2
from torchvision.ops import sigmoid_focal_loss
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
from pytorch_lightning.loggers import WandbLogger

from sklearn.model_selection import StratifiedKFold

from unetbox.net import UNetBox

## Setup

In [None]:
# Authenticate with Weights & Biases from Python instead of `!wandb login <key>`, so the
# API key is never placed on a shell command line (visible to other processes).
# NOTE(review): `secret_value` is not defined in this section; presumably it is read from
# the Kaggle secrets store in a cell not shown here — confirm.
wandb.login(key=secret_value)

In [None]:
@dataclass
class Config:
    """Experiment settings: W&B project, data locations and statistics, training and model knobs.

    The notebook reads these as class attributes (``Config.x``); ``model_params`` and
    ``session_id`` are assigned on the class after its definition.
    """
    project: str                      = 'Unet Ablation'
    session_id: 'str | None'          = None       # Falls back to a UTC timestamp when left unset
    seed: int                         = 2023
    n_folds: int                      = 3
    learning_rate: float              = 2e-3       # This can be changed after running the Tuner; a falsy value runs the LR finder
    kaggle_path: Path                 = Path('./google-research-identify-contrails-reduce-global-warming')
    input_path: Path                  = Path('./dataset')
    output_path: Path                 = Path('.')
    n_channels: int                   = 3          # Number of channels in the images in the dataset
    timeindex: tuple[int, ...]        = (4, )      # Time steps selected from the last axis of every band array
    input_size: tuple[int, int]       = (256, 256) # The size of the first layer's input
    data_mean: tuple[float, ...]      = (275.65, 0.98859, -2.8341)  # Per-channel mean for a single time step
    data_std: tuple[float, ...]       = (14.714, 1.549, 0.93514)    # Per-channel std for a single time step
    batch_size: int                   = 64
    accumulate_grad: int              = 1          # Batches per optimizer step
    num_epochs: int                   = 15
    model_params: dict[str, object]   = field(default_factory=dict)  # Keyword arguments for UNetBox
    align_motion: 'str | None'        = None       # Motion model for channel alignment; None disables it
        
# Architecture switches, passed as keyword arguments when the network is built.
Config.model_params = dict(
    depth=4,
    expansion=16,
    base_chn=Config.n_channels * len(Config.timeindex),  # n_channels for every selected time step
    activation=nn.SiLU,
    encoder='default',
    expansion_layer=True,
    norm_layer=True,
    convup_layer=True,
    se_block=True,
)

# Keep an explicitly configured session id; otherwise stamp the run with the current UTC time.
if not Config.session_id:
    Config.session_id = datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')

In [None]:
# Seed all global RNGs (and DataLoader workers) so runs are reproducible.
pl.seed_everything(seed=Config.seed, workers=True)

## Helpers

In [None]:
class RLE():
    """Run-length encoding helpers for binary masks.

    Pixels are numbered in column-major order (down each column, then across
    columns) and run starts are 1-based: ``encode`` flattens ``mask.T``.
    ``decode`` uses the same ordering, so ``decode(encode(m), *m.shape)``
    reproduces ``m``.
    """

    @staticmethod
    def encode(mask):
        """Encode a 2-D binary mask as a list of ``(start, length)`` runs of ones.

        Args:
            mask: 2-D array of 0/1 (or bool) values.

        Returns:
            List of ``(start, length)`` tuples, where ``start`` is the 1-based
            column-major index of the first pixel of a run.
        """
        # Pad with a zero on each side so every run has both a rising and a falling edge.
        padded = np.zeros(mask.size + 2, dtype=mask.dtype)
        padded[1:-1] = mask.T.flatten()

        # 1-based positions where the value changes; rising edges sit at even
        # positions of this array and falling edges at odd ones.
        edges = np.where(padded[:-1] != padded[1:])[0] + 1
        lengths = edges[1:] - edges[:-1]
        return list(zip(edges, lengths))[::2]

    @staticmethod
    def decode(rle, height, width):
        """Decode ``(start, length)`` runs into a ``(height, width)`` uint8 mask.

        Args:
            rle: sequence of ``(start, length)`` pairs as produced by ``encode``.
                Anything that is not a list/tuple (e.g. a NaN coming from a
                DataFrame) is treated as an empty mask.
            height: number of rows of the output mask.
            width: number of columns of the output mask.

        Returns:
            C-contiguous uint8 array of shape ``(height, width)``.
        """
        flat = np.zeros(height * width, dtype=np.uint8)

        if not isinstance(rle, (list, tuple)):
            rle = []

        for start, length in rle:
            flat[start - 1:start - 1 + length] = 1

        # Column-major, matching encode's flattening of mask.T.
        return np.ascontiguousarray(flat.reshape((height, width), order='F'))

In [None]:
def get_contrail_size(paths):
    """Count the labelled contrail pixels of every sample directory.

    Args:
        paths: iterable of sample directories supporting ``/``; each must contain
            ``human_pixel_masks.npy``.

    Returns:
        List with one ``mask.sum()`` per directory, in input order (a progress bar
        is shown while loading).
    """
    return [
        np.load(sample_dir / 'human_pixel_masks.npy').sum()
        for sample_dir in tqdm(paths)
    ]

In [None]:
def create_ash_image(basepath, timeindex=4, clip=False):
    """Stack three brightness-temperature combinations into a 3-channel image.

    Output channels (last axis):
        0: band_14
        1: band_14 - band_11
        2: band_15 - band_14

    Args:
        basepath: sample directory supporting ``/``; must contain ``band_11.npy``,
            ``band_14.npy`` and ``band_15.npy``.
        timeindex: index, or sequence of indices, taken from the last axis of each
            band array. A sequence keeps that axis, so the result gains a
            ``len(timeindex)`` axis just before the channel axis.
        clip: if True, clip the channels to (243, 303), (-4, 5) and (-4, 2).

    Returns:
        float32 array with the three channels stacked on the last axis.
    """
    def load(filename):
        return np.load(basepath / filename)[..., timeindex].astype(np.float32)

    band_15 = load('band_15.npy')
    band_14 = load('band_14.npy')
    band_11 = load('band_11.npy')

    channels = [band_14, band_14 - band_11, band_15 - band_14]

    if clip:
        bounds = [(243, 303), (-4, 5), (-4, 2)]
        channels = [
            channel.clip(min=low, max=high)
            for channel, (low, high) in zip(channels, bounds)
        ]

    return np.stack(channels, axis=-1)

In [None]:
class ContrailsDataset(Dataset):
    """Map-style dataset yielding ``(input, mask)`` pairs for contrail segmentation.

    Each item is built from one sample directory: ``create_ash_image`` produces the
    input and, unless ``testset`` is set, ``human_pixel_masks.npy`` the target.

    Args:
        paths: indexable sequence of sample directories (support ``/`` and ``.name``).
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'`` (albumentations style).
        testset: if True, no mask is loaded and an all-zero mask is returned.
        timeindex: time steps forwarded to ``create_ash_image``.
        align: motion model forwarded to ``align_channels``; falsy disables alignment.
        storage: optional directory used to cache the optical flow as
            ``<storage>/train/<sample name>/flow.npy``.
    """
    def __init__(self,
        paths,
        transforms,
        testset=False,
        timeindex=(4,),
        align=None,
        storage=None):
        super().__init__()

        self.paths = paths
        self.transforms = transforms
        self.testset = testset
        self.timeindex = timeindex
        self.align = align
        self.storage = storage

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        sample_dir = self.paths[idx]

        # Cached optical flow, if any. Without storage this looks under '/train/...'.
        dst_dir = self.storage or Path('/')
        flow_path = dst_dir / 'train' / sample_dir.name / 'flow.npy'
        flow = np.load(flow_path) if flow_path.exists() else None

        image = create_ash_image(sample_dir, timeindex=self.timeindex)

        out_flow = None
        if self.align:
            # Align every earlier time step to the last one.
            # NOTE(review): `align_channels` is not defined in this section (assumed to be
            # provided elsewhere); the same cached `flow` is used for every t and only the
            # last computed flow is kept.
            for t in range(len(self.timeindex) - 1):
                image[..., t, :], out_flow = align_channels(
                    image[..., t, :],
                    image[..., -1, :],
                    flow=flow,
                    motion=self.align,
                    iterations=10
                )

        # Cache the flow only when one was actually computed (not, e.g., with alignment
        # disabled). exist_ok avoids a race between DataLoader workers.
        if self.storage and flow is None and out_flow is not None:
            os.makedirs(flow_path.parent, exist_ok=True)
            np.save(flow_path, out_flow)

        # Merge the time and channel axes into a single channel axis.
        if image.ndim > 3:
            image = image.reshape((*image.shape[:2], -1))

        if self.testset:
            mask = np.zeros(image.shape[:2], dtype=np.float32)
        else:
            mask_path = sample_dir / 'human_pixel_masks.npy'
            mask = np.load(mask_path).squeeze().astype(np.float32)

        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']

        # Prepend a channel axis to the mask.
        return image, mask[None, ...]

In [None]:
class _StratifiedBatchSampler:
    """Batch sampler yielding the test folds of a fresh StratifiedKFold split on every pass."""

    def __init__(self, labels, num_batches):
        self.labels = np.asarray(labels)
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        # Re-split on every iteration (i.e. every epoch) so batch composition changes
        # between epochs. random_state is unset, so the global NumPy RNG drives the shuffle.
        k_fold = StratifiedKFold(n_splits=self.num_batches, shuffle=True)
        placeholder = np.zeros((len(self.labels), 1))
        for _, fold in k_fold.split(placeholder, self.labels):
            yield fold


class DataModule(pl.LightningDataModule):
    """LightningDataModule serving ContrailsDataset for one train/validation split.

    Args:
        train_df: DataFrame with a ``path`` column of sample directories and a ``bin``
            column used to stratify the training batches.
        valid_df: DataFrame with a ``path`` column of sample directories.
        input_size: stored but not used by this class.
        data_mean: per-channel means for one time step; must be a tuple/list, since
            ``*`` repeats it once per entry of ``timeindex``.
        data_std: per-channel stds for one time step (same convention as data_mean).
        num_cpus: number of DataLoader workers.
        batch_size: target batch size.
        transforms: training transforms; defaults to Normalize + ToTensorV2.
            Validation always uses Normalize + ToTensorV2.
        timeindex: forwarded to ContrailsDataset.
        align: forwarded to ContrailsDataset.
        storage: forwarded to ContrailsDataset.
    """

    def __init__(self,
        train_df,
        valid_df,
        input_size,
        data_mean,
        data_std,
        num_cpus=os.cpu_count(),
        batch_size=32,
        transforms=None,
        timeindex=(4,),
        align=None,
        storage=None):
        super().__init__()

        self.train_df, self.valid_df = train_df, valid_df
        self.batch_size = batch_size
        self.input_size = input_size
        self.num_cpus = num_cpus
        self.timeindex = timeindex
        self.align = align
        self.storage = storage

        # Channels are (time step, band combination) pairs, so tile the stats per time step.
        self.data_mean, self.data_std = data_mean * len(timeindex), data_std * len(timeindex)

        self.transforms_train = transforms or A.Compose([
            A.Normalize(self.data_mean, self.data_std, max_pixel_value=1.0),
            ToTensorV2()
        ])
        self.transforms_valid = A.Compose([
            A.Normalize(self.data_mean, self.data_std, max_pixel_value=1.0),
            ToTensorV2()
        ])

    def train_dataloader(self):
        trainset = ContrailsDataset(
            self.train_df['path'].values,
            self.transforms_train,
            timeindex=self.timeindex,
            align=self.align,
            storage=self.storage
        )
        num_batches = len(trainset) // self.batch_size

        if num_batches < 2:
            # StratifiedKFold needs at least two splits; fall back to plain shuffled batches.
            return DataLoader(
                trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_cpus
            )

        # Each batch is one stratified fold (about batch_size samples), so every batch mirrors
        # the contrail-size distribution of the training set; folds are re-drawn every epoch.
        batch_sampler = _StratifiedBatchSampler(self.train_df['bin'].values, num_batches)
        return DataLoader(trainset, batch_sampler=batch_sampler, num_workers=self.num_cpus)

    def val_dataloader(self):
        valset = ContrailsDataset(
            self.valid_df['path'].values,
            self.transforms_valid,
            timeindex=self.timeindex,
            align=self.align,
            storage=self.storage
        )
        return DataLoader(valset, batch_size=self.batch_size, num_workers=self.num_cpus)

In [None]:
class Model(pl.LightningModule):
    """LightningModule wrapping a segmentation network trained with Adam + OneCycleLR.

    Args:
        model: either a ``dict`` of keyword arguments for ``UNetBox`` or a built module.
        total_steps: total optimizer steps, used to size ``OneCycleLR``.
        learning_rate: peak learning rate of the one-cycle schedule.
        criterion: loss callable ``criterion(outputs, labels)`` returning a scalar
            tensor; its name (the wrapped function's name for a ``functools.partial``)
            is used in the logged metric keys.
    """
    def __init__(self, model, total_steps, learning_rate, criterion):
        super().__init__()

        self.model = UNetBox(**model) if isinstance(model, dict) else model
        self.criterion = criterion
        self.learning_rate_max = learning_rate
        self.total_steps = total_steps

        self.save_hyperparameters()

        # Per-epoch accumulators for the validation loss (reset in on_validation_epoch_end).
        self.valid_loss = 0.
        self.valid_num_batches = 0

    def _criterion_name(self):
        """Return the loss name used in metric keys.

        For a ``functools.partial`` this is the wrapped function's ``__name__``;
        otherwise the callable's own ``__name__`` or, failing that, its class name.
        """
        func = getattr(self.criterion, 'func', self.criterion)
        return getattr(func, '__name__', type(func).__name__)

    def forward(self, x):
        # Call the module rather than .forward() so registered hooks run.
        return self.model(x)

    def configure_optimizers(self):
        """Adam over the trainable parameters with a per-step OneCycleLR schedule.

        OneCycleLR warms up from ``max_lr / 10`` (div_factor) to ``max_lr`` and, with
        ``final_div_factor=1``, anneals back to ``max_lr / 10`` over ``total_steps``.
        """
        optimizer = Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            # Placeholder; OneCycleLR sets the starting lr to max_lr / div_factor.
            lr=self.learning_rate_max / 5e1
        )

        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.learning_rate_max,
            div_factor=10.,
            final_div_factor=1.,
            total_steps=self.total_steps
        )

        config = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'strict': False
            }
        }

        return config

    def training_step(self, batch, batch_idx):

        inputs, labels = batch

        # Forward pass
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)

        # Add training metrics
        self.log(f'train/{self._criterion_name()}', loss, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        # Validation is only computed when logging to W&B.
        if not self.logger or not isinstance(self.logger, WandbLogger): return

        inputs, labels = batch

        # Forward pass
        outputs = self.model(inputs)

        self.valid_loss += self.criterion(outputs, labels).cpu().item()
        self.valid_num_batches += 1

    def on_validation_epoch_end(self):
        # Log the mean of the per-batch losses. When validation_step accumulated nothing
        # (e.g. no WandbLogger), log nothing rather than a misleading 0.
        if self.valid_num_batches:
            loss = self.valid_loss / self.valid_num_batches
            self.log(f'validation/{self._criterion_name()}', loss, logger=True)

        self.valid_loss = 0.
        self.valid_num_batches = 0

## Losses

### Continuous Dice Loss

In [None]:
def continuous_dice_loss(inputs, targets, reduction='none', continuous=True, dim=-1):
    """Soft Dice loss computed on logits.

    Args:
        inputs: logits; ``sigmoid`` is applied here.
        targets: binary targets with the same shape as ``inputs``.
        reduction: ``'none'`` returns one loss per sample (all axes after the first
            are flattened). Any other value flattens the whole batch and returns a
            single Dice loss over all pixels, not a mean of per-sample losses.
        continuous: if False, probabilities are thresholded at 0.5 (hard Dice,
            not differentiable).
        dim: axis summed over after flattening.

    Returns:
        ``1 - 2 * |X & Y| / (|X| + |Y|)``, defined as 0 where both prediction and
        target are empty.
    """
    probs = inputs.sigmoid()

    if not continuous:
        probs = probs > .5

    # Flatten label and prediction tensors
    start_dim = 1 if reduction == 'none' else 0
    probs = probs.flatten(start_dim=start_dim)
    targets = targets.flatten(start_dim=start_dim)

    intersection = (probs * targets).sum(dim=dim)
    union = probs.sum(dim=dim) + targets.sum(dim=dim)

    # Empty prediction + empty target is a perfect match. Divide by a safe denominator,
    # rather than masking a 0/0 NaN afterwards, so gradients stay finite as well.
    empty = union == 0
    safe_union = torch.where(empty, torch.ones_like(union), union)
    dice_loss = 1. - 2. * intersection / safe_union

    return torch.where(empty, torch.zeros_like(dice_loss), dice_loss)

## Transforms

In [None]:
# Default pipeline: per-channel normalization (statistics tiled once per selected time step),
# then conversion of the image to a tensor.
transforms = A.Compose(
    [
        A.Normalize(
            mean=Config.data_mean * len(Config.timeindex),
            std=Config.data_std * len(Config.timeindex),
            max_pixel_value=1.0,
        ),
        ToTensorV2(),
    ]
)

## Create Dataset

In [None]:
# Pool the sample directories of the official train and validation splits (train first).
data_paths = [
    sample_dir
    for split in ('train', 'validation')
    for sample_dir in (Config.kaggle_path / split).listdir()
]

Binning (group samples by contrail pixel count; the bins are used to stratify the splits):

In [None]:
# Contrail-size bins: bin key -> inclusive (min, max) number of contrail pixels.
bins = {
    0: (0, 0),
    1: (1, 99),
    2: (100, 328),
    3: (329, 907),
    4: (908, 2**16)
}


def binning(x):
    """Return the key of the bin in ``bins`` whose inclusive range contains ``x``.

    Raises:
        ValueError: if ``x`` lies outside every bin. Such values are rejected instead
            of being silently assigned to bin 0, which would mix them with empty masks.
    """
    for key, (low, high) in bins.items():
        if low <= x <= high:
            return key
    raise ValueError(f'contrail size {x} is outside all bins')

In [None]:
# Contrail pixel count per sample, aligned with data_paths.
contrail_size = get_contrail_size(data_paths)

# One row per sample: its directory, contrail pixel count and size bin.
data_df = pd.DataFrame(list(zip(data_paths, contrail_size)), columns=['path', 'contrail_size'])
data_df['bin'] = [binning(size) for size in data_df['contrail_size']]

In [None]:
# Focal loss on logits, averaged over all elements; alpha/gamma spelled out at
# torchvision's default values.
loss_func = partial(sigmoid_focal_loss, alpha=0.25, gamma=2, reduction='mean')

## LR Finder

In [None]:
# Learning-rate range test; runs only when Config.learning_rate is falsy (e.g. 0 or None).
if not Config.learning_rate:

    # The sweep uses the whole dataset as both training and validation data (no fold split).
    data = DataModule(
        data_df,
        data_df,
        Config.input_size,
        Config.data_mean,
        Config.data_std,
        batch_size=Config.batch_size,
        transforms=transforms,
        timeindex=Config.timeindex,
        align=Config.align_motion
    )

    # Optimizer steps per epoch (after gradient accumulation) times epochs; sizes the model's
    # OneCycleLR schedule. learning_rate=1. is a placeholder peak LR.
    # NOTE(review): presumably the tuner substitutes its own LR sweep for this schedule — confirm.
    total_steps = math.ceil(len(data.train_dataloader()) / Config.accumulate_grad) * Config.num_epochs
    model = Model(Config.model_params, total_steps, 1., loss_func)

    trainer = pl.Trainer(
        precision='16-mixed',
        accelerator='gpu',
        devices=1,
        max_epochs=Config.num_epochs,
        accumulate_grad_batches=Config.accumulate_grad
    )
    tuner = pl.tuner.tuning.Tuner(trainer)
    # Sweep learning rates starting from 1e-5.
    lr_finder = tuner.lr_find(model=model, datamodule=data, min_lr=1e-5)

    # Plot loss vs. learning rate with the suggested value marked.
    fig = lr_finder.plot(suggest=True)
    fig.show()

## Train

In [None]:
# Cross-validation folds stratified by contrail-size bin.
# `folds` is a one-shot iterator: re-run this cell before re-running the training loop.
k_fold = StratifiedKFold(shuffle=True, n_splits=Config.n_folds)
folds = enumerate(k_fold.split(X=data_df, y=data_df['bin']))

In [None]:
# Train one model per cross-validation fold, each logged as its own W&B run within the
# session group.
pbar = tqdm(
    iterable=folds,
    desc='Fold',
    total=Config.n_folds,
    position=0
)

try:
    for fold, (train_index, valid_index) in pbar:

        train_df, valid_df = data_df.iloc[train_index], data_df.iloc[valid_index]

        data = DataModule(
            train_df,
            valid_df,
            Config.input_size,
            Config.data_mean,
            Config.data_std,
            batch_size=Config.batch_size,
            transforms=transforms,
            timeindex=Config.timeindex,
            align=Config.align_motion
        )

        # Optimizer steps per epoch (after gradient accumulation) times epochs; sizes OneCycleLR.
        total_steps = math.ceil(len(data.train_dataloader()) / Config.accumulate_grad) * Config.num_epochs
        model = Model(Config.model_params, total_steps, Config.learning_rate, loss_func)

        # This is to close the previous run and start a new one;
        # wandb_logger.finalize('success') doesn't work as expected
        wandb.finish()

        wandb_logger = WandbLogger(
            project=Config.project,
            group=Config.session_id,
            name=f'{Config.session_id}_{fold}',
            log_model='all'
        )

        wandb_logger.watch(model, log='all')

        trainer = pl.Trainer(
            precision='16-mixed',
            accelerator='gpu',
            devices=1,
            max_epochs=Config.num_epochs,
            accumulate_grad_batches=Config.accumulate_grad,
            deterministic=True,
            callbacks=[
                LearningRateMonitor(logging_interval='step'),
                ModelCheckpoint(monitor=f'validation/{model.criterion.func.__name__}', mode='min', save_top_k=3)
            ],
            logger=wandb_logger
        )

        trainer.fit(model=model, datamodule=data)
finally:
    # The in-loop wandb.finish() only closes the previous fold's run; close the last
    # one as well, including when training fails or is interrupted.
    wandb.finish()