## CLIP-ViP Modified Code

### Dataset and Dataloader
In order for the model to use the emotion data that we extracted from the text, we need to modify the Dataset class so that it also loads this newly added data. Below is the modified version of the HDVILAVideoRetrievalDataset class. We mainly update the \_\_getitem\_\_ method so that it additionally reads the emotion data.

In [None]:
import torch
from torch.utils.data import Dataset
import random
import os
import json
from torch.utils.data.dataloader import default_collate
from src.utils.logger import LOGGER
from src.utils.basic_utils import flat_list_of_lists
from src.datasets.data_utils import mask_batch_text_tokens, img_collate
from src.datasets.dataloader import init_transform_dict, init_transform_dict_simple
import decord # video loader 
from decord import VideoReader
from decord import cpu, gpu
decord.bridge.set_bridge("torch")
import math
import torch.nn.functional as F
import numpy as np
import cv2
import lmdb
import glob
import src.utils.stop_words as stop_words
from PIL import Image
from src.datasets.sample_frames import SampleFrames

In [None]:
class HDVILAVideoRetrievalDataset(Dataset):
    """Video-text retrieval dataset that also returns per-caption emotion scores.

    Annotation entries are read from a ``.json`` or ``.jsonl`` file. Each entry
    must provide either ``video`` (a file name) or ``video_id``, a ``caption``
    (a string or a list of strings), one value per name in ``EMOTION_KEYS`` and,
    when ``vis_format == "frame"``, ``num_frame``.
    """

    # Emotion names read from every annotation entry, in output order. Position i
    # of the returned ``emotions`` list is paired with row i of the text model's
    # emotion embedding table, so this order must stay fixed.
    EMOTION_KEYS = ("joy", "trust", "surprise", "anticipation",
                    "fear", "sadness", "disgust", "anger")

    def __init__(self, cfg, vis_dir, anno_path, vis_format='video', mode="train"):
        assert vis_format in ["video", "frame"]
        self.cfg = cfg
        self.vis_dir = vis_dir
        self.anno_path = anno_path
        self.mode = mode
        self.vis_format = vis_format
        self.n_clips = cfg.train_n_clips if mode == "train" else cfg.test_n_clips
        self.num_frm = cfg.train_num_frms if mode == "train" else cfg.test_num_frms
        self.sample_rate = cfg.sample_rate
        # Number of captions sampled per video. ``pos_num`` wins when present
        # (historical behaviour); otherwise fall back to ``text_pos_num`` itself
        # instead of raising AttributeError for configs that only define it.
        if hasattr(cfg, "text_pos_num"):
            self.pos_num = getattr(cfg, "pos_num", cfg.text_pos_num)
        else:
            self.pos_num = 1
        self.transform = init_transform_dict_simple(video_res=cfg.video_res,
                                                    input_res=cfg.input_res)[mode]
        self.frame_sampler = SampleFrames(clip_len=self.num_frm,
                                          frame_interval=self.sample_rate,
                                          num_clips=self.n_clips,
                                          temporal_jitter=True)
        self.init_dataset_process()

    def init_dataset_process(self):
        """Load the annotation list and, in demo mode, list the video directory."""
        json_type = os.path.splitext(self.anno_path)[-1]
        assert json_type in ['.json', '.jsonl']

        # Use a context manager so the annotation file is always closed.
        with open(self.anno_path) as f:
            if json_type == '.jsonl':
                # Skip blank lines (e.g. a trailing newline) instead of crashing.
                self.datalist = [json.loads(line) for line in f if line.strip()]
            else:
                self.datalist = json.load(f)
        if self.cfg.is_demo:
            self.dir_list = os.listdir(self.vis_dir)

    def id2path(self, id):
        """Map a video id to its path under ``vis_dir``.

        Video format: ``<vis_dir>/<last path component of id>.mp4``, or
        ``<vis_dir>/<id>.avi`` when ``vis_dir`` contains "lsmdc".
        Frame format: ``<vis_dir>/<id>`` (a directory of frames).
        """
        clip_name = id
        if self.vis_format == 'video':
            name = os.path.join(self.vis_dir, clip_name.split('/')[-1] + ".mp4")
            if "lsmdc" in self.vis_dir:
                name = os.path.join(self.vis_dir, clip_name + ".avi")
        else:
            name = os.path.join(self.vis_dir, clip_name)
        return name

    def __len__(self):
        """Number of videos in the demo directory, or of annotation entries."""
        if self.cfg.is_demo:
            return len(self.dir_list)
        return len(self.datalist)

    def get_sample_idx(self, total_frame_num):
        """Return the frame indices to decode for a video with ``total_frame_num`` frames.

        sample_rate > 0: delegate to ``SampleFrames`` (fixed frame interval).
        sample_rate == 0: ``n_clips * num_frm`` indices spread uniformly over the
            video; with ``cfg.sample_jitter`` in train mode the start and end
            points are randomly jittered.

        Raises:
            ValueError: if ``sample_rate`` is negative.
        """
        if self.sample_rate > 0:
            results = {"total_frames": total_frame_num, "start_index": 0}
            return self.frame_sampler(results)["frame_inds"]
        if self.sample_rate < 0:
            raise ValueError(f"sample_rate must be >= 0, got {self.sample_rate}")

        num_samples = self.n_clips * self.num_frm
        if getattr(self.cfg, "sample_jitter", False) and self.mode == "train":
            # max(..., 1) avoids ZeroDivisionError when only one frame is sampled.
            interval = int(total_frame_num / max(num_samples - 1, 1))
            start = np.random.randint(0, interval + 1)
            end = np.random.randint(total_frame_num - 1 - interval, total_frame_num)
            idx = np.linspace(start, end, num_samples).astype(int)
            # When interval is close to total_frame_num (e.g. num_samples <= 2),
            # start can pass the last frame and end can go negative; clip into range.
            return np.clip(idx, 0, total_frame_num - 1)
        return np.linspace(0, total_frame_num - 1, num_samples).astype(int)

    def load_video(self, vis_path):
        """Decode the sampled frames of a video file and apply ``self.transform``."""
        vr = VideoReader(vis_path, ctx=cpu(0))
        frame_idx = self.get_sample_idx(len(vr))
        img_array = vr.get_batch(frame_idx)  # (n_clips*num_frm, H, W, 3)

        # NHWC uint8 -> NCHW float in [0, 1] before the transform.
        img_array = img_array.permute(0, 3, 1, 2).float() / 255.
        return self.transform(img_array)

    def load_frames(self, vis_path, total_frame_num):
        """Load sampled JPEG frames named ``<dir>_<NNN>.jpg`` from a frame directory."""
        frame_idx = self.get_sample_idx(total_frame_num)
        prefix = vis_path.split('/')[-1]

        img_array = []
        for i in frame_idx:
            frame_file = os.path.join(vis_path, prefix + '_{0:03d}.jpg'.format(i))
            img_array.append(np.array(Image.open(frame_file).convert("RGB")))
        img_array = torch.from_numpy(np.array(img_array))  # (n_clips*num_frm, H, W, 3)

        img_array = img_array.permute(0, 3, 1, 2).float() / 255.
        return self.transform(img_array)

    def __getitem__(self, index):
        """Return a dict with ``video``, ``texts``, ``emotions`` and ``vis_id``."""
        if self.cfg.dummy_data:
            # Provide every key the collator reads so dummy batches collate.
            return dict(
                video=torch.randn(self.n_clips * self.num_frm, 3,
                                  self.cfg.input_res[0], self.cfg.input_res[1]),  # [clips*num_frm, C, H, W]
                texts=["This is a dummy sentence, which contains nothing meaningful."],
                emotions=[0] * len(self.EMOTION_KEYS),
                vis_id=f"dummy_{index}",
            )

        if self.cfg.is_demo:
            # Demo mode: iterate over the video directory, using the query and
            # emotions supplied in the config.
            video = self.dir_list[index]
            vis_id, _ = os.path.splitext(video)
            texts = [self.cfg.query]
            emotions = self.cfg.emotion
        else:
            entry = self.datalist[index]
            if "video_id" in entry:
                vis_id = entry['video_id']
            else:
                vis_id, _ = os.path.splitext(entry["video"])
            texts = entry['caption']

            if isinstance(texts, list):
                texts = random.sample(entry['caption'], self.pos_num)
                if 'didemo' in self.anno_path:
                    # DiDeMo: concatenate all sentence captions into a single paragraph.
                    texts = [' '.join(entry['caption'])]
            else:
                texts = [texts]

            emotions = [entry[key] for key in self.EMOTION_KEYS]

        vis_path = self.id2path(vis_id)
        if self.vis_format == 'video':
            video = self.load_video(vis_path)
        else:
            video = self.load_frames(vis_path, self.datalist[index]['num_frame'])

        return dict(
            video=video,  # [clips*num_frm, C, H_crop, W_crop]
            texts=texts,
            emotions=emotions,
            vis_id=vis_id,
        )

