# VQA

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import CLIPVisionModel

import torch
import numpy as np
import pandas as pd
import PIL.Image

from torch.utils.tensorboard import SummaryWriter
from torch.optim.sgd import SGD
from torch.utils.data import Dataset
import torch.nn as nn
import typing as tp

import wandb

import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2

In [None]:
# TensorBoard writer shared by both training loops below (default log directory).
writer = SummaryWriter(log_dir=None)

In [None]:
# --- runtime device and pretrained checkpoints -------------------------------
DEVICE: str
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
LLM_NAME: str = 'gpt2'
VIT_NAME: str = 'openai/clip-vit-large-patch14'

# --- feature widths ------------------------------------------------------------
# VIT_EMB_SIZE is the input width of the adapters applied to the vision model's
# `pooler_output`; LLM_INP_EMB_SIZE is the width each visual-prefix token is
# reshaped to before being concatenated with GPT-2 input embeddings, so both
# must match the pretrained checkpoints above.
VIT_EMB_SIZE: int = 1024
LLM_INP_EMB_SIZE: int = 768
LLM_OUT_EMB_SIZE: int = 768  # presumably the GPT-2 hidden width - TODO confirm
RET_EMB_SIZE: int = 600      # width of the shared retrieval projection space

# --- data ----------------------------------------------------------------------
DATA_PATH: str = "./"  # concatenated directly with the CSV file name: keep the trailing "/"
BATCH_SIZE: int = 1

# --- special tokens --------------------------------------------------------------
# Padding and the question/answer separator are both the GPT-2 end-of-text token.
RET_TOKEN: str = "[RET]"
PAD_TOKEN: str = "<|endoftext|>"
IMG_CLS_TOKEN: str = "<|image|>"
EOT_TOKEN: str = "<|endoftext|>"

# Number of optimisation iterations per training session.
SESSION_SIZE: int = 1000

## Dataset for image captioning task

In [None]:
class CLEVRCaptioningDataset(Dataset):
    """Image/caption pairs for the captioning stage.

    Rows are read from ``DATA_PATH + data_type + "_annotation_dataframe.csv"``;
    the ``Path`` column holds the image file and ``Annotation`` the caption.
    """

    def __init__(
        self,
        tokenizer: GPT2Tokenizer,
        data_type: str = "train",
        transform: tp.Optional[albu.Compose | None] = None
    ) -> None:
        """
        Args:
            tokenizer: tokenizer applied to the captions.
            data_type: split name used to build the CSV file name.
            transform: albumentations pipeline applied to every image. Its
                output is cast with ``.float()``, so it must end in a tensor
                conversion. ``None`` is accepted here, but then loading any
                image raises ``ValueError``.
        """
        super().__init__()
        self.data_type = data_type
        self.data = pd.read_csv(DATA_PATH + data_type + "_annotation_dataframe.csv")
        self.transform = transform
        self.tokenizer = tokenizer

    def get_batch(self, batch_size: int) -> tp.Tuple:
        """Sample ``batch_size`` rows uniformly at random, with replacement.

        Returns:
            ``(imgs, annotations_tokens)``: the transformed images stacked along
            a new leading batch dimension, and the captions tokenized as one
            PyTorch batch padded to the longest caption.
        """
        rand_idxes = np.random.randint(0, len(self.data), batch_size)
        # Positional selection: the random integers are row positions, so the
        # sample must not depend on the DataFrame's index labels.
        rows = self.data.iloc[rand_idxes]

        imgs = torch.stack([self._get_image(path) for path in rows["Path"]])
        annotations_tokens = self.tokenizer(
            rows["Annotation"].to_list(), return_tensors="pt", padding=True
        )
        return (imgs, annotations_tokens)

    def _get_image(self, image_path: str) -> torch.Tensor:
        """Load ``image_path`` as RGB, apply ``self.transform`` and cast to float.

        Raises:
            ValueError: if no transform was configured.
        """
        # Validate configuration before touching the file system.
        if self.transform is None:
            raise ValueError("Transformation must be at least ToTensor, but None recieved")
        # The context manager closes the file handle PIL opened.
        with PIL.Image.open(image_path) as image:
            pixels = np.array(image.convert("RGB"))
        return self.transform(image=pixels)["image"].float()

    def __getitem__(self, index) -> tp.Tuple:
        """Return ``(img, annotation_tokens)`` for the row at position ``index``."""
        row = self.data.iloc[index]
        # Read columns by name (as get_batch does) instead of unpacking the
        # whole row, which fails if the CSV carries any extra column.
        img = self._get_image(row["Path"])
        annotation_tokens = self.tokenizer(row["Annotation"], return_tensors="pt")
        return (img, annotation_tokens)

    def __len__(self):
        return len(self.data)

