In [None]:
from src.dataset_coco import DatasetCOCO
from src.model_backbone import DinoBackbone
from src.model_head import DinoFCOSHead
from src.loss import compute_loss
import config.config as cfg
from src.common import tensor_to_image
from src.utils import decode_outputs, plot_detections
from collate_fn import collate_fn

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import datetime
import json
import sys
import copy
import os
import time

%matplotlib inline

In [None]:
# Choose where tensors and modules live for the rest of the notebook.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device: ", device)

################ LOAD ALL THE PARAMETERS #############################
# Every setting is copied out of config/config.py into a module-level
# constant; the cells below only use these names, never `cfg` directly.

# --- Dataset ---
COCO_ROOT = cfg.COCO_ROOT  # root of the folder with the prepared data
IMG_SIZE, PATCH_SIZE = cfg.IMG_SIZE, cfg.PATCH_SIZE
PROB_AUGMENT_TRAINING = cfg.PROB_AUGMENT_TRAINING
PROB_AUGMENT_VALID = cfg.PROB_AUGMENT_VALID
IMG_MEAN, IMG_STD = cfg.IMG_MEAN, cfg.IMG_STD

# --- Model ---
DINOV3_DIR = cfg.DINOV3_DIR  # local repo passed to torch.hub.load(source="local")
FPN_CH = cfg.FPN_CH
DINO_MODEL, DINO_WEIGHTS = cfg.DINO_MODEL, cfg.DINO_WEIGHTS
MODEL_TO_NUM_LAYERS = cfg.MODEL_TO_NUM_LAYERS  # looked up with DINO_MODEL as key
MODEL_TO_EMBED_DIM = cfg.MODEL_TO_EMBED_DIM    # looked up with DINO_MODEL as key
N_LAYERS_UNFREEZE = cfg.N_LAYERS_UNFREEZE
N_CONVS = cfg.N_CONVS

# --- Training ---
BATCH_SIZE = cfg.BATCH_SIZE
FOCAL_ALPHA, FOCAL_GAMMA = cfg.FOCAL_ALPHA, cfg.FOCAL_GAMMA
WEIGHT_REG, WEIGHT_CTR = cfg.WEIGHT_REG, cfg.WEIGHT_CTR

LEARNING_RATE, WEIGHT_DECAY = cfg.LEARNING_RATE, cfg.WEIGHT_DECAY
NUM_EPOCHS = cfg.NUM_EPOCHS
NUM_SAMPLES_PLOT = cfg.NUM_SAMPLES_PLOT

LOAD_MODEL, SAVE_MODEL = cfg.LOAD_MODEL, cfg.SAVE_MODEL
MODEL_PATH_TRAIN_LOAD = cfg.MODEL_PATH_TRAIN_LOAD
RESULTS_PATH = cfg.RESULTS_PATH

# --- Inference (used only when plotting decoded predictions) ---
SCORE_THRESH, NMS_THRESH = cfg.SCORE_THRESH, cfg.NMS_THRESH

# Number of DataLoader worker processes for both splits.
NUM_WORKERS = 8

# Training split: reshuffled every epoch, augmented with PROB_AUGMENT_TRAINING.
train_set = DatasetCOCO(COCO_ROOT, "train", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_TRAINING, IMG_MEAN, IMG_STD)
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True, collate_fn=collate_fn)

# Validation split: NOT shuffled, so the validation pass iterates in a fixed
# order and the first batch (used for the per-epoch qualitative plots) holds
# the same images every epoch. Augmentation with PROB_AUGMENT_VALID > 0 can
# still vary them.
val_set = DatasetCOCO(COCO_ROOT, "val", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_VALID, IMG_MEAN, IMG_STD)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False, collate_fn=collate_fn)

# The detection head is sized from the training split's class list.
num_classes = len(train_set.class_names)

# Pretrained DINO ViT loaded from the local hub repository.
dino_model = torch.hub.load(
    repo_or_dir=DINOV3_DIR,
    model=DINO_MODEL,
    source="local",
    weights=DINO_WEIGHTS,
)
n_layers_dino = MODEL_TO_NUM_LAYERS[DINO_MODEL]
embed_dim = MODEL_TO_EMBED_DIM[DINO_MODEL]

dino_backbone = DinoBackbone(dino_model, n_layers_dino).to(device)
model_head = DinoFCOSHead(backbone_out_channels=embed_dim, fpn_channels=FPN_CH, num_classes=num_classes, num_convs=N_CONVS).to(device)

# Optionally resume the detection head (the backbone always starts from DINO_WEIGHTS).
if LOAD_MODEL:
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    model_head.load_state_dict(torch.load(MODEL_PATH_TRAIN_LOAD, map_location=device))
    print("Model successfully loaded!")

# Freeze the whole backbone ...
for p in dino_backbone.parameters():
    p.requires_grad = False

# ... then unfreeze the final norm layer and the last N transformer blocks.
if N_LAYERS_UNFREEZE > 0:
    for p in dino_backbone.dino.norm.parameters():
        p.requires_grad = True
    for block in dino_backbone.dino.blocks[-N_LAYERS_UNFREEZE:]:
        for param in block.parameters():
            param.requires_grad = True

