# VQA Medical - Training on Google Colab

This notebook trains the VQA Medical model on Google Colab using the modular `vqamed` codebase.

## 1. Setup Environment

In [None]:
# Check GPU availability (Config.device falls back to "cpu" when CUDA is unavailable)
!nvidia-smi

In [None]:
# Install dependencies
!pip install torch torchvision transformers tqdm matplotlib pillow -q

## 2. Upload/Clone Project

**Option A**: Clone from GitHub (if you have pushed the repository there)

In [None]:
# Option A: Clone from GitHub
# !git clone https://github.com/YOUR_USERNAME/vqa_med.git
# %cd vqa_med

**Option B**: Upload the vqamed package manually

Upload the `vqamed/` folder to Colab using the file browser on the left, or run the cells below to write the package files inline.

In [None]:
# Option B: Create the package inline (run this if you didn't clone/upload)
# The following %%writefile cells write each module into this directory;
# exist_ok=True makes re-running the cell harmless.
import os
os.makedirs('vqamed', exist_ok=True)

In [None]:
%%writefile vqamed/__init__.py
"""VQA Medical - Visual Question Answering for Medical Images"""

# Re-export the public objects of each submodule so notebook code can write
# e.g. ``from vqamed import Config, VQADataset, VQAModel``.
from .config import Config
from .dataset import VQADataset, parse_qa_set
from .encoders import VisualEncoder, TextEncoder
from .fusion import CrossAttentionFusion
from .decoder import AnswerDecoder
from .model import VQAModel
from .training import train_epoch, validate_epoch, EarlyStopping

# Names exported by ``from vqamed import *``.
__all__ = [
    "Config",
    "VQADataset",
    "parse_qa_set",
    "VisualEncoder",
    "TextEncoder",
    "CrossAttentionFusion",
    "AnswerDecoder",
    "VQAModel",
    "train_epoch",
    "validate_epoch",
    "EarlyStopping",
]

In [None]:
%%writefile vqamed/config.py
"""Configuration for VQA Medical model."""

from dataclasses import dataclass, field
from pathlib import Path
import torch


@dataclass
class Config:
    """Hyperparameters and dataset locations for VQA Medical training.

    Each split root (``train_path``, ``validation_path``, ``test_path``) is
    expected to hold an ``images`` directory plus a QA text file; the derived
    locations are exposed as read-only string properties. Creating an
    instance also creates the parent directory of ``save_path``.
    """

    # --- Data and checkpoint locations ---
    train_path: str = "data/Training"
    validation_path: str = "data/Validation"
    test_path: str = "data/Test"
    save_path: str = "checkpoints/best_model.pt"

    # --- Architecture ---
    embed_dim: int = 512
    num_heads: int = 8
    decoder_layers: int = 4
    max_len: int = 32
    text_model_name: str = "bert-base-uncased"

    # --- Optimisation ---
    batch_size: int = 32
    num_epochs: int = 30
    learning_rate: float = 1e-4
    patience: int = 5

    # Chosen per instance at construction time: "cuda" when available, else "cpu".
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")

    def __post_init__(self):
        """Ensure the directory that will hold the checkpoint exists."""
        checkpoint_dir = Path(self.save_path).parent
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    @property
    def train_images_dir(self) -> str:
        """Directory containing the training images."""
        return "{}/images".format(self.train_path)

    @property
    def train_qa_file(self) -> str:
        """QA-pair file for the training split."""
        return "{}/all_qa_pairs.txt".format(self.train_path)

    @property
    def val_images_dir(self) -> str:
        """Directory containing the validation images."""
        return "{}/images".format(self.validation_path)

    @property
    def val_qa_file(self) -> str:
        """QA-pair file for the validation split."""
        return "{}/all_qa_pairs.txt".format(self.validation_path)

    @property
    def test_images_dir(self) -> str:
        """Directory containing the test images."""
        return "{}/images".format(self.test_path)

    @property
    def test_qa_file(self) -> str:
        """Question / reference-answer file for the test split."""
        return "{}/questions_w_ref_answers.txt".format(self.test_path)

In [None]:
%%writefile vqamed/dataset.py
"""Dataset module for VQA Medical."""

