# U-Net notebook for Palaeochannels dataset
The input for training:

-geojson file of:
    train tiles
    train area
    validation(test) tiles
    validation(test) area 
    ground truth

- tif file of:
    image


In [1]:
from dotenv import load_dotenv
load_dotenv()

import os
from pathlib import Path
workdir = Path(os.getenv("WORKDIR", '..'))
scratchdir = Path(os.getenv("SCRATCHDIR", '..')) #checkpoint location


In [2]:
from segmentation_models_pytorch.decoders.unet import Unet
from torchgeo.models.resnet import ResNet50_Weights, resnet50
from torchinfo import summary
import sys
sys.path.append(str(workdir))
from esa_cls_palaeo.dataset import SingleRasterPalaeochannelDataset

import geopandas as gpd
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset
import kornia.augmentation as K

from typing import Any
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import segmentation_models_pytorch as smp
import os
from lightning.pytorch.loggers import TensorBoardLogger

from lightning.pytorch import LightningModule
from segmentation_models_pytorch.decoders.unet import Unet
from torchgeo.models.resnet import ResNet50_Weights, resnet50
from torchmetrics.classification import BinaryJaccardIndex, BinaryPrecision, BinaryRecall, BinaryPrecisionRecallCurve

from torchvision.utils import make_grid, draw_segmentation_masks

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = Unet(encoder_name="resnet50", in_channels=3)
summary(model.encoder, input_size=(32, 3, 256, 256))

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /home/vscode/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 94.7MB/s]


Layer (type:depth-idx)                   Output Shape              Param #
ResNetEncoder                            [32, 3, 256, 256]         --
├─Conv2d: 1-1                            [32, 64, 128, 128]        9,408
├─BatchNorm2d: 1-2                       [32, 64, 128, 128]        128
├─ReLU: 1-3                              [32, 64, 128, 128]        --
├─MaxPool2d: 1-4                         [32, 64, 64, 64]          --
├─Sequential: 1-5                        [32, 256, 64, 64]         --
│    └─Bottleneck: 2-1                   [32, 256, 64, 64]         --
│    │    └─Conv2d: 3-1                  [32, 64, 64, 64]          4,096
│    │    └─BatchNorm2d: 3-2             [32, 64, 64, 64]          128
│    │    └─ReLU: 3-3                    [32, 64, 64, 64]          --
│    │    └─Conv2d: 3-4                  [32, 64, 64, 64]          36,864
│    │    └─BatchNorm2d: 3-5             [32, 64, 64, 64]          128
│    │    └─ReLU: 3-6                    [32, 64, 64, 64]          --
│ 

In [4]:
summary(model, input_size=(32, 3, 256, 256))

Layer (type:depth-idx)                        Output Shape              Param #
Unet                                          [32, 1, 256, 256]         --
├─ResNetEncoder: 1-1                          [32, 3, 256, 256]         --
│    └─Conv2d: 2-1                            [32, 64, 128, 128]        9,408
│    └─BatchNorm2d: 2-2                       [32, 64, 128, 128]        128
│    └─ReLU: 2-3                              [32, 64, 128, 128]        --
│    └─MaxPool2d: 2-4                         [32, 64, 64, 64]          --
│    └─Sequential: 2-5                        [32, 256, 64, 64]         --
│    │    └─Bottleneck: 3-1                   [32, 256, 64, 64]         75,008
│    │    └─Bottleneck: 3-2                   [32, 256, 64, 64]         70,400
│    │    └─Bottleneck: 3-3                   [32, 256, 64, 64]         70,400
│    └─Sequential: 2-6                        [32, 512, 32, 32]         --
│    │    └─Bottleneck: 3-4                   [32, 512, 32, 32]         379,392

Loading encoder's weights from TorchGeo

