In [2]:
from transformers import BertModel, BertTokenizer, ViTImageProcessor, ViTModel
import torch
from torchinfo import summary
from torch import nn
from torch.nn import Transformer, TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer

In [3]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
# Every model and tensor below is moved to this module-level `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
class TextTokenizer(torch.nn.Module):
    """Module wrapper around a pretrained 'bert-base-uncased' tokenizer.

    Args:
        text_tokenizer: Class exposing ``from_pretrained``; it is used to load
            the 'bert-base-uncased' tokenizer.
        max_length: Value forwarded to the tokenizer as ``max_length`` on
            every call.
    """

    def __init__(self, text_tokenizer=BertTokenizer, max_length=25):
        super().__init__()
        self.max_length = max_length
        self.text_tokenizer = text_tokenizer.from_pretrained('bert-base-uncased')

    def forward(self, input_question, padding='max_length', truncation=True):
        """Tokenize ``input_question`` and move the result to ``device``.

        Args:
            input_question: Text handed to the tokenizer unchanged.
            padding: Forwarded to the tokenizer's ``padding`` argument.
            truncation: Forwarded to the tokenizer's ``truncation`` argument.

        Returns:
            The tokenizer output (PyTorch tensors, ``return_tensors='pt'``)
            after ``.to(device)``.
        """
        encoded = self.text_tokenizer(
            input_question,
            max_length=self.max_length,
            padding=padding,
            truncation=truncation,
            return_tensors='pt',
        )
        return encoded.to(device)

class ImageProcessor(torch.nn.Module):
    """Module wrapper around the pretrained ViT image processor.

    Args:
        image_model_processor: Class exposing ``from_pretrained``; it is used
            to load the 'google/vit-base-patch16-224-in21k' processor.
    """

    def __init__(self, image_model_processor=ViTImageProcessor):
        super().__init__()
        checkpoint = 'google/vit-base-patch16-224-in21k'
        self.image_model_processor = image_model_processor.from_pretrained(checkpoint)

    def forward(self, image):
        """Run the processor on ``image`` and move the result to ``device``.

        Returns:
            The processor output (PyTorch tensors, ``return_tensors='pt'``)
            after ``.to(device)``.
        """
        processed = self.image_model_processor(image, return_tensors='pt')
        return processed.to(device)

class TextEmbedding(torch.nn.Module):
    """Pretrained 'bert-base-uncased' text encoder placed on ``device``.

    Args:
        text_model: Class exposing ``from_pretrained``; it is used to load the
            'bert-base-uncased' weights.
    """

    def __init__(self, text_model=BertModel):
        super().__init__()
        pretrained = text_model.from_pretrained('bert-base-uncased')
        self.text_model = pretrained.to(device)

    def forward(self, tokens):
        """Encode tokenized text.

        Args:
            tokens: Object exposing ``input_ids`` and ``attention_mask``
                attributes.

        Returns:
            The model's ``last_hidden_state``. The whole tensor is returned;
            no single token (such as the first one) is selected here.
        """
        encoder_out = self.text_model(
            attention_mask=tokens.attention_mask,
            input_ids=tokens.input_ids,
        )
        return encoder_out.last_hidden_state


class ImageEmbedding(torch.nn.Module):
    """Pretrained ViT image encoder placed on ``device``.

    Args:
        image_model: Class exposing ``from_pretrained``; it is used to load
            the 'google/vit-base-patch16-224-in21k' weights.
    """

    def __init__(self, image_model=ViTModel):
        super().__init__()
        pretrained = image_model.from_pretrained('google/vit-base-patch16-224-in21k')
        self.image_model = pretrained.to(device)

    def forward(self, image):
        """Encode a processed image.

        Args:
            image: Object exposing a ``pixel_values`` attribute.

        Returns:
            The model's ``last_hidden_state``.
        """
        encoder_out = self.image_model(pixel_values=image.pixel_values)
        return encoder_out.last_hidden_state