## Dataset for visual question answering task

In [None]:
class CLEVRVQADataset(Dataset):
    """(image, question, answer) triples for the VQA stage.

    Rows are read from ``DATA_PATH + data_type + "_dataframe.csv"``; the code
    uses its ``Path``, ``Question`` and ``Answer`` columns.
    """

    def __init__(
        self,
        tokenizer: GPT2Tokenizer,
        data_type: str = "train",
        transform: tp.Optional[albu.Compose | None] = None
    ) -> None:
        """
        Args:
            tokenizer: tokenizer applied to questions, answers and prompts.
            data_type: split name used to build the CSV file name.
            transform: albumentations pipeline applied to every image. Its
                output is cast with ``.float()``, so it must end in a tensor
                conversion. ``None`` is accepted here, but then loading any
                image raises ``ValueError``.
        """
        super().__init__()
        self.data_type = data_type
        # Force the text columns to str. On a large CSV, pandas infers types
        # chunk by chunk, so answers that look numeric ("2", "3") can come back
        # as a mix of int and str. That breaks len() below and the string
        # concatenation in get_batch.
        self.data = pd.read_csv(
            DATA_PATH + data_type + "_dataframe.csv",
            dtype={"Question": str, "Answer": str},
        )
        self.transform = transform
        self.tokenizer = tokenizer

        # Longest question / answer measured in characters (not tokens);
        # 0 for an empty split.
        self.max_question_length = max(
            (len(ques) for ques in self.data["Question"]), default=0
        )
        self.max_answer_length = max(
            (len(ans) for ans in self.data["Answer"]), default=0
        )

    def get_batch(self, batch_size: int) -> tp.Tuple:
        """Sample ``batch_size`` rows uniformly at random, with replacement.

        Returns:
            ``(imgs, promt_tokens, questions_tokens, answers_tokens)``. ``imgs``
            stacks the transformed images along a new batch dimension. Each
            prompt is ``question + EOT_TOKEN + answer``. Every text batch is
            tokenized to PyTorch tensors, padded to its own longest sequence.
        """
        rand_idxes = np.random.randint(0, len(self.data), batch_size)
        # Positional selection, independent of the DataFrame's index labels.
        rows = self.data.iloc[rand_idxes]

        questions = rows["Question"].to_list()
        answers = rows["Answer"].to_list()
        prompts = [q + EOT_TOKEN + a for q, a in zip(questions, answers)]

        imgs = torch.stack([self._get_image(path) for path in rows["Path"]])

        questions_tokens = self.tokenizer(questions, return_tensors="pt", padding=True)
        answers_tokens = self.tokenizer(answers, return_tensors="pt", padding=True)
        promt_tokens = self.tokenizer(prompts, return_tensors="pt", padding=True)

        return (imgs, promt_tokens, questions_tokens, answers_tokens)

    def _get_image(self, image_path: str) -> torch.Tensor:
        """Load ``image_path`` as RGB, apply ``self.transform`` and cast to float.

        Raises:
            ValueError: if no transform was configured.
        """
        # Validate configuration before touching the file system.
        if self.transform is None:
            raise ValueError("Transformation must be at least ToTensor, but None recieved")
        # The context manager closes the file handle PIL opened.
        with PIL.Image.open(image_path) as image:
            pixels = np.array(image.convert("RGB"))
        return self.transform(image=pixels)["image"].float()

    def __getitem__(self, index) -> tp.Tuple:
        """Return ``(img, question_tokens, answer_tokens)`` for the row at position ``index``."""
        row = self.data.iloc[index]
        # Read columns by name instead of unpacking the whole row, which
        # fails if the CSV carries any extra column.
        img = self._get_image(row["Path"])
        question_tokens = self.tokenizer(row["Question"], return_tensors="pt")
        answer_tokens = self.tokenizer(row["Answer"], return_tensors="pt")
        return (img, question_tokens, answer_tokens)

    def __len__(self):
        return len(self.data)