In [5]:
train_tiles_df = gpd.read_file(workdir / 'data/AREA_TRAIN_TEST/TrainTiles_CLS_UTM.geojson')
train_aoi_df = gpd.read_file(workdir / 'data/AREA_TRAIN_TEST/TrainSet_CLS_UTM.geojson')
test_tiles_df = gpd.read_file(workdir / 'data/AREA_TRAIN_TEST/TestTiles_CLS_UTM.geojson')
test_aoi_df = gpd.read_file(workdir / 'data/AREA_TRAIN_TEST/TestSet_CLS_UTM.geojson')

spring_tif_path = workdir / "data/GEE/three_seasons_median_2022/spring_march-april_median.tif"
spring_features_df = gpd.read_file(workdir / 'data/FEATURES/spring_V2.geojson')
summer_tif_path = workdir / "data/GEE/three_seasons_median_2022/summer_july-aug_median.tif"
summer_features_df = gpd.read_file(workdir / 'data/FEATURES/summer_V2.geojson')
winter_tif_path = workdir / "data/GEE/three_seasons_median_2022/winter_nov-dec_median.tif"
winter_features_df = gpd.read_file(workdir / 'data/FEATURES/winter_V2.geojson')



spring_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, spring_tif_path, spring_features_df, train_aoi_df)
summer_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, summer_tif_path, summer_features_df, train_aoi_df)
winter_train_dataset = SingleRasterPalaeochannelDataset(train_tiles_df, winter_tif_path, winter_features_df, train_aoi_df)

spring_test_dataset = SingleRasterPalaeochannelDataset(test_tiles_df, spring_tif_path, spring_features_df, test_aoi_df)
summer_test_dataset = SingleRasterPalaeochannelDataset(test_tiles_df, summer_tif_path, summer_features_df, test_aoi_df)
winter_test_dataset = SingleRasterPalaeochannelDataset(test_tiles_df, winter_tif_path, winter_features_df, test_aoi_df)

full_train_dataset = ConcatDataset([spring_train_dataset, summer_train_dataset, winter_train_dataset])
full_test_dataset = ConcatDataset([spring_test_dataset, summer_test_dataset, winter_test_dataset])
print(f"Datasets build! {len(full_train_dataset)} training tiles, {len(full_test_dataset)} testing tiles.")

Datasets build! 2472 training tiles, 255 testing tiles.


# LightningModule adapted from Andaleeb's notebook!

In [6]:
from typing import Sequence
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from lightning.pytorch.callbacks import StochasticWeightAveraging

