# Vision-Guided Cross-Attention and Late-Fusion Network

In [None]:
import os

# Route HTTP(S) traffic (model / dataset downloads) through the lab proxy.
PROXY_URL = "http://192.41.170.23:3128"
os.environ['http_proxy'] = PROXY_URL
os.environ['https_proxy'] = PROXY_URL

# Hugging Face cache location. `datasets` / `transformers` resolve their cache
# paths when they are imported, so this must be set before the imports in the
# following cells for it to take effect.
os.environ['HF_HOME'] = os.path.join(".", "cache")

### Import Necessary Libraries and Environment Setup

In [None]:
#!pip uninstall -y transformers tokenizers
#!pip install transformers==4.9.0 tokenizers==0.10.1

In [None]:
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset, set_caching_enabled
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from transformers import (
    # Preprocessing / Common
    AutoTokenizer, AutoFeatureExtractor,
    # Text & Image Models (Now, image transformers like ViTModel, DeiTModel, BEiT can also be loaded using AutoModel)
    AutoModel,
    # Training / Evaluation
    TrainingArguments, Trainer,
    # Misc
    logging
)

# import nltk
# nltk.download('wordnet')
from nltk.corpus import wordnet

from sklearn.metrics import accuracy_score, f1_score

In [None]:
# SET CACHE FOR HUGGINGFACE TRANSFORMERS + DATASETS
# Hugging Face cache directory (models + datasets) under ./cache.
# NOTE(review): HF libraries normally read HF_HOME when they are imported; the
# imports happen in the cell above, so this assignment may not change where
# files are cached. Set it before importing `datasets`/`transformers` to be sure.
os.environ.update({'HF_HOME': os.path.join(".", "cache")})
# To restrict training to a single GPU, uncomment:
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Quiet transformers' logging down to errors only; keep datasets' caching on.
logging.set_verbosity_error()
set_caching_enabled(True)

In [None]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

if device.type == 'cuda':
    # Extra info: which GPU is being used.
    print(torch.cuda.get_device_name(0))

### Load the Dataset

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import os
import pandas as pd
from datasets import load_dataset

# Paths to the train / validation CSV files.
train_file = "mrad-train.csv"
val_file = "mrad-valid.csv"

# Each CSV is loaded as a single split; `load_dataset` names the split of a lone
# `data_files` entry "train", hence split="train" for the validation file too.
train_dataset = load_dataset("csv", data_files=train_file, split="train")
val_dataset = load_dataset("csv", data_files=val_file, split="train")

# Answer strings of both splits.
train_answers = train_dataset["answer"]
val_answers = val_dataset["answer"]

# Answer vocabulary (label id = position in this list) over both splits.
# Sorted so the label ids are reproducible: the iteration order of a set of
# strings changes between interpreter runs (hash randomization), which would
# silently reshuffle label ids relative to previously saved checkpoints.
# key=str keeps the sort total even if a non-string value slips in.
answer_space = sorted(set(train_answers + val_answers), key=str)

# Map the answers to label indices
def map_to_label(example, space=None):
    """Attach integer class ids to a batch of examples.

    Args:
        example: batched columns as passed by ``Dataset.map(batched=True)``;
            only ``example["answer"]`` (a list of answers) is read.
        space: answer vocabulary to index into. Defaults to the notebook-level
            ``answer_space``.

    Returns:
        ``{"label": [...]}`` holding each answer's position in the vocabulary,
        or -1 for answers that are not in it.
    """
    vocabulary = answer_space if space is None else space
    # Build the lookup once per batch: O(1) per answer instead of the two
    # linear list scans (`in` + `.index`). setdefault keeps the first position
    # of any duplicate, matching list.index.
    index_of = {}
    for position, answer in enumerate(vocabulary):
        index_of.setdefault(answer, position)
    return {"label": [index_of.get(answer, -1) for answer in example["answer"]]}

# Apply the mapping function to the train and val datasets
train_dataset = train_dataset.map(map_to_label, batched=True)
val_dataset = val_dataset.map(map_to_label, batched=True)

# Print the updated datasets
print(train_dataset)
print(val_dataset)


### Check VQA Sample

In [None]:
# Number of distinct answers, i.e. the default number of output classes of MultimodalVQAModel.
len(answer_space)

