In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
from tqdm import tqdm

from torch.utils.data import DataLoader

from clort import ArgoCL
from clort import ArgoCl_collate_fxn

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Dataset root handed to ArgoCL below. The path is relative, so it resolves
# against the kernel's current working directory.
root = "../../../datasets/argoverse-tracking/argov1_proc"

In [4]:
# Build the ArgoCL dataset over the 'train4' split with images, point clouds
# and bounding boxes all switched on.
# Presumably alternative point_cloud_size entries kept for later experiments:
# [20, 50, 100, 250, 500, 1000, 1500]
dataset = ArgoCL(
    root,
    temporal_horizon=1,
    temporal_overlap=0,
    distance_threshold=(0, 100),
    splits=['train4'],
    img_size=(224, 224),
    point_cloud_size=[20],
    in_global_frame=True,
    pivot_to_first_frame=True,
    image=True,
    pcl=True,
    bbox=True,
)

In [5]:
# pcls, pcls_sz, imgs, imgs_sz, bboxs, track_idxs, cls_idxs, frame_sz = dataset[0]
# Wrap the dataset in a loader: one dataset item per batch, assembled by the
# project's own collate function.
dl = DataLoader(
    dataset=dataset,
    collate_fn=ArgoCl_collate_fxn,
    batch_size=1,
)

In [6]:
# Keep an explicit iterator so individual batches can be pulled with next().
dl_it = iter(dl)

In [7]:
# Pull one collated batch from the iterator; it unpacks into nine fields.
pcls, pcls_sz, imgs, imgs_sz, bboxs, track_idxs, cls_idxs, frame_sz, sample_sz = next(dl_it)

In [8]:
import importlib
from clort.model.encoders import MultiViewEncoder, PointCloudEncoder, MultiModalEncoder, CrossObjectEncoder
from timm import create_model
from torchviz import make_dot
from torchview import draw_graph

In [9]:
bboxs.shape

torch.Size([13, 8, 3])

In [10]:
# Instantiate the four encoder stages and place each on the first CUDA device.
# Configured widths: image branch -> 256, point-cloud branch -> 128,
# fusion of the two (256 + 128 in) -> 256.
mv_enc = MultiViewEncoder(out_dim=256).to('cuda:0')
pc_enc = PointCloudEncoder(out_dims=128).to('cuda:0')
mm_enc = MultiModalEncoder(mv_in_dim=256, pc_in_dim=128, out_dim=256).to('cuda:0')
# Positional arguments are presumably (in_dim, out_dim) — TODO confirm against CrossObjectEncoder.
cr_enc = CrossObjectEncoder(256, 128).to('cuda:0')

In [11]:
# Move the image, point-cloud and box tensors onto the same device as the encoders.
# The remaining batch fields (sizes, track/class indices) are not moved here.
imgs = imgs.to('cuda:0')
pcls = pcls.to('cuda:0')
bboxs = bboxs.to('cuda:0')

In [12]:
# Image branch: run the multi-view encoder on the batch images, with imgs_sz passed alongside.
# In the recorded run this yields a 13 x 256 tensor.
mv_e = mv_enc(imgs, imgs_sz)
print(f'{mv_e.shape = }')

mv_e.shape = torch.Size([13, 256])


In [13]:
# Point-cloud branch: the encoder receives the points, pcls_sz and the boxes.
# In the recorded run this yields a 13 x 128 tensor.
pc_e = pc_enc(pcls, pcls_sz, bboxs)
print(f'{pc_e.shape = }')

pc_e.shape = torch.Size([13, 128])


In [14]:
# Fuse the image and point-cloud embeddings with the multi-modal encoder.
# In the recorded run this yields a 13 x 256 tensor.
mm_e = mm_enc(mv_e, pc_e)
print(f'{mm_e.shape = }')

mm_e.shape = torch.Size([13, 256])


In [15]:
# Final stage: cross-object encoder over the fused embeddings, with frame_sz passed alongside.
# In the recorded run this yields a 13 x 128 tensor.
cr_e = cr_enc(mm_e, frame_sz)
print(f'{cr_e.shape = }')

cr_e.shape = torch.Size([13, 128])


In [16]:
from clort import MemoryBank, MemoryBankInfer
from clort.model.ContrastiveLoss import ContrastiveLoss

In [29]:
# Memory bank used with the loss below, sized by dataset.n_tracks with
# 128-wide entries and Q=5.
# alpha evaluates to [0.5, 0.4, 0.3, 0.2, 0.1].
mb = MemoryBank(
    dataset.n_tracks,
    128,
    Q=5,
    alpha=torch.arange(1, 6, dtype=torch.float32).flip([0])/10,
    device=torch.device('cuda:0'),
)
# Second bank built from MemoryBankInfer with the same size arguments and t=3.
mbinfer = MemoryBankInfer(
    dataset.n_tracks,
    128,
    Q=5,
    t=3,
    device='cuda:0',
)

In [20]:
# Contrastive loss module with the static_contrast option enabled.
cl = ContrastiveLoss(static_contrast=True)

In [21]:
# Loss over the current embeddings, their track ids and the memory bank contents.
# The track ids become an int32 tensor; no device argument is given, so it is
# created on the default device.
# NOTE(review): cr_e is produced by modules on cuda:0 — confirm ContrastiveLoss handles the index tensor's device.
loss = cl(cr_e, torch.tensor(track_idxs.astype(int), dtype=torch.int32), mb.get_memory())

In [24]:
# Show the scalar loss as a plain Python number.
loss.item()

3.357988119125366

In [56]:
# Number of items in the dataset. The expression is restored because the
# cell's recorded output can only come from evaluating it.
len(dataset)