from pathlib import Path
from typing import Optional, Callable, Tuple, List, Dict, Any

import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms


def parse_qa_set(qa_pairs_txt_path: str) -> Tuple[int, List[Dict[str, str]]]:
    """Parse pipe-separated QA pairs from a text file.

    Two line layouts are accepted:

    * ``image_id|question|answer`` (exactly 3 fields)
    * ``image_id|<ignored>|question|answer[|...]`` (4 or more fields); the
      second field and anything after the fourth are ignored.

    Blank (or whitespace-only) lines are skipped.

    Args:
        qa_pairs_txt_path: Path to the UTF-8 encoded QA file.

    Returns:
        A ``(count, records)`` tuple. ``records`` is a list of dicts with keys
        ``'image_id'``, ``'question'`` and ``'answer'``; ``count`` is
        ``len(records)``.

    Raises:
        ValueError: If a non-blank line has fewer than three fields.
    """
    qa_set = []

    with open(qa_pairs_txt_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Empty lines carry no sample (previously they crashed unpacking).
                continue
            elements = stripped.split('|')
            if len(elements) > 3:
                image_id, question, answer = elements[0], elements[2], elements[3]
            elif len(elements) == 3:
                image_id, question, answer = elements
            else:
                raise ValueError(
                    f"{qa_pairs_txt_path}:{line_no}: expected at least 3 "
                    f"'|'-separated fields, got {len(elements)}: {stripped!r}"
                )
            qa_set.append({
                'image_id': image_id,
                'question': question,
                'answer': answer
            })

    return len(qa_set), qa_set


class VQADataset(Dataset):
    """Map-style dataset pairing medical images with question/answer text.

    The image for each QA record is loaded from
    ``images_dir / f"{image_id}.jpg"``, where the records come from
    :func:`parse_qa_set`.

    Args:
        images_dir: Directory containing the ``.jpg`` images.
        qa_file: Pipe-separated QA file (see :func:`parse_qa_set`).
        transform: Callable applied to the RGB ``PIL.Image``. Defaults to a
            224x224 resize, tensor conversion and ImageNet mean/std
            normalisation.
        tokenizer: Optional Hugging Face-style tokenizer for the question.
            Without one, ``input_ids`` is the raw question string and
            ``attention_mask`` is ``None``. PyTorch's default collate function
            cannot batch ``None``, so a tokenizer is required for DataLoader use.
        max_len: Padding/truncation length of the tokenized question.
    """

    def __init__(
        self,
        images_dir: str,
        qa_file: str,
        transform: Optional[Callable] = None,
        tokenizer: Optional[Any] = None,
        max_len: int = 32
    ):
        self.images_dir = Path(images_dir)
        self.length, self.items = parse_qa_set(qa_file)
        self.transform = transform if transform is not None else self._default_transform()
        self.tokenizer = tokenizer
        self.max_len = max_len

    @staticmethod
    def _default_transform() -> transforms.Compose:
        """Resize to 224x224, convert to tensor, normalise with ImageNet stats."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load one sample.

        Returns:
            Dict with ``'image'`` (transformed image), ``'input_ids'`` and
            ``'attention_mask'`` (1-D tensors of length ``max_len`` when a
            tokenizer is set, otherwise the raw question and ``None``), and
            ``'answer'`` (raw answer string).
        """
        item = self.items[idx]
        img_path = self.images_dir / f"{item['image_id']}.jpg"

        # Close the file handle as soon as the pixels are decoded; DataLoader
        # workers open many images per epoch.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        question = item['question']
        if self.tokenizer is not None:
            tokens = self.tokenizer(
                question,
                padding='max_length',
                truncation=True,
                max_length=self.max_len,
                return_tensors='pt'
            )
            # Drop the leading batch axis added by return_tensors='pt'.
            input_ids = tokens['input_ids'].squeeze(0)
            attention_mask = tokens['attention_mask'].squeeze(0)
        else:
            input_ids = question
            attention_mask = None

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'answer': item['answer']
        }

In [None]:
%%writefile vqamed/encoders.py
"""Encoder modules for VQA Medical."""

import torch
import torch.nn as nn
import torchvision.models as models
from transformers import AutoModel


