In [1]:
from src.dataset_coco import DatasetCOCO
from src.model_backbone import DinoBackbone
from src.model_head import DinoFCOSHead
from src.loss import compute_loss
#from src.draw_samples_training import draw_samples_training
import config.config as cfg
from src.utils import collate_fn

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import datetime
import json
import sys
import copy
import os

%matplotlib inline

In [2]:
# Setup cell: reads the run configuration, builds the COCO data loaders, the
# DINOv3 backbone + FCOS head, the optimizer, and (optionally) the metadata
# that is written next to every checkpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: ", device)

################ LOAD ALL THE PARAMETERS #############################
# DATASET PARAMETERS
COCO_ROOT = cfg.COCO_ROOT # Root to the folder with the prepared data
IMG_SIZE = cfg.IMG_SIZE
PATCH_SIZE = cfg.PATCH_SIZE
PROB_AUGMENT_TRAINING = cfg.PROB_AUGMENT_TRAINING
PROB_AUGMENT_VALID = cfg.PROB_AUGMENT_VALID
IMG_MEAN = cfg.IMG_MEAN
IMG_STD = cfg.IMG_STD

# MODEL PARAMETERS
DINOV3_DIR = cfg.DINOV3_DIR
FPN_CH = cfg.FPN_CH
DINO_MODEL = cfg.DINO_MODEL
MODEL_TO_NUM_LAYERS = cfg.MODEL_TO_NUM_LAYERS
MODEL_TO_EMBED_DIM = cfg.MODEL_TO_EMBED_DIM
N_LAYERS_UNFREEZE = cfg.N_LAYERS_UNFREEZE

# TRAINING PARAMETERS
BATCH_SIZE = cfg.BATCH_SIZE
FOCAL_ALPHA = cfg.FOCAL_ALPHA
FOCAL_GAMMA = cfg.FOCAL_GAMMA
WEIGHT_REG = cfg.WEIGHT_REG
WEIGHT_CTR = cfg.WEIGHT_CTR

LEARNING_RATE = cfg.LEARNING_RATE
NUM_EPOCHS = cfg.NUM_EPOCHS
NUM_SAMPLES_PLOT = cfg.NUM_SAMPLES_PLOT

LOAD_MODEL = cfg.LOAD_MODEL
SAVE_MODEL = cfg.SAVE_MODEL
MODEL_PATH_TRAIN_LOAD = cfg.MODEL_PATH_TRAIN_LOAD
RESULTS_PATH = cfg.RESULTS_PATH

################ DATA #############################
train_set = DatasetCOCO(COCO_ROOT, "train", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_TRAINING, IMG_MEAN, IMG_STD)
train_dataloader = DataLoader(train_set, batch_size = BATCH_SIZE, num_workers=8, shuffle=True, collate_fn=collate_fn)

val_set = DatasetCOCO(COCO_ROOT, "val", IMG_SIZE, PATCH_SIZE, PROB_AUGMENT_VALID, IMG_MEAN, IMG_STD)
val_dataloader = DataLoader(val_set, batch_size = BATCH_SIZE, num_workers=8, shuffle=True, collate_fn=collate_fn)

num_classes = len(train_set.class_names)

################ MODEL #############################
# NOTE(review): the hub entry name is hard-coded here, while the layer count
# and embedding width below are looked up through DINO_MODEL. The two must
# describe the same architecture -- confirm whenever DINO_MODEL is changed.
dino_model = torch.hub.load(
        repo_or_dir=DINOV3_DIR,
        model="dinov3_vits16plus",
        source="local"
)
n_layers_dino = MODEL_TO_NUM_LAYERS[DINO_MODEL]
embed_dim = MODEL_TO_EMBED_DIM[DINO_MODEL]

dino_backbone = DinoBackbone(dino_model, n_layers_dino).to(device)
model_head = DinoFCOSHead(backbone_out_channels=embed_dim, fpn_channels=FPN_CH, num_classes=num_classes).to(device)

# Load model (only the detection head is restored from the checkpoint).
if LOAD_MODEL:
    # map_location lets a checkpoint written on a GPU machine be loaded when
    # only the CPU is available (and vice versa).
    model_head.load_state_dict(torch.load(MODEL_PATH_TRAIN_LOAD, map_location=device))
    print("Model successfully loaded!")


# Freeze parameters
for p in dino_backbone.parameters():
    p.requires_grad = False

