In [2]:
# Install the necessary packages
!pip install transformers datasets scikit-learn peft -q

In [3]:
import torch
from transformers import AutoTokenizer, AutoImageProcessor, AutoModel, ViTModel
from PIL import Image
from IPython.display import display
from torch import nn
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from peft import LoraConfig, get_peft_model

In [4]:
# Define the VQA model class (ensure this matches your training definition)
class VQAModel(nn.Module):
    def __init__(self, text_model_name="bert-base-cased", image_model_name="google/vit-base-patch16-224", num_answers=162496):
        super(VQAModel, self).__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.text_fc = nn.Linear(self.text_model.config.hidden_size, 512)
        self.image_fc = nn.Linear(self.image_model.config.hidden_size, 512)
        self.classifier = nn.Linear(1024, num_answers)

    def forward(self, text_inputs, image_inputs):
        text_outputs = self.text_model(**text_inputs).last_hidden_state[:, 0, :]  # CLS token
        image_outputs = self.image_model(**image_inputs).last_hidden_state[:, 0, :]  # CLS token
        text_features = self.text_fc(text_outputs)
        image_features = self.image_fc(image_outputs)
        combined_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(combined_features)
        return logits

In [5]:
# Load the trained model
def load_model(model_path, num_answers, device):
    # Initialize the model architecture
    model = VQAModel(num_answers=num_answers)
    model = model.to(device)
    
    # LoRA Configuration
    lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["query", "value"], lora_dropout=0.1, bias="none")
    lora_model = get_peft_model(model, lora_config)
    
    # Load the saved state dictionary
    lora_model.load_state_dict(torch.load(model_path, map_location=device))
    lora_model.eval()
    
    return lora_model

In [6]:
# Load tokenizer and image processor
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

In [7]:
# Load the dataset
dataset = load_dataset("HuggingFaceM4/VQAv2")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/352 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/7.24M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.49M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.97M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.5G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.65G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.3G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating testdev split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [9]:
val_dataset = dataset['validation']

In [10]:
# Load sample data
def load_sample_data(example):
    image = example['image']
    question = example['question']
    answers = [ans['answer'] for ans in example['answers']]
    return image, question, answers

In [15]:
# Perform inference
def predict_answer(model, image, question, tokenizer, image_processor, device, answer_list):
    text_inputs = tokenizer(question, padding='max_length', truncation=True, return_tensors="pt")
    image_inputs = image_processor(images=[image], return_tensors="pt")
    text_inputs = {k: v.squeeze(0).to(device) for k, v in text_inputs.items()}
    image_inputs = {k: v.squeeze(0).to(device) for k, v in image_inputs.items()}
    
    # Debugging print statements
    print(f"Text inputs: {text_inputs}")
    print(f"Image inputs: {image_inputs}")
    
    with torch.no_grad():
        logits = model(text_inputs, image_inputs)
    predicted_answer_index = logits.argmax().item()
    
    # Decode predicted answer index
    predicted_answer = answer_list[predicted_answer_index] if predicted_answer_index < len(answer_list) else "Unknown"
    
    return predicted_answer

In [16]:
# Load the model
model_path = '/kaggle/input/lora-5/lora_vqa_model_5percent.pth'  # Ensure this path is correct
num_answers = 162496  # Update this based on your training configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model(model_path, num_answers, device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
# Evaluate the model on the validation dataset
all_labels = []
all_predictions = []

for example in val_dataset:
    image, question, answers = load_sample_data(example)
    predicted_answer = predict_answer(model, image, question, bert_tokenizer, image_processor, device, answers)
    
    # Use the first answer as the ground truth label for evaluation
    all_labels.append(answers[0])
    all_predictions.append(predicted_answer)

# Calculate evaluation metrics
accuracy = accuracy_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions, average='weighted')
precision = precision_score(all_labels, all_predictions, average='weighted')
recall = recall_score(all_labels, all_predictions, average='weighted')

print(f"Accuracy: {accuracy:.7f}")
print(f"F1 Score: {f1:.7f}")
print(f"Precision: {precision:.7f}")
print(f"Recall: {recall:.7f}")

Accuracy: 0.1939423
F1 Score: 0.0971998
Precision: 0.0746676
Recall: 0.1939239