In [None]:
# len(set(train_answers))
# print(sorted(set(train_answers)))

In [None]:
# print(sorted(set(val_answers)))

In [None]:
from PIL import Image
import os
import numpy as np
from IPython.display import display

def preprocess_images(image_ids, image_folder="VQA_RAD Image Folder"):
    """Load images by file name and return them as RGB PIL images.

    Args:
        image_ids: image file names, resolved relative to ``image_folder``.
        image_folder: directory containing the images.

    Returns:
        List of RGB images, in the same order as ``image_ids``.
    """
    images = []
    for image_id in image_ids:
        # convert() returns a new image, so the source file can be closed
        # right away instead of leaking one open handle per image.
        with Image.open(os.path.join(image_folder, image_id)) as source:
            images.append(source.convert('RGB'))
    return images

def showExample(train=True, example_id=None):
    """Display one VQA example: the image at half size, its question, answer and label id.

    Args:
        train: pick from ``train_dataset`` if True, else from ``val_dataset``.
        example_id: row index to show; a random row is chosen when None.
    """
    dataset = train_dataset if train else val_dataset
    if example_id is None:
        example_id = np.random.randint(len(dataset))

    # Fetch the row once rather than re-indexing the dataset for every field.
    example = dataset[example_id]

    # Context manager releases the image file handle after it is displayed.
    with Image.open(os.path.join("VQA_RAD Image Folder", example["image_name"])) as image:
        display(image.reduce(2))

    print("Question:\t", example["question"])
    print("Answer:\t\t", example["answer"], "(Label: {0})".format(example["label"]))


In [None]:
# Display a random training example (image, question, answer and label id).
showExample(train=True, example_id=None)

### Create a Multimodal Collator

The `Trainer()` uses this collator when it builds the `DataLoader`, so every batch drawn from the dataset is converted into model inputs automatically.

The collator processes the **question (text)** and the **image**, and returns the **tokenized text** along with the **featurized image** and the integer **label**. These are fed into our multimodal transformer model for VQA.

In [None]:
@dataclass
class MultimodalCollator:
    """Turn raw VQA examples into a batch of model inputs for ``Trainer``.

    Tokenizes each example's ``question``, loads and featurizes the image named
    by ``image_name`` (relative to ``image_folder``) and stacks the integer
    ``label`` values. Accepts either a list of example dicts (what the Trainer's
    DataLoader passes) or a single columnar dict of lists.
    """

    tokenizer: "AutoTokenizer"
    preprocessor: "AutoFeatureExtractor"
    # Directory that `image_name` values are resolved against.
    image_folder: str = "VQA_RAD Image Folder"

    def tokenize_text(self, texts: List[str]):
        """Tokenize questions, padded to the longest in the batch and truncated to 24 tokens.

        Returns ``input_ids``, ``token_type_ids`` and ``attention_mask``, each
        keeping its leading batch dimension (also for a batch of one).
        """
        encoded_text = self.tokenizer(
            text=texts,
            padding='longest',
            max_length=24,
            truncation=True,
            return_tensors='pt',
            return_token_type_ids=True,
            return_attention_mask=True,
        )
        # No `.squeeze()`: it would drop the batch dimension whenever the batch
        # contains a single example (e.g. a final partial batch of size 1).
        return {
            "input_ids": encoded_text['input_ids'],
            "token_type_ids": encoded_text['token_type_ids'],
            "attention_mask": encoded_text['attention_mask'],
        }

    def _load_image(self, image_id: str):
        """Open one image as RGB and close the underlying file."""
        with Image.open(os.path.join(self.image_folder, image_id)) as source:
            return source.convert('RGB')

    def preprocess_images(self, images: List[str]):
        """Load the named images and run them through the feature extractor.

        Returns ``pixel_values`` with its leading batch dimension preserved.
        """
        processed_images = self.preprocessor(
            images=[self._load_image(image_id) for image_id in images],
            return_tensors="pt",
        )
        return {
            "pixel_values": processed_images['pixel_values'],
        }

    @staticmethod
    def _column(raw_batch, key):
        """Return the values of ``key`` from a columnar dict or a list of row dicts."""
        if isinstance(raw_batch, dict):
            return raw_batch[key]
        return [example[key] for example in raw_batch]

    def __call__(self, raw_batch_dict):
        return {
            **self.tokenize_text(self._column(raw_batch_dict, 'question')),
            **self.preprocess_images(self._column(raw_batch_dict, 'image_name')),
            'labels': torch.tensor(
                self._column(raw_batch_dict, 'label'),
                dtype=torch.int64,
            ),
        }

