In [1]:
# Reload all imported modules (e.g. the local `clort` package) every time before
# code is executed, so edits to library code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

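## Dataset

Build the contrastive-tracking dataset from the Argoverse `train1` split and wrap it in a shuffled `DataLoader`.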

In [8]:
import sys

import torch
from torch.utils.data import DataLoader
from clort.clearn.data.cltracking import ContrastiveLearningTracking

In [3]:
# Location of the Argoverse tracking split used for training (relative to this notebook).
root = '../../../datasets/argoverse-tracking/train1/'

# Dataset options, gathered in one place so they are easy to tweak between runs.
dataset_kwargs = dict(
    occlusion_thresh=30.,
    central_crop=True,
    img_tr_ww=(0.9, 0.9),
    image_size_threshold=100,
    img_reshape=(256, 256),
    ids_repeat=200,
)
dataset = ContrastiveLearningTracking(root, **dataset_kwargs)



In [10]:
# Initialise the dataset's internal state before it is wrapped in a DataLoader.
# NOTE(review): the meaning of the (0, 50) arguments is defined in
# ContrastiveLearningTracking.dataset_init and is not visible here — presumably an
# index range over the tracking logs; confirm.
dataset.dataset_init(0, 50)

In [5]:
# Fixed seed so results survive Restart & Run All:
# - the DataLoader gets its own seeded generator, making the shuffle order reproducible;
# - the global torch seed presumably also fixes the weight initialisation of the
#   models built in the Model section (created after this cell) — TODO confirm they
#   draw from torch's default RNG.
SEED = 0
torch.manual_seed(SEED)
dl = DataLoader(dataset, batch_size=2, shuffle=True,
                generator=torch.Generator().manual_seed(SEED))

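## Model

A visual encoder and a point-cloud encoder whose outputs are fused by `FeatureMixer`; all three are trained jointly with `SoftNearestNeighbourLoss` on track ids.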

In [6]:
from clort.clearn.models import VisualEncoder, PointCloudEncoder, FeatureMixer
from mzLosses.contrastive import SoftNearestNeighbourLoss

In [7]:
# Feature widths fed into the mixer. They must match the encoders' output sizes
# (the mixer's first visual layer takes 512 inputs) — TODO confirm against
# VisualEncoder / PointCloudEncoder if either changes.
VIS_SIZE, PCL_SIZE = 512, 60

vis_model = VisualEncoder()
pcl_model = PointCloudEncoder(10)  # NOTE(review): meaning of the `10` argument is not visible here
feature_mixer = FeatureMixer(vis_size=VIS_SIZE, pcl_size=PCL_SIZE)

In [8]:
# Opt-in switch: CUDA is used only if requested AND available; otherwise fall back to CPU.
use_gpu = False
if torch.cuda.is_available() and use_gpu:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [9]:
# Move every sub-network onto the selected device (nn.Module.to moves it in place).
for net in (vis_model, pcl_model, feature_mixer):
    net.to(device)

# Display the fusion head's layer summary.
feature_mixer

FeatureMixer(
  (activation): SELU()
  (v_linear_1): Linear(in_features=512, out_features=128, bias=False)
  (v_bn_1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (v_linear_2): Linear(in_features=128, out_features=64, bias=True)
  (f_linear_1): Linear(in_features=124, out_features=128, bias=True)
  (f_linear_2): Linear(in_features=128, out_features=64, bias=True)
)

In [10]:
criterion = SoftNearestNeighbourLoss()

# One AdamW optimizer (default hyper-parameters) per sub-network; the training loop
# zeroes and steps all three every batch.
vis_optim, pcl_optim, feature_optim = (
    torch.optim.AdamW(net.parameters())
    for net in (vis_model, pcl_model, feature_mixer)
)

In [None]:
# Training loop: one pass over `dl`, jointly updating both encoders and the mixer.
# Switch every sub-network to training mode explicitly: FeatureMixer contains a
# BatchNorm1d layer, whose behaviour differs between train() and eval(), and a
# notebook kernel can carry an eval() over from an earlier run.
for net in (vis_model, pcl_model, feature_mixer):
    net.train()

for step, (imgs, pcls, track_ids) in enumerate(dl):
    # Expected layouts: imgs (batch, n_view, C, H, W) and pcls (batch, n_view, d, n).
    # Both must share the same (batch, n_view) prefix, since views are flattened into
    # a single sample axis below and must stay aligned with each other.
    if imgs.dim() != 5 or pcls.dim() != 4 or imgs.shape[:2] != pcls.shape[:2]:
        raise ValueError(
            f'unexpected batch shapes: imgs {tuple(imgs.shape)}, pcls {tuple(pcls.shape)}'
        )

    # Merge batch and view axes so each view is an independent sample for the encoders.
    # flatten(0, 1) uses each tensor's own dims and, unlike .view, accepts non-contiguous input.
    imgs = imgs.flatten(0, 1).to(device)
    pcls = pcls.flatten(0, 1).to(device)
    track_ids = track_ids.flatten().to(device)  # presumably (batch, n_view) -> one id per view

    vis_optim.zero_grad()
    pcl_optim.zero_grad()
    feature_optim.zero_grad()

    imgs_enc = vis_model(imgs)
    pcls_enc = pcl_model(pcls)
    final_enc = feature_mixer(imgs_enc, pcls_enc)

    # Contrastive loss over the fused embeddings, labelled by track id.
    loss = criterion(final_enc, track_ids)

    loss.backward()
    vis_optim.step()
    pcl_optim.step()
    feature_optim.step()

    # .item() logs a plain float instead of the tensor repr carrying its grad_fn.
    print(f'step {step}: loss = {loss.item():.4f}')

In [16]:
# NOTE: sys.getsizeof is shallow. It counts only the dataset object itself (48 bytes
# here), not the objects it references (e.g. any cached samples), so it is not a
# measure of the dataset's memory footprint.
sys.getsizeof(dataset)

48

In [27]:
import numpy as np

In [29]:
# Scratch check: unpack a 4-element zero array into a Python list (elements stay np.float64).
[*np.zeros(4)]

[0.0, 0.0, 0.0, 0.0]