In [1]:
# Reload edited modules (e.g. the local `clort` package) automatically before executing code,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from typing import List

import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from clort import ArgoCL, ArgoCl_collate_fxn
from clort.model import ContrastiveLoss, MemoryBank, MemoryBankInfer, MultiViewEncoder

# torch.autograd.set_detect_anomaly(True)

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Start a Weights & Biases run so the training loop's wandb.log(...) calls have a run to log to.
# Set `use_wandb = True` to stream metrics to the "CLORT" project. With False the run is created
# in "disabled" mode, which turns wandb.log(...) into a no-op instead of an error.
use_wandb: bool = False

wandb.init(
    # set the wandb project where this run will be logged
    project="CLORT",
    # track hyperparameters and run metadata
    # NOTE(review): "dataset" says Train1 and "epochs" is hard-coded, while the config cell
    # (which executes after this one) trains on splits=['train4'] for n_epochs=30 — keep in sync.
    config={
        "architecture": "Multi_View Encoder",
        "dataset": "ArgoCL : Train1",
        "epochs": 30,
    },
    mode="online" if use_wandb else "disabled",
)

In [4]:
root: str = "../../../datasets/argoverse-tracking/argov1_proc/"
splits: List[str] = ['train4']
model_save_dir: str = '~/.tmp/CLORT/'
load_saved_model: str | None = None
batch_size: int = 1
th: int = 1
to: int = 0
nw: int = 0
model_device: torch.device | str = 'cuda'
memory_device: torch.device | str = 'cpu'
static_contrast: bool = False
n_epochs: int = 30

In [5]:
def _argo_common_kwargs() -> dict:
    """Keyword arguments shared by the training and validation ArgoCL datasets.

    A new dict (with a new ``point_cloud_size`` list) is built on every call, so the two
    datasets never share mutable argument objects.
    """
    return dict(
        distance_threshold=(0, 100),
        img_size=(224, 224),
        point_cloud_size=[20, 50, 100, 250, 500, 1000, 1500],
        in_global_frame=True,
        pivot_to_first_frame=True,
        image=True,
        pcl=True,
        bbox=True,
    )


# Training split(s) use the configurable temporal window (th / to).
train_dataset = ArgoCL(root, temporal_horizon=th, temporal_overlap=to,
                       splits=splits, **_argo_common_kwargs())

# Validation fixes temporal_horizon=1 / temporal_overlap=0 regardless of th / to.
val_dataset = ArgoCL(root, temporal_horizon=1, temporal_overlap=0,
                     splits=['val'], **_argo_common_kwargs())

In [5]:
# Both loaders iterate in dataset order and share the ArgoCL collate function and worker count.
_dl_common = {'shuffle': False, 'collate_fn': ArgoCl_collate_fxn, 'num_workers': nw}

train_dl = DataLoader(train_dataset, batch_size, **_dl_common)
val_dl = DataLoader(val_dataset, 1, **_dl_common)  # validation: one sample per batch

In [6]:
# Embedding width shared by the encoder output (out_dim) and the memory banks' feature size.
n_features: int = 256

In [7]:
# Build the multi-view encoder and place it on the training device in one expression
# (assuming MultiViewEncoder is a torch.nn.Module, .to() moves it in place and returns it).
mv_enc = MultiViewEncoder(out_dim=n_features).to(model_device)

In [8]:
# Weights handed to the training memory bank; presumably one per slot, matching the
# third positional argument (5) — confirm against MemoryBank.
mb_alpha = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.1], dtype=torch.float32, device=memory_device)
mb = MemoryBank(train_dataset.n_tracks, n_features, 5, alpha=mb_alpha, device=memory_device)

# Training loss: temperature 0.05, static contrast controlled by the config flag.
cl = ContrastiveLoss(temp=0.05, static_contrast=static_contrast)

In [9]:
# Per-submodule learning rates, in the order the parameter groups are registered
# (the order matters for loading saved optimizer state). sv_enc1 trains 100x slower
# than the other submodules — presumably a pretrained backbone; confirm.
_param_group_lrs = {
    'sv_enc1': 1e-6,
    'sv_enc2': 1e-4,
    'sv_enc3': 1e-4,
    'gat': 1e-4,
    'projection_head': 1e-4,
}

# Every group sets its own lr, so the top-level lr below is only the AdamW default.
optimizer = torch.optim.AdamW(
    params=[
        {'params': getattr(mv_enc, module_name).parameters(), 'lr': group_lr}
        for module_name, group_lr in _param_group_lrs.items()
    ],
    lr=1e-4,
    weight_decay=1e-3,
)

In [10]:
# Fresh-run defaults: no epoch completed yet, so the schedule starts at the initial LRs.
last_epoch = -1

# Divide every parameter group's LR by 10 every 20 scheduler steps
# (the training loop steps the scheduler once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.1,
    last_epoch=last_epoch,
)

# Loss history: one entry per training iteration / per validation pass.
training_loss: List[float] = []
validation_loss: List[float] = []

In [11]:
# Optionally resume from a checkpoint written by the training loop.
if load_saved_model is not None:
    # Expand "~" (torch.load/open() do not). Load onto CPU first: load_state_dict copies
    # weights onto the model's device and the optimizer moves its state to its params' device.
    ckpt = torch.load(os.path.expanduser(load_saved_model), map_location='cpu')
    mv_enc.load_state_dict(ckpt['mv_enc'])
    optimizer.load_state_dict(ckpt['optimizer'])
    lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
    training_loss = ckpt['train_loss']
    validation_loss = ckpt['val_loss']
    # Resume after the last completed epoch; without this the loop restarts at epoch 0.
    last_epoch = ckpt['EPOCH']
    # NOTE(review): the training MemoryBank `mb` is not part of the checkpoint, so a resumed
    # run starts with a freshly initialised memory bank.