# Build the optimizer AFTER freezing, from every trainable parameter. Passing
# only model_head.parameters() would leave the unfrozen backbone blocks with
# gradients that are computed but never applied.
trainable_params = [p for p in model_head.parameters() if p.requires_grad]
trainable_params += [p for p in dino_backbone.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Parameter bookkeeping (head + backbone).
n_params_backbone = sum(p.numel() for p in dino_backbone.parameters())
n_params = n_params_backbone + sum(p.numel() for p in model_head.parameters())
print("Total number of parameters: ", n_params)
n_trainable_params = sum(p.numel() for p in trainable_params)
print("Total number of trainable parameters: ", n_trainable_params)
print("Number parameters backbone: ", n_params_backbone)

if SAVE_MODEL:
    # One timestamped folder per run; the training cell writes the per-epoch
    # head checkpoints and JSON summaries into it.
    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = os.path.join(RESULTS_PATH, current_date)

    # Hyper-parameters copied into every per-epoch JSON so a run can be
    # reproduced from its results folder alone.
    json_params = {
        "IMG_SIZE": IMG_SIZE,
        "PATCH_SIZE": PATCH_SIZE,
        "PROB_AUGMENT_TRAINING": PROB_AUGMENT_TRAINING,
        "PROB_AUGMENT_VALID": PROB_AUGMENT_VALID,
        "DINO_MODEL": DINO_MODEL,
        "N_LAYERS_UNFREEZE": N_LAYERS_UNFREEZE,
        "N_CONVS": N_CONVS,
        "FPN_CH": FPN_CH,
        "BATCH_SIZE": BATCH_SIZE,
        "NUM_EPOCHS": NUM_EPOCHS,
        "FOCAL_ALPHA": FOCAL_ALPHA,
        "FOCAL_GAMMA": FOCAL_GAMMA,
        "WEIGHT_REG": WEIGHT_REG,
        "WEIGHT_CTR": WEIGHT_CTR,
        "LEARNING_RATE": LEARNING_RATE,
        "WEIGHT_DECAY": WEIGHT_DECAY,
        "LOAD_MODEL": LOAD_MODEL,
        "MODEL_PATH_TRAIN_LOAD": MODEL_PATH_TRAIN_LOAD,
    }

In [None]:
# ---- Helpers shared by the training and validation loops ----

def _to_device(images, boxes, labels):
    """Move one collated batch to `device`.

    `images` becomes a float tensor; `boxes` and `labels` are per-image lists
    whose tensors are cast to float and int respectively.
    """
    images = images.to(device, dtype=torch.float)
    boxes = [box.to(device, dtype=torch.float) for box in boxes]
    labels = [label.to(device, dtype=torch.int) for label in labels]
    return images, boxes, labels


def _level_strides(outputs, img_h):
    """Return one stride per output level.

    The first stride is the input height divided by the height (dim 2) of the
    first classification map; each following level doubles it. Assumes the
    cls maps are laid out (batch, channels, H, W) -- TODO confirm.
    """
    first_stride = img_h / outputs["cls"][0].shape[2]
    return [first_stride * 2 ** level for level in range(len(outputs["cls"]))]


def _plot_samples(images, boxes, labels, outputs, strides, class_names):
    """Plot ground truth and decoded predictions for the first images of a batch.

    Shows at most NUM_SAMPLES_PLOT images, clamped to the batch size so a
    short batch cannot raise IndexError. For each image, the multi-level head
    outputs are sliced down to that image and decoded with `decode_outputs`.
    The ground truth is drawn first (only when the image has boxes), then the
    predictions.
    """
    n_levels = len(outputs["cls"])
    with torch.no_grad():  # visualisation only: record no autograd history
        for i in range(min(NUM_SAMPLES_PLOT, images.shape[0])):
            output_sample = {
                key: [outputs[key][j][i].detach().unsqueeze(0) for j in range(n_levels)]
                for key in ("cls", "reg", "ctr")
            }
            boxes_plot, scores_plot, labels_plot = decode_outputs(
                output_sample, images[i].shape[1:], strides,
                score_thresh=SCORE_THRESH, nms_thresh=NMS_THRESH,
            )
            gt_boxes = boxes[i]
            if gt_boxes.ndim > 1:
                img_h, img_w = images[i].shape[1], images[i].shape[2]
                # Columns are scaled as normalised (x, y, w, h) and converted
                # to absolute (x1, y1, x2, y2) pixel corners for drawing.
                box_target = torch.zeros_like(gt_boxes)
                box_target[:, 0] = gt_boxes[:, 0] * img_w
                box_target[:, 1] = gt_boxes[:, 1] * img_h
                box_target[:, 2] = box_target[:, 0] + gt_boxes[:, 2] * img_w
                box_target[:, 3] = box_target[:, 1] + gt_boxes[:, 3] * img_h
                # Ground-truth boxes are drawn with a score of 1.
                plot_detections(tensor_to_image(images[i], IMG_MEAN, IMG_STD), box_target.cpu(),
                                torch.ones_like(labels[i]).cpu(), labels[i].cpu(), class_names, figsize=(5, 5))
            # A fresh image is rendered for the predictions, so nothing drawn on the GT figure can leak into it.
            plot_detections(tensor_to_image(images[i], IMG_MEAN, IMG_STD), boxes_plot.cpu(),
                            scores_plot.cpu(), labels_plot.cpu(), class_names, figsize=(5, 5))


# Clip exactly the parameters the optimizer updates (head plus any backbone
# parameters it was given), not just the head.
clip_params = [p for group in optimizer.param_groups for p in group["params"]]

# DINO backbone always in eval mode
dino_backbone.eval()

for epoch in range(NUM_EPOCHS):
    ##################### TRAIN #######################
    model_head.train()
    train_loss = 0.0
    n_train_batches = 0

    for batch_idx, (images, boxes, labels) in enumerate(tqdm(train_dataloader)):
        images, boxes, labels = _to_device(images, boxes, labels)

        optimizer.zero_grad()

        # Forward pass
        feat = dino_backbone(images)
        outputs = model_head(feat)
        strides = _level_strides(outputs, images.shape[2])

        # loss[0] is the total that is backpropagated; loss[1..3] are the cls / reg / ctr terms.
        loss = compute_loss(outputs, boxes, labels, images.shape[2:], strides, FOCAL_ALPHA, FOCAL_GAMMA, WEIGHT_REG, WEIGHT_CTR,
                            scale_ranges=None)

        loss[0].backward()
        # Clip the global gradient norm to 5.0, then step.
        torch.nn.utils.clip_grad_norm_(clip_params, max_norm=5.0)
        optimizer.step()

        train_loss += loss[0].item()
        n_train_batches += 1

        # Progress report: running mean of the total loss plus this batch's components.
        if batch_idx % 200 == 0 and batch_idx > 0:
            print(f"Epoch {epoch+1}, batch {batch_idx}, Loss: {train_loss/(batch_idx+1)}, cls loss: {loss[1].item()}, regression loss: {loss[2].item()}, ctr loss: {loss[3].item()}")

        if batch_idx % 2000 == 0 and batch_idx > 0:
            _plot_samples(images, boxes, labels, outputs, strides, train_set.class_names)

    # Count the batches explicitly: an empty loader would otherwise leave
    # batch_idx undefined (or stale from a previous loop).
    train_loss /= max(n_train_batches, 1)

    ##################### VALIDATION #######################
    model_head.eval()
    val_loss = val_loss_cls = val_loss_reg = val_loss_ctr = 0.0
    n_val_batches = 0

    with torch.no_grad():
        for batch_idx, (images, boxes, labels) in enumerate(tqdm(val_dataloader)):
            images, boxes, labels = _to_device(images, boxes, labels)

            feat = dino_backbone(images)
            outputs = model_head(feat)
            strides = _level_strides(outputs, images.shape[2])

            loss = compute_loss(outputs, boxes, labels, images.shape[2:], strides, FOCAL_ALPHA, FOCAL_GAMMA, WEIGHT_REG, WEIGHT_CTR,
                                scale_ranges=None)

            val_loss += loss[0].item()
            val_loss_cls += loss[1].item()
            val_loss_reg += loss[2].item()
            val_loss_ctr += loss[3].item()
            n_val_batches += 1

            if batch_idx == 0:
                _plot_samples(images, boxes, labels, outputs, strides, val_set.class_names)

    n_val_batches = max(n_val_batches, 1)
    val_loss /= n_val_batches
    val_loss_cls /= n_val_batches
    val_loss_reg /= n_val_batches
    val_loss_ctr /= n_val_batches

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss}, val loss NN total: {val_loss}, val loss CLS: {val_loss_cls},  val loss Reg: {val_loss_reg}, val loss ctr: {val_loss_ctr}")

    if SAVE_MODEL:
        os.makedirs(folder_path, exist_ok=True)
        json_params_epoch = {
            **json_params,
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_loss_cls": val_loss_cls,
            "val_loss_reg": val_loss_reg,
            "val_loss_ctr": val_loss_ctr,
        }
        model_path = os.path.join(folder_path, f"model_{epoch}.pth")
        json_path = os.path.join(folder_path, f"params_{epoch}.json")
        torch.save(model_head.state_dict(), model_path)
        # The head checkpoint alone cannot reproduce a run in which backbone
        # parameters were trainable. Also store those tensors; restore them
        # with dino_backbone.load_state_dict(..., strict=False).
        backbone_trainable = {
            name: p.detach().cpu() for name, p in dino_backbone.named_parameters() if p.requires_grad
        }
        if backbone_trainable:
            torch.save(backbone_trainable, os.path.join(folder_path, f"backbone_trainable_{epoch}.pth"))
        with open(json_path, "w") as outfile:
            json.dump(json_params_epoch, outfile)

print("Training finished.")