In [None]:
from src.datasets import fetch_dataset
from src.model_backbone import DinoBackbone
from src.model_head import LiteFlowHead
from src.loss import sequence_loss
import config.config as cfg
from src.common import tensor_to_image
from src.utils import flow_to_image

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import datetime
import json
import sys
import copy
import os
import time

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: ", device)

# ----------------------------------------------------------------------
# Mirror every setting of the `config.config` module (imported as `cfg`)
# into a notebook-level constant so the cells below can use short names.
# ----------------------------------------------------------------------

# Dataset settings
DATASET_NAME, DATASET_LOCATIONS = cfg.DATASET_NAME, cfg.DATASET_LOCATIONS
IMG_SIZE, PATCH_SIZE = cfg.IMG_SIZE, cfg.PATCH_SIZE
PROB_AUGMENT_TRAINING, PROB_AUGMENT_VALID = cfg.PROB_AUGMENT_TRAINING, cfg.PROB_AUGMENT_VALID
IMG_MEAN, IMG_STD = cfg.IMG_MEAN, cfg.IMG_STD

# Backbone settings (the two MODEL_TO_* objects are indexed by DINO_MODEL later on)
DINOV3_DIR, DINO_MODEL, DINO_WEIGHTS = cfg.DINOV3_DIR, cfg.DINO_MODEL, cfg.DINO_WEIGHTS
MODEL_TO_NUM_LAYERS, MODEL_TO_EMBED_DIM = cfg.MODEL_TO_NUM_LAYERS, cfg.MODEL_TO_EMBED_DIM

# Optimisation settings
BATCH_SIZE = cfg.BATCH_SIZE
LEARNING_RATE, WEIGHT_DECAY = cfg.LEARNING_RATE, cfg.WEIGHT_DECAY
NUM_EPOCHS = cfg.NUM_EPOCHS
NUM_SAMPLES_PLOT = cfg.NUM_SAMPLES_PLOT

# Checkpoint loading / result saving
LOAD_MODEL, SAVE_MODEL = cfg.LOAD_MODEL, cfg.SAVE_MODEL
MODEL_PATH_TRAIN_LOAD = cfg.MODEL_PATH_TRAIN_LOAD
RESULTS_PATH = cfg.RESULTS_PATH

# Training split: built with the training augmentation probability and
# reshuffled by the DataLoader on every pass.
train_set = fetch_dataset(DATASET_NAME, DATASET_LOCATIONS, "train", IMG_SIZE, IMG_MEAN, IMG_STD, PROB_AUGMENT_TRAINING)
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=8, shuffle=True)

# Validation split: built with the validation augmentation probability.
# It is iterated in a fixed order (shuffle=False) so that the first batch,
# which the validation loop plots, contains the same samples every epoch
# and the per-epoch validation loss is not perturbed by batch composition.
val_set = fetch_dataset(DATASET_NAME, DATASET_LOCATIONS, "val", IMG_SIZE, IMG_MEAN, IMG_STD, PROB_AUGMENT_VALID)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

# Load the DINO network from the local repository at DINOV3_DIR, using the
# model name and weights taken from the config.
dino_model = torch.hub.load(repo_or_dir=DINOV3_DIR, model=DINO_MODEL, source="local", weights=DINO_WEIGHTS)

# Per-model lookups keyed by the selected DINO variant.
n_layers_dino = MODEL_TO_NUM_LAYERS[DINO_MODEL]
embed_dim = MODEL_TO_EMBED_DIM[DINO_MODEL]

# Wrap the hub model in the project's backbone class, then move it to the device.
dino_backbone = DinoBackbone(dino_model, n_layers_dino)
dino_backbone = dino_backbone.to(device)

# Flow head that consumes the two backbone outputs.
# NOTE(review): in_channels is hard-coded to 384 while `embed_dim` (looked up
# above) is not used here; a previously used head, Dinov3FlowHead, was built
# with in_ch=embed_dim. Confirm that 384 matches the selected DINO_MODEL.
model_head = LiteFlowHead(
    out_size=IMG_SIZE,
    in_channels=384,
    proj_channels=256,
    radius=4,
    fusion_channels=448,
    fusion_layers=3,
    refinement_layers=2,
)
model_head = model_head.to(device)