# Unfreeze last N transformer blocks
if N_LAYERS_UNFREEZE > 0:
    # Unfreeze the last norm layer
    for p in dino_backbone.dino.norm.parameters():
        p.requires_grad = True
    for block in dino_backbone.dino.blocks[-N_LAYERS_UNFREEZE:]:
        for param in block.parameters():
            param.requires_grad = True

# The optimizer is built only now, after the freeze/unfreeze decisions, so that
# every backbone tensor that was unfrozen is actually updated. Building it from
# the head alone would leave the unfrozen layers untouched by optimizer.step().
# With N_LAYERS_UNFREEZE == 0 this is exactly the head-only optimizer.
# NOTE: once backbone layers are trained, a head-only checkpoint no longer
# fully describes the model; the backbone weights have to be stored as well.
backbone_trainable_params = [p for p in dino_backbone.parameters() if p.requires_grad]
optimizer = optim.Adam(list(model_head.parameters()) + backbone_trainable_params, lr=LEARNING_RATE)

# Parameter statistics (backbone + head).
n_params_backbone = sum(p.numel() for p in dino_backbone.parameters())
n_params_head = sum(p.numel() for p in model_head.parameters())
n_params = n_params_backbone + n_params_head
print("Total number of parameters: ", n_params)
n_trainable_params = (
    sum(p.numel() for p in backbone_trainable_params)
    + sum(p.numel() for p in model_head.parameters() if p.requires_grad)
)
print("Total number of trainable parameters: ", n_trainable_params)
print("Number parameters backbone: ", n_params_backbone)

# Run metadata: one timestamped results folder per run, plus the settings that
# are dumped as JSON alongside each saved model.
if SAVE_MODEL:
    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = f"{RESULTS_PATH}/{current_date}"

    json_params = {
        "IMG_SIZE" : IMG_SIZE,
        "PATCH_SIZE" : PATCH_SIZE,
        "PROB_AUGMENT_TRAINING": PROB_AUGMENT_TRAINING,
        "PROB_AUGMENT_VALID": PROB_AUGMENT_VALID,
        "DINO_MODEL": DINO_MODEL,
        "N_LAYERS_UNFREEZE": N_LAYERS_UNFREEZE,
        "FPN_CH": FPN_CH,
        "FOCAL_ALPHA" : FOCAL_ALPHA,
        "FOCAL_GAMMA" : FOCAL_GAMMA,
        "WEIGHT_REG" : WEIGHT_REG,
        "WEIGHT_CTR": WEIGHT_CTR,
        "LEARNING_RATE" : LEARNING_RATE,
        "LOAD_MODEL" : LOAD_MODEL,
        "MODEL_PATH_TRAIN_LOAD" : MODEL_PATH_TRAIN_LOAD,
    }

Using device:  cuda
loading annotations into memory...
Done (t=7.50s)
creating index...
index created!
Total number of classes: 80
loading annotations into memory...
Done (t=1.29s)
creating index...
index created!
Total number of classes: 80
Total number of parameters:  33235477
Total number of trainable parameters:  4538005
Number parameters backbone:  28697472


In [None]:
# Training cell: runs NUM_EPOCHS epochs of training followed by validation and,
# when SAVE_MODEL is set, writes one checkpoint + JSON summary per epoch.

# DINO backbone always in eval mode (model_head.train() below only switches
# the head, because backbone and head are separate modules).
dino_backbone.eval()

# Tensors the optimizer actually updates. Gradient clipping is applied to this
# exact set so that every parameter being stepped is covered by the norm clip.
optimized_params = [p for group in optimizer.param_groups for p in group["params"]]
_optimized_ids = {id(p) for p in optimized_params}
# True only when backbone weights were handed to the optimizer; in that case
# the backbone has to be checkpointed together with the head.
backbone_is_trained = any(id(p) in _optimized_ids for p in dino_backbone.parameters())


def _batch_loss(image, boxes, labels):
    """Move one collated batch to `device`, run backbone + head, return the loss.

    Args:
        image: batched image tensor produced by the data loader.
        boxes: per-image sequence of box tensors.
        labels: per-image sequence of label tensors.

    Returns:
        Whatever `compute_loss` returns. The loops below index it as
        [0] total loss, [1] classification, [2] regression, [3] centerness.
    """
    image = image.to(device, dtype=torch.float)
    boxes = [box.to(device, dtype=torch.float) for box in boxes]
    labels = [label.to(device, dtype=torch.int) for label in labels]

    outputs = model_head(dino_backbone(image))

    # Stride of the first output level, derived from its spatial size; the
    # next two levels use 2x and 4x that stride.
    first_stride = IMG_SIZE / outputs['cls'][0].shape[2]
    strides = [first_stride, first_stride*2, first_stride*4]

    return compute_loss(outputs, boxes, labels, image.shape[2:], strides, FOCAL_ALPHA, FOCAL_GAMMA, WEIGHT_REG, WEIGHT_CTR)