In [5]:
class VQAModel(nn.Module):
    """Answer classifier over joint question and image embeddings.

    The question and image embeddings are concatenated along dim 1, passed
    through a transformer encoder, and the first position of the encoded
    sequence is classified into ``num_classes`` answers.

    Args:
        dim_model: Feature width given to the encoder layer (``d_model``) and
            to the classifier input. The concatenation happens along dim 1,
            so this width is not enlarged by it.
        nhead: Number of attention heads in the encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Size of the classifier output.
    """

    def __init__(self, dim_model=768, nhead=12, num_layers=1, num_classes=8000):
        super().__init__()
        self.text_embedder = TextEmbedding()
        self.image_embedder = ImageEmbedding()

        fusion_layer = TransformerEncoderLayer(d_model=dim_model, nhead=nhead)
        self.transformerEncoder = TransformerEncoder(
            encoder_layer=fusion_layer,
            num_layers=num_layers,
        ).to(device)

        self.classifier = nn.Linear(dim_model, num_classes).to(device)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, questions, images):
        """Return softmax scores over the answer classes.

        Args:
            questions: Input accepted by ``self.text_embedder``.
            images: Input accepted by ``self.image_embedder``.

        Returns:
            Softmax (over dim 1) of the classifier applied to the first
            position of the encoded sequence.
        """
        text_states = self.text_embedder(questions)
        image_states = self.image_embedder(images)

        # Join both modalities along dim 1, then swap the first two axes so the
        # encoder receives (seq, batch, feature).
        # NOTE(review): assumes both embedders return (batch, seq, feature) — confirm.
        joint = torch.cat((text_states, image_states), dim=1).permute(1, 0, 2)
        encoded = self.transformerEncoder(joint)

        # Position 0 of the joint sequence comes from the question embedding,
        # because the text states are placed first in the concatenation.
        first_position = encoded[0]

        return self.softmax(self.classifier(first_position))


In [6]:
from datasets import load_dataset

