# **Easy-VQA**

## ResNet, BERT

* No data augmentation
* 4 layers
* 10 epochs

###Dataset loading and setup

In [None]:
!pip install easy-vqa
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, f1_score, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# Import easy-vqa dataset
from easy_vqa import get_train_questions, get_test_questions
from easy_vqa import get_train_image_paths, get_test_image_paths

print('Importing done')
# -----------------------------
# CONFIGURATION
# -----------------------------
# Get easy-vqa data
train_questions, train_answers, train_image_ids = get_train_questions()
test_questions, test_answers, test_image_ids = get_test_questions()
train_image_paths = get_train_image_paths()
test_image_paths = get_test_image_paths()

# Create label map from unique answers.
# The answers are sorted so the answer -> class-index assignment is the same
# on every run: iterating a plain set of strings follows hash order, which
# changes between interpreter sessions and would silently re-number classes
# (breaking any checkpoint or cached prediction tied to the old indices).
unique_answers = sorted(set(train_answers))
label_map = {ans: i for i, ans in enumerate(unique_answers)}
num_classes = len(label_map)
print('Configuration done')

# -----------------------------
# DATASET CLASS
# -----------------------------
class VQADataset(Dataset):
    """Easy-VQA samples served as ``(image, input_ids, attention_mask, label)``.

    Args:
        questions: Question strings, one per sample.
        answers: Answer strings, aligned with ``questions``.
        image_ids: Per-sample key used to look up the image in ``image_paths``.
        image_paths: Mapping (or sequence) from image id to file path.
        label_map: Mapping from answer string to integer class index.
        transform: Callable applied to the loaded RGB PIL image.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        max_len: Length every question is padded/truncated to.
    """

    def __init__(self, questions, answers, image_ids, image_paths, label_map, transform, tokenizer, max_len=20):
        self.questions, self.answers = questions, answers
        self.image_ids, self.image_paths = image_ids, image_paths
        self.label_map = label_map
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        # One sample per question.
        return len(self.questions)

    def __getitem__(self, idx):
        # Resolve the question's image id to a file and load it as RGB.
        path = self.image_paths[self.image_ids[idx]]
        pixels = self.transform(Image.open(path).convert("RGB"))

        # Tokenize to a fixed length so samples can be stacked into a batch.
        tokens = self.tokenizer(
            self.questions[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch axis of size 1; drop it.
        ids = tokens["input_ids"].squeeze(0)
        mask = tokens["attention_mask"].squeeze(0)

        target = torch.tensor(self.label_map[self.answers[idx]])
        return pixels, ids, mask, target

# -----------------------------
# FINETUNABLE RESNET
# -----------------------------
def get_finetunable_resnet50():
    """Return a pretrained ResNet-50 whose only trainable parameters are in ``layer4``.

    The ``fc`` head is swapped for ``nn.Identity`` so the module returns the
    features that used to feed the classifier.
    """
    backbone = models.resnet50(pretrained=True)
    # A parameter stays trainable exactly when its qualified name mentions layer4.
    for pname, weight in backbone.named_parameters():
        weight.requires_grad = "layer4" in pname
    backbone.fc = nn.Identity()
    return backbone

# -----------------------------
# FINETUNABLE BERT
# -----------------------------
def get_finetunable_bert():
    """Return ``bert-base-uncased`` with only encoder layers 8-11 trainable."""
    trainable_tags = tuple(f"encoder.layer.{i}" for i in range(8, 12))
    bert = BertModel.from_pretrained("bert-base-uncased")
    # Freeze everything except parameters whose name contains one of the tags.
    for pname, weight in bert.named_parameters():
        weight.requires_grad = any(tag in pname for tag in trainable_tags)
    return bert

# -----------------------------
# MULTIMODAL TRANSFORMER
# -----------------------------
class VQAFusionTransformer(nn.Module):
    """Fuse one ResNet image token with BERT question tokens in a transformer.

    The image feature and every BERT hidden state are projected to ``d_model``,
    concatenated (image token first), tagged with modality and learned position
    embeddings, and passed through a ``nn.TransformerEncoder``. The encoder
    output at position 0 (the image token) feeds the classification head.

    Args:
        num_classes: Number of answer classes produced by the head.
        d_model: Width of the fusion transformer.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion encoder layer.
    """

    def __init__(self, num_classes, d_model=512, num_layers=4, num_heads=8):
        super().__init__()
        self.resnet = get_finetunable_resnet50()
        self.bert = get_finetunable_bert()

        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)

        # Index 1 marks the image token, index 0 the text tokens (see forward).
        self.mod_embed = nn.Embedding(2, d_model)
        # Learned positions for up to 64 fused tokens (1 image + text tokens).
        self.pos_embed = nn.Parameter(torch.randn(1, 64, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            images: Image batch fed to the ResNet backbone.
            input_ids: Token ids fed to BERT.
            attention_mask: BERT attention mask; entries equal to 0 are
                treated as padding in the fusion encoder as well.
        """
        B = images.size(0)
        dev = images.device

        img_feat = self.resnet(images)
        img_tokens = self.img_proj(img_feat).unsqueeze(1)

        txt_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        txt_tokens = self.txt_proj(txt_out.last_hidden_state)

        x = torch.cat([img_tokens, txt_tokens], dim=1)

        # Built directly on the target device instead of CPU-then-copy.
        mod_ids = torch.cat([
            torch.ones((B, 1), dtype=torch.long, device=dev),
            torch.zeros((B, input_ids.size(1)), dtype=torch.long, device=dev)
        ], dim=1)

        x = x + self.mod_embed(mod_ids) + self.pos_embed[:, :x.size(1), :]

        # Questions are padded to a fixed length; without this mask the fusion
        # encoder would attend to [PAD] positions. True = ignore. The image
        # token (column 0) is never masked.
        pad_mask = torch.cat([
            torch.zeros((B, 1), dtype=torch.bool, device=dev),
            attention_mask.to(dev) == 0
        ], dim=1)

        fused = self.encoder(x, src_key_padding_mask=pad_mask)
        return self.cls(fused[:, 0])


###Training

In [None]:

print('Starting the training')

# -----------------------------
# MAIN TRAINING LOOP + EVALUATION
# -----------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# The ResNet backbone is loaded with ImageNet-pretrained weights, which were
# trained on inputs standardized with these channel statistics; feeding raw
# [0, 1] tensors shifts the input distribution seen by the frozen layers.
# (This is preprocessing, not augmentation; it matches the MobileNetV2 section.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets (train and test share the same deterministic transform)
train_dataset = VQADataset(
    train_questions, train_answers, train_image_ids, train_image_paths,
    label_map, transform, tokenizer
)
test_dataset = VQADataset(
    test_questions, test_answers, test_image_ids, test_image_paths,
    label_map, transform, tokenizer
)

# Create dataloaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Initialize model; the optimizer only receives parameters left trainable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQAFusionTransformer(num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss()

# Training loop: 10 passes over train_loader, printing the mean batch loss per epoch
for epoch in range(10):  # Reduced epochs for demo
    model.train()
    total_loss = 0
    for batch in train_loader:
        # Move the whole (images, input_ids, attention_mask, labels) tuple at once
        images, input_ids, attention_mask, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
print('Training done')


###Evaluation

In [None]:

# -----------------------------
# EVALUATION LOOP ON TEST SET
# -----------------------------
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, input_ids, attention_mask, labels in test_loader:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model(images, input_ids, attention_mask)
        preds = torch.argmax(outputs, dim=1).cpu()
        all_preds.extend(preds.tolist())
        # Labels are only compared on the host, so they never need to visit the device.
        all_labels.extend(labels.tolist())

# Compute metrics
acc = accuracy_score(all_labels, all_preds)
print(f"\nTest Accuracy: {acc * 100:.2f}%")

f1_macro = f1_score(all_labels, all_preds, average='macro')
f1_weighted = f1_score(all_labels, all_preds, average='weighted')
print(f"F1 Score (macro): {f1_macro:.4f}")
print(f"F1 Score (weighted): {f1_weighted:.4f}")

# Full classification report.
# label_map keys/values iterate in the same order, so class_ids[i] is the index
# of target_names[i]. Passing labels= explicitly keeps names and rows aligned
# even when a class never shows up in the test labels or predictions (without
# it, scikit-learn infers the label set from the data and the fixed-length
# target_names list can mismatch or shift).
target_names = list(label_map.keys())
class_ids = list(label_map.values())
print("\nClassification Report:")
print(classification_report(
    all_labels, all_preds,
    labels=class_ids, target_names=target_names, zero_division=0
))

# Confusion Matrix (same explicit label order so tick labels match the rows/columns)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=target_names, yticklabels=target_names, cmap='Blues')
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()

## MobileNetV2, DistilBERT

* Image augmentation
* 20 epochs

In [None]:
!pip install easy_vqa
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import DistilBertTokenizer, DistilBertModel
from PIL import Image
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from easy_vqa import get_train_questions, get_test_questions, get_train_image_paths, get_test_image_paths
from collections import defaultdict
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

# -----------------------------
# CONFIGURATION FOR FULL DATASET
# -----------------------------
config = dict(
    batch_size=128,     # Larger batch size for full dataset
    lr=3e-5,
    epochs=20,
    hidden_size=256,
    num_workers=4,
    image_size=128,     # Kept small for speed
    validate_every=2,   # Validate every 2 epochs
    max_seq_length=20,  # For tokenizer
)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility (NumPy, torch CPU, and all CUDA devices if present)
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(42)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(42)

# -----------------------------
# FULL DATA LOADING
# -----------------------------
print("Loading full dataset...")
train_questions, train_answers, train_image_ids = get_train_questions()
test_questions, test_answers, test_image_ids = get_test_questions()
train_image_paths = get_train_image_paths()
test_image_paths = get_test_image_paths()

# Create label map from all answers: every distinct answer in train or test,
# in sorted order, gets the next integer class index
unique_answers = sorted(set(train_answers + test_answers))
label_map = dict(zip(unique_answers, range(len(unique_answers))))
num_classes = len(label_map)
print(f"Training samples: {len(train_questions)}")
print(f"Test samples: {len(test_questions)}")
print(f"Number of classes: {num_classes}")

# -----------------------------
# OPTIMIZED DATASET CLASS
# -----------------------------
class EasyVQADataset(Dataset):
    """Easy-VQA samples as ``(image, input_ids, attention_mask, label)`` tuples.

    Images are opened and questions tokenized lazily in ``__getitem__``.

    Args:
        questions: Question strings, one per sample.
        answers: Answer strings aligned with ``questions``.
        image_ids: Per-sample key into ``image_paths``.
        image_paths: Lookup from image id to file path.
        label_map: Mapping from answer string to class index.
        transform: Callable applied to the loaded RGB PIL image.
        tokenizer: Callable tokenizer returning ``input_ids`` / ``attention_mask``.
        max_len: Fixed token length for padding/truncation.
    """

    def __init__(self, questions, answers, image_ids, image_paths, label_map, transform, tokenizer, max_len=20):
        self.questions, self.answers = questions, answers
        self.image_ids, self.image_paths = image_ids, image_paths
        self.label_map = label_map
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        with_path = self.image_paths[self.image_ids[idx]]
        pixels = self.transform(Image.open(with_path).convert("RGB"))

        # Tokenize on-the-fly (more memory efficient)
        tokens = self.tokenizer(
            self.questions[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )
        # Drop the size-1 batch axis the tokenizer adds.
        ids = tokens['input_ids'].squeeze(0)
        mask = tokens['attention_mask'].squeeze(0)
        target = torch.tensor(self.label_map[self.answers[idx]])

        return (pixels, ids, mask, target)

# -----------------------------
# LIGHTWEIGHT MODEL ARCHITECTURE
# -----------------------------
class FastEasyVQAModel(nn.Module):
    """Concatenation-fusion VQA classifier built on MobileNetV2 and DistilBERT.

    Image features (MobileNetV2 with its classifier removed) and the DistilBERT
    hidden state at token position 0 are concatenated and fed to a two-layer MLP.

    Args:
        num_classes: Number of answer classes produced by the MLP head.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Lightweight image encoder (MobileNetV2) with the final layer removed;
        # only the last four blocks of ``features`` stay trainable.
        self.img_encoder = models.mobilenet_v2(pretrained=True)
        self.img_encoder.classifier = nn.Identity()
        self._train_only(self.img_encoder, self.img_encoder.features[-4:])

        # Lightweight text encoder (DistilBERT); only its last transformer
        # layer stays trainable.
        self.txt_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self._train_only(self.txt_encoder, self.txt_encoder.transformer.layer[-1])

        # Efficient fusion and classifier
        fused_width = 1280 + 768  # MobileNetV2 (1280) + DistilBERT (768)
        self.classifier = nn.Sequential(
            nn.Linear(fused_width, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _train_only(module, trainable_part):
        """Freeze every parameter of ``module``, then re-enable ``trainable_part``."""
        for weight in module.parameters():
            weight.requires_grad = False
        for weight in trainable_part.parameters():
            weight.requires_grad = True

    def forward(self, images, input_ids, attention_mask):
        """Return class logits for a batch of images and tokenized questions."""
        visual = self.img_encoder(images)

        # Text summary = hidden state at the first token position (CLS token).
        hidden = self.txt_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        textual = hidden[:, 0, :]

        return self.classifier(torch.cat([visual, textual], dim=1))

# -----------------------------
# TRAINING SETUP
# -----------------------------
print("Setting up training...")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

side = config['image_size']
imagenet_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Image transforms with augmentation (training only)
transform = transforms.Compose([
    transforms.Resize((side, side)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_stats),
])

# Test transforms without augmentation
test_transform = transforms.Compose([
    transforms.Resize((side, side)),
    transforms.ToTensor(),
    transforms.Normalize(**imagenet_stats),
])

# Create datasets; arguments shared by both splits are gathered once
dataset_kwargs = dict(label_map=label_map, tokenizer=tokenizer, max_len=config['max_seq_length'])
train_dataset = EasyVQADataset(
    train_questions, train_answers, train_image_ids, train_image_paths,
    transform=transform, **dataset_kwargs
)
test_dataset = EasyVQADataset(
    test_questions, test_answers, test_image_ids, test_image_paths,
    transform=test_transform, **dataset_kwargs
)

# Optimized data loaders; only the training loader shuffles and keeps its workers alive
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, persistent_workers=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

model = FastEasyVQAModel(num_classes=num_classes).to(device)

# Optimizer with weight decay, restricted to the parameters left trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=config['lr'], weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

# Mixed precision training
scaler = GradScaler()

# Learning rate scheduler: halves the LR when the monitored (maximized) metric plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2, verbose=True
)

# -----------------------------
# TRAINING LOOP
# -----------------------------
print(f"Training on {len(train_questions)} samples...")
best_acc = 0.0
train_losses = []
val_accuracies = []

num_epochs = config['epochs']
for epoch in range(num_epochs):
    # ---- training pass (mixed precision) ----
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")

    for batch in progress_bar:
        images, input_ids, attn_mask, labels = (t.to(device, non_blocking=True) for t in batch)

        optimizer.zero_grad()

        with autocast():
            outputs = model(images, input_ids, attn_mask)
            loss = loss_fn(outputs, labels)

        # Scale the loss for backward, then step/update through the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({'loss': batch_loss})

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validate only every N epochs, and always after the final epoch
    is_last_epoch = (epoch + 1) == num_epochs
    if (epoch + 1) % config['validate_every'] != 0 and not is_last_epoch:
        continue

    # ---- validation pass over test_loader ----
    model.eval()
    val_loss, correct = 0.0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in test_loader:
            images, input_ids, attn_mask, labels = (t.to(device, non_blocking=True) for t in batch)

            outputs = model(images, input_ids, attn_mask)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= len(test_loader)
    val_acc = correct / len(test_dataset)
    val_accuracies.append(val_acc)

    # Update learning rate from the validation accuracy
    scheduler.step(val_acc)

    print(f"\nEpoch {epoch+1}/{config['epochs']}: "
          f"Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Val Acc: {val_acc:.2%}")

    # Checkpoint whenever validation accuracy reaches a new high
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_easyvqa_model.pth')
        print("Saved new best model!")

# Plot training curves
plt.figure(figsize=(12, 5))

# Left: mean training loss, one point per epoch (x-axis is the 1-based epoch
# number so it lines up with the validation panel).
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Right: validation accuracy. Validation only ran on epochs divisible by
# config['validate_every'] plus the final epoch, so the points are placed at
# exactly those epochs. (Spreading them with linspace(0, epochs, n) put every
# point at the wrong x position.) The slice keeps the plot usable if training
# was interrupted before all validations happened.
val_epochs = [
    e for e in range(1, config['epochs'] + 1)
    if e % config['validate_every'] == 0 or e == config['epochs']
][:len(val_accuracies)]
plt.subplot(1, 2, 2)
plt.plot(val_epochs, val_accuracies, label='Val Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# -----------------------------
# COMPREHENSIVE EVALUATION
# -----------------------------
print("\nFinal Evaluation:")
# all_labels / all_preds come from the most recent validation pass of the
# training loop (the final epoch), not from the saved best checkpoint.
#
# label_map keys and values iterate in the same order, so class_ids[i] is the
# index of class_names[i]. Passing labels= explicitly keeps report rows and
# matrix axes aligned with the names even if some class is absent from both
# the test labels and the predictions.
class_names = list(label_map.keys())
class_ids = list(label_map.values())
print(classification_report(
    all_labels, all_preds,
    labels=class_ids, target_names=class_names, zero_division=0
))

# Confusion matrix
plt.figure(figsize=(15, 12))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names,
            yticklabels=class_names, cmap='Blues')
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix - Full EasyVQA Dataset")
plt.show()

# -----------------------------
# PREDICTION VISUALIZATION
# -----------------------------
def visualize_predictions(model, tokenizer, transform, test_data, image_paths, label_map, num_samples=5):
    """Plot sampled test questions next to the model's top predictions.

    For each sample the left panel shows the image and the right panel shows
    the question, the predicted answer with its softmax confidence, the true
    answer and the highest-probability classes (up to three).

    Relies on the module-level ``device`` and ``config['max_seq_length']``.

    Args:
        model: Classifier called as ``model(image, input_ids, attention_mask)``.
        tokenizer: Tokenizer used to encode each question.
        transform: Image transform applied before inference.
        test_data: Sequence of dicts with ``image_id``, ``question``, ``answer``.
        image_paths: Lookup from image id to file path.
        label_map: Mapping from answer string to class index.
        num_samples: Maximum number of samples to draw (chosen at random
            without replacement when ``test_data`` is larger).
    """
    model.eval()
    reverse_label_map = {v: k for k, v in label_map.items()}

    # Select random samples
    if len(test_data) > num_samples:
        indices = np.random.choice(len(test_data), num_samples, replace=False)
        test_data = [test_data[i] for i in indices]

    # Size the grid by what is actually drawn, so a short test_data does not
    # leave empty rows (and never request a zero-height figure).
    n_rows = max(len(test_data), 1)
    plt.figure(figsize=(15, 3*n_rows))

    with torch.no_grad():
        for i, item in enumerate(test_data):
            img_path = image_paths[item["image_id"]]
            try:
                # Load and process image; keep an untransformed copy for display
                image = Image.open(img_path).convert("RGB")
                display_image = image.copy()
                image = transform(image).unsqueeze(0).to(device)

                # Tokenize question
                encoding = tokenizer(
                    item["question"],
                    padding="max_length",
                    truncation=True,
                    max_length=config['max_seq_length'],
                    return_tensors="pt"
                ).to(device)

                # Predict
                outputs = model(image, encoding["input_ids"], encoding["attention_mask"])
                probs = torch.nn.functional.softmax(outputs, dim=1)
                confidence, pred_id = torch.max(probs, dim=1)
                pred_label = reverse_label_map[pred_id.item()]

                # Get top predictions; topk raises if asked for more entries
                # than there are classes, so cap it at the class count.
                top_k = min(3, probs.size(1))
                top_probs, top_ids = torch.topk(probs, top_k)
                top_labels = [reverse_label_map[idx.item()] for idx in top_ids[0]]
                top_confs = [f"{prob.item():.1%}" for prob in top_probs[0]]

                # Display results
                plt.subplot(n_rows, 2, 2*i+1)
                plt.imshow(display_image)
                plt.title(f"Image {i+1}")
                plt.axis('off')

                plt.subplot(n_rows, 2, 2*i+2)
                plt.text(0.1, 0.9, f"Q: {item['question']}", fontsize=10)
                plt.text(0.1, 0.7, f"Predicted: {pred_label} ({confidence.item():.1%})",
                         fontsize=10, color='green')
                plt.text(0.1, 0.5, f"Actual: {item['answer']}", fontsize=10, color='blue')
                plt.text(0.1, 0.3, "Top Predictions:", fontsize=10)
                for j in range(top_k):
                    plt.text(0.15, 0.2-0.1*j,
                             f"{j+1}. {top_labels[j]} ({top_confs[j]})",
                             fontsize=9)
                plt.axis('off')

            except Exception as e:
                # Best effort: report the failing sample and keep drawing the rest.
                print(f"Error processing {img_path}: {str(e)}")

    plt.tight_layout()
    plt.show()

# Prepare test data with answers: one record per test question, pairing each
# question with its image id and ground-truth answer (zip walks the three
# parallel lists in lockstep instead of indexing them by position).
test_data_with_answers = [
    {"image_id": image_id, "question": question, "answer": answer}
    for image_id, question, answer in zip(test_image_ids, test_questions, test_answers)
]

# Visualize predictions (test-time transform, i.e. no augmentation)
print("\nVisualizing predictions...")
visualize_predictions(
    model=model,
    tokenizer=tokenizer,
    transform=test_transform,
    test_data=test_data_with_answers,
    image_paths=test_image_paths,
    label_map=label_map,
    num_samples=5
)

# -----------------------------
# MODEL SAVING
# -----------------------------
def save_full_model(model, tokenizer, label_map, config):
    """Write the weights, tokenizer and metadata needed to reload the model.

    Produces three artifacts in the current working directory:
    ``easyvqa_full_model.pth`` (``model.state_dict()``), ``easyvqa_tokenizer/``
    (``tokenizer.save_pretrained``) and ``easyvqa_metadata.pkl`` (pickled label
    map and config).

    Args:
        model: Module whose ``state_dict()`` is saved.
        tokenizer: Tokenizer exposing ``save_pretrained``.
        label_map: Mapping from answer string to class index.
        config: Settings dict; must contain ``image_size`` and ``max_seq_length``.
    """
    import pickle

    # Collected up front so a missing config key fails before anything is written
    # to the metadata file.
    metadata = {
        'label_map': label_map,
        'config': config,
        'image_size': config['image_size'],
        'max_seq_length': config['max_seq_length']
    }

    # Save model weights
    torch.save(model.state_dict(), 'easyvqa_full_model.pth')

    # Save tokenizer
    tokenizer.save_pretrained('./easyvqa_tokenizer')

    # Save metadata
    with open('easyvqa_metadata.pkl', 'wb') as sink:
        pickle.dump(metadata, sink)

    for line in (
        "\nModel artifacts saved:",
        "- Model weights: easyvqa_full_model.pth",
        "- Tokenizer: easyvqa_tokenizer/",
        "- Metadata: easyvqa_metadata.pkl",
    ):
        print(line)

save_full_model(model, tokenizer, label_map, config)

# Path-VQA

###Dataset loading and setup

In [None]:
import os
import random
from tqdm.auto import tqdm
import io
from io import BytesIO
import requests
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import dask.dataframe as dd
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,models
from transformers import BertTokenizer,BertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,classification_report

print('Importing done')

splits = {
    'train': 'data/train-*-of-*.parquet',
    'validation': 'data/validation-*-of-*.parquet',
    'test': 'data/test-*-of-*.parquet'
}

# Read every shard of each split through dask, then materialize it as pandas
train_df, val_df, test_df = (
    pd.DataFrame(dd.read_parquet("hf://datasets/flaviagiammarino/path-vqa/" + splits[part]).compute())
    for part in ("train", "validation", "test")
)
# -----------------------------
# Build Unified Label Map
# -----------------------------

# Answers from all three splits share one index space
all_answers = pd.concat([frame['answer'] for frame in (train_df, val_df, test_df)])
print(f'Total number of answers: {len(all_answers)}')
unique_answers = sorted(all_answers.unique())
label_map = dict(zip(unique_answers, range(len(unique_answers))))

print(f"Total unique answers across splits: {len(label_map)}")

print('Label map done')

# -----------------------------
# Set Seed
# -----------------------------
def set_seed(seed=42):
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, with benchmark-based algorithm selection off
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)

# -----------------------------
# Dataset for Path-VQA
# -----------------------------
class PathVQADataset(Dataset):
    """Path-VQA rows served as ``(image, input_ids, attention_mask, label)``.

    Args:
        df: DataFrame with ``image`` (dict holding raw ``bytes``), ``question``
            and ``answer`` columns; its index is reset to 0..n-1.
        label_map: Mapping from answer string to class index.
        transform: Callable applied to the decoded RGB PIL image.
        tokenizer: Callable tokenizer returning ``input_ids`` / ``attention_mask``.
        max_len: Fixed token length for padding/truncation.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        self.df = df.reset_index(drop=True)
        self.label_map = label_map
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # The image is stored inline as encoded bytes; decode it from memory.
        raw = io.BytesIO(record['image']['bytes'])
        pixels = self.transform(Image.open(raw).convert("RGB"))

        encoded = self.tokenizer(
            record['question'],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the size-1 batch axis the tokenizer adds.
        ids = encoded['input_ids'].squeeze(0)
        mask = encoded['attention_mask'].squeeze(0)

        target = torch.tensor(self.label_map[record['answer']])
        return pixels, ids, mask, target
# The ResNet backbone is loaded with ImageNet-pretrained weights, which were
# trained on inputs standardized with these channel statistics; feeding raw
# [0, 1] tensors shifts the input distribution seen by the frozen layers.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Datasets (all splits share the same deterministic transform and tokenizer)
train_dataset = PathVQADataset(train_df, label_map, transform, tokenizer)
val_dataset = PathVQADataset(val_df, label_map, transform, tokenizer)
test_dataset = PathVQADataset(test_df, label_map, transform, tokenizer)

# Loaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)
print("Data preparation done")

# -----------------------------
# Model
# -----------------------------
def get_finetunable_resnet50():
    """Return a pretrained ResNet-50 whose only trainable parameters are in ``layer4``.

    The ``fc`` head is swapped for ``nn.Identity`` so the module returns the
    features that used to feed the classifier.
    """
    backbone = models.resnet50(pretrained=True)
    # A parameter stays trainable exactly when its qualified name mentions layer4.
    for pname, weight in backbone.named_parameters():
        weight.requires_grad = "layer4" in pname
    backbone.fc = nn.Identity()
    return backbone

def get_finetunable_bert():
    """Return ``bert-base-uncased`` with only encoder layers 8-11 trainable."""
    trainable_tags = tuple(f"encoder.layer.{i}" for i in range(8, 12))
    bert = BertModel.from_pretrained("bert-base-uncased")
    # Freeze everything except parameters whose name contains one of the tags.
    for pname, weight in bert.named_parameters():
        weight.requires_grad = any(tag in pname for tag in trainable_tags)
    return bert

class VQAFusionTransformer(nn.Module):
    """Fuse one ResNet image token with BERT question tokens in a transformer.

    The image feature and every BERT hidden state are projected to ``d_model``,
    concatenated (image token first), tagged with modality and learned position
    embeddings, and passed through a ``nn.TransformerEncoder``. The encoder
    output at position 0 (the image token) feeds the classification head.

    Args:
        num_classes: Number of answer classes produced by the head.
        d_model: Width of the fusion transformer.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion encoder layer.
        max_len: Maximum number of text tokens; the position table holds
            ``max_len + 1`` entries (text tokens plus the image token).
    """

    def __init__(self, num_classes, d_model=512, num_layers=4, num_heads=8, max_len=64):
        super().__init__()
        self.max_len = max_len
        self.resnet = get_finetunable_resnet50()
        self.bert = get_finetunable_bert()

        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)

        # Index 1 marks the image token, index 0 the text tokens (see forward).
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            images: Image batch fed to the ResNet backbone.
            input_ids: Token ids fed to BERT.
            attention_mask: BERT attention mask; entries equal to 0 are
                treated as padding in the fusion encoder as well.
        """
        B = images.size(0)
        dev = images.device

        img_feat = self.resnet(images)
        img_tokens = self.img_proj(img_feat).unsqueeze(1)

        txt_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        txt_tokens = self.txt_proj(txt_out.last_hidden_state)

        x = torch.cat([img_tokens, txt_tokens], dim=1)

        # Built directly on the target device instead of CPU-then-copy.
        mod_ids = torch.cat([
            torch.ones((B, 1), dtype=torch.long, device=dev),
            torch.zeros((B, input_ids.size(1)), dtype=torch.long, device=dev)
        ], dim=1)

        x = x + self.mod_embed(mod_ids) + self.pos_embed[:, :x.size(1), :]

        # Questions are padded to a fixed length; without this mask the fusion
        # encoder would attend to [PAD] positions. True = ignore. The image
        # token (column 0) is never masked.
        pad_mask = torch.cat([
            torch.zeros((B, 1), dtype=torch.bool, device=dev),
            attention_mask.to(dev) == 0
        ], dim=1)

        fused = self.encoder(x, src_key_padding_mask=pad_mask)
        return self.cls(fused[:, 0])

###Training

In [None]:
# -----------------------------
# Train and Evaluate Inline
# -----------------------------
# Assume train_loader, val_loader, test_loader, label_map are already defined

import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQAFusionTransformer(num_classes=len(label_map)).to(device)
# Only parameters left trainable by the backbone helpers are optimized
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss()

# Per-epoch history for the plots
train_losses = []
val_accuracies = []
val_f1s = []

# Early-stopping state: stop after `patience` consecutive epochs without a
# new best validation accuracy
best_val_acc = 0.0
best_model_state = None
patience = 3
epochs_without_improvement = 0

for epoch in range(10):  # Track loss and accuracy per epoch
    model.train()
    total_loss = 0
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]"):
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(images, input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1).cpu()
            val_preds.extend(preds.tolist())
            val_labels.extend(labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_accuracies.append(val_acc)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_f1s.append(val_f1)

    print(f"\nEpoch {epoch+1}: Train Loss = {avg_train_loss:.4f} | Val Acc = {val_acc*100:.2f}% | F1 = {val_f1:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() hands back references to the live parameter tensors,
        # which later optimizer steps overwrite in place. Deep-copy so the
        # snapshot really holds this epoch's weights; otherwise the "best"
        # state silently becomes whatever the final epoch left behind.
        best_model_state = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
        print("✅ Best model updated")
    else:
        epochs_without_improvement += 1
        print(f"⚠️ No improvement for {epochs_without_improvement} epoch(s)")

        if epochs_without_improvement >= patience:
            print("⏹️ Early stopping triggered.")
            break

# Plot training curve
plt.figure(figsize=(12, 5))

# Left panel: mean training loss per epoch
plt.subplot(1, 2, 1)
plt.plot(train_losses, label="Train Loss", marker='o')
plt.title("Training Loss over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.legend()

# Right panel: validation accuracy and macro F1, both shown as percentages
plt.subplot(1, 2, 2)
for series, series_name, mark in (
    (val_accuracies, "Val Accuracy", 'o'),
    (val_f1s, "Val F1 (Macro)", 'x'),
):
    plt.plot([v*100 for v in series], label=series_name, marker=mark)
plt.title("Validation Accuracy & F1")
plt.xlabel("Epoch")
plt.ylabel("%")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Save the best model (skipped when no epoch ever improved on the initial 0.0)
if best_model_state is not None:
    torch.save(best_model_state, "best_vqa_model.pth")
    print("✅ Model saved to 'best_vqa_model.pth'")


###Evaluation

In [None]:
# Load best model
model.load_state_dict(best_model_state)

# Final Evaluation: collect predictions and ground truth over the test split
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        images, input_ids, attention_mask, labels = (t.to(device) for t in batch)

        outputs = model(images, input_ids, attention_mask)
        preds = torch.argmax(outputs, dim=1).cpu()
        all_preds.extend(preds.tolist())
        all_labels.extend(labels.cpu().tolist())

acc = accuracy_score(all_labels, all_preds)
f1_macro, f1_weighted = (
    f1_score(all_labels, all_preds, average=mode) for mode in ('macro', 'weighted')
)

# Invert label_map to get index -> answer
inv_label_map = {idx: answer for answer, idx in label_map.items()}

# Class ids that occur in the test labels or in the predictions
present_labels = sorted(set(all_preds) | set(all_labels))
present_names = [inv_label_map[i] for i in present_labels]

# Debug print
print(f"Predicted/Test Labels: {len(present_labels)}")
print(f"Total label_map classes: {len(label_map)}")

# Classification report restricted to the classes that actually occur
print("\nClassification Report:")
print(classification_report(
    all_labels,
    all_preds,
    labels=present_labels,
    target_names=present_names
))

print(f"\nTest Accuracy: {acc*100:.2f}%")
print(f"F1 Macro: {f1_macro:.4f} | F1 Weighted: {f1_weighted:.4f}")
from collections import Counter

# Rank classes by how often they appear in predictions and ground truth combined,
# and keep the ten most frequent (ties resolved by first appearance).
class_counts = Counter(all_preds)
class_counts.update(all_labels)
top_classes = [cls for cls, _ in class_counts.most_common(10)]

# Keep only the samples whose prediction AND label are both top classes.
top_class_set = set(top_classes)
kept_pairs = [
    (p, t) for p, t in zip(all_preds, all_labels)
    if p in top_class_set and t in top_class_set
]
filtered_preds = [p for p, _ in kept_pairs]
filtered_labels = [t for _, t in kept_pairs]

# Rows are true classes, columns are predicted classes, both in top_classes order.
cm = confusion_matrix(filtered_labels, filtered_preds, labels=top_classes)

# Heatmap with answer strings on both axes
tick_names = [inv_label_map[i] for i in top_classes]
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True,
            xticklabels=tick_names,
            yticklabels=list(tick_names),
            cmap='Blues', fmt='g')
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (Top 10 Classes Only)")
plt.tight_layout()
plt.show()


# VQA-RAD

With grad-cam

From : vqa.ipynb

###Dataset loading and setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os
import random
import io
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
# Parquet shard (path relative to the Hub dataset repo) for each published split.
splits = {
    'train': 'data/train-00000-of-00001-eb8844602202be60.parquet',
    'test': 'data/test-00000-of-00001-e5bc3d208bb4deeb.parquet'
}

hub_root = "hf://datasets/flaviagiammarino/vqa-rad/"
train_df = pd.read_parquet(f"{hub_root}{splits['train']}")
test_df = pd.read_parquet(f"{hub_root}{splits['test']}")

# Hold out 10% of the published training split as a validation set (fixed seed).
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

# Answer vocabulary: every distinct answer across the three frames, sorted,
# each mapped to its position in that sorted order.
answer_columns = [frame['answer'] for frame in (train_df, val_df, test_df)]
all_answers = pd.concat(answer_columns)
label_map_vqarad = {}
for position, answer in enumerate(sorted(all_answers.unique())):
    label_map_vqarad[answer] = position

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, auto-tuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# -----------------------------
# Dataset Class
# -----------------------------
class VQARADDataset(Dataset):
    """Dataset over a VQA-RAD dataframe.

    Rows are read through the ``'image'`` (a mapping with a ``'bytes'`` entry),
    ``'question'`` and ``'answer'`` columns. Each item is the tuple
    ``(image, input_ids, attention_mask, label)``.

    Args:
        df: Source dataframe; its index is discarded.
        label_map: Mapping from answer value to integer class index.
        transform: Callable applied to the decoded RGB PIL image.
        tokenizer: Callable tokenizer invoked with ``return_tensors="pt"``.
        max_len: Length questions are padded/truncated to.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        # Drop the incoming index so rows are numbered 0..n-1.
        self.df = df.reset_index(drop=True)
        self.label_map, self.transform = label_map, transform
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Decode the embedded image bytes, force RGB, then apply the transform.
        raw_bytes = io.BytesIO(record['image']['bytes'])
        pixels = self.transform(Image.open(raw_bytes).convert("RGB"))

        encoded = self.tokenizer(
            record['question'],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer output has a leading batch axis of size 1; strip it.
        ids = encoded['input_ids'].squeeze(0)
        mask = encoded['attention_mask'].squeeze(0)

        target = torch.tensor(self.label_map[record['answer']])
        return pixels, ids, mask, target

# -----------------------------
# Transforms and Tokenizer
# -----------------------------
# Training-time preprocessing: resize plus random flip/rotation augmentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Evaluation preprocessing: the same resize/tensor conversion WITHOUT random
# augmentation, so validation/test metrics (and any per-sample visualisation of
# the test set) are reproducible instead of varying from one pass to the next.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

train_dataset = VQARADDataset(train_df, label_map_vqarad, transform, tokenizer)
val_dataset = VQARADDataset(val_df, label_map_vqarad, eval_transform, tokenizer)
test_dataset = VQARADDataset(test_df, label_map_vqarad, eval_transform, tokenizer)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# -----------------------------
# Model Definition
# -----------------------------
def get_finetunable_resnet50():
    """Build an ImageNet-pretrained ResNet-50 feature extractor.

    Only parameters whose name contains ``"layer4"`` stay trainable; everything
    else is frozen. The classification head is replaced by ``nn.Identity`` so
    the network returns the pooled features instead of class logits.

    Returns:
        The configured ResNet-50 module.
    """
    # `pretrained=True` is deprecated in torchvision's multi-weight API.
    # IMAGENET1K_V1 is the checkpoint that flag selected, so the loaded weights
    # are unchanged (ResNet50_Weights.DEFAULT would pick a different checkpoint).
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    for name, param in resnet.named_parameters():
        param.requires_grad = "layer4" in name
    resnet.fc = nn.Identity()
    return resnet

def get_finetunable_bert():
    """Load ``bert-base-uncased`` with only encoder layers 8-11 left trainable.

    Returns:
        The BertModel with ``requires_grad`` set per parameter.
    """
    bert = BertModel.from_pretrained("bert-base-uncased")
    # Name fragments identifying the encoder blocks that remain trainable.
    trainable_markers = tuple(f"encoder.layer.{i}" for i in range(8, 12))
    for name, param in bert.named_parameters():
        param.requires_grad = any(marker in name for marker in trainable_markers)
    return bert

class VQAFusionTransformer(nn.Module):
    """ResNet-50 + BERT features fused by a Transformer encoder for answer classification.

    The image becomes a single token (projected ResNet feature) placed in front
    of the projected BERT token states. The classifier reads the fused state at
    position 0, i.e. the image token.

    Args:
        num_classes: Size of the answer vocabulary (number of output logits).
        d_model: Width of the fusion encoder.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion layer.
        max_len: Maximum number of text tokens; one extra position is reserved
            for the image token.
    """

    def __init__(self, num_classes, d_model=512, num_layers=8, num_heads=8, max_len=64):
        super().__init__()
        # Backbones
        self.resnet = get_finetunable_resnet50()
        self.bert = get_finetunable_bert()

        # Projections of both modalities into the shared fusion width
        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)

        # Learned modality (2 ids) and position (image slot + max_len text slots) embeddings
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        fusion_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(fusion_layer, num_layers=num_layers)

        self.cls = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, images, input_ids, attention_mask):
        batch_size, text_len = images.size(0), input_ids.size(1)

        # One image token per sample
        img_tokens = self.img_proj(self.resnet(images)).unsqueeze(1)

        # One token per text position
        txt_hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        txt_tokens = self.txt_proj(txt_hidden)

        sequence = torch.cat([img_tokens, txt_tokens], dim=1)

        # Modality id 1 marks the image token, 0 marks every text position.
        target_device = images.device
        modality = torch.cat([
            torch.ones((batch_size, 1), dtype=torch.long, device=target_device),
            torch.zeros((batch_size, text_len), dtype=torch.long, device=target_device),
        ], dim=1)

        sequence = sequence + self.mod_embed(modality) + self.pos_embed[:, :sequence.size(1), :]

        # attention_mask is only given to BERT; the fusion encoder gets no padding mask.
        fused = self.encoder(sequence)
        return self.cls(fused[:, 0])

###Training

In [None]:
# -----------------------------
# Training
# -----------------------------
# Load pretrained weights (excluding mismatched layers) on GPU
def load_pretrained_weights(model, pretrained_path, device='cuda'):
    """Warm-start ``model`` from a saved state dict, skipping incompatible entries.

    Only checkpoint entries whose key exists in ``model`` with an identical
    tensor shape are copied; everything else (e.g. a classification head sized
    for a different answer vocabulary) keeps the model's current values.

    Args:
        model: Module to update in place.
        pretrained_path: Path to a file written by ``torch.save(state_dict, ...)``.
        device: Device the checkpoint tensors are mapped to while loading. If a
            CUDA device is requested but CUDA is unavailable, CPU is used instead.
    """
    # Requesting 'cuda' (the default) on a CPU-only host makes torch.load raise;
    # degrade to CPU so the same call works on both kinds of machine.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    print(f"🔄 Loading pretrained model from: {pretrained_path}")

    # Deserialize the checkpoint directly onto the chosen device.
    pretrained_dict = torch.load(pretrained_path, map_location=device)

    model_dict = model.state_dict()

    # Filter out keys that are unknown to the model or mismatch in shape
    # (like the final classification layer).
    filtered_dict = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and model_dict[k].shape == v.shape
    }

    print(f"✅ {len(filtered_dict)}/{len(pretrained_dict)} parameters loaded from pretrained model.")

    # Overlay the compatible entries on the model's own state and load the result.
    model_dict.update(filtered_dict)
    model.load_state_dict(model_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Example usage
model = VQAFusionTransformer(num_classes=len(label_map_vqarad)).to(device)
load_pretrained_weights(model, "/content/drive/MyDrive/My Folder/best_vqa_model.pth", device=device)
print("✅ Loaded PathVQA pre-trained model.")
# Optimise only the parameters left trainable by the model definition.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)


train_losses, val_accuracies, val_f1s = [], [], []
best_val_acc = 0.0
best_model_state = None
patience = 10  # epochs without a validation-accuracy gain tolerated before stopping
epochs_without_improvement = 0

for epoch in range(80):
    # ---- train for one epoch ----
    model.train()
    total_loss = 0
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
        images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()

    train_losses.append(total_loss / len(train_loader))

    # ---- validate ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]"):
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_accuracies.append(val_acc)
    val_f1s.append(val_f1)
    print(f"Epoch {epoch+1}: Loss={train_losses[-1]:.4f} | Val Acc={val_acc*100:.2f}% | F1={val_f1:.4f}")

    # ---- checkpointing / early stopping ----
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() hands back tensors that share storage with the live
        # parameters, so keeping it as-is would let later optimizer steps
        # overwrite the "best" weights. Clone every tensor to freeze a snapshot.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break
# Number of completed epochs (may be fewer than planned after early stopping).
epochs_ran = len(train_losses)
epoch_axis = range(1, epochs_ran + 1)


def _finish_panel(title, y_label):
    """Apply the title, axis labels, grid and legend shared by both panels."""
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.grid(True)
    plt.legend()


plt.figure(figsize=(14, 5))

# Left panel: mean training loss per epoch
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Train Loss', marker='o', color='red')
_finish_panel('Training Loss', 'Loss')

# Right panel: validation accuracy and macro F1, both as percentages
val_acc_pct = [v * 100 for v in val_accuracies]
val_f1_pct = [f * 100 for f in val_f1s]
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, val_acc_pct, label='Val Accuracy (%)', marker='s', color='blue')
plt.plot(epoch_axis, val_f1_pct, label='Val F1 Macro (%)', marker='x', color='green')
_finish_panel('Validation Accuracy and F1', 'Percentage')

plt.tight_layout()
plt.show()

###Evaluation

In [None]:

#testing
# Restore the best checkpoint tracked during training before scoring, as the
# other evaluation cells in this notebook do. Skipped when none was recorded.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
model.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    for images, input_ids, attention_mask, labels in tqdm(test_loader, desc="Test"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        labels = labels.to(device)
        preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
        test_preds.extend(preds.cpu().tolist())
        test_labels.extend(labels.cpu().tolist())

# Accuracy and macro F1 are printed as percentages.
print("Test Accuracy:", accuracy_score(test_labels, test_preds) * 100)
print("Test F1 (macro):", f1_score(test_labels, test_preds, average='macro') * 100)
print("\nClassification Report:\n", classification_report(test_labels, test_preds))
class GradCAM:
    """Grad-CAM heatmaps for one layer of an ``(images, input_ids, attention_mask)`` model.

    Forward and backward hooks on ``target_layer`` capture its output and the
    gradient flowing into that output; ``generate`` combines them into a
    class-specific heatmap. Call ``remove_hooks`` when finished.

    Args:
        model: Module called as ``model(images, input_ids, attention_mask)``
            that returns per-class scores with the batch on axis 0.
        target_layer: Sub-module of ``model`` whose output is 4-D
            (the code averages over axes 0, 2 and 3).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the hooks that cache the target layer's output and its gradient."""
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
        self.hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach every registered hook; safe to call more than once."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

    def generate(self, image_tensor, input_ids, attention_mask, class_idx=None):
        """Compute the Grad-CAM heatmap for a single (unbatched) sample.

        Args:
            image_tensor: Image tensor without a batch axis.
            input_ids: Token ids without a batch axis.
            attention_mask: Attention mask without a batch axis.
            class_idx: Class to explain; defaults to the model's top prediction.

        Returns:
            ``(heatmap, class_idx)`` where ``heatmap`` is a NumPy array with
            non-negative values, scaled so its maximum is 1 unless the map is
            entirely zero (then it is returned as all zeros).
        """
        # Eval mode, but gradients are still needed for the backward pass below.
        self.model.eval()

        # Run on whatever device the model lives on instead of relying on a
        # module-level ``device`` variable.
        model_device = next(self.model.parameters()).device

        # Add the batch axis; requiring grad on the image guarantees a gradient
        # path through the target layer even when its own weights are frozen.
        image_tensor = image_tensor.unsqueeze(0).to(model_device).requires_grad_()
        input_ids = input_ids.unsqueeze(0).to(model_device)
        attention_mask = attention_mask.unsqueeze(0).to(model_device)

        with torch.set_grad_enabled(True):
            output = self.model(image_tensor, input_ids, attention_mask)

            if class_idx is None:
                class_idx = output.argmax(dim=1).item()

            self.model.zero_grad()

            # Back-propagate only the selected class score. The graph is not
            # retained because nothing reuses it after this call.
            one_hot = torch.zeros_like(output)
            one_hot[0, class_idx] = 1
            output.backward(gradient=one_hot)

        # Channel weights = gradient averaged over batch and spatial axes.
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        # Broadcast instead of looping per channel; this also leaves the cached
        # activations untouched rather than scaling them in place.
        weighted = self.activations * pooled_gradients.view(1, -1, 1, 1)

        # Average over channels, drop the batch axis, keep positive evidence only.
        heatmap = torch.relu(torch.mean(weighted, dim=1).squeeze(0))

        # Normalise to [0, 1]; an all-zero map would otherwise become NaN (0 / 0).
        peak = torch.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        return heatmap.cpu().numpy(), class_idx

# Visualization function
def show_gradcam_comparison(img_tensor, heatmap, gt_label, pred_label, is_correct):
    """Plot an image next to its Grad-CAM overlay.

    Args:
        img_tensor: 3-D image tensor; axis 0 is moved to the end before display.
        heatmap: 2-D array of heatmap values, resized to the image size here.
        gt_label: Ground-truth answer text shown above the original image.
        pred_label: Predicted answer text shown above the overlay.
        is_correct: Selects the CORRECT/INCORRECT caption and its green/red colour.
    """
    import cv2
    import matplotlib.cm as cm

    img = img_tensor.permute(1, 2, 0).cpu().numpy()
    # Min-max scale to [0, 1] for display. A constant image has zero range, so
    # skip the division there (it would produce NaNs) and show it as all zeros.
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo) if hi > lo else img - lo

    # cv2.resize takes (width, height); colour the map with 'jet' and drop alpha.
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap_colored = cm.jet(heatmap)[..., :3]
    superimposed_img = heatmap_colored * 0.5 + img * 0.5

    plt.figure(figsize=(12, 4))

    # Original image
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(f"Original Image\nGT: {gt_label}")
    plt.axis('off')

    # Grad-CAM
    plt.subplot(1, 2, 2)
    plt.imshow(superimposed_img)
    plt.title(f"Grad-CAM\nPred: {pred_label}\n{'CORRECT' if is_correct else 'INCORRECT'}",
             color='green' if is_correct else 'red')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# Main visualization code
# Shows up to `num_samples` correctly and `num_samples` incorrectly answered
# test samples, each with its Grad-CAM overlay on the last ResNet stage.
grad_cam = GradCAM(model, model.resnet.layer4)

# Invert the label map once instead of rebuilding key/value lists per sample.
# Indices are assigned by enumerate(), so they coincide with insertion order.
idx_to_answer = {idx: ans for ans, idx in label_map_vqarad.items()}

num_samples = 3
correct_count = 0
incorrect_count = 0

model.eval()
try:
    for idx, (img, input_ids, attention_mask, label) in enumerate(test_dataset):
        if correct_count >= num_samples and incorrect_count >= num_samples:
            break

        # Get prediction
        with torch.no_grad():
            img_tensor = img.unsqueeze(0).to(device)
            input_ids_tensor = input_ids.unsqueeze(0).to(device)
            attention_mask_tensor = attention_mask.unsqueeze(0).to(device)

            output = model(img_tensor, input_ids_tensor, attention_mask_tensor)
            pred_class = output.argmax(dim=1).item()

        is_correct = (pred_class == label.item())

        # Skip if we already have enough samples of this kind
        if is_correct and correct_count >= num_samples:
            continue
        if not is_correct and incorrect_count >= num_samples:
            continue

        # Generate Grad-CAM (explains the model's top prediction)
        heatmap, _ = grad_cam.generate(img, input_ids, attention_mask)

        # Get label names
        gt_label_name = idx_to_answer[label.item()]
        pred_label_name = idx_to_answer[pred_class]

        # Show visualization
        show_gradcam_comparison(img, heatmap, gt_label_name, pred_label_name, is_correct)

        # Update counters
        if is_correct:
            correct_count += 1
        else:
            incorrect_count += 1
finally:
    # Always detach the hooks, even if a sample raises, so they do not keep
    # firing on later forward/backward passes of the model.
    grad_cam.remove_hooks()

## Grid-search

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image
import io, random, numpy as np
from tqdm.auto import tqdm

# 1. Set seed and device
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, auto-tuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Load pretrained weights correctly
def load_pretrained_weights(model, path, device):
    """Copy shape-compatible checkpoint entries into ``model`` and move it to ``device``.

    Args:
        model: Module updated in place.
        path: File containing a saved state dict.
        device: Target for both the loaded tensors and the model.
    """
    target_state = model.state_dict()

    # Keep only entries the model knows about with an identical tensor shape.
    compatible = {}
    for key, tensor in torch.load(path, map_location=device).items():
        if key in target_state and target_state[key].shape == tensor.shape:
            compatible[key] = tensor.to(device)

    print(f"Loaded {len(compatible)}/{len(target_state)} parameters.")

    target_state.update(compatible)
    model.load_state_dict(target_state)
    model.to(device)

# 3. Dataset Class
class VQARADDataset(Dataset):
    """Dataset over a VQA-RAD dataframe yielding ``(image, input_ids, attention_mask, label)``.

    Rows are read through the ``'image'`` (a mapping with a ``'bytes'`` entry),
    ``'question'`` and ``'answer'`` columns.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        # Drop the incoming index so rows are numbered 0..n-1.
        self.df = df.reset_index(drop=True)
        self.label_map, self.transform = label_map, transform
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Decode the embedded image bytes, force RGB, then apply the transform.
        raw_bytes = io.BytesIO(record['image']['bytes'])
        pixels = self.transform(Image.open(raw_bytes).convert("RGB"))

        encoded = self.tokenizer(
            record['question'],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        target = torch.tensor(self.label_map[record['answer']])

        # The tokenizer output has a leading batch axis of size 1; strip it.
        return pixels, encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0), target

# 4. Model
class VQAFusionTransformer(nn.Module):
    """ResNet-50 + BERT fusion classifier used by the grid search.

    The projected ResNet feature forms one image token placed in front of the
    projected BERT token states; the classifier reads the fused state at
    position 0 (the image token). No parameters are frozen here; callers set
    ``requires_grad`` themselves.

    Args:
        num_classes: Number of output logits.
        d_model: Width of the fusion encoder.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion layer.
        max_len: Maximum number of text tokens (one more slot is kept for the image).
    """

    def __init__(self, num_classes, d_model=512, num_layers=4, num_heads=8, max_len=64):
        super().__init__()
        # Backbones
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.resnet.fc = nn.Identity()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Projections into the shared fusion width
        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)

        # Learned modality and position embeddings
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        fusion_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(fusion_layer, num_layers=num_layers)
        self.cls = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, images, input_ids, attention_mask):
        batch_size, text_len = images.size(0), input_ids.size(1)

        img_tokens = self.img_proj(self.resnet(images)).unsqueeze(1)
        txt_hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        txt_tokens = self.txt_proj(txt_hidden)
        sequence = torch.cat([img_tokens, txt_tokens], dim=1)

        # Modality id 1 marks the image token, 0 marks every text position.
        target_device = images.device
        modality = torch.cat([
            torch.ones((batch_size, 1), dtype=torch.long, device=target_device),
            torch.zeros((batch_size, text_len), dtype=torch.long, device=target_device),
        ], dim=1)

        sequence = sequence + self.mod_embed(modality) + self.pos_embed[:, :sequence.size(1), :]
        fused = self.encoder(sequence)
        return self.cls(fused[:, 0])

# 5. Training loop

def train_vqa_model_finetune_from_pathvqa(pathvqa_model_path, train_loader, val_loader, label_map, lr=1e-5,
                                           resnet_unfreeze=False, bert_unfreeze=False, tf_layers=4, epochs=1):
    """Fine-tune a checkpoint-initialised fusion model and score it on validation data.

    Args:
        pathvqa_model_path: Path of the state dict used to warm-start the model.
        train_loader: Loader yielding ``(images, input_ids, attention_mask, labels)``.
        val_loader: Loader with the same batch layout, used for scoring.
        label_map: Answer -> index mapping; its size sets the number of classes.
        lr: AdamW learning rate.
        resnet_unfreeze: If true, ResNet parameters whose name contains
            ``"layer4"`` are trained; otherwise the whole ResNet is frozen.
        bert_unfreeze: If true, BERT encoder layers 8-11 are trained;
            otherwise the whole BERT is frozen.
        tf_layers: Number of fusion encoder layers.
        epochs: Number of passes over ``train_loader``. The default of 1 keeps
            the previous single-pass behaviour.

    Returns:
        ``(accuracy, macro_f1)`` measured on ``val_loader`` after training.
    """
    model = VQAFusionTransformer(num_classes=len(label_map), num_layers=tf_layers).to(device)
    load_pretrained_weights(model, pathvqa_model_path, device)

    # Backbone parameters are trainable only when requested AND in the selected layers.
    for name, param in model.resnet.named_parameters():
        param.requires_grad = bool(resnet_unfreeze) and "layer4" in name

    for name, param in model.bert.named_parameters():
        param.requires_grad = bool(bert_unfreeze) and any(
            f"encoder.layer.{i}" in name for i in range(8, 12)
        )

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for images, input_ids, attention_mask, labels in tqdm(train_loader):
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            outputs = model(images, input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad(); loss.backward(); optimizer.step()

    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in val_loader:
            images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
            preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
            val_preds.extend(preds.cpu().tolist())
            # Labels are only compared on the host, so they never need to visit the device.
            val_labels.extend(labels.tolist())

    acc = accuracy_score(val_labels, val_preds)
    f1 = f1_score(val_labels, val_preds, average='macro')
    return acc, f1

# 6. Grid Search Example
# Exhaustive sweep over learning rate, backbone unfreezing and fusion depth.
learning_rates = [1e-5, 3e-5]
resnet_opts = [False, True]
bert_opts = [False, True]
transformer_layers = [2, 4]

results = []

for lr in learning_rates:
    for resnet_unfreeze in resnet_opts:
        for bert_unfreeze in bert_opts:
            for tf_layer in transformer_layers:
                # Re-seed before every configuration. Otherwise each run starts
                # from whatever RNG state the previous one left behind, so the
                # random initialisation and RNG-driven data order differ between
                # configurations and their scores are not directly comparable.
                set_seed(42)
                print(f"\n🔧 Running: lr={lr}, resnet={resnet_unfreeze}, bert={bert_unfreeze}, layers={tf_layer}")
                acc, f1 = train_vqa_model_finetune_from_pathvqa(
                    pathvqa_model_path="/content/drive/MyDrive/vqa_models/best_vqa_model.pth",
                    train_loader=train_loader, val_loader=val_loader,
                    label_map=label_map_vqarad, lr=lr,
                    resnet_unfreeze=resnet_unfreeze, bert_unfreeze=bert_unfreeze,
                    tf_layers=tf_layer
                )
                results.append({
                    "lr": lr,
                    "resnet_unfreeze": resnet_unfreeze,
                    "bert_unfreeze": bert_unfreeze,
                    "tf_layers": tf_layer,
                    "val_acc": acc,
                    "val_f1": f1
                })
                print(f"✅ Done: Acc={acc:.4f}, F1={f1:.4f}")

# Tabulate all runs, best validation accuracy first.
import pandas as pd
results_df = pd.DataFrame(results)
print(results_df.sort_values(by="val_acc", ascending=False))


##Bio-BERT

In [None]:
import os, random, io
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModel  # For BioBERT

# -----------------------------
# Setup
# -----------------------------
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, auto-tuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Dataset Loading
# -----------------------------
# Parquet shard (path relative to the Hub dataset repo) for each published split.
splits = {
    'train': 'data/train-00000-of-00001-eb8844602202be60.parquet',
    'test': 'data/test-00000-of-00001-e5bc3d208bb4deeb.parquet'
}

hub_root = "hf://datasets/flaviagiammarino/vqa-rad/"
train_df = pd.read_parquet(f"{hub_root}{splits['train']}")
test_df = pd.read_parquet(f"{hub_root}{splits['test']}")

# Hold out 10% of the published training split as a validation set (fixed seed).
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

# Answer vocabulary: every distinct answer across the three frames, sorted,
# each mapped to its position in that sorted order.
answer_columns = [frame['answer'] for frame in (train_df, val_df, test_df)]
all_answers = pd.concat(answer_columns)
label_map_vqarad = {}
for position, answer in enumerate(sorted(all_answers.unique())):
    label_map_vqarad[answer] = position

# -----------------------------
# Dataset Class
# -----------------------------
class VQARADDataset(Dataset):
    """Dataset over a VQA-RAD dataframe yielding ``(image, input_ids, attention_mask, label)``.

    Rows are read through the ``'image'`` (a mapping with a ``'bytes'`` entry),
    ``'question'`` and ``'answer'`` columns.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        # Drop the incoming index so rows are numbered 0..n-1.
        self.df = df.reset_index(drop=True)
        self.label_map, self.transform = label_map, transform
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Decode the embedded image bytes, force RGB, then apply the transform.
        raw_bytes = io.BytesIO(record['image']['bytes'])
        pixels = self.transform(Image.open(raw_bytes).convert("RGB"))

        tokenizer_kwargs = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.max_len,
            "return_tensors": "pt",
        }
        encoded = self.tokenizer(record['question'], **tokenizer_kwargs)

        target = torch.tensor(self.label_map[record['answer']])
        # The tokenizer output has a leading batch axis of size 1; strip it.
        return pixels, encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0), target

# -----------------------------
# Transforms & Tokenizer
# -----------------------------
# Training-time preprocessing: resize plus random flip/rotation augmentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Evaluation preprocessing: the same resize/tensor conversion WITHOUT random
# augmentation, so validation/test metrics are reproducible instead of varying
# from one pass over the data to the next.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

train_dataset = VQARADDataset(train_df, label_map_vqarad, transform, tokenizer)
val_dataset = VQARADDataset(val_df, label_map_vqarad, eval_transform, tokenizer)
test_dataset = VQARADDataset(test_df, label_map_vqarad, eval_transform, tokenizer)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# -----------------------------
# Model: ResNet + BioBERT + Transformer
# -----------------------------
class VQAFusionTransformer(nn.Module):
    """ResNet-50 + BioBERT features fused by a Transformer encoder for answer classification.

    The projected ResNet feature forms one image token placed in front of the
    projected BioBERT token states; the classifier reads the fused state at
    position 0 (the image token). Only ResNet ``layer4`` and BioBERT encoder
    layers 8-11 are left trainable among the backbone parameters.

    Args:
        num_classes: Number of output logits.
        d_model: Width of the fusion encoder.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion layer.
        max_len: Maximum number of text tokens (one more slot is kept for the image).
    """

    def __init__(self, num_classes, d_model=512, num_layers=8, num_heads=8, max_len=64):
        super().__init__()
        # `pretrained=True` is deprecated in torchvision's multi-weight API.
        # IMAGENET1K_V1 is the checkpoint that flag selected, so the weights are unchanged.
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        for name, param in self.resnet.named_parameters():
            param.requires_grad = "layer4" in name
        self.resnet.fc = nn.Identity()

        self.biobert = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
        for name, param in self.biobert.named_parameters():
            param.requires_grad = any(f"encoder.layer.{i}" in name for i in range(8, 12))

        # Projections into the shared fusion width
        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)
        # Learned modality (2 ids) and position (image slot + max_len text slots) embeddings
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls = nn.Sequential(
            nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return class logits for a batch of (image, question) pairs."""
        B = images.size(0)

        # One image token per sample
        img_feat = self.resnet(images)
        img_tokens = self.img_proj(img_feat).unsqueeze(1)

        # One token per text position
        txt_out = self.biobert(input_ids=input_ids, attention_mask=attention_mask)
        txt_tokens = self.txt_proj(txt_out.last_hidden_state)

        x = torch.cat([img_tokens, txt_tokens], dim=1)
        # Modality id 1 marks the image token, 0 marks every text position.
        mod_ids = torch.cat([
            torch.ones((B, 1), dtype=torch.long),
            torch.zeros((B, input_ids.size(1)), dtype=torch.long)
        ], dim=1).to(images.device)

        x = x + self.mod_embed(mod_ids) + self.pos_embed[:, :x.size(1), :]
        # attention_mask is only given to BioBERT; the fusion encoder gets no padding mask.
        fused = self.encoder(x)
        return self.cls(fused[:, 0])

# -----------------------------
# Training + Evaluation
# -----------------------------
model = VQAFusionTransformer(num_classes=len(label_map_vqarad)).to(device)
# Optimise only the parameters left trainable by the model definition.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

best_model_state = None
train_losses, val_accuracies, val_f1s = [], [], []
# patience = epochs without a validation-accuracy gain tolerated before stopping
best_val_acc, patience, epochs_without_improvement = 0.0, 5, 0

for epoch in range(30):
    # ---- train for one epoch ----
    model.train()
    total_loss = 0
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
        images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()

    train_losses.append(total_loss / len(train_loader))

    # ---- validate ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in val_loader:
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_accuracies.append(val_acc)
    val_f1s.append(val_f1)

    print(f"Epoch {epoch+1}: Loss={train_losses[-1]:.4f} | Val Acc={val_acc*100:.2f}% | F1={val_f1:.4f}")

    # ---- checkpointing / early stopping ----
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() hands back tensors that share storage with the live
        # parameters, so keeping it as-is would let later optimizer steps
        # overwrite the "best" weights. Clone every tensor to freeze a snapshot.
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

# Save (only if some epoch actually produced a best checkpoint)
if best_model_state is not None:
    torch.save(best_model_state, "vqarad_biobert_finetuned_model.pth")

# -----------------------------
# Test Evaluation
# -----------------------------
# Score the best checkpoint tracked during training on the held-out test split.
model.load_state_dict(best_model_state)
model.eval()

test_preds = []
test_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        # Each batch is (images, input_ids, attention_mask, labels).
        images, input_ids, attention_mask, labels = (t.to(device) for t in batch)
        logits = model(images, input_ids, attention_mask)
        preds = torch.argmax(logits, dim=1)
        test_preds += preds.cpu().tolist()
        test_labels += labels.cpu().tolist()

# Accuracy and macro F1 are printed as percentages.
test_accuracy = accuracy_score(test_labels, test_preds)
test_macro_f1 = f1_score(test_labels, test_preds, average='macro')
print("Test Accuracy:", test_accuracy * 100)
print("Test F1 (macro):", test_macro_f1 * 100)
print("\nClassification Report:\n", classification_report(test_labels, test_preds))


##Clinical-BERT

In [None]:
import os, random, io
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModel  # Auto classes for the text encoder

# -----------------------------
# Setup
# -----------------------------
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, auto-tuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Dataset Loading
# -----------------------------
# Parquet shard (path relative to the Hub dataset repo) for each published split.
splits = {
    'train': 'data/train-00000-of-00001-eb8844602202be60.parquet',
    'test': 'data/test-00000-of-00001-e5bc3d208bb4deeb.parquet'
}

hub_root = "hf://datasets/flaviagiammarino/vqa-rad/"
train_df = pd.read_parquet(f"{hub_root}{splits['train']}")
test_df = pd.read_parquet(f"{hub_root}{splits['test']}")

# Hold out 10% of the published training split as a validation set (fixed seed).
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

# Answer vocabulary: every distinct answer across the three frames, sorted,
# each mapped to its position in that sorted order.
answer_columns = [frame['answer'] for frame in (train_df, val_df, test_df)]
all_answers = pd.concat(answer_columns)
label_map_vqarad = {}
for position, answer in enumerate(sorted(all_answers.unique())):
    label_map_vqarad[answer] = position

# -----------------------------
# Dataset Class
# -----------------------------
class VQARADDataset(Dataset):
    """Dataset over a dataframe of visual-question-answering samples.

    Every row is read through its ``image`` (a mapping whose ``bytes`` entry
    holds the encoded image), ``question`` and ``answer`` fields. Items are
    ``(image, input_ids, attention_mask, label)`` tuples.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        """Store the sample table and the preprocessing callables.

        Args:
            df: DataFrame with ``image``, ``question`` and ``answer`` columns.
            label_map: Mapping from answer value to integer class index.
            transform: Callable applied to the decoded RGB PIL image.
            tokenizer: Callable tokenizer invoked with Hugging Face style
                keyword arguments (``padding``, ``truncation``, ...).
            max_len: Length every tokenized question is padded/truncated to.
        """
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.transform = transform
        self.label_map = label_map
        # Renumber rows 0..N-1 so positional lookups match the index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of samples (rows)."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask, label)`` for row ``idx``."""
        sample = self.df.iloc[idx]

        # Decode the embedded image bytes, force three channels, then transform.
        encoded_image = io.BytesIO(sample['image']['bytes'])
        img = Image.open(encoded_image).convert("RGB")
        img = self.transform(img)

        encoded_question = self.tokenizer(
            sample['question'],
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        label = torch.tensor(self.label_map[sample['answer']])
        # The tokenizer returns a batch of one; drop that leading dimension.
        input_ids = encoded_question['input_ids'].squeeze(0)
        attention_mask = encoded_question['attention_mask'].squeeze(0)
        return img, input_ids, attention_mask, label

# -----------------------------
# Transforms & Tokenizer
# -----------------------------
# Training pipeline: random flip/rotation augment each sample on the fly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Evaluation pipeline: deterministic resize only. Validation and test images
# must not be randomly perturbed, otherwise the reported metrics (and the
# early-stopping decision based on them) change from run to run.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

train_dataset = VQARADDataset(train_df, label_map_vqarad, transform, tokenizer)
val_dataset = VQARADDataset(val_df, label_map_vqarad, eval_transform, tokenizer)
test_dataset = VQARADDataset(test_df, label_map_vqarad, eval_transform, tokenizer)

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# -----------------------------
# Model: ResNet + BioBERT + Transformer
# -----------------------------
class VQAFusionTransformer(nn.Module):
    """Fuse a ResNet-50 image feature with BioBERT question tokens in a Transformer.

    The image is reduced to one token, concatenated in front of the projected
    text tokens, and the fused representation at the image-token position
    (index 0) is classified into ``num_classes`` answers.

    Args:
        num_classes: Number of answer classes produced by the classifier head.
        d_model: Width of the fusion Transformer.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion encoder layer.
        max_len: Maximum number of text tokens; the positional table holds
            ``max_len + 1`` entries (one extra for the image token).
    """

    def __init__(self, num_classes, d_model=512, num_layers=8, num_heads=8, max_len=64):
        super().__init__()
        # Image backbone: freeze everything, then re-enable only `layer4`.
        self.resnet = models.resnet50(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad = False
        for name, param in self.resnet.named_parameters():
            if "layer4" in name:
                param.requires_grad = True
        # Drop the classification head so the backbone returns pooled features.
        self.resnet.fc = nn.Identity()

        # Text backbone: freeze everything, then re-enable encoder layers 8-11.
        self.biobert = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
        for param in self.biobert.parameters():
            param.requires_grad = False
        for name, param in self.biobert.named_parameters():
            if any(f"encoder.layer.{i}" in name for i in range(8, 12)):
                param.requires_grad = True

        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)
        # Modality embedding: index 1 marks the image token, index 0 text tokens.
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.cls = nn.Sequential(
            nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return class logits of shape ``(B, num_classes)``.

        Args:
            images: Image batch; assumed ``(B, 3, H, W)`` as expected by ResNet.
            input_ids: Token ids, ``(B, T)`` with ``T <= max_len``.
            attention_mask: ``(B, T)`` mask from the tokenizer, non-zero for
                real tokens and 0 for padding.
        """
        B = images.size(0)
        target_device = images.device

        img_feat = self.resnet(images)
        img_tokens = self.img_proj(img_feat).unsqueeze(1)

        txt_out = self.biobert(input_ids=input_ids, attention_mask=attention_mask)
        txt_tokens = self.txt_proj(txt_out.last_hidden_state)

        x = torch.cat([img_tokens, txt_tokens], dim=1)
        mod_ids = torch.cat([
            torch.ones((B, 1), dtype=torch.long, device=target_device),
            torch.zeros((B, input_ids.size(1)), dtype=torch.long, device=target_device)
        ], dim=1)

        x = x + self.mod_embed(mod_ids) + self.pos_embed[:, :x.size(1), :]

        # Questions are padded to a fixed length, so most text positions are
        # [PAD]. Mask them (True = ignore) so the fusion encoder attends only
        # to the image token and the real question tokens.
        pad_mask = torch.cat([
            torch.zeros((B, 1), dtype=torch.bool, device=attention_mask.device),
            attention_mask == 0
        ], dim=1)

        fused = self.encoder(x, src_key_padding_mask=pad_mask)
        # The image token at position 0 serves as the pooled representation.
        return self.cls(fused[:, 0])

# -----------------------------
# Training + Evaluation
# -----------------------------
model = VQAFusionTransformer(num_classes=len(label_map_vqarad)).to(device)
# Optimize only the parameters that were left trainable inside the model.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

best_model_state = None
train_losses, val_accuracies, val_f1s = [], [], []
best_val_acc, patience, epochs_without_improvement = 0.0, 5, 0

for epoch in range(30):
    # ---- Train for one epoch ----
    model.train()
    total_loss = 0
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
        images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean loss per batch for this epoch.
    train_losses.append(total_loss / len(train_loader))

    # ---- Validate ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in val_loader:
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_accuracies.append(val_acc)
    val_f1s.append(val_f1)

    print(f"Epoch {epoch+1}: Loss={train_losses[-1]:.4f} | Val Acc={val_acc*100:.2f}% | F1={val_f1:.4f}")

    # ---- Checkpoint / early stopping on validation accuracy ----
    # The first epoch always produces a snapshot so that saving and reloading
    # `best_model_state` later never receives None.
    if best_model_state is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps would overwrite. Store independent CPU copies
        # so the snapshot really is the best epoch, not the last one.
        best_model_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in model.state_dict().items()
        }
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

# Persist the selected weights to disk.
torch.save(best_model_state, "vqarad_biobert_finetuned_model.pth")

# -----------------------------
# Test Evaluation
# -----------------------------
# Restore the selected weights and score them on the held-out test split.
model.load_state_dict(best_model_state)
model.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    for images, input_ids, attention_mask, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Predicted class = index of the largest logit.
        preds = model(images, input_ids, attention_mask).argmax(dim=1)
        test_preds += preds.cpu().tolist()
        test_labels += labels.cpu().tolist()

print("Test Accuracy:", accuracy_score(test_labels, test_preds) * 100)
print("Test F1 (macro):", f1_score(test_labels, test_preds, average='macro') * 100)
print("\nClassification Report:\n", classification_report(test_labels, test_preds))


#Botany-VQA

### Dataset loading and setup

In [None]:
# Mount Google Drive at /content/drive (Colab only) so files stored there can be
# read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

# Step 1: Download Oxford Flowers 102 dataset into the current working directory:
# the image archive plus two MATLAB files (presumably class labels and split
# ids, judging by their names — confirm before relying on them).
!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat

# Step 2: Extract images (-v prints every extracted file name in the cell output).
# NOTE(review): later code opens images as "jpg/image_NNNNN.jpg", so the archive
# is presumably expected to unpack into ./jpg/ — confirm.
!tar -xvzf 102flowers.tgz

import os
import random
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report

# -----------------------------
# Set seed
# -----------------------------
def set_seed(seed=42):
    """Make the run repeatable by seeding all RNGs and pinning cuDNN behavior.

    Args:
        seed: Integer seed used for Python's ``random``, NumPy and PyTorch
            (CPU generator and every CUDA device).
    """
    # cuDNN: no autotuned kernel selection, deterministic kernels only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # PyTorch generators.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

    # NumPy and the standard-library generator.
    np.random.seed(seed)
    random.seed(seed)

set_seed(42)

# -----------------------------
# Read and fix your CSV
# -----------------------------
# Load version 1 of the question/answer table from the mounted Google Drive.
# The cells below read its `image_path`, `question` and `answer` columns.
df = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/botany_vqa_v1.csv")
# To run on the second dataset version, load its CSV here in the same way.
print("✅ Loaded CSV:", df.shape)

# Apply 5-digit zero padding fix to filenames
def fix_image_path(img_path):
    """Rewrite an image path to the zero-padded ``jpg/image_NNNNN.jpg`` form.

    The image number is the integer found between the last underscore of
    ``img_path`` and the first dot after it; any directory part and the
    original extension are discarded.

    Args:
        img_path: Path or file name such as ``"image_7.jpg"``.

    Returns:
        ``"jpg/image_<number padded to at least 5 digits>.jpg"``.

    Raises:
        ValueError: If the extracted text is not an integer.
    """
    after_last_underscore = img_path.rpartition('_')[2]
    number_text = after_last_underscore.partition('.')[0]
    return "jpg/image_{:05d}.jpg".format(int(number_text))

# Normalize every stored path to the extracted "jpg/image_NNNNN.jpg" layout.
df['image_path'] = df['image_path'].map(fix_image_path)

# -----------------------------
# Label mapping
# -----------------------------
# One class index per distinct answer, assigned in sorted order.
label_map = {
    answer: class_id
    for class_id, answer in enumerate(sorted(df['answer'].unique()))
}
print("✅ Classes:", label_map)

# -----------------------------
# Train/Val/Test Split
# -----------------------------
# 20% of all rows become the test set; 10% of the remainder is validation.
train_df, test_df = train_test_split(df, random_state=42, test_size=0.2)
train_df, val_df = train_test_split(train_df, random_state=42, test_size=0.1)

# -----------------------------
# Dataset Class
# -----------------------------
class BotanyVQADataset(Dataset):
    """Dataset over a dataframe of (image path, question, answer) rows.

    Items are ``(image, input_ids, attention_mask, label)`` tuples.
    """

    def __init__(self, df, label_map, transform, tokenizer, max_len=64):
        """Store the sample table and the preprocessing callables.

        Args:
            df: DataFrame with ``image_path``, ``question`` and ``answer`` columns.
            label_map: Mapping from answer value to integer class index.
            transform: Callable applied to the loaded RGB PIL image.
            tokenizer: Callable tokenizer invoked with Hugging Face style
                keyword arguments (``padding``, ``truncation``, ...).
            max_len: Length every tokenized question is padded/truncated to.
        """
        # Renumber rows 0..N-1 so positional lookups match the index.
        self.df = df.reset_index(drop=True)
        self.label_map = label_map
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples (rows)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask, label)`` for row ``idx``."""
        row = self.df.iloc[idx]
        img_path = row['image_path']
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)

        # Use the tokenizer handed to this instance, not whatever global
        # `tokenizer` happens to be defined when the item is fetched.
        tokens = self.tokenizer(
            row['question'], padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )

        label = torch.tensor(self.label_map[row['answer']])
        # The tokenizer returns a batch of one; drop that leading dimension.
        return img, tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0), label

# -----------------------------
# Transforms and Tokenizer
# -----------------------------
# Training pipeline: colour jitter, random flip and rotation augment each
# sample on the fly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# Evaluation pipeline: deterministic resize only. Validation and test images
# must not be randomly perturbed, otherwise the reported metrics change from
# run to run.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

train_dataset = BotanyVQADataset(train_df, label_map, transform, tokenizer)
val_dataset = BotanyVQADataset(val_df, label_map, eval_transform, tokenizer)
test_dataset = BotanyVQADataset(test_df, label_map, eval_transform, tokenizer)

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# -----------------------------
# Model
# -----------------------------
def get_finetunable_resnet50():
    """Return a pretrained ResNet-50 with all weights trainable and no classifier.

    The final ``fc`` layer is replaced by ``nn.Identity`` so the network
    outputs its pooled features instead of class scores.
    """
    backbone = models.resnet50(pretrained=True)
    # Mark every parameter of the backbone as trainable.
    backbone.requires_grad_(True)
    backbone.fc = nn.Identity()
    return backbone

def get_finetunable_bert():
    """Return a pretrained ``bert-base-uncased`` model with all weights trainable."""
    text_encoder = BertModel.from_pretrained("bert-base-uncased")
    # Mark every parameter of the encoder as trainable.
    for weight in text_encoder.parameters():
        weight.requires_grad = True
    return text_encoder

class VQAFusionTransformer(nn.Module):
    """Fuse a ResNet-50 image feature with BERT question tokens in a Transformer.

    The image is reduced to one token, concatenated in front of the projected
    text tokens, and the fused representation at the image-token position
    (index 0) is classified into ``num_classes`` answers.

    Args:
        num_classes: Number of answer classes produced by the classifier head.
        d_model: Width of the fusion Transformer.
        num_layers: Number of fusion encoder layers.
        num_heads: Attention heads per fusion encoder layer.
        max_len: Maximum number of text tokens; the positional table holds
            ``max_len + 1`` entries (one extra for the image token).
    """

    def __init__(self, num_classes, d_model=512, num_layers=8, num_heads=8, max_len=64):
        super().__init__()
        self.resnet = get_finetunable_resnet50()
        self.bert = get_finetunable_bert()
        self.img_proj = nn.Linear(2048, d_model)
        self.txt_proj = nn.Linear(768, d_model)
        # Modality embedding: index 1 marks the image token, index 0 text tokens.
        self.mod_embed = nn.Embedding(2, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.cls = nn.Sequential(
            nn.Linear(d_model, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return class logits of shape ``(B, num_classes)``.

        Args:
            images: Image batch; assumed ``(B, 3, H, W)`` as expected by ResNet.
            input_ids: Token ids, ``(B, T)`` with ``T <= max_len``.
            attention_mask: ``(B, T)`` mask from the tokenizer, non-zero for
                real tokens and 0 for padding.
        """
        B = images.size(0)
        target_device = images.device
        img_feat = self.resnet(images)
        img_tokens = self.img_proj(img_feat).unsqueeze(1)
        txt_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        txt_tokens = self.txt_proj(txt_out.last_hidden_state)
        x = torch.cat([img_tokens, txt_tokens], dim=1)
        mod_ids = torch.cat([
            torch.ones((B, 1), dtype=torch.long, device=target_device),
            torch.zeros((B, input_ids.size(1)), dtype=torch.long, device=target_device)
        ], dim=1)
        x = x + self.mod_embed(mod_ids) + self.pos_embed[:, :x.size(1), :]
        # Questions are padded to a fixed length, so most text positions are
        # [PAD]. Mask them (True = ignore) so the fusion encoder attends only
        # to the image token and the real question tokens.
        pad_mask = torch.cat([
            torch.zeros((B, 1), dtype=torch.bool, device=attention_mask.device),
            attention_mask == 0
        ], dim=1)
        fused = self.encoder(x, src_key_padding_mask=pad_mask)
        # The image token at position 0 serves as the pooled representation.
        return self.cls(fused[:, 0])

###Training

In [None]:

# -----------------------------
# Training
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQAFusionTransformer(num_classes=len(label_map)).to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-5)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

train_losses, val_accuracies, val_f1s = [], [], []
best_val_acc = 0.0
# Defined up front so the evaluation and save cells never hit a NameError.
best_model_state = None
# NOTE(review): patience (30) exceeds the number of epochs (10), so the
# early-stopping branch below can never trigger with these settings.
patience = 30
epochs_without_improvement = 0

for epoch in range(10):
    # ---- Train for one epoch ----
    model.train()
    total_loss = 0
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
        images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
        outputs = model(images, input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean loss per batch for this epoch.
    train_losses.append(total_loss / len(train_loader))

    # ---- Validate ----
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]"):
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            preds = torch.argmax(model(images, input_ids, attention_mask), dim=1)
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    val_accuracies.append(val_acc)
    val_f1s.append(val_f1)
    print(f"Epoch {epoch+1}: Loss={train_losses[-1]:.4f} | Val Acc={val_acc*100:.2f}% | F1={val_f1:.4f}")

    # ---- Checkpoint / early stopping on validation accuracy ----
    # The first epoch always produces a snapshot so a best state always exists.
    if best_model_state is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps would overwrite. Store independent CPU copies
        # so the snapshot really is the best epoch, not the last one.
        best_model_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in model.state_dict().items()
        }
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

# -----------------------------
# Plotting
# -----------------------------
plt.figure(figsize=(14, 5))

# Left panel: mean training loss per epoch.
plt.subplot(1, 2, 1)
plt.plot(
    range(1, len(train_losses) + 1),
    train_losses,
    label='Train Loss',
    marker='o',
    color='red',
)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()

# Right panel: validation accuracy and macro F1, both scaled to percent.
plt.subplot(1, 2, 2)
plt.plot(
    range(1, len(val_accuracies) + 1),
    [acc * 100 for acc in val_accuracies],
    label='Val Accuracy (%)',
    marker='s',
    color='blue',
)
plt.plot(
    range(1, len(val_f1s) + 1),
    [score * 100 for score in val_f1s],
    label='Val F1 Macro (%)',
    marker='x',
    color='green',
)
plt.title('Validation Accuracy and F1')
plt.xlabel('Epoch')
plt.ylabel('Percentage')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.show()

###Evaluation

In [None]:

# -----------------------------
# Final Test Evaluation
# -----------------------------
# Restore the selected weights and collect predictions over the test split.
model.load_state_dict(best_model_state)
model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
    for images, input_ids, attention_mask, labels in tqdm(test_loader, desc="Test"):
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Predicted class = index of the largest logit.
        preds = model(images, input_ids, attention_mask).argmax(dim=1)
        test_preds += preds.cpu().tolist()
        test_labels += labels.cpu().tolist()
class GradCAMExtractor:
    """Capture activations and gradients of ``model.resnet.layer4`` for Grad-CAM.

    Usage: create the extractor, run a forward pass with gradients enabled,
    call ``backward()`` on the score of interest, then call ``get_gradcam()``.
    Call ``remove_hooks()`` when finished.

    Args:
        model: Object exposing the target layer as ``model.resnet.layer4``.
    """

    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to the target layer that records its output
        and, through a tensor hook, the gradient flowing back into that output."""
        def save_gradient(grad):
            self.gradients = grad.detach()

        def forward_hook(module, input, output):
            self.activations = output.detach()
            # Drop gradients from any earlier sample so they can never be
            # paired with these new activations.
            self.gradients = None
            # A tensor hook on the layer output replaces the deprecated
            # Module.register_backward_hook. It is skipped when the forward
            # pass runs without autograd (e.g. under torch.no_grad()).
            if output.requires_grad:
                output.register_hook(save_gradient)

        target_layer = self.model.resnet.layer4
        self.hook_handles.append(target_layer.register_forward_hook(forward_hook))

    def remove_hooks(self):
        """Detach every registered hook from the model."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

    def get_gradcam(self, class_idx=None):
        """Return the Grad-CAM heatmap of the last forward/backward pass.

        Args:
            class_idx: Unused; kept for interface compatibility. The class is
                determined by the score on which ``backward()`` was called.

        Returns:
            NumPy array with the spatial shape of the captured activations
            (singleton dimensions squeezed), scaled to ``[0, 1]``; all zeros
            when no location has a positive response.

        Raises:
            RuntimeError: If no forward and backward pass has been recorded.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM needs a forward pass followed by backward() before get_gradcam()."
            )

        pooled_grads = torch.mean(self.gradients, dim=[0, 2, 3])  # [C]
        # Weight each channel by its pooled gradient without modifying the
        # stored activations, so repeated calls return the same heatmap.
        weighted = self.activations * pooled_grads.view(1, -1, 1, 1)
        heatmap = torch.mean(weighted, dim=1).squeeze()
        heatmap = torch.clamp(heatmap, min=0)

        # Normalize to [0, 1]; skip when the map is all zeros to avoid 0/0 = NaN.
        peak = torch.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap.cpu().numpy()

# Report test-set accuracy and macro-averaged F1 as percentages, using the
# predictions and labels collected by the test loop above.
print("Test Accuracy:", accuracy_score(test_labels, test_preds) * 100)
print("Test F1 (macro):", f1_score(test_labels, test_preds, average='macro') * 100)
# Per-class report, currently disabled:
# print("\nClassification Report:\n", classification_report(test_labels, test_preds))

import cv2
def visualize_gradcam(model, extractor, dataset, correct=True, num=2):
    """Plot Grad-CAM overlays for the first ``num`` matching samples of ``dataset``.

    Samples are visited in index order. A sample matches when the model's
    prediction is right (``correct=True``) or wrong (``correct=False``).

    Args:
        model: Network called as ``model(images, input_ids, attention_mask)``
            and returning class logits.
        extractor: Object whose ``get_gradcam()`` returns a 2-D heatmap after
            a backward pass, and whose ``model`` attribute supports
            ``zero_grad()``.
        dataset: Indexable dataset yielding
            ``(image, input_ids, attention_mask, label)`` with ``image`` a
            CPU tensor in channel-first layout.
        correct: Select correctly (True) or incorrectly (False) predicted samples.
        num: Maximum number of figures to show; nothing is shown for ``num <= 0``.
    """
    model.eval()
    want_correct = bool(correct)
    shown = 0
    for i in range(len(dataset)):
        if shown >= num:
            break

        image, input_ids, attention_mask, label = dataset[i]
        img_tensor = image.unsqueeze(0).to(device)
        input_ids = input_ids.unsqueeze(0).to(device)
        attention_mask = attention_mask.unsqueeze(0).to(device)
        # `label` is normally already a tensor; as_tensor avoids the copy
        # warning that torch.tensor(tensor) emits.
        label = torch.as_tensor(label).unsqueeze(0).to(device)

        extractor.model.zero_grad()
        output = model(img_tensor, input_ids, attention_mask)
        pred = torch.argmax(output, dim=1)

        is_correct = bool((pred == label).item())
        if is_correct != want_correct:
            continue

        # Back-propagate the predicted class score to obtain the gradients.
        output[0, pred.item()].backward()
        # Guard against NaNs from a degenerate (all-zero) heatmap.
        heatmap = np.nan_to_num(extractor.get_gradcam())

        # Channel-last copy of the input, rescaled to [0, 1] for display.
        img_np = image.permute(1, 2, 0).numpy()
        low, high = img_np.min(), img_np.max()
        if high > low:
            img_np = (img_np - low) / (high - low)
        else:
            img_np = np.zeros_like(img_np)

        # Resize the heatmap to the image's own size (cv2 expects width, height).
        height, width = img_np.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (width, height))
        heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
        # OpenCV colour maps are BGR, while the image and matplotlib are RGB.
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
        superimposed = cv2.addWeighted(np.uint8(255 * img_np), 0.6, heatmap_colored, 0.4, 0)

        plt.figure(figsize=(4, 4))
        plt.title(f"Pred: {pred.item()} | GT: {label.item()}")
        plt.imshow(superimposed)
        plt.axis('off')
        plt.show()

        shown += 1
extractor = GradCAMExtractor(model)

# Always detach the hooks, even if a visualization call raises; otherwise they
# stay attached to the model and keep firing on every later forward pass.
try:
    print("✅ Correct Predictions")
    visualize_gradcam(model, extractor, test_dataset, correct=True, num=2)

    print("❌ Incorrect Predictions")
    visualize_gradcam(model, extractor, test_dataset, correct=False, num=2)
finally:
    extractor.remove_hooks()


# Path in your Google Drive (requires the Drive mount at /content/drive)
model_save_path = "/content/drive/MyDrive/vqa_fusion_model.pth"

# Save the best model weights (the state dict kept by the training loop,
# not the whole model object)
torch.save(best_model_state, model_save_path)

print(f"✅ Model saved to: {model_save_path}")
