<a href="https://colab.research.google.com/github/akshaygopalkr/EM-VLM4AD/blob/main/eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Library Installation

In [None]:
# Install PEFT (LoRA) and the COCO caption-evaluation tooling used below.
# NOTE(review): versions are unpinned; pin them to make reported metrics reproducible.
!pip install peft
!pip install pycocoevalcap
!pip install pycocotools

Before running the evaluation, disable the SPICE metric by commenting out its scorer in `/usr/local/lib/python3.10/dist-packages/pycocoevalcap/eval.py`.

In [None]:
# Mount Google Drive: the checkpoint, annotations and image-id map are read from, and the
# predictions/metrics are written to, drive/MyDrive/DriveLM.
from google.colab import drive, files
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
import torch
import json
import os

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
import os
from collections import namedtuple
from tqdm import tqdm as progress_bar
from transformers import T5Tokenizer, T5ForConditionalGeneration
from peft import LoraConfig, get_peft_model, LoftQConfig
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import vit_b_32
import json
import pandas as pd

## Loading Dataset

In [None]:
# Quietly extract the dataset archive from Drive into the working directory
# (the test split is later read from data/multi_frame/ — presumably provided by this archive).
!unzip -q drive/MyDrive/DriveLM/data.zip

## Dataset Code

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class MultiFrameDataset(Dataset):
    """
    Multi-view question-answering dataset backed by a JSON list of ``[qa, img_paths]`` entries.

    ``qa`` is a dict with keys ``'Q'`` and ``'A'``; ``img_paths`` is a dict whose values are image
    file paths (keys presumably view/camera names — TODO confirm). Each sample is returned as
    ``(question prompt, stacked view images, answer text, sorted image paths)``.

    :param input_file: path to the JSON file described above
    :param tokenizer: tokenizer used by the collate functions to encode questions and answers
    :param transform: optional callable applied to each image (as a float tensor); when ``None``
        the raw float image is used unchanged
    """

    def __init__(self, input_file, tokenizer, transform=None):
        with open(input_file, encoding='utf-8') as f:
            self.data = json.load(f)

        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Each entry is a (question/answer dict, {view: image path}) pair
        qa, img_path = self.data[idx]
        img_paths = list(img_path.values())

        q_text, a_text = qa['Q'], qa['A']
        q_text = f"Question: {q_text} Answer:"

        # Load every view in the JSON's order and stack them -> (num_views, C, H, W).
        # The transform is optional, so only apply it when one was supplied.
        imgs = []
        for p in img_paths:
            img = read_image(p).float()
            if self.transform is not None:
                img = self.transform(img)
            imgs.append(img.to(device))
        imgs = torch.stack(imgs, dim=0)

        # Sorted paths give a stable, order-independent key (val_model uses the first one)
        return q_text, imgs, a_text, sorted(img_paths)

    def _encode(self, texts):
        """Tokenize a batch of strings with padding and move the token ids to ``device``."""
        return self.tokenizer(list(texts), padding=True, return_tensors="pt").input_ids.to(device)

    def collate_fn(self, batch):
        """
        Collate samples for training.

        :return: ``(question ids, images (B, num_views, C, H, W), answer ids)``
        """
        q_texts, imgs, a_texts, _ = zip(*batch)
        imgs = torch.stack(list(imgs), dim=0)

        encodings = self._encode(q_texts)
        labels = self._encode(a_texts)

        return encodings, imgs, labels

    def collate_fn_test(self, batch):
        """
        Collate samples for evaluation; also keeps the raw questions and image paths.

        :return: ``(question strings, question ids, images, answer ids, list of sorted path lists)``
        """
        q_texts, imgs, a_texts, img_paths = zip(*batch)

        imgs = torch.stack(list(imgs), dim=0)
        img_paths = list(img_paths)
        encodings = self._encode(q_texts)
        labels = self._encode(a_texts)

        return q_texts, encodings, imgs, labels, img_paths

## Model Code

