# Model Training Notebook

In [15]:
import sys
import os
sys.path.append(os.path.abspath('../src'))

In [2]:
# In a Jupyter notebook or IPython environment, run this in the first cell
%load_ext autoreload
%autoreload 2

In [3]:
%reload_ext autoreload

### Change Detection on LEVIR-CD 

In [16]:
from models import TinyCD, SiameseResNetUNet

"""model = TinyCD(
    bkbn_name="efficientnet_b4",
    pretrained=True,
    output_layer_bkbn="3",
    out_channels=2,
    freeze_backbone=False
)
model.to("cuda")"""

model = SiameseResNetUNet(
            in_channels=3,
            out_channels=2,
            backbone_name="resnet34",
            pretrained=True,
            freeze_backbone=False,
            mode="conc"
        )
model.to("cuda")

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /home/onyxia/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 96.7MB/s]


SiameseResNetUNet(
  (firstconv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (firstbn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (firstrelu): ReLU(inplace=True)
  (firstmaxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (encoder1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

In [17]:
from datasets import Levir_cd_dataset

from training.augmentations import (
    get_val_augmentation_pipeline,
    get_train_augmentation_pipeline
    )

origin_dir = "../data/Levir-cd-256"
train_transform = get_train_augmentation_pipeline(image_size=(256, 256), 
                                          mean = None,
                                          std = None
                                          )

val_transform = get_val_augmentation_pipeline(image_size=(256, 256), 
                                          mean = None,
                                          std = None
                                          )
train_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=train_transform,
                              type = "train"
                              )
val_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=val_transform,
                              type = "val"
                              )
test_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=val_transform,
                              type = "test"
                              )

  Expected `dict[str, any]` but got `UniformParams` with value `UniformParams(noise_type=...6, 0.0784313725490196)])` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Loaded 7120 train samples.
Loaded 1024 val samples.
Loaded 2048 test samples.


In [18]:
### Define a Weighted Random Sampler 
from training.utils import define_weighted_random_sampler 

weighted_sampler, class_weights_dict = define_weighted_random_sampler(
        dataset=train_data, 
        mask_key="mask", 
        subset_size=200,
        seed=42
    )
print("Class Weights : ", class_weights_dict)

Counting class frequencies: 100%|██████████| 200/200 [00:05<00:00, 33.96it/s]
Assigning sample weights: 7120it [02:04, 57.20it/s]

Class Weights :  {np.int32(1): 17.380557728750507, np.int32(0): 1.061047982404288}





In [20]:
class_weights = [class_weights_dict[i] for i in range(len(class_weights_dict))]

In [21]:
from torch.utils.data import DataLoader
train_dl = DataLoader(dataset=train_data, batch_size=32, pin_memory=True, num_workers=8, sampler=weighted_sampler)
val_dl = DataLoader(dataset=val_data, shuffle=True, batch_size=32, pin_memory=True, num_workers=8)
test_dl = DataLoader(dataset=test_data, shuffle=False, batch_size=32, pin_memory=True, num_workers=8)

In [None]:
from training import train, testing
from torch.utils.data import DataLoader
import torch.nn as nn
import torch 
from metrics import iou_score, f1_score, precision, recall
from losses import DiceLoss, Ensemble, FocalLoss
import torch.optim as optim
from training.utils import define_weighted_random_sampler

mode = "multiclass"
nb_epochs = 60
criterion =  torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights), reduction='mean').to("cuda")
metrics = [f1_score, iou_score]

optimizer = optim.AdamW(params=filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-2, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
early_stopping_params = {"patience": 10, "trigger_times": 0}

train(
    model = model,
    train_dl = train_dl,
    valid_dl = val_dl,
    loss_fn = criterion,
    optimizer = optimizer, 
    scheduler = scheduler, 
    metrics = metrics,
    nb_epochs = nb_epochs,
    experiment_name = "Levir_CD_Siamese_ResNet18_Unet",
    log_dir="../runs",
    model_dir="../models",
    resume_path=None,
    early_stopping_params = early_stopping_params,
    image_key = "post_image",
    mask_key = "mask",
    num_classes = len(class_weights), 
    verbose = False,  # Adding verbose flag
    checkpoint_interval = 10,  # Add checkpoint interval parameter
    debug = False,  # Add debug flag for memory logging, 
    training_log_interval = 5, 
    is_mixed_precision=True,
    reduction= "weighted",
    class_weights = class_weights,
    siamese=True,
)

INFO:root:Experiment logs are recorded at ../runs/Levir_CD_Siamese_ResNet18_Unet_20250106-215644
INFO:root:Model graph has been logged
INFO:root:Hyperparameters have been logged
INFO:root:Epoch 1
INFO:root:--------------------
Epoch 1:   0%|          | 0/223 [00:00<?, ?batch/s]

In [14]:
from metrics import compute_model_class_performance

compute_model_class_performance(
    model=model,
    dataloader=test_dl,
    num_classes=2,
    device='cuda',
    class_names=["No Change", "Change"], 
    siamese=True,
    image_key="image",
    mask_key="mask",
    average_mode="macro",
    output_file="../outputs/Levir_CD_Siamese_ResNet18_Unet_with_TTA.txt",
    tta=True
)

Testing: 100%|██████████| 64/64 [01:48<00:00,  1.69s/batch]


Per-Class Performance Metrics (TTA):
+-----------+-----------+--------+----------+--------+--------+
|   Class   | Precision | Recall | F1 Score |  IoU   |  Dice  |
+-----------+-----------+--------+----------+--------+--------+
| No Change |   0.9983  | 0.9823 |  0.9902  | 0.9807 | 0.9902 |
|   Change  |   0.7463  | 0.9687 |  0.8431  | 0.7287 | 0.8431 |
+-----------+-----------+--------+----------+--------+--------+
----------------------------------------
Overall Performance Metrics:
  Precision (macro): 0.8723
  Recall (macro):    0.9755
  F1 Score (macro):  0.9167
Metrics have been saved to ../outputs/Levir_CD_Siamese_ResNet18_Unet_with_TTA.txt