### Collator for creating batches

A custom collator is used to create batches of:
* token sequences
* the attention mask corresponding to each sequence
* videos
* and, newly added, the emotion labels for each sequence.

In [None]:
class VideoRetrievalCollator(object):
    """Build batches of videos, tokenized captions and emotion labels.

    Args:
        tokenizer: object exposing ``batch_encode_plus`` (called with
            ``padding="max_length"``, ``truncation=True``, ``return_tensors="pt"``).
        max_length: length every caption is padded/truncated to.
        is_train: stored on the instance; not read by ``collate_batch``.
    """

    def __init__(self, tokenizer, max_length=40, is_train=True):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_train = is_train

    def collate_batch(self, batch):
        """Collate dataset items (dicts with video/texts/emotions/vis_id) into one batch.

        Returns a dict with:
            video: collated videos, one entry per item.
            text_input_ids / text_input_mask: (N, max_length), N = total captions.
            emotions: LongTensor (N, num_emotions); row i belongs to caption i.
            vis_id: one id per item.
        """
        if isinstance(batch[0]["video"], torch.Tensor):
            v_collate = default_collate
        else:
            v_collate = img_collate
        # For tensor items of shape [clips*num_frm, C, H, W] this is
        # [B, clips*num_frm, C, H_crop, W_crop].
        video = v_collate([d["video"] for d in batch])

        # An item may carry several captions (pos_num > 1). Repeat its emotion
        # vector once per caption so emotion rows stay aligned with text rows;
        # with one caption per item this is one emotion row per item.
        text_str_list = []
        emotion_rows = []
        for d in batch:
            for text in d["texts"]:
                text_str_list.append(text)
                emotion_rows.append(d["emotions"])
        # NOTE(review): the long cast truncates fractional scores (0.4 -> 0);
        # confirm the annotations store integer counts/flags.
        emotions = torch.tensor(emotion_rows, dtype=torch.long)

        vis_id = default_collate([d["vis_id"] for d in batch])

        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        text_input_ids = batch_enc.input_ids  # (N, L)
        text_input_mask = batch_enc.attention_mask  # (N, L)

        return dict(
            video=video,
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            emotions=emotions,
            vis_id=vis_id
        )