class VisualEncoder(nn.Module):
    """Image encoder: pretrained DenseNet-121 trunk followed by an MLP head.

    The convolutional ``features`` of torchvision's DenseNet-121 are followed
    by the same ReLU and global average pooling that torchvision's own
    ``DenseNet.forward`` applies before its classifier. A two-layer MLP then
    projects the pooled vector to ``embed_dim``.

    Args:
        embed_dim: Size of the output embedding.
    """

    def __init__(self, embed_dim: int = 512):
        super().__init__()
        base_model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        self.features = base_model.features
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Channel count of the trunk output (1024 for DenseNet-121); read it
        # from the discarded classifier instead of hard-coding it.
        feature_dim = base_model.classifier.in_features
        self.projection = nn.Sequential(
            nn.Linear(feature_dim, 768),
            nn.ReLU(),
            nn.Linear(768, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into ``(B, embed_dim)`` vectors.

        Args:
            x: Image batch, presumably ``(B, 3, H, W)`` normalised like the
               dataset transform.
        """
        x = self.features(x)
        # torchvision's DenseNet applies ReLU to the trunk output (which ends
        # in a BatchNorm layer) before pooling; the pretrained weights were
        # trained with it, so apply it here too.
        x = torch.relu(x)
        x = torch.flatten(self.global_pool(x), 1)
        return self.projection(x)


class TextEncoder(nn.Module):
    """Question encoder: pretrained Hugging Face transformer plus an MLP head.

    The last-layer hidden state of the first token (``[CLS]`` for BERT-style
    models) goes through a two-layer MLP with a 512-wide hidden layer and is
    projected to ``embed_dim``.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        embed_dim: Size of the output embedding.
    """

    def __init__(self, model_name: str = 'bert-base-uncased', embed_dim: int = 512):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_name)
        hidden_size = backbone.config.hidden_size
        self.transformer = backbone
        # Attribute names and layer order are kept stable so state_dict keys
        # (transformer.*, projection.0.*, projection.2.*) stay compatible.
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, embed_dim)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return one ``embed_dim`` vector per sequence, taken from its first token."""
        hidden_states = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state
        first_token_state = hidden_states[:, 0]
        return self.projection(first_token_state)

In [None]:
%%writefile vqamed/fusion.py
"""Fusion module for VQA Medical."""

import torch
import torch.nn as nn


class CrossAttentionFusion(nn.Module):
    """Fuse text and image features with text-to-image cross-attention.

    Text features are the queries and image features the keys/values. The
    attention output is added back to the text features (residual), layer
    normalised, and passed through a two-layer MLP.

    Inputs may be sequences ``(B, L, dim)`` or one vector per sample
    ``(B, dim)``. In the 2-D case each row is treated as a length-1 sequence
    of its own sample. A 2-D ``text_embeds`` yields a ``(B, dim)`` result.

    Args:
        dim: Feature size of both inputs and of the output.
        heads: Number of attention heads; must divide ``dim``.
    """

    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True
        )
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(
        self,
        text_embeds: torch.Tensor,
        image_embeds: torch.Tensor
    ) -> torch.Tensor:
        # nn.MultiheadAttention interprets a 2-D tensor as ONE unbatched
        # sequence. Passing (B, dim) directly would let every sample attend to
        # every other sample's image. Promote per-sample vectors to length-1
        # sequences so samples stay independent.
        squeeze_back = text_embeds.dim() == 2
        if squeeze_back:
            text_embeds = text_embeds.unsqueeze(1)
        if image_embeds.dim() == 2:
            image_embeds = image_embeds.unsqueeze(1)

        attn_output, _ = self.cross_attn(
            query=text_embeds,
            key=image_embeds,
            value=image_embeds
        )
        x = self.norm(attn_output + text_embeds)
        x = self.mlp(x)
        return x.squeeze(1) if squeeze_back else x

In [None]:
%%writefile vqamed/decoder.py
"""Answer decoder module for VQA Medical."""

import torch
import torch.nn as nn


class AnswerDecoder(nn.Module):
    """Autoregressive Transformer decoder producing answer-token logits.

    Token embeddings plus learned absolute position embeddings are passed
    through ``num_layers`` ``nn.TransformerDecoderLayer`` blocks. The blocks
    use causal self-attention and cross-attention over ``memory``, and a final
    linear layer maps to vocabulary logits.

    Args:
        vocab_size: Vocabulary size (token embedding and output logits).
        embed_dim: Model width (``d_model``).
        max_len: Longest target sequence supported by the position embedding.
        num_layers: Number of decoder layers.
        nhead: Attention heads per layer.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 512,
        max_len: int = 32,
        num_layers: int = 4,
        nhead: int = 8
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        self.transformer = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                d_model=embed_dim,
                nhead=nhead,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.max_len = max_len
        self.embed_dim = embed_dim

    def forward(
        self,
        tgt_input_ids: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Compute next-token logits for every target position.

        Args:
            tgt_input_ids: Target token ids of shape ``(B, T)``, ``T <= max_len``.
            memory: Sequence attended to by cross-attention,
                ``(B, S, embed_dim)``.
            tgt_key_padding_mask: Optional ``(B, T)`` bool mask where ``True``
                marks padded target positions that must not be attended to.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``T`` exceeds ``max_len``.
        """
        B, T = tgt_input_ids.size()
        if T > self.max_len:
            raise ValueError(
                f"target length {T} exceeds decoder max_len {self.max_len}"
            )
        device = tgt_input_ids.device

        pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        tgt_emb = self.token_emb(tgt_input_ids) + self.pos_emb(pos_ids)

        # True strictly above the diagonal: position i may not attend to j > i.
        causal_mask = torch.triu(
            torch.ones((T, T), device=device),
            diagonal=1
        ).bool()

        out = self.transformer(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=causal_mask,
            # Previously accepted but silently dropped; forward it so padded
            # target positions are actually masked.
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        return self.fc_out(out)

In [None]:
%%writefile vqamed/model.py
"""Main VQA model combining all components."""

import torch
import torch.nn as nn

from .encoders import VisualEncoder, TextEncoder
from .fusion import CrossAttentionFusion
from .decoder import AnswerDecoder


class VQAModel(nn.Module):
    """Complete VQA model for medical image question answering.

    The visual and text encoders each produce features for a sample. The
    fusion layer lets the question features attend to the image features,
    and the fused result serves as the cross-attention ``memory`` of the
    answer decoder. The decoder returns next-token logits for the
    (teacher-forced) answer ids.

    Args:
        visual_encoder: Image encoder.
        text_encoder: Question encoder.
        fusion: Text/image fusion layer.
        decoder: Autoregressive answer decoder.
    """

    def __init__(
        self,
        visual_encoder: VisualEncoder,
        text_encoder: TextEncoder,
        fusion: CrossAttentionFusion,
        decoder: AnswerDecoder
    ):
        super().__init__()
        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder
        self.fusion = fusion
        self.decoder = decoder

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Return answer logits of shape ``(B, T, vocab_size)``.

        Args:
            image: Image batch for the visual encoder.
            input_ids: Tokenized question ids.
            attention_mask: Question attention mask.
            decoder_input_ids: Answer ids fed to the decoder, ``(B, T)``.
            decoder_padding_mask: Optional ``(B, T)`` bool mask of padded answer
                positions, forwarded as the decoder's ``tgt_key_padding_mask``.
        """
        img_feat = self.visual_encoder(image)
        text_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The encoders return one vector per sample (B, D). Add an explicit
        # length-1 sequence axis: nn.MultiheadAttention treats a 2-D input as a
        # single unbatched sequence, which would mix samples across the batch.
        if text_out.dim() == 2:
            text_out = text_out.unsqueeze(1)
        if img_feat.dim() == 2:
            img_feat = img_feat.unsqueeze(1)

        fused = self.fusion(text_out, img_feat)
        if fused.dim() == 2:
            fused = fused.unsqueeze(1)

        return self.decoder(
            tgt_input_ids=decoder_input_ids,
            memory=fused,
            tgt_key_padding_mask=decoder_padding_mask
        )

    @classmethod
    def from_config(cls, config, vocab_size: int) -> "VQAModel":
        """Build all sub-modules from a ``Config`` and the target vocabulary size."""
        visual_encoder = VisualEncoder(embed_dim=config.embed_dim)
        text_encoder = TextEncoder(
            model_name=config.text_model_name,
            embed_dim=config.embed_dim
        )
        fusion = CrossAttentionFusion(
            dim=config.embed_dim,
            heads=config.num_heads
        )
        decoder = AnswerDecoder(
            vocab_size=vocab_size,
            embed_dim=config.embed_dim,
            max_len=config.max_len,
            num_layers=config.decoder_layers,
            nhead=config.num_heads
        )

        return cls(visual_encoder, text_encoder, fusion, decoder)

In [None]:
%%writefile vqamed/training.py
"""Training utilities for VQA Medical."""

from typing import Any

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def _answer_loss(
    model: torch.nn.Module,
    batch: dict,
    tokenizer: Any,
    device: torch.device,
    max_len: int
) -> torch.Tensor:
    """Compute the teacher-forced cross-entropy loss for one batch.

    Answers are tokenized on the fly. The decoder input is the answer sequence
    without its last token, and the targets are the same sequence shifted
    left by one. Padding targets are ignored by the loss.
    """
    images = batch['image'].to(device)
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    # Pad to the longest answer in the batch instead of max_len. Trailing pad
    # positions only add targets that the loss ignores, and with a causal
    # decoder they cannot influence earlier positions. The loss is therefore
    # unchanged while the decoder processes shorter sequences.
    answer_tokens = tokenizer(
        batch['answer'],
        padding='longest',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
    answer_ids = answer_tokens['input_ids']
    decoder_input_ids = answer_ids[:, :-1].to(device)
    labels = answer_ids[:, 1:].to(device)

    logits = model(
        image=images,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids
    )

    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=tokenizer.pad_token_id
    )


def train_epoch(
    model: torch.nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    tokenizer: Any,
    device: torch.device,
    max_len: int = 32,
    log_interval: int = 20
) -> float:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: VQA model called with ``image``, ``input_ids``,
            ``attention_mask`` and ``decoder_input_ids``.
        dataloader: Yields dicts with ``'image'``, ``'input_ids'``,
            ``'attention_mask'`` and ``'answer'`` (list of strings).
        optimizer: Optimizer stepping the model parameters.
        tokenizer: Tokenizer used to encode answers; its ``pad_token_id`` is
            ignored in the loss.
        device: Device the tensors are moved to.
        max_len: Truncation length for answers.
        log_interval: Print progress every ``log_interval`` batches.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    model.train()
    total_loss = 0.0

    for cnt, batch in enumerate(dataloader):
        if cnt % log_interval == 0:
            print(f'{cnt} / {num_batches}')

        loss = _answer_loss(model, batch, tokenizer, device, max_len)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / num_batches


@torch.no_grad()
def validate_epoch(
    model: torch.nn.Module,
    dataloader: DataLoader,
    tokenizer: Any,
    device: torch.device,
    max_len: int = 32,
    log_interval: int = 20
) -> float:
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Takes the same arguments as :func:`train_epoch` minus the optimizer and
    runs under ``torch.no_grad`` in eval mode.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("validate_epoch received an empty dataloader")

    model.eval()
    total_loss = 0.0

    for cnt, batch in enumerate(dataloader):
        if cnt % log_interval == 0:
            print(f"{cnt} / {num_batches}")

        loss = _answer_loss(model, batch, tokenizer, device, max_len)
        total_loss += loss.item()

    return total_loss / num_batches


class EarlyStopping:
    """Track validation loss and flag when training should stop.

    Calling the instance with the latest validation loss returns ``True``
    when that loss beats the best one so far by more than ``min_delta``; the
    best loss is updated and the streak counter reset. Otherwise the counter
    of consecutive non-improving calls grows, and once it reaches
    ``patience`` the ``should_stop`` flag is set (and never cleared).

    Attributes:
        patience: Consecutive non-improving calls tolerated.
        min_delta: Minimum decrease that counts as an improvement.
        counter: Current number of consecutive non-improving calls.
        best_loss: Lowest accepted loss (``inf`` before the first call).
        should_stop: Whether ``counter`` has reached ``patience``.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        threshold = self.best_loss - self.min_delta
        if not val_loss < threshold:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
            return False

        self.best_loss = val_loss
        self.counter = 0
        return True

## 3. Mount Google Drive & Set Up Data

In [None]:
# Mount Google Drive so the dataset and the checkpoint directory configured
# below (both under /content/drive/MyDrive) are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Option 1: If dataset is on Google Drive
# DATA_PATH = "/content/drive/MyDrive/datasets/ImageClef-2019-VQA-Med"

# Option 2: If using Kaggle dataset, download it first
# (expects your Kaggle API token, kaggle.json, in the root of MyDrive)
# !pip install kaggle -q
# !mkdir -p ~/.kaggle
# !cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle datasets download -d ammar111/imageclef-2019-vqa-med
# !unzip imageclef-2019-vqa-med.zip -d data/
# ...then point DATA_PATH below at the extracted folder under data/.

# Set your data path here. It must contain Training/, Validation/ and Test/
# sub-folders (the Config cell below builds the split paths from it).
DATA_PATH = "/content/drive/MyDrive/datasets/ImageClef-2019-VQA-Med"

## 4. Configuration

In [None]:
import random

import torch
from vqamed import Config

# Seed Python's and PyTorch's RNGs before any model or DataLoader is created,
# so weight initialisation, shuffling and dropout are repeatable across runs.
# torch.manual_seed seeds all devices (CPU and CUDA); bit-exact GPU results
# would additionally require deterministic cuDNN settings.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Configure paths and hyperparameters
config = Config(
    train_path=f"{DATA_PATH}/Training",
    validation_path=f"{DATA_PATH}/Validation",
    test_path=f"{DATA_PATH}/Test",
    save_path="/content/drive/MyDrive/vqa_med_checkpoints/best_model.pt",

    # Model params
    embed_dim=512,
    num_heads=8,
    decoder_layers=4,
    max_len=32,

    # Training params
    batch_size=32,
    num_epochs=30,
    learning_rate=1e-4,
    patience=5,
)

device = torch.device(config.device)
print(f"Device: {device}")
print(f"Train path: {config.train_images_dir}")

## 5. Create Datasets & DataLoaders

In [None]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from vqamed import VQADataset

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)

# The training loss ignores pad_token_id, and greedy decoding starts from
# cls_token_id and stops at sep_token_id. Fail fast if the chosen tokenizer
# lacks any of these, instead of silently overriding its pad token.
missing_special_tokens = [
    attr for attr in ("pad_token_id", "cls_token_id", "sep_token_id")
    if getattr(tokenizer, attr, None) is None
]
if missing_special_tokens:
    raise ValueError(
        f"Tokenizer {config.text_model_name!r} lacks required special tokens: "
        f"{missing_special_tokens}"
    )

# Create datasets
train_dataset = VQADataset(
    images_dir=config.train_images_dir,
    qa_file=config.train_qa_file,
    tokenizer=tokenizer,
    max_len=config.max_len
)

val_dataset = VQADataset(
    images_dir=config.val_images_dir,
    qa_file=config.val_qa_file,
    tokenizer=tokenizer,
    max_len=config.max_len
)

# Pinned host memory only speeds up host-to-GPU copies; it is pointless on a
# CPU-only runtime.
pin_memory = device.type == 'cuda'

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=pin_memory
)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

