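# Model Training Notebook

Trains a Siamese ResNet-18 U-Net for two-class change detection ("No Change" / "Change") on LEVIR-CD, then reports per-class test-set metrics with and without test-time augmentation (TTA).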

In [None]:
import sys
import os

# Make the project sources in ../src importable. Presumably this is where the
# project-local packages imported in later cells (models, datasets, training,
# metrics, losses) live. The relative path is resolved against the kernel's
# current working directory, so the notebook must be started from its own folder.
SRC_DIR = os.path.abspath('../src')
# Guard so that re-running this cell does not stack duplicate sys.path entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
# Enable IPython's autoreload extension so that edits to the project modules are
# picked up without restarting the kernel. Mode 2 reloads all modules (except
# ones excluded via %aimport) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): this reloads the autoreload extension immediately after the
# previous cell loaded it and set `%autoreload 2`. It looks redundant; confirm it
# does not reset the reload mode, or drop this cell.
%reload_ext autoreload

## Change Detection on LEVIR-CD

In [None]:
from models import TinyCD, SiameseResNetUNet

# Siamese U-Net with a resnet18 backbone, initialized with pretrained backbone
# weights and trained end-to-end (backbone not frozen).
#
# Alternative model, kept as a reference configuration (TinyCD is imported above):
#   TinyCD(bkbn_name="efficientnet_b4", pretrained=True, output_layer_bkbn="3",
#          out_channels=2, freeze_backbone=False)
model = SiameseResNetUNet(
    in_channels=3,
    out_channels=2,           # two output classes: "No Change" / "Change"
    backbone_name="resnet18",
    pretrained=True,
    freeze_backbone=False,
    mode="conc",              # presumably concatenation-based fusion; see models.SiameseResNetUNet
)
# Assign instead of leaving `model.to(...)` as the cell's last expression, which
# would dump the whole module repr. For an nn.Module, .to() moves the parameters
# in place and returns the module itself.
model = model.to("cuda")

In [None]:
from datasets import Levir_cd_dataset
from training.augmentations import (
    get_train_augmentation_pipeline,
    get_val_augmentation_pipeline,
)

origin_dir = "../data/Levir-cd-256"

# Both pipelines are built with image_size, mean and std all set to None.
# Presumably this means no resizing and no normalization statistics; confirm in
# training.augmentations.
pipeline_kwargs = dict(image_size=None, mean=None, std=None)
train_transform = get_train_augmentation_pipeline(**pipeline_kwargs)
val_transform = get_val_augmentation_pipeline(**pipeline_kwargs)


def _make_split(split, transform):
    """Build the LEVIR-CD dataset for one split ("train", "val" or "test")."""
    return Levir_cd_dataset(origin_dir=origin_dir, transform=transform, type=split)


# Validation and test share the deterministic (validation) transform.
train_data = _make_split("train", train_transform)
val_data = _make_split("val", val_transform)
test_data = _make_split("test", val_transform)

In [None]:
from training.utils import define_weighted_random_sampler

# Build a weighted random sampler and per-class weights from the training split.
# They are presumably estimated on a subset of `subset_size` samples, with the
# seed fixed for repeatability.
# NOTE(review): by default neither result feeds training. The training DataLoader
# is not given this sampler, and the loss uses hand-set class weights.
sampler_config = {
    "dataset": train_data,
    "mask_key": "mask",
    "subset_size": 200,
    "seed": 42,
}
weighted_sampler, class_weights_dict = define_weighted_random_sampler(**sampler_config)
print("Class Weights : ", class_weights_dict)

In [None]:
# Per-class weights for the weighted cross-entropy term of the loss.
# Index 0 = "No Change", index 1 = "Change" (the class_names order used in training).
# The hand-set 1:20 ratio presumably counters the imbalance between unchanged and
# changed pixels. Set the flag to True to use the weights estimated by
# define_weighted_random_sampler instead.
USE_COMPUTED_CLASS_WEIGHTS = False
MANUAL_CLASS_WEIGHTS = [1.0, 20.0]

if USE_COMPUTED_CLASS_WEIGHTS:
    # Assumes class_weights_dict is keyed by class index 0..num_classes-1; confirm.
    # float() keeps the loss weight tensor floating-point even if the helper returns ints.
    class_weights = [float(class_weights_dict[i]) for i in range(len(class_weights_dict))]
else:
    class_weights = list(MANUAL_CLASS_WEIGHTS)

In [None]:
import torch
from torch.utils.data import DataLoader

# Settings shared by all three splits.
loader_kwargs = dict(batch_size=32, pin_memory=True, num_workers=8)

# Training batches must be reshuffled every epoch. Without shuffle (or a sampler)
# each epoch would iterate in fixed dataset order. A seeded generator makes the
# shuffle order reproducible.
# To re-enable the weighted sampler, pass `sampler=weighted_sampler` and drop
# `shuffle`: DataLoader rejects shuffle=True together with a sampler.
train_dl = DataLoader(
    dataset=train_data,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
    **loader_kwargs,
)

# Evaluation splits keep a fixed order so that validation/test metrics and logs
# are reproducible across epochs and runs.
val_dl = DataLoader(dataset=val_data, shuffle=False, **loader_kwargs)
test_dl = DataLoader(dataset=test_data, shuffle=False, **loader_kwargs)

In [None]:
# `optim` is otherwise first imported in the training cell further down. Import
# it here so that this cell also runs on a fresh kernel (Restart & Run All).
import torch.optim as optim

# Optimizer *class*, not an instance. It is presumably instantiated by train()
# together with params_opt.
optimizer = optim.AdamW

In [None]:
from training import train, testing
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from metrics import iou_score, f1_score, precision, recall
import torch.optim as optim
from losses import Ensemble, DiceLoss

mode = "multiclass"
nb_epochs = 3

# Loss: class-weighted cross-entropy plus Dice, combined by losses.Ensemble with
# weights 0.7 / 0.3 (presumably a weighted sum; see losses.Ensemble).
# CrossEntropyLoss requires a floating-point weight tensor. Pinning float32 keeps
# this working even if class_weights holds ints (e.g. [1, 20]), which torch.tensor
# would otherwise turn into int64.
criterion = Ensemble(
    list_losses=[
        torch.nn.CrossEntropyLoss(
            weight=torch.tensor(class_weights, dtype=torch.float32),
            reduction='mean',
        ).to("cuda"),
        DiceLoss(mode=mode),
    ],
    weights=[0.7, 0.3],
)
metrics = [f1_score, iou_score]

# Optimizer and scheduler are passed as classes together with their kwargs.
optimizer = optim.AdamW
params_opt = {"lr": 1e-3, "weight_decay": 1e-2, "amsgrad": False}

scheduler = optim.lr_scheduler.CosineAnnealingLR
# NOTE(review): T_max=100 while nb_epochs=3. If train() steps the scheduler once
# per epoch, the cosine schedule barely decays during this run. Confirm the
# stepping granularity and align T_max with it.
params_sc = {"T_max": 100}
# NOTE(review): patience=5 exceeds nb_epochs=3. If patience counts epochs, early
# stopping cannot trigger in this run.
early_stopping_params = {"patience": 5, "trigger_times": 0}

train(
    model=model,
    train_dl=train_dl,
    valid_dl=val_dl,
    test_dl=test_dl,
    loss_fn=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    params_opt=params_opt,
    params_sc=params_sc,
    metrics=metrics,
    nb_epochs=nb_epochs,
    experiment_name="Levir_CD_Siamese_ResNet18_Unet",
    log_dir="../runs",
    model_dir="../models",
    resume_path=None,
    early_stopping_params=early_stopping_params,
    # NOTE(review): the evaluation cell passes image_key="image". Confirm which
    # key the Levir_cd_dataset samples provide for the siamese inputs.
    image_key="post_image",
    mask_key="mask",
    num_classes=len(class_weights),
    verbose=False,
    # NOTE(review): larger than nb_epochs=3. If counted in epochs, no periodic
    # checkpoint is written in this run.
    checkpoint_interval=10,
    debug=False,  # memory-logging debug output disabled
    training_log_interval=5,
    is_mixed_precision=True,
    reduction="weighted",
    class_weights=class_weights,
    class_names=["No Change", "Change"],
    siamese=True,
    tta=False,
)


In [None]:
from metrics import compute_model_class_performance

# Per-class test-set report, first with and then without test-time augmentation.
# Output files are named after the backbone actually trained ("resnet18" in the
# model cell). They were previously mislabelled "ResNet34".
# NOTE(review): training used image_key="post_image", but this uses "image".
# Confirm which key the dataset provides for siamese inputs.
# NOTE(review): this evaluates the in-memory `model` as left by train(). Whether
# that holds the final-epoch or the best checkpoint is not visible here.
test_performance = {}
for tta in (True, False):
    suffix = "with_TTA" if tta else "without_TTA"
    test_performance[suffix] = compute_model_class_performance(
        model=model,
        dataloader=test_dl,
        num_classes=2,
        device='cuda',
        class_names=["No Change", "Change"],
        siamese=True,
        image_key="image",
        mask_key="mask",
        average_mode="macro",
        output_file=f"../outputs/Levir_CD_Siamese_ResNet18_Unet_{suffix}.txt",
        tta=tta,
    )
test_performance