## Model architecture proposed in the paper

In [None]:
class MegaModel(nn.Module):
    """GPT-2 language model conditioned on CLIP image features.

    The pretrained GPT-2 (``LLM_NAME``) and CLIP vision model (``VIT_NAME``) are
    frozen. The trainable parts are the linear adapters created in ``__init__``
    and the token-embedding matrix. The embedding matrix stays trainable so
    that its ``[RET]`` row can be tuned via ``zero_grad_token_embeddings``.

    In ``"captioning"`` mode, ``vit2token`` projects the CLIP ``pooler_output``
    into ``n_visual_tokens`` GPT-2 input embeddings, which are prepended to the
    text embeddings.
    """

    def __init__(
        self,
        n_visual_tokens: int = 5
    ) -> None:
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(LLM_NAME)
        self.llm = GPT2LMHeadModel.from_pretrained(LLM_NAME).to(DEVICE)
        self.vit = CLIPVisionModel.from_pretrained(VIT_NAME).to(DEVICE)
        # Order matters: the vocabulary (and embedding matrix) must be resized
        # before requires_grad flags are set and before the matrix is cached below.
        self._add_special_tokens()
        self._freeze_llm_vit()

        # RET_TOKEN is registered as a single token in _add_special_tokens, so
        # this is its own id. Encoding the raw text would give the id of "[".
        self.ret_token_idx = self.tokenizer.convert_tokens_to_ids(RET_TOKEN)
        self.pad_token_idx = self.tokenizer.encode(PAD_TOKEN)[0]

        self.n_visual_tokens = n_visual_tokens

        # Trainable adapters. The misspelled `vit2retsapce` is kept as-is
        # because attribute names become state_dict keys.
        self.vit2token = nn.Linear(VIT_EMB_SIZE, LLM_INP_EMB_SIZE * n_visual_tokens).to(DEVICE)
        self.vit2retsapce = nn.Linear(VIT_EMB_SIZE, RET_EMB_SIZE).to(DEVICE)
        self.llm2retspace = nn.Linear(LLM_OUT_EMB_SIZE, RET_EMB_SIZE).to(DEVICE)
        self.input_embeddings = self.llm.get_input_embeddings()

    def get_tokenizer(self):
        return self.tokenizer

    def decode_from_logits(self, logits):
        """Greedy-decode ``logits`` (argmax over the last dim) into strings, one per batch row."""
        return self.tokenizer.batch_decode(torch.argmax(logits, dim=2).to("cpu"))

    def zero_grad_token_embeddings(self) -> None:
        """Zero every token-embedding gradient row except the ``[RET]`` row.

        Call between ``backward()`` and ``optimizer.step()`` so that only the
        ``[RET]`` embedding gets updated. Parameters without a gradient are
        skipped.
        """
        for param in self.llm.transformer.wte.parameters():
            if param.grad is None:
                continue
            # Build the row mask on the gradient's own device.
            rows = torch.arange(param.grad.shape[0], device=param.grad.device)
            param.grad[rows != self.ret_token_idx] = 0

    def zero_grad_llm_vit(self) -> None:
        """Clear the gradients held by the GPT-2 and CLIP submodules."""
        self.llm.zero_grad()
        self.vit.zero_grad()

    def _visual_prefix(self, images: torch.Tensor) -> torch.Tensor:
        """Project CLIP pooled features to ``(batch, n_visual_tokens, emb)`` LLM inputs."""
        vit_embeddings = self.vit(images)["pooler_output"]
        return self.vit2token(vit_embeddings).view(
            vit_embeddings.size(0), self.n_visual_tokens, -1
        )

    def _prefix_ones(self, like: torch.Tensor) -> torch.Tensor:
        """All-ones ``(batch, n_visual_tokens)`` tensor matching ``like``'s dtype/device."""
        return torch.ones(
            (like.size(0), self.n_visual_tokens), dtype=like.dtype, device=like.device
        )

    def generate(
        self,
        images,
        question,
        max_new_tokens: int = 10,
    ):
        """Generate a continuation of ``question`` conditioned on ``images``.

        Args:
            images: pixel tensor passed straight to the CLIP vision model.
            question: tokenizer output with ``input_ids`` and ``attention_mask``.
            max_new_tokens: number of tokens to generate after the prompt.

        Returns:
            The output of ``self.llm.generate`` for an ``inputs_embeds`` prompt.
        """
        prompt_embeddings = torch.cat([
            self._visual_prefix(images),
            self.input_embeddings(question["input_ids"]),
        ], dim=1)
        attention_mask = torch.cat([
            self._prefix_ones(question["attention_mask"]),
            question["attention_mask"],
        ], dim=1)

        return self.llm.generate(
            inputs_embeds=prompt_embeddings,
            attention_mask=attention_mask,
            output_hidden_states=True,
            pad_token_id=self.pad_token_idx,
            # The old `max_length=len(embeddings) + 10` used the batch size
            # (len of a 3-D tensor), not the prompt length. Budget the new
            # tokens explicitly instead.
            max_new_tokens=max_new_tokens,
        )

    def forward(
        self,
        images,
        text,
        model_mode: str = "captioning"
    ):
        """Run GPT-2 with next-token labels equal to ``text["input_ids"]``.

        In ``"captioning"`` mode, the visual prefix built from ``images`` is
        prepended. Its positions are always attended to and are excluded from
        the loss (label -100). In any other mode, ``images`` is ignored and the
        text runs alone.

        NOTE(review): padded positions are not masked out of ``labels``, so
        batches with padding also train on the pad/EOS token.
        """
        input_ids = text["input_ids"]
        attention_mask = text["attention_mask"]
        labels = input_ids
        prompt_embeddings = self.input_embeddings(input_ids)

        if model_mode == "captioning":
            # Only run the (expensive) vision model when its output is used.
            prompt_embeddings = torch.cat(
                [self._visual_prefix(images), prompt_embeddings], dim=1
            )
            img_attention = self._prefix_ones(attention_mask)
            images_labels = torch.full_like(self._prefix_ones(labels), -100)

            attention_mask = torch.cat([img_attention, attention_mask], dim=1)
            labels = torch.cat([images_labels, labels], dim=1)

        return self.llm(
            inputs_embeds=prompt_embeddings,
            labels=labels,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

    def _add_special_tokens(self) -> None:
        """Use EOS for padding and register ``RET_TOKEN`` as a single token."""
        self.tokenizer.pad_token = self.tokenizer.eos_token
        added = self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [RET_TOKEN]}
        )
        if added:
            # Grow the LLM's embedding matrix to cover the new token id.
            self.llm.resize_token_embeddings(len(self.tokenizer))

    def _freeze_llm_vit(self) -> None:
        """Switch the pretrained models to eval mode and stop their weights from training."""
        self.llm.eval()
        self.vit.eval()
        # eval() only disables dropout. Without requires_grad_(False), every
        # backward pass would still compute gradients for all ViT/LLM weights.
        # Gradients still flow *through* the frozen LLM to the adapters.
        self.llm.requires_grad_(False)
        self.vit.requires_grad_(False)
        # Keep the token-embedding matrix trainable so its [RET] row can be
        # learned (see zero_grad_token_embeddings).
        self.llm.get_input_embeddings().weight.requires_grad_(True)