class PalaeochannelRGBExperimentModule(LightningModule):
    def __init__(self, *args: Any, 
                 learning_rate: float = 1.0e-3, 
                 logits_threshold: float = 0.1, 
                 weight_decay: float = 1.0e-3, 
                 clip_stds: float = 2.5, 
                 swa_lrs = 1e-3,
                 swa_epoch_start = 20,
                 tversky_gamma: float = 1.0, 
                 tversky_alpha: float = 0.4, 
                 tversky_beta: float = 0.6,
                 model_tag: str = 'unet-resnet50-sen2-rgb-moco', # Consider a smarter tagging strategy.
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        
        # Model creation and weights loading
        if model_tag == 'unet-resnet50-sen2-rgb-moco':
            self.model = Unet(encoder_name="resnet50", in_channels=3)
            encoder_model = resnet50(weights=ResNet50_Weights.SENTINEL2_RGB_MOCO)
            self.model.encoder.load_state_dict(encoder_model.state_dict())
        else:
            self.model = Unet(encoder_name="mit_b5", decoder_attention_type='scse', in_channels=3)
        
        # TverskyLoss Loss function 
        # self.loss_criterion = smp.losses.TverskyLoss(mode="binary", 
        #                                              gamma=self.hparams.tversky_gamma, 
        #                                              alpha=self.hparams.tversky_alpha, 
        #                                              beta=self.hparams.tversky_beta,)
        # DiceLoss Loss function
        self.loss_criterion = smp.losses.DiceLoss(mode='binary')
        # Augmentations
        self.spatial_augmentation_pipeline = K.AugmentationSequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["input", "mask"]  # Apply to both image and mask
        )
        self.color_augmentation_pipeline = K.AugmentationSequential(
            # K.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.2, hue=0.1),
            data_keys=["input"]  # Apply to images only
        )
        # Metrics creation
        self.train_iou = BinaryJaccardIndex(threshold=self.hparams.logits_threshold)
        self.train_precision = BinaryPrecision(threshold=self.hparams.logits_threshold)
        self.train_recall = BinaryRecall(threshold=self.hparams.logits_threshold)
        
        self.validation_iou = BinaryJaccardIndex(threshold=self.hparams.logits_threshold)
        self.validation_precision = BinaryPrecision(threshold=self.hparams.logits_threshold)
        self.validation_recall = BinaryRecall(threshold=self.hparams.logits_threshold)
        
        self.validation_prec_rec_curve = BinaryPrecisionRecallCurve(thresholds=20)
        
        
        
    def configure_callbacks(self) -> Sequence[Callback] | Callback:
        swa = StochasticWeightAveraging(swa_lrs=self.hparams.swa_lrs, 
                                        swa_epoch_start=self.hparams.swa_epoch_start)
        return [swa]
    
    def configure_optimizers(self) -> OptimizerLRScheduler:
        return optim.Adam(self.model.parameters(),
                          lr=self.hparams.learning_rate, 
                          weight_decay=self.hparams.weight_decay)

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # Fix batch types.
        images = batch['image'].float()
        masks = batch['mask'].float().unsqueeze(1)
        
        # Finish raster to tensor conversion
        images = torch.movedim(images, -1, -3)
        images = images[:, [3, 2, 1], :, :] # Select RGB bands
         
        # Clip each tile bands using local statistics.
        from torch.masked import masked_tensor
        masked_images = masked_tensor(images, images > 0.0)
        images_mean = masked_images.mean(dim=(2, 3), keepdim=True).get_data()
        images_std = masked_images.std(dim=(2, 3), keepdim=True).get_data()
        images_max_clip = images_mean + self.hparams.clip_stds * images_std
        images_min_clip = images_mean - self.hparams.clip_stds * images_std
        images = (images - images_min_clip) / (images_max_clip - images_min_clip)
        
        # Data augmentation during training.
        if self.trainer.training: 
            images = self.color_augmentation_pipeline(images)
            images, masks = self.spatial_augmentation_pipeline(images, masks)
        
        batch['image'] = images
        batch['mask'] = masks.squeeze(1).int()
        return super().on_after_batch_transfer(batch, dataloader_idx)
    
    def forward(self, batch):
        x = batch['image']
        # Do not propagate gradient to the encoder network for now.
        with torch.no_grad():
            features = self.model.encoder(x)
        decoder_output = self.model.decoder(*features)

        return self.model.segmentation_head(decoder_output)
    
    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        logits = self.forward(batch).squeeze()
        loss = self.loss_criterion(logits, batch['mask'])
        
        # Update metrics
        self.train_iou(logits, batch['mask'])
        self.train_precision(logits, batch['mask'])
        self.train_recall(logits, batch['mask'])
        
        # Log step loss
        batch_size = batch['image'].shape[0]
        self.log('train/loss', loss, on_epoch=True, on_step=True, batch_size=batch_size)
        
        # Log batch output (images)
        if batch_idx == 0:
            self.log_batch_output(batch, logits)
            
        return loss
    
    def on_train_epoch_end(self) -> None:
        # Log metrics
        self.log('train/iou', self.train_iou)
        self.log('train/precision', self.train_precision)
        self.log('train/recall', self.train_recall)
        return super().on_train_epoch_end()
    
    def log_batch_output(self, batch, logits):
        if isinstance(self.logger, TensorBoardLogger):
            stage = 'none'
            if self.trainer.validating:
                stage = 'validation'
            if self.trainer.training:
                stage = 'train'
            mask = batch['mask'].unsqueeze(1)
            segmentation_mask = logits.unsqueeze(1) > self.hparams.logits_threshold
            summary_writer: SummaryWriter = self.logger.experiment
            summary_writer.add_images(f'{stage}/image', batch['image'], global_step=self.trainer.global_step)
            summary_writer.add_images(f'{stage}/mask', mask * 255, global_step=self.trainer.global_step)
            summary_writer.add_images(f'{stage}/logits', segmentation_mask, global_step=self.trainer.global_step)
            
            
    # The same thing as the training step but on validation objects.
    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        logits = self.forward(batch).squeeze()
        print(logits.shape)
        print('batch[mask]',batch['mask'].shape)
        loss = self.loss_criterion(logits, batch['mask'])
        
        self.validation_iou(logits, batch['mask'])
        self.validation_precision(logits, batch['mask'])
        self.validation_recall(logits, batch['mask'])
        
        self.validation_prec_rec_curve(logits, batch['mask'])
        
        batch_size = batch['image'].shape[0]
        self.log('validation/loss', loss, on_epoch=True, batch_size=batch_size)
        if batch_idx == 0:
            self.log_batch_output(batch, logits)
        
        return loss
        
    def on_validation_epoch_end(self) -> None:
        self.log('validation/iou', self.validation_iou)
        self.log('validation/precision', self.validation_precision)
        self.log('validation/recall', self.validation_recall)
        
        # Log the precision recall curve!
        if isinstance(self.logger, TensorBoardLogger):
            summary_writer: SummaryWriter = self.logger.experiment
            fig_, ax_ = self.validation_prec_rec_curve.plot(score=True)
            summary_writer.add_figure('validation/prec_rec_curve', figure=fig_, global_step=self.trainer.global_step)
            
        return super().on_validation_epoch_end()


