# Model Training Notebook

In [6]:
import sys
import os
sys.path.append(os.path.abspath('../src'))

In [7]:
# In a Jupyter notebook or IPython environment, run this in the first cell
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
%reload_ext autoreload

### Change Detection on LEVIR-CD 

In [9]:
from models import TinyCD, SiameseResNetUNet

"""model = TinyCD(
    bkbn_name="efficientnet_b4",
    pretrained=True,
    output_layer_bkbn="3",
    out_channels=2,
    freeze_backbone=False
)
model.to("cuda")"""

model = SiameseResNetUNet(
            in_channels=3,
            out_channels=2,
            backbone_name="resnet18",
            pretrained=True,
            freeze_backbone=False,
            mode="conc"
        )
model.to("cuda")

SiameseResNetUNet(
  (firstconv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (firstbn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (firstrelu): ReLU(inplace=True)
  (firstmaxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (encoder1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

In [13]:
from datasets import Levir_cd_dataset

from training.augmentations import (
    get_val_augmentation_pipeline,
    get_train_augmentation_pipeline
    )

origin_dir = "../data/Levir-cd-256"
train_transform = get_train_augmentation_pipeline(image_size=None, 
                                          mean = None,
                                          std = None
                                          )

val_transform = get_val_augmentation_pipeline(image_size=None, 
                                          mean = None,
                                          std = None
                                          )
train_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=train_transform,
                              type = "train"
                              )
val_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=val_transform,
                              type = "val"
                              )
test_data = Levir_cd_dataset(origin_dir=origin_dir, 
                              transform=val_transform,
                              type = "test"
                              )

  original_init(self, **validated_kwargs)
  Expected `dict[str, any]` but got `UniformParams` with value `UniformParams(noise_type=...6, 0.0784313725490196)])` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Loaded 7120 train samples.
Loaded 1024 val samples.
Loaded 2048 test samples.


In [8]:
### Define a Weighted Random Sampler 
from training.utils import define_weighted_random_sampler 

weighted_sampler, class_weights_dict = define_weighted_random_sampler(
        dataset=train_data, 
        mask_key="mask", 
        subset_size=200,
        seed=42
    )
print("Class Weights : ", class_weights_dict)

Counting class frequencies: 100%|██████████| 200/200 [00:04<00:00, 41.19it/s]
Assigning sample weights: 7120it [01:35, 74.94it/s]

Class Weights :  {np.int32(0): 1.0582325839752058, np.int32(1): 18.172516342789166}





In [14]:
#class_weights = [class_weights_dict[i] for i in range(len(class_weights_dict))]
class_weights = [1.0, 20.0]

In [15]:
from torch.utils.data import DataLoader
train_dl = DataLoader(dataset=train_data, batch_size=32, pin_memory=True, num_workers=8) #sampler=weighted_sampler)
val_dl = DataLoader(dataset=val_data, shuffle=True, batch_size=32, pin_memory=True, num_workers=8)
test_dl = DataLoader(dataset=test_data, shuffle=False, batch_size=32, pin_memory=True, num_workers=8)

In [21]:
optimizer = optim.AdamW

In [36]:
from training import train, testing
from torch.utils.data import DataLoader
import torch.nn as nn
import torch 
from metrics import iou_score, f1_score, precision, recall
import torch.optim as optim
from losses import Ensemble, DiceLoss

mode = "multiclass"
nb_epochs = 3

criterion = Ensemble(
    list_losses=[
        torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights), reduction='mean').to("cuda"), 
        DiceLoss(mode=mode)
    ],
    weights=[0.7,0.3]
)
metrics = [f1_score, iou_score]

optimizer = optim.AdamW
params_opt = {"lr":1e-3, "weight_decay":1e-2, "amsgrad": False}

scheduler = optim.lr_scheduler.CosineAnnealingLR
params_sc = {"T_max" : 100}
early_stopping_params = {"patience": 5, "trigger_times": 0}

train(
    model = model,
    train_dl = train_dl,
    valid_dl = val_dl,
    test_dl = test_dl, 
    loss_fn = criterion,
    optimizer = optimizer, 
    scheduler = scheduler, 
    params_opt=params_opt,
    params_sc=params_sc,
    metrics = metrics,
    nb_epochs = nb_epochs,
    experiment_name = "Levir_CD_Siamese_ResNet18_Unet",
    log_dir="../runs",
    model_dir="../models",
    resume_path=None,
    early_stopping_params = early_stopping_params,
    image_key = "post_image",
    mask_key = "mask",
    num_classes = len(class_weights), 
    verbose = False,  # Adding verbose flag
    checkpoint_interval = 10,  # Add checkpoint interval parameter
    debug = False,  # Add debug flag for memory logging, 
    training_log_interval = 5, 
    is_mixed_precision=True,
    reduction= "weighted",
    class_weights = class_weights,
    class_names=["No Change", "Change"], 
    siamese=True,
    tta=False
)


INFO:root:Experiment logs are recorded at ../runs/Levir_CD_Siamese_ResNet18_Unet
INFO:root:Model Signature has been defined
2025/01/12 18:32:55 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
INFO:root:Hyperparameters have been logged
Testing: 100%|██████████| 64/64 [00:23<00:00,  2.75batch/s]
INFO:root:Per-Class Performance and Overall Performance file logged as artifact
2025/01/12 18:33:53 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/01/12 18:33:53 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


🏃 View run Levir_CD_Siamese_ResNet18_Unet_20250112-183253 at: http://localhost:5000/#/experiments/205686240822932636/runs/f154aa45f491457080d09865b8ddfe31
🧪 View experiment at: http://localhost:5000/#/experiments/205686240822932636


In [30]:
from metrics import compute_model_class_performance

compute_model_class_performance(
    model=model,
    dataloader=test_dl,
    num_classes=2,
    device='cuda',
    class_names=["No Change", "Change"], 
    siamese=True,
    image_key="image",
    mask_key="mask",
    average_mode="macro",
    output_file="../outputs/Levir_CD_Siamese_ResNet34_Unet_with_TTA.txt",
    tta=True
)
compute_model_class_performance(
    model=model,
    dataloader=test_dl,
    num_classes=2,
    device='cuda',
    class_names=["No Change", "Change"], 
    siamese=True,
    image_key="image",
    mask_key="mask",
    average_mode="macro",
    output_file="../outputs/Levir_CD_Siamese_ResNet34_Unet_without_TTA.txt",
    tta=False
)

Testing:   6%|▋         | 4/64 [00:07<01:41,  1.69s/batch]

Testing: 100%|██████████| 64/64 [01:40<00:00,  1.57s/batch]


Per-Class Performance Metrics (TTA):
+-----------+-----------+--------+----------+--------+--------+
|   Class   | Precision | Recall | F1 Score |  IoU   |  Dice  |
+-----------+-----------+--------+----------+--------+--------+
| No Change |   0.9985  | 0.9796 |  0.9889  | 0.9781 | 0.9889 |
|   Change  |   0.7187  | 0.9720 |  0.8263  | 0.7041 | 0.8263 |
+-----------+-----------+--------+----------+--------+--------+
----------------------------------------
Overall Performance Metrics:
  Precision (macro): 0.8586
  Recall (macro):    0.9758
  F1 Score (macro):  0.9076
Metrics have been saved to ../outputs/Levir_CD_Siamese_ResNet34_Unet_with_TTA.txt


Testing: 100%|██████████| 64/64 [00:29<00:00,  2.20batch/s]


Per-Class Performance Metrics ():
+-----------+-----------+--------+----------+--------+--------+
|   Class   | Precision | Recall | F1 Score |  IoU   |  Dice  |
+-----------+-----------+--------+----------+--------+--------+
| No Change |   0.9984  | 0.9773 |  0.9877  | 0.9758 | 0.9877 |
|   Change  |   0.6966  | 0.9710 |  0.8113  | 0.6824 | 0.8113 |
+-----------+-----------+--------+----------+--------+--------+
----------------------------------------
Overall Performance Metrics:
  Precision (macro): 0.8475
  Recall (macro):    0.9742
  F1 Score (macro):  0.8995
Metrics have been saved to ../outputs/Levir_CD_Siamese_ResNet34_Unet_without_TTA.txt