914

In [25]:
# Backpropagate the loss. Autograd anomaly detection is switched on in the
# import cell, so this pass runs with those checks active.
loss.backward()

In [28]:
# Update the memory bank with the current embeddings and their track ids.
# The embeddings are detached first, so what is passed carries no autograd history.
mb.update(cr_e.detach(), torch.tensor(track_idxs.tolist(), dtype=torch.int32))

In [44]:
# Update the inference bank with random values shaped like cr_e (not the real embeddings).
# torch.rand is unseeded here, so the stored values differ from run to run.
mbinfer.update(torch.rand(cr_e.size(), device='cuda:0'), torch.tensor(track_idxs.tolist(), dtype=torch.int32))

In [45]:
# Read back what the inference bank holds for this batch's track ids.
# NOTE(review): the indices here come from from_numpy(astype(int)), i.e. NumPy's default
# integer width, whereas the update calls pass int32 tensors — confirm the bank accepts both.
mbinfer.get_reprs(torch.from_numpy(track_idxs.astype(int)))

tensor([[[ 0.0337,  0.0940,  0.0266,  ...,  0.0828,  0.1607,  0.0054],
         [ 0.0415,  0.1009,  0.0382,  ...,  0.0892,  0.1592,  0.0118],
         [ 0.0482,  0.1035,  0.0466,  ...,  0.0939,  0.1554,  0.0181],
         [ 0.0447,  0.0980,  0.0382,  ...,  0.0909,  0.1573,  0.0152],
         [ 0.0313,  0.0907,  0.0220,  ...,  0.0800,  0.1601,  0.0035]],

        [[ 0.0429,  0.1678, -0.0163,  ...,  0.0904,  0.0695,  0.1069],
         [ 0.0568,  0.1603, -0.0105,  ...,  0.0965,  0.0705,  0.1139],
         [ 0.0651,  0.1550, -0.0049,  ...,  0.0972,  0.0712,  0.1164],
         [ 0.0538,  0.1651, -0.0080,  ...,  0.0908,  0.0721,  0.1109],
         [ 0.0363,  0.1703, -0.0187,  ...,  0.0869,  0.0686,  0.1030]],

        [[ 0.0597,  0.1504, -0.0457,  ...,  0.1107, -0.0049,  0.0235],
         [ 0.0754,  0.1570, -0.0240,  ...,  0.1210,  0.0071,  0.0320],
         [ 0.0872,  0.1602, -0.0078,  ...,  0.1265,  0.0175,  0.0392],
         [ 0.0768,  0.1576, -0.0253,  ...,  0.1192,  0.0100,  0.0342],
  

In [46]:
cr_e

tensor([[-0.0259,  0.0561, -0.0279,  ...,  0.0289,  0.1386, -0.0400],
        [-0.0178,  0.1298, -0.0530,  ...,  0.0674,  0.0366,  0.0619],
        [-0.0264,  0.0831, -0.1193,  ...,  0.0494, -0.0678, -0.0252],
        ...,
        [-0.0393,  0.0615, -0.0967,  ...,  0.0097, -0.1211, -0.0463],
        [ 0.0349, -0.0104, -0.0475,  ..., -0.0390, -0.0741, -0.0316],
        [ 0.0489,  0.1094, -0.1338,  ...,  0.0071, -0.0629, -0.0189]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [52]:
def perturb(e):
    """Return ``e`` multiplied by 2.0.

    The product is bound to a fresh result rather than written back with an
    in-place operator, so the caller's object is left as it was.
    """
    return e * 2.

In [53]:
# Apply perturb to the embeddings and keep the result under a separate name.
a = perturb(cr_e)

In [54]:
# Element-wise comparison of the original embeddings with the perturbed copy.
cr_e == a

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0')

In [26]:
# Build the autograd graph of the point-cloud embedding with torchviz and
# render it to './out' (the recorded output reports 'out.pdf').
make_dot(pc_e).render('./out')
# # # model_graph = draw_graph(sv_enc, input_size=imgs.shape, expand_nested=False)
# # tuple(mv_e.size())
# mb.get_reprs(torch.tensor(track_idxs.tolist(), dtype=torch.int32))

'out.pdf'

In [1]:
# g = draw_graph(pc_enc, input_data=(pcls, pcls_sz, bboxs))

In [15]:
# Autograd graph for `s`, labelled with mv_enc's named parameters.
# NOTE(review): `s` is not assigned in any visible cell, so this fails on a fresh
# kernel — TODO restore the cell that produced it or remove this one.
d = make_dot(s, params=dict(mv_enc.named_parameters()))

In [16]:
# Render the graph to './out' — the same path the earlier make_dot(pc_e) cell
# renders to, so the two results share one output file.
d.render('./out')

'out.pdf'

In [38]:
# Adaptive max pooling with a 1x1 target size, applied directly to the image batch.
out = nn.AdaptiveMaxPool2d((1, 1))(imgs)

In [39]:
out.shape

torch.Size([123, 3, 1, 1])

In [40]:
s[-1].shape

torch.Size([123, 512, 7, 7])

In [41]:
import matplotlib.pyplot as plt

In [42]:
# First image of the batch, moved from channel-first to channel-last layout
# and converted to a NumPy array.
# imgs is moved to cuda:0 earlier in the notebook and Tensor.numpy() only
# accepts CPU tensors, so bring the data to the host first. detach()/cpu()
# leave an ordinary CPU tensor's values unchanged, so this also works when
# imgs never left the CPU.
i = imgs[0, :, :, :].permute(1, 2, 0).detach().cpu().numpy()

In [44]:
imgs.max()

tensor(1.)