## Initialization

In [None]:
# Build the frozen GPT-2 + CLIP model with the default visual-prefix length.
my_mega_model = MegaModel(n_visual_tokens=5)

In [None]:
# The CLIP vision model expects pixels rescaled to [0, 1] and standardized
# with the CLIP channel statistics (the values CLIPImageProcessor uses).
# Feeding raw 0-255 values puts the frozen encoder far outside its training
# distribution.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

transform = albu.Compose([
    albu.Resize(
        height=224,
        width=224
    ),
    # max_pixel_value=255 also rescales the uint8 pixels to [0, 1].
    albu.Normalize(mean=CLIP_MEAN, std=CLIP_STD, max_pixel_value=255.0),
    ToTensorV2()
])

# Both datasets share the model's tokenizer.
train_dataset = CLEVRVQADataset(
    tokenizer=my_mega_model.tokenizer,
    data_type="train",
    transform=transform
)

train_annotation_dataset = CLEVRCaptioningDataset(
    tokenizer=my_mega_model.tokenizer,
    data_type="train",
    transform=transform
)

In [None]:
# Plain SGD over every parameter of the model.
lr = 1e-3
opt = SGD(params=my_mega_model.parameters(), lr=lr)


def lambda1(epoch):
    """Learning-rate multiplier: 0.8 ** (number of scheduler.step() calls)."""
    return 0.8 ** epoch


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=opt, lr_lambda=lambda1)

