In [1]:
# Re-import edited modules (e.g. the local `clort` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [4]:
import os
from typing import List

import numpy as np
import torch
import torch.nn as nn
import time
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from clort.data import ArgoCL, ArgoCl_collate_fxn
from clort.model import ContrastiveLoss, MemoryBank, MemoryBankInfer, MultiViewEncoder

# torch.autograd.set_detect_anomaly(True)

In [6]:
# Hyperparameters and run metadata logged to wandb; later cells read them back via `run.config`.
# NOTE(review): "Renet" and "momemtum" look like typos, but they are logged config values/keys —
# correcting them would rename the wandb entries, so they are kept verbatim here.
# NOTE(review): validation data is the same split as training data ("train4") — confirm intent.
run_config = {
    "architecture": "Multi_View Encoder : Renet Single View Encoder",
    "training data": ["train4"],
    "validation data": ["train4"],
    "saved model": None,
    "batch size": 1,
    "temporal horizon": 1,
    "temporal overlap": 0,
    "max objects": None,
    "static contrast": True,
    "Normalization": "None",
    "Normalization momemtum": 0.1,
    "n_epochs": 30,
}

# Start a fresh (non-resumed) wandb run in the CLORT project to track this script.
run = wandb.init(project="CLORT", resume=False, config=run_config)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669268050001542, max=1.0…

In [7]:
# Filesystem locations. Defaults are the original author's local layout; override them with
# environment variables rather than editing the notebook.
root: str = os.environ.get("CLORT_DATA_ROOT", "../../../datasets/argoverse-tracking/argov1_proc/")
train_splits: List[str] = run.config["training data"]
val_splits: List[str] = run.config["validation data"]
model_save_dir: str = os.environ.get("CLORT_MODEL_SAVE_DIR", '/home/shivam/CLORT_MV/')
load_saved_model: str | None = run.config["saved model"]
batch_size: int = run.config["batch size"]
# The run config sets this to None; presumably None means "no cap" inside ArgoCL — TODO confirm.
max_objects: int | None = run.config["max objects"]
th: int = run.config["temporal horizon"]
to: int = run.config["temporal overlap"]
nw: int = 0  # DataLoader worker processes (0 = load in the main process)
model_device: torch.device | str = 'cuda'
memory_device: torch.device | str = 'cpu'  # memory banks live here (embeddings are moved with .cpu())
static_contrast: bool = run.config["static contrast"]
n_epochs: int = run.config['n_epochs']

In [8]:
def _argo_dataset_kwargs():
    """Keyword arguments shared by the train and validation ArgoCL datasets.

    Built fresh on every call so the two datasets never share mutable objects (e.g. the list).
    """
    return dict(
        temporal_horizon=th,
        temporal_overlap=to,
        distance_threshold=(0, 50),
        img_size=(224, 224),
        point_cloud_size=[20, 50, 100, 250, 500, 1000, 1500],
        in_global_frame=True,
        pivot_to_first_frame=True,
        image=True,
        pcl=True,
        bbox=True,
    )


train_dataset = ArgoCL(root, splits=train_splits, max_objects=max_objects, **_argo_dataset_kwargs())

# NOTE(review): unlike the training set, validation does not pass `max_objects` — confirm this is intended.
val_dataset = ArgoCL(root, splits=val_splits, **_argo_dataset_kwargs())

In [9]:
# Training batches are shuffled each epoch; validation uses batch size 1 in a fixed order.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=ArgoCl_collate_fxn,
    num_workers=nw,
)

val_dl = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=ArgoCl_collate_fxn,
    num_workers=nw,
)

In [10]:
n_features: int = 256  # embedding width: encoder `out_dim` and memory-bank feature size

In [11]:
# Multi-view encoder with an `n_features`-wide output, moved to the training device.
mv_enc = MultiViewEncoder(out_dim=n_features).to(model_device)

In [12]:
def _memory_alpha():
    """Return a new 5-element alpha weight tensor for a MemoryBank on `memory_device`.

    Called once per bank so the two banks never share a tensor.
    """
    return torch.tensor([0.5, 0.4, 0.3, 0.2, 0.1], dtype=torch.float32, device=memory_device)


