In [1]:
import json
with open("./data/AIC23_Track2_NL_Retrieval/data/train_nl_extracted_.json", "r") as f:
    data_train = json.load(f)

In [2]:
import pandas as pd

data_df = pd.DataFrame(data_train).transpose().reset_index()
data_df = data_df.rename(columns={'index': 'uuid'})
data_df['colors'] = data_df['colors'].replace(['white', 'black', 'gray', 'red', 'blue', 'green', 'brown', 'yellow'], range(8))
data_df['type'] = data_df['type'].replace(['sedan', 'truck', 'suv', 'van', 'bus', 'hatchback'], range(6))
data_df['motion'] = data_df['motion'].replace(['straight', 'left', 'right', 'stop'], range(4))

In [3]:
data_df

Unnamed: 0,uuid,frames,boxes,nl,nl_other_views,type,motion,colors
0,b06c903c-a25d-45fe-b0d5-294f72e34023,"[./validation/S02/c006/img1/000001.jpg, ./vali...","[[539, 606, 273, 277], [532, 631, 271, 282], [...","[A red sedan drives forward., A red midsize se...","[A red sedan keeping straight., A red sedan ru...",0,0,3
1,3a02a86d-154b-4ee2-bd5e-a0811113a9d8,"[./validation/S02/c007/img1/000001.jpg, ./vali...","[[1292, 359, 403, 161], [1230, 360, 406, 160],...","[A red sedan keeping straight., A red sedan dr...",[A red sedan runs down the street followed by ...,0,0,3
2,eb184f7c-35c7-4f3e-af52-1db57fb087d4,"[./validation/S02/c009/img1/000015.jpg, ./vali...","[[1713, 410, 205, 133], [1683, 404, 218, 133],...",[A red sedan runs down the street followed by ...,"[A red midsize sedan keep straight., A red car...",0,0,3
3,abd32535-8acb-49c5-8da4-9e904d263d8c,"[./validation/S02/c006/img1/000001.jpg, ./vali...","[[374, 373, 249, 219], [372, 383, 238, 219], [...","[A white pickup goes straight., A white truck ...","[A pickup truck is going straight., White pick...",1,0,0
4,3819f85b-103a-4f0b-b0f1-49e6b2fedb68,"[./validation/S02/c007/img1/000011.jpg, ./vali...","[[1452, 294, 431, 176], [1367, 294, 454, 178],...","[A white truck runs down the street., A pickup...",[White dodge ram pickup truck going straight t...,1,0,0
...,...,...,...,...,...,...,...,...
2150,ab54d5c1-b712-4baf-9e8b-af4ff8808141,"[./validation/S05/c036/img1/002380.jpg, ./vali...","[[1622, 294, 298, 284], [1619, 281, 301, 284],...",[A black sedan switch lane to right and follow...,[A black sedan making a right turn at intersec...,0,0,1
2151,23896dd5-a992-4587-b1e9-acfcf864fdb1,"[./validation/S05/c035/img1/002965.jpg, ./vali...","[[1691, 504, 228, 237], [1686, 501, 232, 241],...","[A gray sedan stopped at the intersection., A ...",[],0,3,2
2152,2a5e81fb-27d3-4d52-8c4c-e97be079b266,"[./validation/S05/c036/img1/003076.jpg, ./vali...","[[1423, 480, 497, 417], [1411, 464, 509, 428],...","[A gray sedan stopped at the intersection., A ...",[],0,3,2
2153,a3b55866-1ec7-42a0-bbac-7c6f98ec14bc,"[./validation/S05/c035/img1/003460.jpg, ./vali...","[[1124, 0, 109, 84], [1121, 0, 111, 83], [1118...","[A gray SUV stops at the intersection., A gray...",[A gray hatchback going straight down the stre...,2,2,2


In [124]:
import torch
from torch.utils.data import Dataset
import numpy as np
from dataloader.utils import get_motion_img
import os
import torchvision.transforms as transforms

class Track2CustomDataset(Dataset):
    def __init__(self, video_params, data_tracks, tokenizer, max_len, transforms, config, mode="train"):
        
        self.samples = data_tracks
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.video_params = video_params
        self.mode = mode
        self.config = config
    
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples.iloc[index]

        if self.mode == "train":

            final, motion, motion_line = self.image_features(sample)
            text_inputs = self.lang_features(sample)
            
            sample = {
                'text': text_inputs,
                'video': final,
                'motion': motion,
                'motion_line': motion_line,
                'color_label': sample['colors'],
                'type_label': sample['type'],
                'motion_label': sample['motion']

            }

            return sample
        
        if self.mode == "infer_text":

            text_inputs = self.lang_features(sample)
            return {'text': text_inputs}
        
        if self.mode == "infer_video":
            final, motion, motion_line = self.image_features(sample)
            return {
                'video': final,
                'motion': motion,
                'motion_line': motion_line
            }    
    
    def image_features(self, sample):
        frames_path, boxes = sample['frames'], sample['boxes']
            
        veh_imgs, motion_line, motion = get_motion_img(self.config['general_config']['data_dir'], frames_path, boxes, self.config['arch']['base_settings']['video_params']['num_frames'])

        if self.transforms:
            veh_imgs = [self.transforms(img.astype(np.float32)) for img in veh_imgs]
            motion_line = self.transforms(motion_line.astype(np.float32))
            motion = self.transforms(motion.astype(np.float32))
    
        veh_imgs = torch.stack(veh_imgs)
        
        final = torch.zeros([self.video_params['num_frames'], 3, self.video_params['input_res'], self.video_params['input_res']])
        final[: veh_imgs.shape[0]] = veh_imgs

        return final, motion, motion_line

    def lang_features(self, sample):
        nl_descriptions = sample['nl']        
        text_inputs = []
        for idx, text in enumerate(nl_descriptions):
            # print("text: ", text, ", idx: ", idx)
            tokenized_inp = self.tokenizer.encode_plus(
                                text,
                                truncation=True,
                                add_special_tokens=True,
                                max_length=self.max_len,
                                padding='max_length'
                            )
            text_inputs.append({
                'input_ids': torch.LongTensor(tokenized_inp['input_ids']),
                'attention_mask': torch.LongTensor(tokenized_inp['input_ids'])
            })        
        
        return text_inputs

def videotext_collate_fn(batch_data):
    print(batch_data[0]['color_label'])
    frames = torch.stack([item['video'] for item in batch_data])
    motion = torch.stack([item['motion'] for item in batch_data])
    motion_line = torch.stack([item['motion_line'] for item in batch_data])
    input_ids = torch.stack([cap['input_ids'] for item in batch_data for cap in item['text']])
    attention_mask = torch.stack([cap['attention_mask'] for item in batch_data for cap in item['text']])
    
    color_label = torch.LongTensor([item['color_label'] for item in batch_data])
    type_label = torch.LongTensor([item['type_label'] for item in batch_data])
    motion_label = torch.LongTensor([item['motion_label'] for item in batch_data])

    return {'video': frames, 'text': {'input_ids': input_ids, 'attention_mask': attention_mask}, 'motion': motion, 'motion_line': motion_line,
            'color_label': color_label, 'type_label': type_label, 'motion_label': motion_label}

def text_collate_fn(batch_data):
    input_ids = torch.stack([cap['input_ids'] for item in batch_data for cap in item['text']])
    attention_mask = torch.stack([cap['attention_mask'] for item in batch_data for cap in item['text']])
    
    return {'text': {'input_ids': input_ids, 'attention_mask': attention_mask}}

def video_collate_fn(batch_data):
    frames = torch.stack([item['video'] for item in batch_data])
    motion = torch.stack([item['motion'] for item in batch_data])
    motion_line = torch.stack([item['motion_line'] for item in batch_data])
    return {'video': frames, 'motion': motion, 'motion_line': motion_line}



def get_transforms(img_size, train, size=1):
    if train:
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(img_size * size, scale=(0.8, 1)),
            transforms.RandomApply([transforms.RandomRotation(10)], p=0.5),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
    else:
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((img_size * size, img_size * size)),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])


def get_train_dataloader(config, df):
    dataset = Track2CustomDataset(data_tracks=df, 
                                  video_params=config.arch.base_settings.video_params,
                                  tokenizer=config.general_config.tokenizer,
                                  max_len=int(config.general_config.max_len),
                                  transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=True),
                                  config=config)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.general_config.train_batch_size,
        num_workers=config.general_config.n_workers,
        collate_fn=videotext_collate_fn,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )

    return dataloader
    
def get_valid_dataloader(config, df):
    dataset = Track2CustomDataset(data_tracks=df, 
                                  video_params=config.arch.base_settings.video_params,
                                  tokenizer=config.general_config.tokenizer,
                                  max_len=int(config.general_config.max_len),
                                  transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
                                  config=config)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.general_config.valid_batch_size,
        num_workers=config.general_config.n_workers,
        collate_fn=videotext_collate_fn,
        shuffle=False,
        pin_memory=True,
        drop_last=False
    )

    return dataloader

def get_infer_dataloader(config, df_video, df_text):
    text_dataset = Track2CustomDataset(data_tracks=df_text,
                                       video_params=config.arch.base_settings.video_params,
                                       tokenizer=config.general_config.tokenizer,
                                       max_len=int(config.general_config.max_len),
                                       transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
                                       config=config,
                                       mode="infer_text")
    
    video_dataset = Track2CustomDataset(data_tracks=df_video,
                                        video_params=config.arch.base_settings.video_params,
                                        tokenizer=config.general_config.tokenizer,
                                        max_len=int(config.general_config.max_len),
                                        transforms=get_transforms(config.arch.base_settings.video_params.input_res, train=False),
                                        config=config,
                                        mode="infer_video")
    
    text_dataloader = torch.utils.data.DataLoader(
        text_dataset,
        batch_size=config.general_config.valid_batch_size,
        num_workers=config.general_config.n_workers,
        collate_fn=text_collate_fn,
        shuffle=False,
        pin_memory=True,
        drop_last=False
    )

    video_dataloader = torch.utils.data.DataLoader(
        video_dataset,
        batch_size=config.general_config.valid_batch_size,
        num_workers=config.general_config.n_workers,
        collate_fn=video_collate_fn,
        shuffle=False,
        pin_memory=True,
        drop_last=False
    )
    return video_dataloader, text_dataloader

In [125]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [126]:
torch.stack([torch.tensor(1) for i in range(10)])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [127]:
from dotwiz import DotWiz
with open('./configs/baseline_config.json', "r") as f:
    config = json.load(f)
print(config)

config = DotWiz(config)

config['general_config']['tokenizer'] = tokenizer


{'general_config': {'data_dir': './data/AIC23_Track2_NL_Retrieval/data', 'max_len': 64, 'train_batch_size': 12, 'valid_batch_size': 12, 'n_workers': 8, 'kfolds': 5, 'gradient_checkpointing': True, 'epochs': 10, 'n_warmup_steps': 0, 'gradient_accumulation_steps': 1, 'unscale': True, 'evaluate_n_times_per_epoch': 1, 'max_grad_norm': 1000, 'train_print_frequency': 20, 'valid_print_frequency': 20, 'loss': 'InfoNCE', 'load_checkpoint': None}, 'optimizer': {'weight_decay': 5e-05, 'learning_rate': 2e-05, 'eps': 1e-08, 'betas': [0.9, 0.999]}, 'scheduler': {'scheduler_type': 'linear_warmup_cosine_annealing_lr', 'batch_scheduler': True, 'constant_schedule_with_warmup': {'n_warmup_steps': 0}, 'linear_schedule_with_warmup': {'n_warmup_steps': 0}, 'cosine_schedule_with_warmup': {'n_cycles': 0.5, 'n_warmup_steps': 0}, 'polynomial_decay_schedule_with_warmup': {'n_warmup_steps': 0, 'power': 1.0, 'min_lr': 0.0}, 'linear_warmup_cosine_annealing_lr': {'warmup_epochs': 0, 'max_epochs': 20}}, 'arch': {'bas

In [128]:
config.general_config.train_batch_size = 2
config.general_config.valid_batch_size = 2

In [129]:
data_df['motion'].value_counts()

0    1524
2     244
1     213
3     174
Name: motion, dtype: int64

In [130]:
train_loader = get_train_dataloader(config, data_df)

In [131]:
batch = next(iter(train_loader))

4
2
1
3
2
0
32

2
2
0
3
1
1
1
2


In [132]:
from loss.crossentropy import CrossEntropyLabelSmooth

criterion = CrossEntropyLabelSmooth()

In [133]:
inputs = torch.rand(4, 6)
label = torch.randint(6, size=(4,))

In [134]:
criterion(inputs, label)

tensor(1.8436)

In [135]:
torch.IntTensor([1])

tensor([1], dtype=torch.int32)