# Only the flow head's parameters are handed to the optimizer.
optimizer = optim.Adam(model_head.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Optionally initialise the head from a previously saved state dict.
if LOAD_MODEL:
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine
    # (without it torch.load tries to restore tensors to their original device).
    model_head.load_state_dict(torch.load(MODEL_PATH_TRAIN_LOAD, map_location=device))
    print("Model successfully loaded!")

# Freeze the backbone: none of its parameters should receive gradients.
for p in dino_backbone.parameters():
    p.requires_grad_(False)

# Parameter bookkeeping (element counts) for backbone + head.
n_params_backbone = sum(p.numel() for p in dino_backbone.parameters())
n_params = n_params_backbone + sum(p.numel() for p in model_head.parameters())
n_trainable_params = sum(
    p.numel()
    for module in (dino_backbone, model_head)
    for p in module.parameters()
    if p.requires_grad
)

print("Total number of parameters: ", n_params)
print("Total number of trainable parameters: ", n_trainable_params)
print("Number parameters backbone: ", n_params_backbone)

if SAVE_MODEL:
    # Each run writes into its own folder named after the start timestamp.
    # `folder_path` and `json_params` only exist when SAVE_MODEL is truthy.
    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = f"{RESULTS_PATH}/{current_date}"

    # Settings stored next to every checkpoint; the training loop copies this
    # dict and adds the epoch number and losses before dumping it to JSON.
    # NOTE(review): BATCH_SIZE, WEIGHT_DECAY and NUM_EPOCHS are not recorded here.
    json_params = dict(
        DATASET_NAME=DATASET_NAME,
        IMG_SIZE=IMG_SIZE,
        PATCH_SIZE=PATCH_SIZE,
        PROB_AUGMENT_TRAINING=PROB_AUGMENT_TRAINING,
        PROB_AUGMENT_VALID=PROB_AUGMENT_VALID,
        DINO_MODEL=DINO_MODEL,
        LEARNING_RATE=LEARNING_RATE,
        LOAD_MODEL=LOAD_MODEL,
        MODEL_PATH_TRAIN_LOAD=MODEL_PATH_TRAIN_LOAD,
    )

In [None]:
def _plot_flow_samples(im1, im2, flow_gt, flow_pred, max_samples):
    """Display image pairs next to their ground-truth and predicted flow.

    One figure with four panels (Image 1, Image 2, Flow GT, Flow Pred) is
    shown per sample.

    Args:
        im1, im2: batched image tensors; sample ``i`` is read as ``im1[i]``.
        flow_gt, flow_pred: batched flow tensors; each sample is permuted
            with ``(1, 2, 0)`` before being converted by ``flow_to_image``.
        max_samples: upper bound on the number of samples to draw.
    """
    panel_titles = ["Image 1", "Image 2", "Flow GT", "Flow Pred"]
    # Bound by the real batch length: the last batch of a loader (or a tiny
    # dataset) can hold fewer than BATCH_SIZE samples.
    n_show = min(max_samples, im1.shape[0])

    with torch.no_grad():
        for i in range(n_show):
            panels = [
                tensor_to_image(im1[i], IMG_MEAN, IMG_STD),
                tensor_to_image(im2[i], IMG_MEAN, IMG_STD),
                flow_to_image(flow_gt[i].detach().permute(1, 2, 0).cpu().numpy()),
                flow_to_image(flow_pred[i].detach().permute(1, 2, 0).cpu().numpy()),
            ]

            fig, axes = plt.subplots(1, 4, figsize=(16, 4))  # 1 row, 4 columns
            for ax, img, title in zip(axes, panels, panel_titles):
                ax.imshow(img)
                ax.set_title(title)
                ax.axis("off")

            plt.tight_layout()
            plt.show()
            # Release the figure explicitly so repeated plotting cannot pile up
            # open figures, whatever matplotlib backend is active.
            plt.close(fig)


# DINO backbone always in eval mode
dino_backbone.eval()

for epoch in range(NUM_EPOCHS):
    ##################### TRAIN #######################
    model_head.train()
    train_loss = 0.0
    n_train_batches = 0  # batches that actually produced an optimizer step

    for batch_idx, (im1, im2, flow_gt, valid) in enumerate(tqdm(train_dataloader)):
        im1 = im1.to(device, dtype=torch.float)
        im2 = im2.to(device, dtype=torch.float)
        flow_gt = flow_gt.to(device)
        valid = valid.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass. The optimizer only holds the head's parameters, so the
        # backbone never needs gradients; no_grad makes that explicit.
        with torch.no_grad():
            feat1 = dino_backbone(im1)
            feat2 = dino_backbone(im2)
        flow_pred = model_head(feat1, feat2)

        # Check if flow_pred has NaNs
        if torch.isnan(flow_pred).any():
            print("flow_pred contains NaN values in batch_idx: ", batch_idx)

        # Calculate loss
        loss, metrics = sequence_loss(flow_pred, flow_gt, valid)

        # A NaN/Inf loss would yield non-finite gradients; clipping and stepping
        # with them overwrites the head's weights with NaN for the rest of the
        # run. Report the batch and skip it instead.
        if not torch.isfinite(loss).all():
            print("loss is NaN/Inf in batch_idx:", batch_idx)
            print("flow_pred min/max:", flow_pred.min().item(), flow_pred.max().item())
            print("flow_gt min/max:", flow_gt.min().item(), flow_gt.max().item())
            print("valid sum:", valid.sum().item())
            print("skipping optimizer step for this batch")
            continue

        # Backward pass
        loss.backward()

        # Gradient clipping and Optimize
        # clip all gradients to max norm 5.0
        torch.nn.utils.clip_grad_norm_(model_head.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss += loss.item()
        n_train_batches += 1

        # Show error (running mean over the batches that were used)
        if (batch_idx % 250 == 0 and batch_idx > 0):
            print(f"Epoch {epoch+1}, batch {batch_idx}, Loss: {train_loss/n_train_batches}")

        # Plot
        if (batch_idx % 500 == 0 and batch_idx > 0):
            _plot_flow_samples(im1, im2, flow_gt, flow_pred, NUM_SAMPLES_PLOT)

    # Mean over the batches that contributed; NaN if there were none
    # (empty loader or every batch skipped) rather than a misleading 0.0.
    train_loss = train_loss / n_train_batches if n_train_batches else float("nan")

    ##################### VALIDATION #######################
    model_head.eval()
    val_loss = 0.0
    n_val_batches = 0

    with torch.no_grad():
        for batch_idx, (im1, im2, flow_gt, valid) in enumerate(tqdm(val_dataloader)):
            im1 = im1.to(device, dtype=torch.float)
            im2 = im2.to(device, dtype=torch.float)
            flow_gt = flow_gt.to(device)
            valid = valid.to(device)

            # Forward pass
            feat1 = dino_backbone(im1)
            feat2 = dino_backbone(im2)
            flow_pred = model_head(feat1, feat2)

            # Calculate loss
            loss, metrics = sequence_loss(flow_pred, flow_gt, valid)

            # Total loss
            val_loss += loss.item()
            n_val_batches += 1

            # Plot the first validation batch only
            if (batch_idx == 0):
                _plot_flow_samples(im1, im2, flow_gt, flow_pred, NUM_SAMPLES_PLOT)

    val_loss = val_loss / n_val_batches if n_val_batches else float("nan")

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss}, val loss NN: {val_loss}")

    if SAVE_MODEL:
        os.makedirs(folder_path, exist_ok=True)
        # Save the head's weights and the run parameters for this epoch
        json_params_epoch = json_params.copy()
        json_params_epoch["epoch"] = epoch
        json_params_epoch["train_loss"] = train_loss
        json_params_epoch["val_loss"] = val_loss
        model_path = os.path.join(folder_path, f"model_{epoch}.pth")
        json_path = os.path.join(folder_path, f"params_{epoch}.json")
        torch.save(model_head.state_dict(), model_path)
        with open(json_path, "w") as outfile:
            json.dump(json_params_epoch, outfile)

print("Training finished.")