### Define VG-CALF Model Architecture

Multimodal models can take various forms to capture information from the text and image modalities, usually together with some cross-modal interaction.
Here, we implement **"Vision-Guided Cross-Attention"** and **"Fusion"** modules to strengthen the interaction between the text encoder and the image encoder for VQA on radiology images.

The text encoder can be a text transformer such as BERT, and the image encoder an image transformer such as ViT. After the tokenized question passes through the text encoder and the image features through the image encoder, the cross-attention and fusion modules combine the two representations, and the result is passed through fully-connected layers to a classifier over the answer space.

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel
from typing import Optional

class MultimodalVQAModel(nn.Module):
    """Text + image encoder VQA classifier with vision-guided cross-attention and late fusion.

    For every sample, the pooled image embedding is used as a single query that
    attends over that sample's question-token embeddings (padding masked out).
    The attended vector is concatenated with the pooled text embedding, passed
    through a small MLP ("late fusion") and classified over ``num_labels``
    answers.

    Args:
        num_labels: number of answer classes. ``None`` (default) means
            ``len(answer_space)``, evaluated when the model is constructed.
        intermediate_dim: output width of the fusion MLP fed to the classifier.
        pretrained_text_name: checkpoint for the text encoder
            (``AutoModel.from_pretrained``); ignored if ``text_encoder`` is given.
        pretrained_image_name: checkpoint for the image encoder; ignored if
            ``image_encoder`` is given.
        num_heads: cross-attention heads; must divide the image encoder's
            hidden size.
        text_encoder, image_encoder: optional pre-built encoders used instead of
            loading checkpoints. Each must expose ``config.hidden_size``; the
            text encoder must return ``last_hidden_state`` and ``pooler_output``,
            and the image encoder must return ``pooler_output``.
    """

    def __init__(
            self,
            num_labels: Optional[int] = None,
            intermediate_dim: int = 768,
            pretrained_text_name: str = 'bert-base-uncased',
            pretrained_image_name: str = 'google/vit-base-patch16-224',
            num_heads: int = 16,
            text_encoder: Optional[nn.Module] = None,
            image_encoder: Optional[nn.Module] = None):

        super().__init__()
        # Resolved at construction time (not class-definition time) so a
        # re-built answer vocabulary is picked up.
        self.num_labels = len(answer_space) if num_labels is None else num_labels
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name

        self.text_encoder = (
            text_encoder if text_encoder is not None
            else AutoModel.from_pretrained(self.pretrained_text_name)
        )
        self.image_encoder = (
            image_encoder if image_encoder is not None
            else AutoModel.from_pretrained(self.pretrained_image_name)
        )

        text_dim = self.text_encoder.config.hidden_size
        image_dim = self.image_encoder.config.hidden_size

        # Vision-Guided Cross-Attention: image query (image_dim) over text
        # keys/values (text_dim). When both sizes are equal (e.g. BERT-base +
        # ViT-base, 768) the parameters are identical to
        # nn.MultiheadAttention(768, num_heads), so existing checkpoints load.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=image_dim,
            num_heads=num_heads,
            kdim=text_dim,
            vdim=text_dim,
        )

        # Late-Fusion MLP over [pooled text ; attended image].
        self.fusion = nn.Sequential(
            nn.Linear(text_dim + image_dim, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, intermediate_dim),
            nn.LayerNorm(intermediate_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        self.classifier = nn.Linear(intermediate_dim, self.num_labels)

        self.criterion = nn.CrossEntropyLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Encode question and image, fuse them and classify.

        Returns:
            dict with ``"logits"`` (one row per sample, ``num_labels`` columns)
            and, when ``labels`` is given, the cross-entropy ``"loss"``.
        """
        encoded_text = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )

        text_tokens = encoded_text['last_hidden_state']   # expected (batch, seq, text_dim)
        text_pooled = encoded_text['pooler_output']       # expected (batch, text_dim)
        image_pooled = encoded_image['pooler_output']     # expected (batch, image_dim)

        # nn.MultiheadAttention (batch_first=False) takes (seq, batch, dim).
        # Feeding the 2-D pooled tensors directly would be read as a single
        # *unbatched* sequence, so each sample would attend to the other
        # samples in the batch. Instead: one image query per sample over that
        # sample's own question tokens.
        query = image_pooled.unsqueeze(0)                 # (1, batch, image_dim)
        key_value = text_tokens.transpose(0, 1)           # (seq, batch, text_dim)
        # True marks padding positions that must not be attended to.
        padding_mask = None if attention_mask is None else attention_mask == 0
        attended, _ = self.cross_attention(
            query=query,
            key=key_value,
            value=key_value,
            key_padding_mask=padding_mask,
        )
        attended_image_features = attended.squeeze(0)     # (batch, image_dim)

        # Concatenate pooled text features with the attended image features.
        fused_features = torch.cat(
            [text_pooled, attended_image_features],
            dim=1
        )

        # Apply late fusion module, then classify.
        fused_output = self.fusion(fused_features)
        logits = self.classifier(fused_output)

        out = {
            "logits": logits
        }
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)

        return out
                

### Define a Function for the Model and the Collator

We plan to experiment with different pretrained text and image encoders, so the collator is created together with the model: the tokenizer, the feature extractor and the model's encoders must be loaded from the same pretrained checkpoints.

In [None]:
def createMultimodalVQACollatorAndModel(text='bert-base-uncased', image='google/vit-base-patch16-224'):
    """Build a matching (collator, model) pair for the given checkpoints.

    The tokenizer and feature extractor are loaded from the same checkpoints
    as the model's text and image encoders, and the model is moved to `device`.
    """
    collator = MultimodalCollator(
        tokenizer=AutoTokenizer.from_pretrained(text),
        preprocessor=AutoFeatureExtractor.from_pretrained(image),
    )
    model = MultimodalVQAModel(
        pretrained_text_name=text,
        pretrained_image_name=image,
    ).to(device)
    return collator, model

### Performance Metrics

In [None]:
# Fetch NLTK's WordNet corpus data (presumably for a WordNet-based metric such
# as WUPS) and expose the `wordnet` corpus reader.
import nltk

nltk.download("wordnet")

from nltk.corpus import wordnet

In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
from typing import Tuple, Dict

def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray], answer_names=None) -> Dict[str, float]:
    """Compute BLEU-1, accuracy and macro-F1 for one evaluation pass.

    Args:
        eval_tuple: ``(logits, labels)`` as supplied by ``Trainer``; the logits
            are arg-maxed over the last axis to get predicted label ids.
        answer_names: sequence mapping a label id to its answer text. Defaults
            to the notebook-level ``answer_space``.

    Returns:
        ``bleu``: sentence-level BLEU-1 of the predicted answer text against the
        reference answer text, averaged over examples (0.0 for an empty eval set);
        ``acc``: accuracy; ``f1``: macro-averaged F1 over label ids.
    """
    logits, labels = eval_tuple
    preds = logits.argmax(axis=-1)
    names = answer_space if answer_names is None else answer_names

    # BLEU-1 needs unigram-only weights (the default weights give BLEU-4) and is
    # computed per example on the answer *text*; scoring the stringified list of
    # label ids as one "sentence" measured nothing meaningful.
    smoothing = SmoothingFunction().method1
    bleu_scores = [
        sentence_bleu(
            [str(names[int(label)]).split()],
            str(names[int(pred)]).split(),
            weights=(1.0,),
            smoothing_function=smoothing,
        )
        for label, pred in zip(labels, preds)
    ]

    return {
        # "wups": batch_wup_measure(labels, preds),  # disabled: no WUPS helper is defined here
        "bleu": float(np.mean(bleu_scores)) if bleu_scores else 0.0,
        "acc": accuracy_score(labels, preds),
        # zero_division=0: classes never predicted score 0 without a warning per eval.
        "f1": f1_score(labels, preds, average='macro', zero_division=0)
    }

### Model Training and Evaluation

In [None]:
# !pip install transformers==4.8.2

In [None]:
# Training configuration passed to createAndTrainModel.
args = TrainingArguments(
    output_dir="check",
    seed=12345,
    # Evaluate, log and checkpoint every 96 steps; load_best_model_at_end
    # requires the evaluation and save schedules to line up.
    evaluation_strategy="steps",
    eval_steps=96,
    logging_strategy="steps",
    logging_steps=96,
    save_strategy="steps",
    save_steps=96,
    save_total_limit=5,             # Keep at most 5 checkpoints on disk while training
    metric_for_best_model='acc',    # the "acc" key returned by compute_metrics
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    remove_unused_columns=False,    # the collator needs the raw 'question' / 'image_name' columns
    num_train_epochs=200,
    # Mixed precision is CUDA-only; on the CPU fallback fp16=True would fail.
    fp16=torch.cuda.is_available(),
    # warmup_ratio=0.01,
    # learning_rate=5e-4,
    # weight_decay=1e-4,
    # gradient_accumulation_steps=2,
    dataloader_num_workers=8,
    load_best_model_at_end=True,
)

In [None]:
def createAndTrainModel(dataset, args, text_model='bert-base-uncased', image_model='google/vit-base-patch16-224', multimodal_model='med-flamingo/med-flamingo', output_dir="check"):
    """Build a collator/model pair, train it with ``Trainer`` and evaluate it.

    Args:
        dataset: mapping with a ``'train'`` split (used for training) and a
            ``'test'`` split (used for evaluation).
        args: ``TrainingArguments``; copied, never mutated.
        text_model, image_model: checkpoints for the text and image encoders.
        multimodal_model: currently unused; kept so existing calls still work.
        output_dir: where the Trainer writes checkpoints (overrides ``args.output_dir``).

    Returns:
        ``(collator, model, train_output, eval_metrics)`` where ``train_output``
        is the result of ``Trainer.train()`` and ``eval_metrics`` the dict from
        ``Trainer.evaluate()``.
    """
    collator, model = createMultimodalVQACollatorAndModel(text_model, image_model)

    # Work on a copy so the caller's TrainingArguments are left untouched.
    multi_args = deepcopy(args)
    multi_args.output_dir = output_dir

    multi_trainer = Trainer(
        model,
        multi_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        data_collator=collator,
        compute_metrics=compute_metrics
    )

    train_multi_metrics = multi_trainer.train()
    eval_multi_metrics = multi_trainer.evaluate()

    return collator, model, train_multi_metrics, eval_multi_metrics

In [None]:
# Bundle the two splits under the keys createAndTrainModel expects:
# 'train' for training and 'test' for evaluation.
dataset = dict(train=train_dataset, test=val_dataset)

# Build, train and evaluate the default text/image encoder combination.
collator, model, train_multi_metrics, eval_multi_metrics = createAndTrainModel(
    dataset,
    args,
)


In [None]:
# collator, model, train_multi_metrics, eval_multi_metrics = createAndTrainModel(dataset[train_dataset], args)

In [None]:
# Display the final evaluation metrics returned by Trainer.evaluate() in createAndTrainModel.
eval_multi_metrics

### Model Inferencing

In [None]:
# Optional path to a torch-saved state dict from a training checkpoint, e.g.
# os.path.join("check", "checkpoint-XXXX", "pytorch_model.bin") (the exact file
# name depends on the transformers version). None = no checkpoint to load.
CHECKPOINT_WEIGHTS = None

if CHECKPOINT_WEIGHTS is not None:
    # Rebuild the architecture and restore the trained weights from disk.
    model = MultimodalVQAModel()
    model.load_state_dict(torch.load(CHECKPOINT_WEIGHTS, map_location="cpu"))
# Otherwise reuse the `model` returned by createAndTrainModel: with
# load_best_model_at_end=True it already holds the best checkpoint's weights,
# and re-instantiating it here would replace them with untrained heads.
model.to(device)

In [None]:
# Run the model on a few validation examples and compare predicted vs. true answers.
model.to(device)
model.eval()

# Build a real batch with the same collator used in training.
sample_batch = collator([val_dataset[i] for i in range(min(4, len(val_dataset)))])
sample_batch = {name: tensor.to(device) for name, tensor in sample_batch.items()}

with torch.no_grad():
    output = model(**sample_batch)

predicted_ids = output["logits"].argmax(dim=-1).tolist()
for i, pred_id in enumerate(predicted_ids):
    example = val_dataset[i]
    print(example["question"], "->", answer_space[pred_id], "| true:", example["answer"])