# This dataset is loaded through code hosted in its repository. The installed
# `datasets` release warns that `trust_remote_code=True` will become mandatory,
# so opt in explicitly. Only do this for repositories you trust: the flag lets
# that repository's loading code run locally.
dataset = load_dataset("HuggingFaceM4/VQAv2", split="train", trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/352 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/7.24M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.49M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.97M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.5G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.65G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.3G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating testdev split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [24]:
import pandas as pd

# `dataset` is either a mapping of splits (loaded without `split=`) or the
# train split itself (loaded with `split="train"`). Indexing a single split
# with 'train' would fail, so accept both forms.
train_dataset = dataset['train'] if isinstance(dataset, dict) else dataset

# Question/answer tables; rows without an answer cannot be scored, so drop them.
# Reading and filtering in one expression keeps this cell idempotent on re-run.
train_df = pd.read_csv('/kaggle/input/vqdata/vqa_train_dataset.csv').dropna(subset=['answers'])
val_df = pd.read_csv('/kaggle/input/vqdata/vqa_val_dataset.csv').dropna(subset=['answers'])

In [25]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor

class VQADataset(Dataset):
    """Pairs rows of a question/answer dataframe with records of an image dataset.

    Args:
        dataframe: Table with 'index', 'question' and 'answers' columns.
            'index' is the position of the matching record in
            ``image_dataset``.
        image_dataset: Indexable collection whose records expose 'image' and
            'question' entries.
    """

    def __init__(self, dataframe, image_dataset):
        self.dataframe = dataframe
        self.image_dataset = image_dataset
        self.text_tokenizer = TextTokenizer()
        self.image_processor = ImageProcessor()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the processed image, tokenized question and raw answer for row ``idx``.

        Returns:
            dict with keys 'image' (image processor output), 'questions'
            (tokenizer output with the leading axis of 'input_ids' and
            'attention_mask' removed) and 'answer' (the row's 'answers' value).

        Raises:
            AssertionError: if the image record's question differs from the
                dataframe row's question.
        """
        row = self.dataframe.iloc[idx]
        question = row['question']
        answer = row['answers']

        # Fetch the record once and reuse it for both the image and the check
        # below, instead of indexing the image dataset twice per item.
        record = self.image_dataset[int(row['index'])]

        # sanity check
        assert record['question'] == question, "Mismatching training and Image data"

        image = record['image'].convert('RGB')

        # Tokenize question
        tokens = self.text_tokenizer(question, padding='max_length', truncation=True)
        # Drop the leading singleton axis by writing back through the mapping
        # keys. Assigning `tokens.input_ids = ...` would only create an
        # instance attribute and leave the stored entries (the ones a
        # DataLoader collates) unchanged.
        for key in ('input_ids', 'attention_mask'):
            tokens[key] = tokens[key].squeeze(0)

        # Resizing and normalisation are handled by the image processor.
        image = self.image_processor(image)
        return {
            'image': image,
            'questions': tokens,
            'answer': answer
        }

batch_size = 64

# Validation questions come from `val_df`; their images are looked up in
# `train_dataset` through each row's 'index' column.
val_data = VQADataset(dataframe=val_df, image_dataset=train_dataset)

# Only a validation loader is built in this notebook.
val_dataloader = DataLoader(dataset=val_data, batch_size=batch_size)

In [26]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only open this file when it comes from a trusted source.
with open('/kaggle/input/vqdata/answers_dictionaries.pkl', 'rb') as f:
    data = pickle.load(f)

# The two dictionaries map between answer strings and class ids.
id_to_answer = data['id_to_answer']
answer_to_id = data['answer_to_id']

assert len(answer_to_id) == len(id_to_answer)

# The model is built in the checkpoint-loading cell. Building it here as well
# would load both pretrained backbones a second time, only for that instance
# to be replaced straight away.



In [27]:
CHECKPOINT_FILE = '/kaggle/input/lora32/lora_32_final.pth'

val_model = VQAModel(num_classes=len(answer_to_id))

# map_location places the stored tensors on the device selected at the top of
# the notebook, so a checkpoint saved from a GPU still loads when this run has
# fallen back to the CPU.
state_dict = torch.load(CHECKPOINT_FILE, map_location=device)

# Left as the last expression so the key-matching report is displayed.
val_model.load_state_dict(state_dict)

<All keys matched successfully>

In [28]:
from torchmetrics.classification import Precision, Recall, Accuracy, F1Score, AUROC
# One multiclass metric object per score, each sized to the answer vocabulary
# and placed on `device`. AUROC is imported but not instantiated here.
# NOTE(review): no `average` is passed, so the library default applies —
# confirm it matches the macro averaging used for the final sklearn report.
precision_metric, recall_metric, accuracy_metric, f1_metric = (
    metric_cls(task="multiclass", num_classes=len(answer_to_id)).to(device)
    for metric_cls in (Precision, Recall, Accuracy, F1Score)
)

def evaluate(preds, true):
    """Score one batch of predictions with the module-level metric objects.

    Args:
        preds: Predictions, passed as the first argument to each metric.
        true: Targets, passed as the second argument to each metric.

    Returns:
        dict with keys "precision", "recall", "accuracy" and "f1", each
        holding the value returned by the corresponding metric call.
    """
    # Metrics are invoked in the same order as the keys below.
    return {
        "precision": precision_metric(preds, true),
        "recall": recall_metric(preds, true),
        "accuracy": accuracy_metric(preds, true),
        "f1": f1_metric(preds, true),
    }


In [29]:
import torch.nn.functional as F
import time

batch_no = 0
val_model.eval()

# Per-batch index tensors are gathered here and concatenated once after the
# loop. This keeps them as integer class ids and avoids re-copying an
# ever-growing tensor on every iteration.
answer_batches = []
pred_batches = []

with torch.no_grad():
    start_time = time.time()
    for batch_no, batch in enumerate(val_dataloader, start=1):
        images = batch['image']
        questions = batch['questions']
        answers = batch['answer']

        # Read the collated tensors through their keys and collapse any extra
        # singleton axis while keeping the batch axis. A bare `.squeeze()`
        # would also drop the batch axis when a batch holds a single sample.
        input_ids = questions['input_ids']
        attention_mask = questions['attention_mask']
        pixel_values = images['pixel_values']

        # The model reads these values as attributes, so set them explicitly.
        questions.input_ids = input_ids.reshape(-1, input_ids.shape[-1])
        questions.attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
        images.pixel_values = pixel_values.reshape(-1, *pixel_values.shape[-3:])

        # Forward pass
        outputs = val_model(questions, images)

        # Map answer strings to class ids, clamped to the last valid class.
        answers_ids = [min(len(answer_to_id) - 1, answer_to_id[key]) for key in answers]
        # The ids are already the target indices; building a one-hot matrix
        # and taking its argmax would only recover these same values.
        answer_inds = torch.tensor(answers_ids, dtype=torch.long, device=device)

        # Sanity check: one row of class scores per answer in the batch.
        assert outputs.shape == (len(answers_ids), len(answer_to_id)), "Target and Predicted shapes don't match"

        preds = outputs.argmax(dim=1)

        assert preds.shape == answer_inds.shape, f"Preds_shape: {preds.shape}, Answers_shape: {answer_inds.shape}"
        eval_met = evaluate(preds, answer_inds)

        answer_batches.append(answer_inds)
        pred_batches.append(preds)

        print(f"Batch -> {batch_no} done -> Accu: {eval_met['accuracy']:.2f} -> Prec: {eval_met['precision']:.2f}\r", end="")
    end_time = time.time()
    print(f'Time taken for prediction: {end_time - start_time} seconds')

# Flat tensors of target and predicted class ids over the whole validation
# set. An empty loader yields empty integer tensors instead of an error.
if answer_batches:
    answer_list = torch.cat(answer_batches, dim=0)
    preds_list = torch.cat(pred_batches, dim=0)
else:
    answer_list = torch.empty(0, dtype=torch.long, device=device)
    preds_list = torch.empty(0, dtype=torch.long, device=device)



Time taken for prediction: 674.1015012264252 seconds


In [31]:
# Targets and predictions should have the same length: one entry per
# validation sample that went through the loop.
(answer_list.shape, preds_list.shape)

(torch.Size([29994]), torch.Size([29994]))

In [32]:
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

In [33]:
ans_np = answer_list.cpu().numpy()
preds_np = preds_list.cpu().numpy()

# Macro averaging gives every class the same weight. Classes that are never
# predicted (or never appear as targets) have an undefined precision/recall;
# zero_division=0 scores them as 0 explicitly, which keeps the default values
# while silencing the undefined-metric warning.
f1 = f1_score(ans_np, preds_np, average='macro', zero_division=0)
precision = precision_score(ans_np, preds_np, average='macro', zero_division=0)
recall = recall_score(ans_np, preds_np, average='macro', zero_division=0)
accuracy = accuracy_score(ans_np, preds_np)

print(f'F1 Score for {CHECKPOINT_FILE}: {f1}')
print(f'Precision for {CHECKPOINT_FILE}: {precision}')
print(f'Recall for {CHECKPOINT_FILE}: {recall}')
print(f'Accuracy for {CHECKPOINT_FILE}: {accuracy}')

F1 Score for /kaggle/input/lora32/lora_32_final.pth: 9.908665713403246e-05
Precision for /kaggle/input/lora32/lora_32_final.pth: 5.927558474881457e-05
Recall for /kaggle/input/lora32/lora_32_final.pth: 0.00030175015087507544
Accuracy for /kaggle/input/lora32/lora_32_final.pth: 0.1964392878575715


  _warn_prf(average, modifier, msg_start, len(result))
