In [1]:
from pathlib import Path
print(Path.cwd())

/mnt/data/Taylor/notebooks/KOM-kelp-aco


In [2]:
%set_env WANDB_NOTEBOOK_NAME "/mnt/data/Taylor/notebooks/KOM-kelp-aco/UNet.ipynb"

env: WANDB_NOTEBOOK_NAME="/mnt/data/Taylor/notebooks/KOM-kelp-aco/UNet.ipynb"


In [3]:
import os
import random
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional, Union, Any
import shutil

import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics.functional as fm
import wandb
from PIL import Image
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchmetrics.classification import Dice
from torchvision import transforms as t
from torchvision.transforms.functional import pad, pil_to_tensor
from wandb import AlertLevel
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from lightning.pytorch.profilers import SimpleProfiler

from einops import rearrange 
from unified_focal_loss import AsymmetricUnifiedFocalLoss, FocalTverskyLoss
from datamodule import DataModule

# Config

In [4]:
# CHECKPOINT OPTIONS
CHECKPOINT_DIR = "./checkpoints"
NAME = "UNet"
PROJECT_NAME = "kom-kelp-pa-aco-rgbi"

# DATASET OPTIONS
DATA_DIR = "/home/taylor/data/KP-ACO-RGBI-Nov2023/"
NUM_WORKERS = os.cpu_count() // 2
PIN_MEMORY = True
NUM_CLASSES = 3
IGNORE_INDEX = 2
BATCH_SIZE = 2
FILL_VALUE = 0
NUM_BANDS = 4

# MODEL OPTIONS
LR = 0.0003
ALPHA = 0.8
GAMMA = 0.5
WEIGHT_DECAY = 0.0001
MAX_EPOCHS = 10
PRECISION = "16-mixed"
SYNC_BATCHNORM = True
IMG_SHAPE = 1024
DROPOUT=0.5
WARMUP_PERIOD=1./MAX_EPOCHS

In [5]:
pl.seed_everything(0, workers=True)

Seed set to 0


0

In [6]:
# Make checkpoint directory
Path(CHECKPOINT_DIR, NAME).mkdir(exist_ok=True, parents=True)

# Define model

In [12]:
# Created by: Taylor Denouden
# Organization: Hakai Institute

class UNet(pl.LightningModule):
    def __init__(self, num_classes: int = 2, ignore_index: Optional[int] = None, lr: float = 0.35,
                 weight_decay: float = 0, loss_delta: float = 0.7, loss_gamma: float = 4.0 / 3.0, max_epochs: int = 100, 
                 warmup_period:float=0.3):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.warmup_period = warmup_period
        
        self.loss_delta = loss_delta
        if self.ignore_index is not None:
            self.n = num_classes - 1
        else:
            self.n = num_classes
        
        self.model = smp.UnetPlusPlus('resnet50', in_channels=NUM_BANDS, classes=self.n, 
                                      decoder_attention_type="scse")
        for p in self.model.parameters():
            p.requires_grad = True
        self.model = torch.compile(self.model, fullgraph=False, mode="max-autotune")
        self.loss_fn = AsymmetricUnifiedFocalLoss(delta=loss_delta, gamma=loss_gamma)
        # self.loss_fn = FocalTverskyLoss(delta=loss_delta, gamma=loss_gamma)

    @property
    def example_input_array(self) -> Any:
        return torch.ones((BATCH_SIZE, NUM_BANDS, IMG_SHAPE, IMG_SHAPE), device=self.device, dtype=self.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.forward(x)

    def remove_ignore_pixels(self, logits, y):
        mask = (y != self.ignore_index)
        return logits[mask], y[mask]

    def training_step(self, batch, batch_idx):
        return self._phase_step(batch, batch_idx, phase="train")

    def validation_step(self, batch, batch_idx):
        return self._phase_step(batch, batch_idx, phase="val")

    def test_step(self, batch, batch_idx):
        return self._phase_step(batch, batch_idx, phase="test")

    def _phase_step(self, batch, batch_idx, phase):
        x, y = batch
        
        logits = self.forward(x)

        has_wandb_logger = isinstance(trainer.logger, pl.loggers.WandbLogger)
        if phase == "val" and batch_idx == 0 and has_wandb_logger:
            class_labels = {
                0: "background",
                1: "kelp"
            }
            images = [im for im in x[:,:3,:,:]]
            masks = [
                {
                  "predictions": {
                    "mask_data": pr.detach().cpu().numpy(),
                    "class_labels": class_labels
                  },
                  "ground_truth": {
                    "mask_data": gt.detach().cpu().numpy(),
                    "class_labels": class_labels
                  },
                } for gt, pr in zip(y, logits.argmax(dim=1))
            ]
            self.logger.log_image("Predictions", images=images, masks=masks)
        
        # Flatten and eliminate ignore class instances
        y = rearrange(y, 'b h w -> (b h w)').long()
        logits = rearrange(logits, 'b c h w -> (b h w) c')
        if self.ignore_index is not None:
            logits, y = self.remove_ignore_pixels(logits, y)
        
        if len(y) == 0:
            print("0 length y!")
            return 0
        
        probs = torch.softmax(logits, dim=1)
        
        loss = self.loss_fn(probs, y.long())

        accuracy = fm.accuracy(probs, y, task="multiclass", num_classes=self.n, average='macro')
        miou = fm.jaccard_index(probs, y, task="multiclass", num_classes=self.n, average='macro')
        ious = fm.jaccard_index(probs, y, task="multiclass", num_classes=self.n, average='none')
        recalls = fm.recall(probs, y, task="multiclass", num_classes=self.n, average='none')
        precisions = fm.precision(probs, y, task="multiclass", num_classes=self.n, average='none')
        f1s = fm.f1_score(probs, y, task="multiclass", num_classes=self.n, average='none')

        is_training = phase == "train"
        self.log(f"{phase}/loss", loss, prog_bar=is_training)
        self.log(f"{phase}/miou", miou, prog_bar=is_training)
        self.log(f"{phase}/accuracy", accuracy)

        for c in range(self.n):
            clsx = f"cls{c + 1}" if self.ignore_index and c >= self.ignore_index else f"cls{c}"
            self.log(f"{phase}/{clsx}_iou", ious[c], prog_bar=is_training)
            self.log_dict({
                f"{phase}/{clsx}_recall": recalls[c],
                f"{phase}/{clsx}_precision": precisions[c],
                f"{phase}/{clsx}_f1": f1s[c],
            })

        return loss
    
    def configure_optimizers(self):
        """Init optimizer and scheduler"""
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()),
                                      lr=self.lr, weight_decay=self.weight_decay, amsgrad=True)
        
        steps = self.trainer.estimated_stepping_batches
        warmup_steps = steps*self.warmup_period
        
        linear_warmup_sch = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=warmup_steps)
        cosine_sch = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(steps-warmup_steps))
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [linear_warmup_sch, cosine_sch], milestones=[warmup_steps])
        
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]
        