# Training-side memory bank and contrastive loss.
# The positional `5` matches the length of alpha — presumably the number of slots per track; TODO confirm.
mb = MemoryBank(train_dataset.n_tracks, n_features, 5, alpha=_memory_alpha(), device=memory_device)
cl = ContrastiveLoss(temp=0.05, static_contrast=static_contrast)

# Validation-side memory bank and loss.
# NOTE(review): `MemoryBankInfer` is imported but this uses `MemoryBank` — confirm which is intended.
mb_infer = MemoryBank(val_dataset.n_tracks, n_features, 5, alpha=_memory_alpha(), device=memory_device)
cl_infer = ContrastiveLoss(temp=0.05, static_contrast=static_contrast)

mb_primed = False  # set True once the training bank has been primed in the training cell

In [15]:
# AdamW with one param group per submodule: `sv_enc1` uses lr 1e-4 / wd 1e-3, every other
# submodule uses lr 1e-5 / wd 1e-4. Each group sets its own lr and weight_decay, so the
# constructor-level values below act only as defaults.
_slow_hparams = {'lr': 1e-5, "weight_decay": 1e-4}

param_groups = [{'params': mv_enc.sv_enc1.parameters(), 'lr': 1e-4, "weight_decay": 1e-3}]
param_groups.extend(
    {'params': submodule.parameters(), **_slow_hparams}
    for submodule in (mv_enc.sv_enc2, mv_enc.sv_enc3, mv_enc.gat, mv_enc.projection_head)
)

optimizer = torch.optim.AdamW(params=param_groups, lr=1e-4, weight_decay=1e-3)

In [16]:
# Fresh-start training state; the resume cell may overwrite the scheduler state and loss histories.
last_epoch = -1  # -1 tells StepLR to start from scratch

# Multiply every group's lr by 0.1 every 20 scheduler steps (the training loop steps once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.1,
    last_epoch=last_epoch,
)

# Per-batch loss histories, stored in each saved checkpoint.
training_loss: List[float] = []
validation_loss: List[float] = []

In [17]:
# Restore encoder / optimizer / scheduler state and loss histories when this wandb run is resumed.
# NOTE(review): `wandb.init(..., resume=False)` above means this branch is currently never taken.
if wandb.run.resumed:
    if load_saved_model is None:
        raise ValueError("wandb run was resumed but run.config['saved model'] is None; "
                         "cannot locate a checkpoint to restore")
    print(f'Loading model from file: {load_saved_model = }')
    # wandb.restore downloads the file and hands back a file object opened for (text) reading.
    # torch.load needs binary access, so load from the downloaded path, and close the handle
    # so it is not leaked.
    ckpt_file = wandb.restore(load_saved_model)
    if ckpt_file is None:
        raise FileNotFoundError(f'wandb.restore could not find {load_saved_model!r}')
    with ckpt_file:
        ckpt = torch.load(ckpt_file.name, map_location=model_device)
    mv_enc.load_state_dict(ckpt['mv_enc'])
    optimizer.load_state_dict(ckpt['optimizer'])
    lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
    training_loss = ckpt['train_loss']
    validation_loss = ckpt['val_loss']