## 6. Initialize Model

In [None]:
from vqamed import VQAModel

# Build encoders, fusion and decoder sized to the tokenizer vocabulary, then
# move everything to the target device (nn.Module.to returns the module).
model = VQAModel.from_config(config, vocab_size=len(tokenizer)).to(device)

# Report model size: all parameters vs. those that will receive gradients.
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Training Loop

In [None]:
from pathlib import Path
from tqdm.notebook import tqdm
from vqamed import train_epoch, validate_epoch, EarlyStopping

# NOTE(review): re-running this cell resets the optimizer, early-stopping state
# and loss history, but NOT `model` — training would resume from the current
# weights. Re-run the model-initialisation cell first for a fresh run.

# Optimizer over every model parameter (none are frozen, so the pretrained
# DenseNet and BERT backbones are fine-tuned too)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Early stopping: tracks the best validation loss and the no-improvement streak
early_stopping = EarlyStopping(patience=config.patience)

# Training history (per-epoch mean batch losses), plotted in the next section
train_losses = []
val_losses = []

print(f"Starting training for {config.num_epochs} epochs...")
print("-" * 50)

for epoch in tqdm(range(config.num_epochs), desc="Training"):
    train_loss = train_epoch(
        model, train_loader, optimizer, tokenizer, device, config.max_len
    )
    val_loss = validate_epoch(
        model, val_loader, tokenizer, device, config.max_len
    )
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f"\nEpoch {epoch + 1}: Train Loss = {train_loss:.4f} | Val Loss = {val_loss:.4f}")
    
    # Early stopping check: checkpoint only when validation loss improves, so
    # save_path always holds the best epoch's weights. If no epoch ever
    # improves (e.g. NaN losses), no checkpoint is written at all.
    improved = early_stopping(val_loss)
    if improved:
        Path(config.save_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), config.save_path)
        print("Model improved. Saving.")
    else:
        print(f"No improvement. Patience: {early_stopping.counter}/{config.patience}")
    
    if early_stopping.should_stop:
        print("Early stopping triggered.")
        break