## Image captioning training

In [None]:
wandb.init(
    project="VQA",
    name="Image captioning task",
    config={
        "architecture": "Frozen transformers",
        "dataset": "CLEVR_v1.0",
        "session size": SESSION_SIZE,
        "batch": BATCH_SIZE,
        "initial learning rate": lr
    }
)

# Step the LR scheduler 10 times per session. The interval is an integer,
# never 0, so short or non-multiple-of-10 sessions still behave predictably.
scheduler_interval = max(1, SESSION_SIZE // 10)

predicted_annotations = []
for iter_num in range(SESSION_SIZE):
    imgs, annotation = train_annotation_dataset.get_batch(BATCH_SIZE)

    # NOTE(review): this masks out *every* caption token. In "captioning" mode
    # only the visual-prefix positions stay attendable, so text positions
    # cannot see earlier caption tokens. The VQA loop keeps the tokenizer's
    # mask. Confirm this is intended.
    annotation["attention_mask"] = torch.zeros_like(annotation["attention_mask"])
    my_mega_model.zero_grad()

    output = my_mega_model(
        images=imgs.to(DEVICE),
        text=annotation.to(DEVICE),
        model_mode="captioning"
    )

    output.loss.backward()
    # Drop the pretrained models' gradients so only the adapters are updated.
    my_mega_model.zero_grad_llm_vit()
    opt.step()

    if iter_num % 10 == 0:
        # Snapshot: (argmax decoding of the logits, reference caption text).
        predicted_annotations += [(
            my_mega_model.decode_from_logits(output.logits.detach()),
            my_mega_model.tokenizer.batch_decode(annotation["input_ids"])
        )]

    if (iter_num + 1) % scheduler_interval == 0:
        scheduler.step()

    loss_value = output.loss.detach().cpu().item()
    # Stage-specific tag: the VQA loop logs the same step numbers, and a
    # shared "Loss" tag would mix both curves in one TensorBoard chart.
    writer.add_scalar("Loss/captioning", loss_value, iter_num)
    wandb.log({"loss": loss_value})

wandb.finish()

In [None]:
# Most recent snapshot: model's argmax decoding vs. the reference caption
# (first batch element of each).
latest_decoded, latest_reference = predicted_annotations[-1]
print(latest_decoded[0])
print(latest_reference[0])

In [None]:
imgs, promt, question, answer = train_dataset.get_batch(BATCH_SIZE)
imgs = imgs.to(DEVICE)
question = question.to(DEVICE)

# Inference only: don't build an autograd graph for these passes.
with torch.no_grad():
    # Teacher-forced pass over the question: its argmax is the model's
    # next-token guess at every question position.
    generated_tokens = my_mega_model(
        images=imgs,
        text=question,
        model_mode="captioning"
    )

    # Free-running generation after the image prefix + question.
    generated_output = my_mega_model.generate(
        images=imgs,
        question=question
    )

print("Question:             ", my_mega_model.tokenizer.batch_decode(question["input_ids"])[0])
print("Answer:               ", my_mega_model.tokenizer.batch_decode(answer["input_ids"])[0])
print("Model decoded answer: ", my_mega_model.decode_from_logits(generated_tokens["logits"]))
print("Model generated text: ", my_mega_model.tokenizer.batch_decode(generated_output))

## VQA training

In [None]:
wandb.init(
    project="VQA",
    name="Question answering task",
    config={
        "architecture": "Frozen transformers",
        "dataset": "CLEVR_v1.0",
        "session size": SESSION_SIZE,
        "batch": BATCH_SIZE,
        "initial learning rate": lr
    }
)

# Step the LR scheduler 10 times per session (integer interval, at least 1).
# This is the same scheduler object stepped during the captioning session,
# so the decay continues from there.
scheduler_interval = max(1, SESSION_SIZE // 10)

predicted_answers = []
for iter_num in range(SESSION_SIZE):
    # Use the configured batch size so the run matches the "batch" value
    # logged to wandb above (it was hard-coded to 1).
    imgs, promt, question, answer = train_dataset.get_batch(BATCH_SIZE)

    my_mega_model.zero_grad()

    # "captioning" mode prepends the visual prefix. The loss is computed over
    # the whole `question + EOT + answer` prompt.
    output = my_mega_model(
        images=imgs.to(DEVICE),
        text=promt.to(DEVICE),
        model_mode="captioning"
    )

    output.loss.backward()
    # Drop the pretrained models' gradients so only the adapters are updated.
    my_mega_model.zero_grad_llm_vit()
    opt.step()

    if (iter_num + 1) % 100 == 0:
        # Snapshot: (argmax decoding of the logits, reference prompt text).
        predicted_answers += [(
            my_mega_model.decode_from_logits(output.logits.detach()),
            my_mega_model.tokenizer.batch_decode(promt["input_ids"])
        )]

    if (iter_num + 1) % scheduler_interval == 0:
        scheduler.step()

    loss_value = output.loss.detach().cpu().item()
    # Stage-specific tag so this curve doesn't overwrite the captioning one.
    writer.add_scalar("Loss/vqa", loss_value, iter_num)
    wandb.log({"loss": loss_value})

wandb.finish()
      

In [None]:
# Most recent snapshot: model's argmax decoding vs. the reference prompt
# (first batch element of each).
latest_decoded, latest_prompt = predicted_answers[-1]
print(latest_decoded[0])
print(latest_prompt[0])

In [None]:
imgs, promt, question, answer = train_dataset.get_batch(BATCH_SIZE)
imgs = imgs.to(DEVICE)
question = question.to(DEVICE)

# Inference only: don't build an autograd graph for these passes.
with torch.no_grad():
    # Teacher-forced pass over the question: its argmax is the model's
    # next-token guess at every question position.
    generated_tokens = my_mega_model(
        images=imgs,
        text=question,
        model_mode="captioning"
    )

    # Free-running generation after the image prefix + question.
    generated_output = my_mega_model.generate(
        images=imgs,
        question=question
    )

print("Question:             ", my_mega_model.tokenizer.batch_decode(question["input_ids"])[0])
print("Answer:               ", my_mega_model.tokenizer.batch_decode(answer["input_ids"])[0])
print("Model decoded answer: ", my_mega_model.decode_from_logits(generated_tokens["logits"]))
print("Model generated text: ", my_mega_model.tokenizer.batch_decode(generated_output))