In [18]:
def train(epoch, mv_enc, train_dl, optimizer, criterion, mem_bank, log_step=100, mb_priming = False):
    """Run one pass over ``train_dl``, either training the encoder or only priming the memory bank.

    Args:
        epoch: Epoch index. Used only in the progress-bar text and the wandb log (logged as ``epoch + 1``).
        mv_enc: Multi-view encoder, called as ``mv_enc(imgs, imgs_sz)``.
        train_dl: DataLoader yielding 9-tuples from ``ArgoCl_collate_fxn``. Only the elements at
            positions 2 (``imgs``), 3 (``imgs_sz``) and 5 (``track_idxs``) are used.
        optimizer: Stepped once per batch when training (unused while priming).
        criterion: Loss called as ``criterion(embeddings, track_idxs, memory)``.
        mem_bank: Read via ``get_memory()`` and updated via ``update(...)`` after every batch.
        log_step: Every ``log_step`` iterations, the mean of the last ``log_step`` losses is shown
            on the progress bar and logged to wandb.
        mb_priming: If True, only encode batches and write the embeddings into ``mem_bank``.
            No loss is computed, no backward pass runs and the optimizer is not stepped.

    Returns:
        List of per-batch training losses (empty when ``mb_priming`` is True).

    Raises:
        ValueError: if ``log_step`` < 1 (0 previously caused a ZeroDivisionError in the modulo).
    """
    if log_step < 1:
        raise ValueError(f'log_step must be >= 1, got {log_step}')

    mv_enc.train()  # put the encoder in training mode

    training_loss = []

    t_bar = tqdm(enumerate(train_dl), total=len(train_dl))
    for itr, (_, _, imgs, imgs_sz, _, track_idxs, _, _, _) in t_bar:
        imgs = imgs.to(model_device)
        # track_idxs arrives as a numpy array (it has .astype); convert to an int32 tensor.
        track_idxs = torch.from_numpy(track_idxs.astype(np.int32))

        if mb_priming:
            t_bar.set_description('Priming')
            # Priming only fills the memory bank, so no autograd graph is needed. A backward pass
            # here only freed the graph; its gradients were discarded by the next zero_grad().
            with torch.no_grad():
                mv_e = mv_enc(imgs, imgs_sz)
            mem_bank.update(mv_e.detach().cpu(), track_idxs)  # Update memory bank
            continue

        optimizer.zero_grad()

        mv_e = mv_enc(imgs, imgs_sz)

        loss = criterion(mv_e, track_idxs, mem_bank.get_memory())
        training_loss.append(loss.item())

        loss.backward()
        optimizer.step()

        mem_bank.update(mv_e.detach().cpu(), track_idxs)  # Update memory bank

        if itr % log_step == log_step - 1:
            mean_loss = np.mean(training_loss[-log_step:])
            t_bar.set_description(f'{epoch = } and {itr = } : Mean Training loss : {mean_loss}')

            wandb.log({'epoch': epoch+1, 'itr': itr+1,
                       'Training Loss': mean_loss
                       })

    return training_loss

In [19]:
def val(epoch, mv_enc, train_dl, criterion, mem_bank, log_step=100):
    """Evaluate ``mv_enc`` on a validation loader without updating model weights.

    Args:
        epoch: Epoch index. Used only in the progress-bar text and the wandb log (logged as ``epoch + 1``).
        mv_enc: Multi-view encoder, called as ``mv_enc(imgs, imgs_sz)``.
        train_dl: The *validation* DataLoader to iterate. The parameter name is kept for call
            compatibility. It must yield the same 9-tuples as the training loader.
        criterion: Loss called as ``criterion(embeddings, track_idxs, memory)``.
        mem_bank: Read via ``get_memory()`` and updated via ``update(...)`` after every batch.
        log_step: Every ``log_step`` iterations, the mean of the last ``log_step`` losses is shown
            on the progress bar and logged to wandb.

    Returns:
        List of per-batch validation losses.

    Raises:
        ValueError: if ``log_step`` < 1 (0 previously caused a ZeroDivisionError in the modulo).
    """
    if log_step < 1:
        raise ValueError(f'log_step must be >= 1, got {log_step}')

    # Iterate the loader passed in, not the notebook-global `val_dl`.
    dataloader = train_dl

    mv_enc.eval()  # evaluation mode (no weight updates happen in this function)

    validation_loss = []

    with torch.no_grad():
        v_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for itr, (_, _, imgs, imgs_sz, _, track_idxs, _, _, _) in v_bar:
            imgs = imgs.to(model_device)
            track_idxs = torch.from_numpy(track_idxs.astype(np.int32))

            mv_e = mv_enc(imgs, imgs_sz)

            loss = criterion(mv_e, track_idxs, mem_bank.get_memory())
            validation_loss.append(loss.item())

            # The (validation) memory bank is still updated with each batch's embeddings.
            mem_bank.update(mv_e.detach().cpu(), track_idxs)

            if itr % log_step == log_step - 1:
                mean_loss = np.mean(validation_loss[-log_step:])
                v_bar.set_description(f'{epoch = } and {itr = } : Mean Validation loss : {mean_loss}')

                # NOTE(review): these keys ('Epoch', 'Iteration') differ from train()'s ('epoch', 'itr').
                wandb.log({'Epoch': epoch+1, 'Iteration': itr+1,
                           'Validation Loss': mean_loss
                           })

    return validation_loss