print(f"\nTraining complete! Best model saved to: {config.save_path}")

## 8. Plot Training History

In [None]:
import matplotlib.pyplot as plt

# One x-value per completed epoch (training may stop early).
epochs = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
ax.plot(epochs, val_losses, 'r-', label='Val Loss', linewidth=2)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training vs Validation Loss', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 9. Load Best Model & Run Inference

In [None]:
# Load the best checkpoint written during training.
# map_location lets a checkpoint saved from a GPU session load on a CPU-only
# runtime (and vice versa); weights_only restricts unpickling to tensors and
# plain containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load(config.save_path, map_location=device, weights_only=True)
)
model.eval()
print("Best model loaded!")

In [None]:
@torch.no_grad()
def generate_answer(model, image, question, tokenizer, device, max_len=32):
    """Greedily decode an answer for a single image/question pair.

    Args:
        model: VQA model called as
            ``model(image, input_ids, attention_mask, decoder_input_ids)`` and
            returning logits of shape ``(1, T, vocab_size)``.
        image: One preprocessed image tensor without a batch dimension.
        question: Question text.
        tokenizer: Tokenizer used in training; must define ``cls_token_id``
            and ``sep_token_id``.
        device: Device to run inference on.
        max_len: Question padding length and decoder length limit. It should
            equal the ``max_len`` the model was trained with.

    Returns:
        The decoded answer with special tokens removed.
    """
    model.eval()

    # Tokenize the question exactly as the dataset does during training.
    tokens = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    image = image.unsqueeze(0).to(device)

    # Start with [CLS], the first decoder input token during training.
    decoder_input = torch.tensor([[tokenizer.cls_token_id]], device=device)

    generated = []
    # Training feeds the decoder at most max_len - 1 tokens (the answer minus
    # its last token). Cap the decoder input at that length so only position
    # embeddings that were actually trained are used.
    for _ in range(max_len - 1):
        logits = model(image, input_ids, attention_mask, decoder_input)
        next_token = logits[:, -1, :].argmax(dim=-1)  # shape (1,)

        if next_token.item() == tokenizer.sep_token_id:
            break

        generated.append(next_token.item())
        decoder_input = torch.cat([decoder_input, next_token.unsqueeze(-1)], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True)

In [None]:
# Test on a sample. Show the raw question text from the QA file, not a
# decode of its input_ids: for an uncased tokenizer that round-trip is
# lowercased and truncated to max_len.
sample_idx = 0
sample = val_dataset[sample_idx]
image = sample['image']
question = val_dataset.items[sample_idx]['question']
ground_truth = sample['answer']

# Use the same max_len the decoder was built with.
predicted = generate_answer(model, image, question, tokenizer, device, max_len=config.max_len)

print(f"Question: {question}")
print(f"Ground Truth: {ground_truth}")
print(f"Predicted: {predicted}")

## 10. Download Model

In [None]:
from google.colab import files

# Download the best checkpoint to your local machine through the browser.
# (config.save_path is under /content/drive/MyDrive, so the file also stays
# in Google Drive.)
files.download(config.save_path)