In [12]:
def _run_validation() -> float:
    """Return the mean validation contrastive loss of ``mv_enc`` over ``val_dl``.

    A fresh ``MemoryBankInfer`` is built on every call, so nothing carries over between
    evaluations. Leaves ``mv_enc`` in eval mode; the training loop calls ``mv_enc.train()``
    at the start of every step. Returns NaN for an empty validation loader.
    """
    # NOTE(review): the saved output ends with "ValueError: need at least one array to
    # concatenate" ~62 batches into a validation pass — presumably a val sample with no
    # tracks; confirm whether it is raised in the collate fn or in the model.
    # NOTE(review): validation uses ContrastiveLoss's default temperature while training
    # uses temp=0.05, so the two logged losses may not be on the same scale — confirm.
    if len(val_dl) == 0:
        return float('nan')

    mb_infer = MemoryBankInfer(val_dataset.n_tracks, n_features, 5, 3, 'cpu')
    cl_infer = ContrastiveLoss(static_contrast=False)

    total = 0.0
    mv_enc.eval()
    with torch.no_grad():
        for _, _, v_imgs, v_imgs_sz, _, v_track_idxs, _, _, _ in val_dl:
            v_imgs = v_imgs.to(model_device)
            v_track_idxs = torch.from_numpy(v_track_idxs.astype(np.int32))

            v_emb = mv_enc(v_imgs, v_imgs_sz)
            total += cl_infer(v_emb, v_track_idxs, mb_infer.get_memory()).item()

            # Write this batch's embeddings into the inference memory bank.
            mb_infer.update(v_emb.detach().cpu(), v_track_idxs)

    return total / len(val_dl)


# torch.save (open()) does not expand "~", and the directory may not exist yet.
save_dir = os.path.expanduser(model_save_dir)
os.makedirs(save_dir, exist_ok=True)

for epoch in range(last_epoch + 1, n_epochs):
    model_path = os.path.join(save_dir, f'model_{epoch}.pth')

    # ---- Training ----
    for itr, (_, _, imgs, imgs_sz, _, track_idxs, _, _, _) in enumerate(tqdm(train_dl)):
        mv_enc.train()  # _run_validation() leaves the model in eval mode
        optimizer.zero_grad()

        imgs = imgs.to(model_device)
        track_idxs = torch.from_numpy(track_idxs.astype(np.int32))

        mv_e = mv_enc(imgs, imgs_sz)
        loss = cl(mv_e, track_idxs, mb.get_memory())
        training_loss.append(loss.item())

        loss.backward()
        optimizer.step()

        # Write the detached embeddings into the training memory bank. Without this the bank
        # is only ever read, so the loss always contrasts against its initial contents.
        # NOTE(review): assumes MemoryBank.update(embeddings, track_idxs) mirrors
        # MemoryBankInfer.update — confirm the signature.
        mb.update(mv_e.detach().to(memory_device), track_idxs)

        # ---- Periodic validation (every 10 training iterations) ----
        if itr % 10 == 9:
            val_loss = _run_validation()
            validation_loss.append(val_loss)

            # wandb.log raises unless wandb.init() has created a run.
            if wandb.run is not None:
                wandb.log({'epoch': epoch+1, 'itr': itr+1,
                           'training_loss': np.mean(training_loss[-10:]),
                           'val_loss': val_loss})

    lr_scheduler.step()  # StepLR is stepped once per epoch

    model_info = {
        'EPOCH': epoch,
        'mv_enc': mv_enc.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'train_loss': training_loss,
        'val_loss': validation_loss
    }

    torch.save(model_info, model_path)

9it [00:03,  2.73it/s]

epoch = 0 and itr = 9 : Mean Training loss : 2.2776680469512938



0it [00:00, ?it/s][A
2it [00:00, 11.40it/s][A
4it [00:00, 11.15it/s][A
6it [00:00, 11.31it/s][A
8it [00:00, 11.61it/s][A
10it [00:00, 11.80it/s][A
12it [00:01, 12.08it/s][A
14it [00:01, 12.50it/s][A
16it [00:01, 12.77it/s][A
18it [00:01, 13.01it/s][A
20it [00:01, 13.54it/s][A
22it [00:01, 14.21it/s][A
24it [00:01, 14.86it/s][A
26it [00:01, 15.42it/s][A
28it [00:02, 14.94it/s][A
30it [00:02, 14.46it/s][A
32it [00:02, 14.07it/s][A
34it [00:02, 13.77it/s][A
36it [00:02, 13.34it/s][A
38it [00:02, 12.40it/s][A
40it [00:03, 11.31it/s][A
42it [00:03, 10.66it/s][A
44it [00:03, 10.26it/s][A
46it [00:03,  9.64it/s][A
47it [00:03,  9.37it/s][A
48it [00:04,  9.21it/s][A
49it [00:04,  8.96it/s][A
50it [00:04,  8.70it/s][A
51it [00:04,  8.73it/s][A
52it [00:04,  8.71it/s][A
53it [00:04,  8.78it/s][A
54it [00:04,  8.90it/s][A
55it [00:04,  8.95it/s][A
57it [00:05,  9.79it/s][A
59it [00:05,  9.81it/s][A
60it [00:05,  9.41it/s][A
61it [00:05,  9.50it/s][A
62it [00:

ValueError: need at least one array to concatenate