In [None]:
from src.dataset_coco import DatasetCOCOPanoptic
from src.model_backbone import DinoBackbone
from src.model_head import ASPPDecoder
from src.loss_focal_dice import SemanticLoss
import config.config as cfg
from src.common import tensor_to_image
from src.utils import visualize_maps, outputs_to_maps

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import datetime
import json
import sys
import copy
import os
import time

%matplotlib inline

In [None]:
# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: ", device)

# =====================================================================
# Mirror every setting of the `config.config` module (imported as `cfg`)
# into a notebook-level name, so the cells below never touch `cfg` directly.
# =====================================================================

# --- Dataset: location, geometry, augmentation probability, normalization ---
COCO_ROOT = cfg.COCO_ROOT
IMG_SIZE, PATCH_SIZE = cfg.IMG_SIZE, cfg.PATCH_SIZE
PROB_AUGMENT_TRAINING, PROB_AUGMENT_VALID = cfg.PROB_AUGMENT_TRAINING, cfg.PROB_AUGMENT_VALID
IMG_MEAN, IMG_STD = cfg.IMG_MEAN, cfg.IMG_STD

# --- Model: DINO checkpoint selection plus per-model lookup tables ---
DINOV3_DIR = cfg.DINOV3_DIR
DINO_MODEL, DINO_WEIGHTS = cfg.DINO_MODEL, cfg.DINO_WEIGHTS
# Both tables are indexed by DINO_MODEL further down.
MODEL_TO_NUM_LAYERS, MODEL_TO_EMBED_DIM = cfg.MODEL_TO_NUM_LAYERS, cfg.MODEL_TO_EMBED_DIM
HIDDEN_DIM = cfg.HIDDEN_DIM
TARGET_SIZE = cfg.TARGET_SIZE

# --- Training: batch size and loss weights ---
BATCH_SIZE = cfg.BATCH_SIZE
WEIGHT_LOSS_DICE, WEIGHT_LOSS_FOCAL = cfg.WEIGHT_LOSS_DICE, cfg.WEIGHT_LOSS_FOCAL

# --- Training: optimizer settings, schedule length, plotting budget ---
LEARNING_RATE, WEIGHT_DECAY = cfg.LEARNING_RATE, cfg.WEIGHT_DECAY
NUM_EPOCHS = cfg.NUM_EPOCHS
NUM_SAMPLES_PLOT = cfg.NUM_SAMPLES_PLOT

# --- Checkpointing: what to load, whether to save, and where ---
LOAD_MODEL, SAVE_MODEL = cfg.LOAD_MODEL, cfg.SAVE_MODEL
MODEL_PATH_TRAIN_LOAD = cfg.MODEL_PATH_TRAIN_LOAD
RESULTS_PATH = cfg.RESULTS_PATH

# Training split: reshuffled every epoch.
train_set = DatasetCOCOPanoptic(COCO_ROOT, "train", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_TRAINING, IMG_MEAN, IMG_STD)
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=8, shuffle=True)

# Validation split: iterated in a fixed order (shuffle=False) so that the batch
# composition is the same every epoch and the samples plotted from the first
# validation batch can be compared from one epoch to the next.
val_set = DatasetCOCOPanoptic(COCO_ROOT, "val", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_VALID, IMG_MEAN, IMG_STD)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

# Number of classes, taken from the dataset's class-name list.
num_classes = len(train_set.class_names)

# Load the DINO model from the local hub checkout with the configured weights.
dino_model = torch.hub.load(DINOV3_DIR, DINO_MODEL, source="local", weights=DINO_WEIGHTS)

# Per-model facts looked up from the config tables.
n_layers_dino = MODEL_TO_NUM_LAYERS[DINO_MODEL]
embed_dim = MODEL_TO_EMBED_DIM[DINO_MODEL]

# Wrap the DINO model as the feature extractor and move it to the compute device.
dino_backbone = DinoBackbone(dino_model, n_layers_dino)
dino_backbone = dino_backbone.to(device)

# Decoder head: takes `embed_dim` input channels and emits one output per class
# at a square TARGET_SIZE x TARGET_SIZE resolution.
decoder_target_size = (TARGET_SIZE, TARGET_SIZE)
model_head = ASPPDecoder(
    num_classes=num_classes,
    in_ch=embed_dim,
    target_size=decoder_target_size,
)
model_head = model_head.to(device)