# Setup Callbacks and Trainer

In [13]:
ENABLE_LOGGING = True

profiler = SimpleProfiler(dirpath=".", filename="perf_logs")

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val/miou",
    mode="max",
    filename="{val/miou:.4f}_{epoch}",
    save_top_k=1,
    save_last=True,
    save_on_train_epoch_end=False,
    every_n_epochs=1,
    verbose=False,
)

if ENABLE_LOGGING:
    logger = WandbLogger(
        name=NAME,
        project=PROJECT_NAME,
        save_dir=CHECKPOINT_DIR,
        log_model=True,
    )
    logger.experiment.config["batch_size"] = BATCH_SIZE
    # logger.experiment.config["dropout"] = DROPOUT
else:
    logger = pl.loggers.CSVLogger(save_dir="/tmp/")

trainer = pl.Trainer(
    # profiler=profiler,
    # overfit_batches=10,
    # log_every_n_steps=3,
    # limit_train_batches=3,
    # limit_val_batches=3,
    # accelerator='cpu',
    # fast_dev_run=True,
    deterministic=True,
    benchmark=True,
    max_epochs=MAX_EPOCHS,
    precision=PRECISION,
    logger=logger,
    gradient_clip_val=0.5,
    accumulate_grad_batches=8,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.LearningRateMonitor(),
#         Finetuning(unfreeze_at_epoch=FINETUNE_EPOCH)
    ],
)

/home/taylor/miniforge3/envs/kom/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:668: You passed `deterministic=True` and `benchmark=True`. Note that PyTorch ignores torch.backends.cudnn.deterministic=True when torch.backends.cudnn.benchmark=True.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


# Load dataset

In [14]:
data_module = DataModule(
    DATA_DIR,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=True,
    num_classes=NUM_CLASSES,
    batch_size=BATCH_SIZE,
    fill_value=FILL_VALUE,
    tile_size=IMG_SHAPE,
)

# Load model

In [15]:
# init the model directly on the device and with parameters in half-precision
model = UNet(
    num_classes=NUM_CLASSES,
    ignore_index=IGNORE_INDEX,
    lr=LR,
    loss_delta=ALPHA,
    loss_gamma=GAMMA,
    weight_decay=WEIGHT_DECAY,
    max_epochs=MAX_EPOCHS,
    # dropout=DROPOUT,
    warmup_period=WARMUP_PERIOD,
)

# TRAIN

In [16]:
# %%debug

torch.set_float32_matmul_precision("medium")

try:
    trainer.fit(model, datamodule=data_module)
    
    if not trainer.fast_dev_run and ENABLE_LOGGING:
        best_miou = checkpoint_callback.best_model_score.detach().cpu()
        print("Best mIoU:", best_miou)
        wandb.alert(
            title="Training complete",
            text=f"Best mIoU: {best_miou}",
            level=AlertLevel.INFO,
        )

finally:
    if ENABLE_LOGGING:
        wandb.finish()
    else:
        shutil.rmtree(logger.log_dir)

/home/taylor/miniforge3/envs/kom/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:634: Checkpoint directory /mnt/data/Taylor/notebooks/KOM-kelp-aco/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name    | Type                       | Params | In sizes | Out sizes
------------------------------------------------------------------------------
0 | model   | OptimizedModule            | 51.1 M | ?        | ?        
1 | loss_fn | AsymmetricUnifiedFocalLoss | 0      | ?        | ?        
------------------------------------------------------------------------------
51.1 M    Trainable params
0         Non-trainable params
51.1 M    Total params
204.515   Total estimated model params size (MB)


Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_steps=1` reached.


FileNotFoundError: [Errno 2] No such file or directory: '/tmp/lightning_logs/version_2'

In [4]:
!jupyter nbconvert --to script "UNet.ipynb"

[NbConvertApp] Converting notebook UNet.ipynb to script
[NbConvertApp] Writing 9921 bytes to UNet.py