### Text Embeddings
The following code shows our main contributions and the implementation of our ideas. The CLIP model is complicated and consists of a hierarchy of many classes. Broadly, it consists of two transformers: one that encodes the video and one that encodes the text, trained jointly. These transformers are further divided into smaller components such as encoder and embedding classes. We start with the CLIPTextEmbeddings class, which we modify to incorporate the emotion data when creating the text embeddings.

We do this by initializing one embedding per emotion, for a total of 8 emotion embeddings. These embeddings have the same dimension as the token and positional embeddings. Each sequence has a set of corresponding emotion scores. We treat every emotion with a positive score as present, look up the embeddings of the present emotions, and average them into a single aggregated emotion embedding. A sequence with no present emotion gets a zero vector. This embedding is then added to every token embedding in the sequence alongside the positional embeddings. This allows the model to incorporate the emotion information extracted from each sequence (caption) when learning its representation.

In [None]:
class CLIPTextEmbeddings(nn.Module):
    """Token + position embeddings for CLIP text, plus an averaged emotion embedding.

    Each emotion category (``config.num_emotions`` if set, otherwise 8) has a
    learned vector of size ``hidden_size``. For every sequence, the vectors of
    the emotions with a positive score are averaged and added to every token
    position together with the token and positional embeddings.
    """

    def __init__(self, config: "CLIPTextConfig"):
        super().__init__()
        embed_dim = config.hidden_size
        # The dataset in this notebook emits 8 emotion scores; allow configs to override.
        num_emotions = getattr(config, "num_emotions", 8)

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.emotion_embedding = nn.Embedding(num_emotions, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        emotions: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Return ``[batch, seq_len, hidden]`` embeddings.

        Args:
            input_ids: token ids ``[batch, seq_len]`` (ignored if ``inputs_embeds`` is given).
            position_ids: optional position ids; defaults to ``0..seq_len-1``.
            inputs_embeds: precomputed token embeddings ``[batch, seq_len, hidden]``.
            emotions: per-sequence emotion scores ``[batch, num_emotions]`` (or a
                single ``[num_emotions]`` vector shared by the whole batch). Only
                the sign matters: scores > 0 mark an emotion as present. When
                ``None`` no emotion term is added.
        """
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        embeddings = inputs_embeds + self.position_embedding(position_ids)

        # Skipping the term when emotions is None avoids adding a float32 zeros
        # tensor that would promote half-precision embeddings to float32.
        if emotions is not None:
            emotion_embeds = self._mean_emotion_embedding(emotions)  # [batch, hidden]
            embeddings = embeddings + emotion_embeds.unsqueeze(1)  # broadcast over seq_len

        return embeddings

    def _mean_emotion_embedding(self, emotions: torch.Tensor) -> torch.Tensor:
        """Average the embeddings of the emotions with a positive score (zero if none)."""
        if emotions.dim() == 1:
            emotions = emotions.unsqueeze(0)
        weight = self.emotion_embedding.weight  # [num_emotions, hidden]
        active = (emotions > 0).to(weight.dtype)  # [batch, num_emotions], 0/1
        summed = torch.matmul(active, weight)  # [batch, hidden]
        # clamp(min=1) instead of "+ 1e-8": identical in float32, but 1e-8
        # underflows to 0 in float16, turning all-zero rows into 0/0 = NaN.
        count = active.sum(dim=-1, keepdim=True).clamp(min=1)
        return summed / count

### CLIPTextTransformer
Here we update the CLIPTextTransformer class to accept the emotion data and pass it on to the embedding layer.

In [None]:
class CLIPTextTransformer(nn.Module):
    """CLIP text encoder whose embedding layer also consumes per-sequence emotion labels.

    Runs token/position/emotion embeddings, a causally-masked ``CLIPEncoder`` and
    a final LayerNorm, then pools the hidden state at the position of the highest
    token id in each sequence (the EOT token in CLIP's vocabulary).
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        emotions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode ``input_ids``, optionally conditioned on per-sequence ``emotions``
        (forwarded to ``CLIPTextEmbeddings``). Extra leading dimensions of
        ``input_ids`` / ``attention_mask`` / ``emotions`` are flattened into the batch.

        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse any extra leading dimensions so the encoder sees [bsz, seq_len];
        # bsz/seq_len come from the flattened shape so >2-D inputs also work.
        input_ids = input_ids.view(-1, input_ids.shape[-1])
        bsz, seq_len = input_ids.shape
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
        if emotions is not None and emotions.dim() > 2:
            emotions = emotions.reshape(-1, emotions.shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, emotions=emotions)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        if_fp16 = hidden_states.dtype == torch.float16
        causal_attention_mask = self._build_causal_attention_mask(
            bsz, seq_len, fp16=if_fp16, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, fp16=False, device=None):
        """Return an additive causal mask of shape [bsz, 1, seq_len, seq_len].

        Entries above the diagonal are -inf (future tokens blocked); the diagonal
        and below are 0. ``device`` lets the mask be allocated where it is used
        instead of being built on the default device and copied.
        """
        mask = torch.full((bsz, seq_len, seq_len), float("-inf"), device=device)
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero the diagonal and below
        mask = mask.unsqueeze(1)  # add the head dimension
        if fp16:
            mask = mask.half()
        return mask