# Only the head's parameters are handed to the optimizer.
optimizer = optim.Adam(
    model_head.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Optionally warm-start the decoder head from a previously saved state dict.
if LOAD_MODEL:
    # map_location places the stored tensors on the device selected for this run,
    # so a checkpoint written on a GPU machine still loads on a CPU-only one
    # (and vice versa) instead of failing during deserialization.
    head_state_dict = torch.load(MODEL_PATH_TRAIN_LOAD, map_location=device)
    model_head.load_state_dict(head_state_dict)
    print("Model successfully loaded!")

def _count_params(module, trainable_only=False):
    """Return the total element count of ``module``'s parameters.

    With ``trainable_only=True`` only parameters whose ``requires_grad`` flag is
    set are counted.
    """
    return sum(
        param.numel()
        for param in module.parameters()
        if param.requires_grad or not trainable_only
    )


# Freeze the backbone: none of its parameters receive gradients.
for backbone_param in dino_backbone.parameters():
    backbone_param.requires_grad = False

# Report model size (backbone + head), the trainable share, and the backbone alone.
n_params = _count_params(dino_backbone) + _count_params(model_head)
print("Total number of parameters: ", n_params)

n_trainable_params = (
    _count_params(dino_backbone, trainable_only=True)
    + _count_params(model_head, trainable_only=True)
)
print("Total number of trainable parameters: ", n_trainable_params)

n_params_backbone = _count_params(dino_backbone)
print("Number parameters backbone: ", n_params_backbone)

# Loss function: focal + dice, each term given its configured weight.
# Targets are handled at TARGET_SIZE x TARGET_SIZE and no label index is ignored.
loss_module = SemanticLoss(
    dice_weight=WEIGHT_LOSS_DICE,
    focal_weight=WEIGHT_LOSS_FOCAL,
    target_size=(TARGET_SIZE, TARGET_SIZE),
    ignore_index=None,
)

if SAVE_MODEL:
    # Each run writes into its own timestamped sub-folder of RESULTS_PATH.
    # The folder is only created later, when the training loop saves a checkpoint.
    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = f"{RESULTS_PATH}/{current_date}"

    # Run configuration stored next to every checkpoint. It records the settings
    # that shape training, including batch size, weight decay and epoch count,
    # so a saved model can be traced back to the exact setup that produced it.
    json_params = {
        "IMG_SIZE": IMG_SIZE,
        "PATCH_SIZE": PATCH_SIZE,
        "PROB_AUGMENT_TRAINING": PROB_AUGMENT_TRAINING,
        "PROB_AUGMENT_VALID": PROB_AUGMENT_VALID,
        "DINO_MODEL": DINO_MODEL,
        "HIDDEN_DIM": HIDDEN_DIM,
        "TARGET_SIZE": TARGET_SIZE,
        "WEIGHT_LOSS_DICE": WEIGHT_LOSS_DICE,
        "WEIGHT_LOSS_FOCAL": WEIGHT_LOSS_FOCAL,
        "BATCH_SIZE": BATCH_SIZE,
        "LEARNING_RATE": LEARNING_RATE,
        "WEIGHT_DECAY": WEIGHT_DECAY,
        "NUM_EPOCHS": NUM_EPOCHS,
        "LOAD_MODEL": LOAD_MODEL,
        "MODEL_PATH_TRAIN_LOAD": MODEL_PATH_TRAIN_LOAD,
    }

In [None]:
# Cadence (in batches) of the running-loss printout and of the sample plots during
# training, and the max gradient norm used for clipping.
LOG_EVERY_N_BATCHES = 400
PLOT_EVERY_N_BATCHES = 4000
MAX_GRAD_NORM = 5.0


def plot_targets_and_predictions(images, semantic_targets, semantic_logits, max_samples):
    """Plot target and predicted semantic maps for the first samples of a batch.

    Two figures are drawn per sample with ``visualize_maps``: the image with its
    target map, then the image with the map that ``outputs_to_maps`` builds from
    the sample's logits.

    Args:
        images: batch of image tensors, indexed along dim 0; each one is converted
            with ``tensor_to_image(..., IMG_MEAN, IMG_STD)`` before plotting.
        semantic_targets: batch of target maps, indexed along dim 0.
        semantic_logits: batch of decoder outputs, indexed along dim 0.
        max_samples: upper bound on the number of samples to plot.
    """
    # Bound by the real batch length, not BATCH_SIZE: the loaders do not drop the
    # last batch, so it can be shorter and indexing past it would raise IndexError.
    n_samples = min(max_samples, images.shape[0])
    plot_kwargs = dict(
        class_names=train_set.class_names,
        alpha=1.0,
        figsize=(12, 8),
        draw_semantic_labels=True,
        semantic_label_fontsize=5,
        seed=42,
    )

    with torch.no_grad():
        for i in range(n_samples):
            # Ground truth
            visualize_maps(tensor_to_image(images[i], IMG_MEAN, IMG_STD),
                           semantic_targets[i].detach().cpu().numpy(),
                           **plot_kwargs)
            # Prediction; logits are detached so no autograd history is carried
            # into the plotting helpers (training-time logits require grad).
            semantic_mask_pred = outputs_to_maps(semantic_logits[i].detach(),
                                                 images.shape[2:],
                                                 )
            visualize_maps(tensor_to_image(images[i], IMG_MEAN, IMG_STD),
                           semantic_mask_pred,
                           **plot_kwargs)


# DINO backbone always in eval mode
dino_backbone.eval()

for epoch in range(NUM_EPOCHS):
    ##################### TRAIN #######################
    model_head.train()
    train_loss = 0.0
    num_train_batches = 0

    for batch_idx, (images, semantic_targets, _) in enumerate(tqdm(train_dataloader)):
        images = images.to(device, dtype=torch.float)
        semantic_targets = semantic_targets.to(device, dtype=torch.long)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass. The backbone is frozen and is not part of the optimizer,
        # so its features are computed without autograd tracking.
        with torch.no_grad():
            feat = dino_backbone(images)
        semantic_logits = model_head(feat)

        # Calculate loss: losses[0] is the total, losses[1] focal, losses[2] dice
        losses = loss_module(semantic_logits, semantic_targets)

        # Backward pass
        total_loss = losses[0]
        total_loss.backward()

        # Clip the head's gradient norm, then take the optimizer step
        torch.nn.utils.clip_grad_norm_(model_head.parameters(), max_norm=MAX_GRAD_NORM)
        optimizer.step()

        train_loss += total_loss.item()
        num_train_batches += 1

        # Show running mean of the total loss plus the current batch's components
        # (.item() prints plain numbers rather than tensor reprs)
        if (batch_idx % LOG_EVERY_N_BATCHES == 0 and batch_idx > 0):
            print(f"Epoch {epoch+1}, batch {batch_idx}, Loss: {train_loss/(batch_idx+1)}, focal loss: {losses[1].item()}, dice loss: {losses[2].item()}")

        # Plot samples
        if (batch_idx % PLOT_EVERY_N_BATCHES == 0 and batch_idx > 0):
            plot_targets_and_predictions(images, semantic_targets, semantic_logits, NUM_SAMPLES_PLOT)

    # Mean loss per batch; max(..., 1) keeps an empty loader from raising
    train_loss /= float(max(num_train_batches, 1))

    ##################### VALIDATION #######################
    model_head.eval()
    val_loss = 0.0
    val_loss_dice = 0.0
    val_loss_focal = 0.0
    num_val_batches = 0

    with torch.no_grad():
        for batch_idx, (images, semantic_targets, _) in enumerate(tqdm(val_dataloader)):
            images = images.to(device, dtype=torch.float)
            semantic_targets = semantic_targets.to(device, dtype=torch.long)

            # Forward pass
            feat = dino_backbone(images)
            semantic_logits = model_head(feat)

            # Calculate loss
            losses = loss_module(semantic_logits, semantic_targets)

            # Accumulate total / focal / dice (no backward pass during validation)
            total_loss = losses[0]

            val_loss += total_loss.item()
            val_loss_focal += losses[1].item()
            val_loss_dice += losses[2].item()
            num_val_batches += 1

            # Plot samples from the first validation batch only
            if batch_idx == 0:
                plot_targets_and_predictions(images, semantic_targets, semantic_logits, NUM_SAMPLES_PLOT)

    # Mean losses per validation batch
    val_denominator = float(max(num_val_batches, 1))
    val_loss /= val_denominator
    val_loss_focal /= val_denominator
    val_loss_dice /= val_denominator

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss}, val loss NN total: {val_loss}, val loss focal: {val_loss_focal},  val loss dice: {val_loss_dice}")

    if SAVE_MODEL:
        os.makedirs(folder_path, exist_ok=True)
        # Save the head's weights and this epoch's metrics/params side by side.
        # Note: "epoch" is stored zero-based, matching the file names.
        json_params_epoch = json_params.copy()
        json_params_epoch["epoch"] = epoch
        json_params_epoch["train_loss"] = train_loss
        json_params_epoch["val_loss"] = val_loss
        json_params_epoch["val_loss_focal"] = val_loss_focal
        json_params_epoch["val_loss_dice"] = val_loss_dice
        model_path = os.path.join(folder_path, f"model_{epoch}.pth")
        json_path = os.path.join(folder_path, f"params_{epoch}.json")
        torch.save(model_head.state_dict(), model_path)
        with open(json_path, "w") as outfile:
            json.dump(json_params_epoch, outfile)

print("Training finished.")