In [7]:

train_loader = DataLoader(dataset=full_train_dataset, 
                          batch_size=16, 
                          num_workers=16, 
                          prefetch_factor=16, 
                          pin_memory=True, 
                          persistent_workers=True, 
                          shuffle=True)
val_loader = DataLoader(dataset=full_test_dataset, 
                        batch_size=16, 
                        num_workers=16, 
                        prefetch_factor=16, 
                        pin_memory=True, 
                        persistent_workers=True, 
                        shuffle=True) # Consider shuffling the "full_test_dataset" in the future.

In [8]:
from lightning import Trainer
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint


lightning_module = PalaeochannelRGBExperimentModule(model_tag='mit_b5')
logs_dir = scratchdir / 'run_logs_test'
logger = TensorBoardLogger(name='dice-headonly-image_clip-imagenet-mit_b5-median', save_dir=logs_dir)
checkpointing = ModelCheckpoint(filename='epoch={epoch}-val_iou={validation/iou:.8f}', 
                                auto_insert_metric_name=False, 
                                monitor='validation/iou', 
                                mode='max', 
                                save_top_k=2, 
                                save_last=True)
trainer = Trainer(accelerator='gpu', devices=[0], log_every_n_steps=5, logger=logger, callbacks=[checkpointing])

Downloading: "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b5.pth" to /home/vscode/.cache/torch/hub/checkpoints/mit_b5.pth
100%|██████████| 313M/313M [00:07<00:00, 43.2MB/s] 
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
Missing logger folder: ../run_logs_test/dice-headonly-image_clip-imagenet-mit_b5-median
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

   | Name                          | Type                       | Params
------------------------------------------------------------------------------
0  | model                         | Unet                       | 84.8 M
1  | loss_criterion                | DiceLoss                   | 0     
2  | spatial_augmentation_pipeline | AugmentationSequential     | 0     
3  | color_augmentation_pipeline   | AugmentationSequential     | 0     
4  | train_iou                     | BinaryJaccardIndex         | 0     
5  | train_precision               | BinaryPrecision            | 0     
6  | train_recall                  | BinaryRecall               | 0     
7  | validat

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



torch.Size([16, 256, 256])
batch[mask] torch.Size([16, 256, 256])
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:02<00:02,  0.37it/s]torch.Size([16, 256, 256])
batch[mask] torch.Size([16, 256, 256])
Epoch 0:   0%|          | 0/155 [00:00<?, ?it/s]                           

/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