In [20]:
# Epoch index the training loop starts from: the scheduler's step count (0 on a fresh run,
# or the restored value after resuming from a checkpoint).
last_epoch: int = lr_scheduler.last_epoch

In [21]:
last_epoch  # show the epoch index the training loop will start from

0

In [22]:
# for module in mv_enc.modules():
#     print(f'{len(list(module.modules())) = }')

In [23]:
# for child in mv_enc.sv_enc1.children():
#     for child_ in child.children():
#         for child__ in child_.children():
#             for child___ in child__.children():
#                 if type(child___) == nn.BatchNorm2d:
#                     child___.track_running_stats = False
#                     print(f'{child___ = }')

In [24]:
# for child in mv_enc.sv_enc2.children():
#     if type(child) == nn.LayerNorm:
#         print(dict(child.named_parameters()))

In [25]:
# Prime the training memory bank once per kernel session (no optimizer steps) before real training.
if not mb_primed:
    train(-1, mv_enc, train_dl, optimizer, cl, mb, log_step=100, mb_priming=True)
    mb_primed = True

# A single rolling checkpoint, overwritten after every epoch (loop-invariant, so set up once).
model_path = os.path.join(model_save_dir, 'model.pth')
os.makedirs(model_save_dir, exist_ok=True)

for epoch in range(last_epoch, n_epochs):
    # Append this epoch's per-batch losses to the notebook-level histories so that the checkpoint
    # actually contains them. Previously the returned lists were dropped and the saved histories
    # stayed empty.
    train_loss = train(epoch, mv_enc, train_dl, optimizer, cl, mb, log_step=100, mb_priming=False)
    training_loss.extend(train_loss)

    ###################################################################################
    ### Validation loss: every 10th epoch (epoch indices 9, 19, 29, ...)
    if epoch % 10 == 9:
        val_loss = val(epoch, mv_enc, val_dl, cl_infer, mb_infer, log_step=100)
        validation_loss.extend(val_loss)
    ###################################################################################

    lr_scheduler.step()  # one StepLR step per epoch

    model_info = {
        'EPOCH': epoch,
        'mv_enc': mv_enc.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'train_loss': training_loss,
        'val_loss': validation_loss
    }

    torch.save(model_info, model_path)

    wandb.save(model_path)

Priming: : 914it [03:09,  4.83it/s]
epoch = 1 and itr = 899 : Mean Training loss : 2.063762639909983: : 914it [03:35,  4.24it/s] 
epoch = 2 and itr = 899 : Mean Training loss : 1.5817130291834474: : 914it [03:24,  4.46it/s]
epoch = 3 and itr = 899 : Mean Training loss : 2.0123099377006293: : 914it [03:22,  4.50it/s]
epoch = 4 and itr = 899 : Mean Training loss : 2.627462146356702: : 914it [03:28,  4.39it/s] 
epoch = 5 and itr = 899 : Mean Training loss : 2.082427296638489: : 914it [03:21,  4.54it/s] 
epoch = 6 and itr = 899 : Mean Training loss : 1.9670758653990925: : 914it [03:40,  4.14it/s]
epoch = 7 and itr = 199 : Mean Training loss : 2.374086647182703: : 270it [00:59,  4.51it/s]


KeyboardInterrupt: 

In [26]:
wandb.finish()  # mark the wandb run as finished and flush any pending uploads

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Training Loss,█▂▃▇▄▆▄▆▆▆▅▄▃▄▄▆▆▃▆▃▁▅▄▆▅▇▄▄▃▄▅▇▃▆▂▄▅▆▃▅
epoch,▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█
itr,▁▂▃▅▅▇█▁▃▄▅▆▇▁▂▄▅▅▇█▂▃▅▅▆█▁▃▄▅▆▇▁▂▃▅▅▇█▂

0,1
Training Loss,2.37409
epoch,8.0
itr,200.0