for epoch in range(NUM_EPOCHS):
    ##################### TRAIN #######################
    model_head.train()
    train_loss = 0.0

    for batch_idx, (image, boxes, labels) in enumerate(tqdm(train_dataloader)):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass + loss
        loss = _batch_loss(image, boxes, labels)

        # Backward pass
        loss[0].backward()

        # Clip all optimized gradients to max norm 5.0, then step
        torch.nn.utils.clip_grad_norm_(optimized_params, max_norm=5.0)
        optimizer.step()

        train_loss += loss[0].item()

        # Running mean of the total loss, plus the components of this batch.
        if (batch_idx % 200 == 0 and batch_idx > 0):
            print(f"Epoch {epoch+1}, batch {batch_idx}, Loss: {train_loss/(batch_idx+1)}, cls loss: {loss[1].item()}, regression loss: {loss[2].item()}, ctr loss: {loss[3].item()}")

    # len() instead of the loop variable: an empty loader would otherwise
    # leave batch_idx undefined.
    train_loss /= float(max(len(train_dataloader), 1))

    ##################### VALIDATION #######################
    model_head.eval()
    val_loss = 0.0
    val_loss_cls = 0.0
    val_loss_reg = 0.0
    val_loss_ctr = 0.0

    with torch.no_grad():
        for image, boxes, labels in tqdm(val_dataloader):
            loss = _batch_loss(image, boxes, labels)

            val_loss += loss[0].item()
            val_loss_cls += loss[1].item()
            val_loss_reg += loss[2].item()
            val_loss_ctr += loss[3].item()

    # Guarded the same way as the training average; without this an empty
    # validation loader would divide by the stale training batch index.
    n_val_batches = float(max(len(val_dataloader), 1))
    val_loss /= n_val_batches
    val_loss_cls /= n_val_batches
    val_loss_reg /= n_val_batches
    val_loss_ctr /= n_val_batches

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss}, val loss NN total: {val_loss}, val loss CLS: {val_loss_cls},  val loss Reg: {val_loss_reg}, val loss ctr: {val_loss_ctr}")

    if SAVE_MODEL:
        os.makedirs(folder_path, exist_ok=True)
        # Save model and params
        json_params_epoch = json_params.copy()
        json_params_epoch["epoch"] = epoch
        json_params_epoch["train_loss"] = train_loss
        json_params_epoch["val_loss"] = val_loss
        json_params_epoch["val_loss_cls"] = val_loss_cls
        json_params_epoch["val_loss_reg"] = val_loss_reg
        json_params_epoch["val_loss_ctr"] = val_loss_ctr
        model_path = os.path.join(folder_path,f"model_{epoch}.pth")
        json_path = os.path.join(folder_path,f"params_{epoch}.json")
        torch.save(model_head.state_dict(), model_path)
        if backbone_is_trained:
            # The head checkpoint alone cannot reproduce the model once the
            # backbone weights have been updated, so store them as well.
            torch.save(dino_backbone.state_dict(), os.path.join(folder_path,f"backbone_{epoch}.pth"))
        with open(json_path, "w") as outfile:
            json.dump(json_params_epoch, outfile)

print("Training finished.")

  5%|█▎                      | 201/3697 [02:37<44:52,  1.30it/s]

Epoch 1, batch 200, Loss: 1.5401874098611708, cls loss: 0.4163675904273987, regression loss: 0.9999523162841797, ctr loss: 7.203803397715092e-05


 11%|██▌                     | 401/3697 [05:12<43:20,  1.27it/s]

Epoch 1, batch 400, Loss: 1.4584570926918354, cls loss: 0.3413376212120056, regression loss: 0.9937861561775208, ctr loss: 0.00013686894089914858


 16%|███▉                    | 601/3697 [07:49<39:29,  1.31it/s]

Epoch 1, batch 600, Loss: 1.4171153473576372, cls loss: 0.3267313241958618, regression loss: 0.9886676073074341, ctr loss: 0.00012433268420863897


 22%|█████▏                  | 801/3697 [10:24<37:06,  1.30it/s]

Epoch 1, batch 800, Loss: 1.391193706742238, cls loss: 0.3618829548358917, regression loss: 0.9837740063667297, ctr loss: 0.00010283814481226727


 23%|█████▍                  | 836/3697 [10:51<37:27,  1.27it/s]