In [None]:
# Token width of ViT-B/32, and the number of patch tokens it yields for a 224x224 input:
# (224 / 32) ** 2 = 7 * 7 = 49 (the class token is not counted).
VIT_HIDDEN_STATE = 768
VIT_SEQ_LENGTH = (224 // 32) ** 2

def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable, in absolute terms and as a percentage.

    :param model: any ``nn.Module`` (anything exposing ``named_parameters()``)
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A module without parameters would otherwise raise ZeroDivisionError
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0

    print(
        f"Trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )



class DriveVLMT5(nn.Module):
    """
    Vision-language model: a T5 text model whose encoder input is the question's token embeddings
    concatenated with a single fused multi-view image embedding.

    :param config: object exposing ``lm`` ('T5-Base' selects t5-base; any other value selects
        t5-large wrapped with LoRA adapters) and ``gpa_hidden_size`` (hidden width of the
        gated-pooling attention that fuses the camera views).
    """

    def __init__(self, config):

        super().__init__()

        # Load the pretrained T5 text model (the tokenizer is created outside this class)
        if config.lm == 'T5-Base':
            self.model = T5ForConditionalGeneration.from_pretrained('google-t5/t5-base')
        else:
            self.model = T5ForConditionalGeneration.from_pretrained('google-t5/t5-large')

            # LoftQ (quantization-aware LoRA initialization) settings.
            # NOTE(review): peft only applies LoftQ when init_lora_weights='loftq'; with the
            # default init this config may be ignored — confirm against the peft version used.
            loftq_config = LoftQConfig(loftq_bits=8)
            # Wrap the model with LoRA adapters on the attention query/value projections
            lora_config = LoraConfig(
                r=64,
                lora_alpha=32,
                loftq_config=loftq_config,
                lora_dropout=0.05,
                bias='none',
                target_modules=['q', 'v']
            )
            self.model = get_peft_model(self.model, lora_config)

        hidden_size = self.model.config.d_model

        print('Trainable Parameters for LM model:')
        print_trainable_parameters(self.model)

        # Create instance for multi-view processor (its ViT weights are always frozen here)
        self.mvp = self.MultiViewProcessor(config.gpa_hidden_size, hidden_size, config.lm, freeze=True)

    class MultiViewProcessor(nn.Module):
        """
        Turns the multi-view images into one 49-token image sequence (ViT-B/32 patch embeddings
        fused across views by gated pooling attention) and prepends the question's text
        embeddings to it. Modality embedding index 0 marks text tokens, index 1 image tokens.
        """

        def __init__(self, gpa_hidden_size, hidden_size, lm, freeze=False):

            super().__init__()

            # Use ViT for image embeddings
            self.img_model = vit_b_32(weights='DEFAULT')
            self.lm = lm

            # Modal embedding to distinguish between image and text
            self.modal_embeddings = nn.Embedding(2, hidden_size)
            self.modal_embeddings.weight.data.normal_(mean=0.0, std=0.02)

            # If we are freezing the ViT image model
            if freeze:
                for param in self.img_model.parameters():
                    param.requires_grad = False

            # Gated pooling attention parameters (based on the MIVC paper):
            # weight = softmax_views(w(tanh(Z x) * sigmoid(G x)))
            self.w = nn.Linear(in_features=gpa_hidden_size, out_features=1)
            self.Z = nn.Sequential(
                nn.Linear(in_features=VIT_HIDDEN_STATE * VIT_SEQ_LENGTH, out_features=gpa_hidden_size, bias=False),
                nn.Tanh()
            )
            self.G = nn.Sequential(
                nn.Linear(in_features=VIT_HIDDEN_STATE * VIT_SEQ_LENGTH, out_features=gpa_hidden_size, bias=False),
                nn.Sigmoid()
            )

            # Only non-base LMs need a projection: t5-base's hidden size equals VIT_HIDDEN_STATE
            if self.lm != 'T5-Base':
              self.img_projection_layer = nn.Linear(in_features=VIT_HIDDEN_STATE, out_features=hidden_size)


        def gpa(self, img_embeddings):

            """
            Gated-pooling attention over camera views for a single sample.
            :param img_embeddings: (V, VIT_SEQ_LENGTH * VIT_HIDDEN_STATE) tensor, one flattened
                patch embedding per view
            :return: (VIT_SEQ_LENGTH * VIT_HIDDEN_STATE,) tensor, the attention-weighted sum over views
            """

            # Get weights for gated pooling attention: (V, 1), softmax taken across views
            gpa_weights = torch.softmax(self.w(self.Z(img_embeddings) * self.G(img_embeddings)), dim=0  )

            # Take a linear combination of all the image embeddings
            fused_embeddings = torch.sum(gpa_weights * img_embeddings, dim=0)

            return fused_embeddings

        def get_img_embedding(self, imgs):
            """
            :param imgs: (N, V, 3, 224, 224) batch of multi-view images — assumes the dataset's
                Resize((224, 224)); TODO confirm if other input sizes are used
            :return: (N, VIT_SEQ_LENGTH, H) image token embeddings, H = LM hidden size
            """

            N = imgs.shape[0]

            # Patchify each sample's views with the ViT conv stem -> (N, V, 49, 768)
            merged_embedding = torch.stack([self.img_model._process_input(img) for img in imgs], dim=0)

            # Prepend the ViT class token to every view -> (N, V, 50, 768)
            batch_class_tokens = self.img_model.class_token.expand(merged_embedding.shape[1], -1, -1).repeat(N, 1, 1, 1)
            merged_embedding = torch.cat([batch_class_tokens, merged_embedding], dim=2)

            # Add positional embeddings and remove class token -> (N, V, 49, 768).
            # Note: the ViT transformer encoder layers are never run; only the patch projection
            # and positional embedding of the pretrained ViT are used.
            merged_embedding += self.img_model.encoder.pos_embedding.repeat(N, 1, 1, 1)
            merged_embedding = merged_embedding[:, :, 1:]

            # Fuse the views of each sample with gated pooling attention -> (N, 49, 768)
            merged_embedding = torch.stack([self.gpa(embedding.flatten(start_dim=1)).reshape(VIT_SEQ_LENGTH,
                                            VIT_HIDDEN_STATE) for embedding in merged_embedding], dim=0)

            # Project to the LM hidden size -> (N, 49, H) (H is 1024 for t5-large; t5-base skips this)
            if self.lm != 'T5-Base':
              merged_embedding = self.img_projection_layer(merged_embedding)

            # Add the image modality embedding (index 1), broadcast over the batch
            merged_embedding += self.modal_embeddings(
                torch.ones((1, merged_embedding.shape[1]), dtype=torch.int, device=device))

            return merged_embedding

        def forward(self, text_enc, imgs, text_model):
            """
            :param text_enc: (N, S) question token ids
            :param imgs: multi-view images, see ``get_img_embedding``
            :param text_model: the T5 model whose input embedding table embeds ``text_enc``
            :return: (N, S + 49, H) encoder input embeddings (text tokens first, then image tokens)
            """

            # Get the image embeddings (N x 49 x H)
            imgs_embedding = self.get_img_embedding(imgs)

            # Get the text embeddings (N x S x H)
            text_embeddings = text_model.get_input_embeddings()(text_enc)

            # Add the text modality embedding (index 0), broadcast over the batch
            text_embeddings += self.modal_embeddings(torch.zeros((1, text_embeddings.shape[1]), dtype=torch.int,
                                                                 device=device))

            # Concatenate along the sequence axis -> (N x (S + 49) x H)
            merged_embedding = torch.cat([text_embeddings, imgs_embedding], dim=1)

            return merged_embedding

    def forward(self, text_enc, imgs, labels=None):

        # Get the merged embeddings
        merged_embedding = self.mvp(text_enc, imgs, self.model)

        # When labels are given, T5 computes the loss and feeds the right-shifted labels to the
        # decoder (teacher forcing); use ``generate`` for free-running decoding instead.
        return self.model(inputs_embeds=merged_embedding, labels=labels)

## Hyperparameters

In [None]:
# Evaluation hyperparameters. `model_name` is the checkpoint folder under multi_frame_results;
# `max_len` is read by val_model, which drops generated answers longer than this many characters.
Config = namedtuple('Instance', ['batch_size', 'gpa_hidden_size', 'model_name', 'lm', 'max_len'])

config = Config(
    batch_size = 16,
    gpa_hidden_size = 128,       # must match the checkpoint (sizes the Z/G/w layers)
    model_name = '20240229-205610',
    lm = 'T5-Large',
    max_len = 512                # NOTE(review): placeholder — set to the value used in training
)

## Evaluation Code

In [None]:
def val_model(dloader, max_len=None, max_new_tokens=512):
    """
    Generate an answer for every test question and save them as COCO-style captions.

    Answers are produced by free-running greedy decoding (``generate``) from the fused
    question+image embeddings. The reference answers in the batch are NOT fed to the decoder,
    so they cannot leak into the predictions (as they would with teacher-forced logits).

    :param dloader: DataLoader built with ``MultiFrameDataset.collate_fn_test``; yields
        ``(q_texts, encodings, imgs, labels, img_paths)``
    :param max_len: decoded answers longer than this many characters are skipped. ``None``
        falls back to ``config.max_len`` when the config defines it, else no length filter.
    :param max_new_tokens: upper bound on the number of tokens generated per answer
    :return: the list of ``{'image_id', 'caption'}`` records written to ``predictions.json``
    """
    if max_len is None:
        max_len = getattr(config, 'max_len', None)

    model.eval()
    ids_answered = set()
    test_data = []

    with torch.no_grad():
        for idx, (q_texts, encodings, imgs, _labels, img_paths) in progress_bar(enumerate(dloader), total=len(dloader)):

            # Fuse question tokens with the multi-view image tokens, then decode greedily
            merged_embedding = model.mvp(encodings, imgs, model.model)
            outputs = model.model.generate(inputs_embeds=merged_embedding, max_new_tokens=max_new_tokens)

            # Get the text output
            text_outputs = processor.batch_decode(outputs, skip_special_tokens=True)

            if idx % 100 == 0:
                print(q_texts)
                print(text_outputs)

            for image_path, q_text, text_output in zip(img_paths, q_texts, text_outputs):

                # image_id.json is keyed by '<first sorted image path> <question prompt>'
                image_id = image_id_dict[image_path[0] + ' ' + q_text][0]

                # Skip duplicate questions and over-long answers
                if image_id in ids_answered:
                    continue
                if max_len is not None and len(text_output) > max_len:
                    continue

                ids_answered.add(image_id)
                test_data.append({'image_id': image_id, 'caption': text_output})

    # Save test output to file
    with open(os.path.join('drive', 'MyDrive', 'DriveLM', 'multi_frame_results', config.model_name, 'predictions.json'), 'w') as f:
        json.dump(test_data, f)

    return test_data


def save_experiment(scores=None, output_path=None):
    """
    Save caption-evaluation metrics as a one-row CSV (one column per metric).

    :param scores: mapping of metric name -> score; defaults to ``coco_eval.eval``
    :param output_path: CSV destination; defaults to
        ``drive/MyDrive/DriveLM/multi_frame_results/<config.model_name>/metrics.csv``
    :return: the DataFrame that was written
    """
    if scores is None:
        scores = coco_eval.eval
    if output_path is None:
        output_path = os.path.join('drive', 'MyDrive', 'DriveLM', 'multi_frame_results',
                                   config.model_name, 'metrics.csv')

    metrics = pd.DataFrame({metric: [score] for metric, score in scores.items()})
    metrics.to_csv(output_path, index=False, header=True)
    return metrics

# Load the tokenizer and the fine-tuned model
model = DriveVLMT5(config)
model.to(device)

if config.lm == 'T5-Base':
    processor = T5Tokenizer.from_pretrained('google-t5/t5-base')
else:
    processor = T5Tokenizer.from_pretrained('google-t5/t5-large')

# Register '<' as its own token — TODO confirm this matches the tokenizer setup used in training
processor.add_tokens('<')

results_dir = os.path.join('drive', 'MyDrive', 'DriveLM', 'multi_frame_results', config.model_name)

# map_location lets a checkpoint saved from a GPU run load on whatever device is available here
model.load_state_dict(torch.load(os.path.join(results_dir, 'latest_model.pth'), map_location=device))

# Load dataset and dataloader
test_dset = MultiFrameDataset(
    input_file=os.path.join('data', 'multi_frame',
                            'multi_frame_test.json'),
    tokenizer=processor,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
    ])
)
# Evaluation must see every test question, in a reproducible order: no shuffling, and the
# final partial batch is kept (drop_last=True would silently skip up to batch_size-1 questions).
test_dloader = DataLoader(test_dset, shuffle=False, batch_size=config.batch_size, drop_last=False,
                          collate_fn=test_dset.collate_fn_test)

# Load in image ids ('<first image path> <question>' -> [image_id, ...])
with open(os.path.join('drive', 'MyDrive', 'DriveLM', 'data', 'multi_frame', 'image_id.json')) as f:
    image_id_dict = json.load(f)

# Generate predictions for the test set (written to predictions.json)
val_model(test_dloader)

annotation_file = os.path.join('drive', 'MyDrive', 'DriveLM', 'data', 'multi_frame', 'multi_frame_test_coco.json')
results_file = os.path.join(results_dir, 'predictions.json')

# Create the ground-truth COCO object and the matching result object
coco = COCO(annotation_file)
coco_result = coco.loadRes(results_file)

# Create coco_eval object by taking coco and coco_result
coco_eval = COCOEvalCap(coco, coco_result)

# Score only the image ids that received a prediction (val_model skips duplicate questions and
# over-long answers, so some annotated ids may have no prediction).
coco_eval.params['image_id'] = coco_result.getImgIds()

# Evaluate results (SPICE is expected to be disabled in pycocoevalcap, see the note at the top)
coco_eval.evaluate()

# Save the experiment results
save_experiment()

In [None]:
# Release (disconnect) the Colab runtime once evaluation has finished.
from google.colab import runtime
runtime.unassign()