# In [None]:
## 📦 1. Install Required Libraries
# !pip install datasets transformers torch torchvision pandas openpyxl pillow scikit-learn tqdm

## 📚 2. Import Libraries
import os
import json
import re
import glob
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Transformers
from transformers import (
    ViTModel, ViTImageProcessor,
    AutoTokenizer, AutoModel,
    get_linear_schedule_with_warmup
)

# Computer Vision
from PIL import Image
import torchvision.transforms as T

# Sklearn
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Device selection: use the first CUDA GPU when one is visible, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"üîß Device Used: {device}")
if _cuda_ok:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is the card's total capacity in bytes (not the free amount),
    # reported here in decimal gigabytes.
    print(f"   Available Memory: {_gpu_props.total_memory / 1e9:.2f} GB")

## 📊 3. Load and Analyze Data from Local Files
def _first_number(text):
    """Return the first run of digits in *text* as an int, or None if there is none."""
    match = re.search(r'\d+', text)
    return int(match.group()) if match else None


def _build_image_index(image_folder_path, image_extensions):
    """Index the image files directly inside *image_folder_path*.

    Returns ``(files, by_stem, by_number)``:
      * ``files``: sorted image paths whose extension (compared case-insensitively)
        is in *image_extensions*; hidden files (leading '.') are skipped, matching
        what ``glob('*.ext')`` would return.
      * ``by_stem``: lowercase file name without extension -> path.
      * ``by_number``: int value of the first digit run in that stem
        (e.g. 54610 for ``synpic54610``) -> path; the first path in sorted order
        wins when several files share a number.
    """
    wanted = {ext.lower() for ext in image_extensions}
    # One directory listing with a lowercase extension check avoids the duplicate
    # hits that globbing '*.jpg' and '*.JPG' separately yields on case-insensitive
    # file systems.
    files = sorted(
        os.path.join(image_folder_path, name)
        for name in os.listdir(image_folder_path)
        if not name.startswith('.')
        and os.path.splitext(name)[1].lower() in wanted
        and os.path.isfile(os.path.join(image_folder_path, name))
    )

    by_stem = {}
    by_number = {}
    for path in files:
        stem = os.path.splitext(os.path.basename(path))[0].lower()
        by_stem[stem] = path
        number = _first_number(stem)
        if number is not None:
            by_number.setdefault(number, path)
    return files, by_stem, by_number


def _resolve_image_path(image_stem, by_stem, by_number):
    """Map a lowercase image stem from the sheet to a local image path, or None.

    Tries an exact stem match first, then falls back to the numeric id, so that
    ``synpic54610`` finds ``54610.jpg`` (and vice versa). Unlike substring
    matching, this never pairs ``synpic5461`` with ``synpic54610``.
    """
    if image_stem in by_stem:
        return by_stem[image_stem]
    number = _first_number(image_stem)
    return by_number.get(number) if number is not None else None


def _split_examples(rows, test_split, val_split, random_state):
    """Split *rows* into ``(train, val, test)`` lists.

    Uses two ``train_test_split`` calls exactly as before when both fractions are
    positive; a fraction of 0 simply yields an empty split instead of an error.
    """
    from sklearn.model_selection import train_test_split

    holdout = test_split + val_split
    if holdout <= 0:
        return list(rows), [], []
    train_rows, held_out = train_test_split(
        rows, test_size=holdout, random_state=random_state
    )
    if val_split <= 0:
        return train_rows, [], held_out
    if test_split <= 0:
        return train_rows, held_out, []
    val_rows, test_rows = train_test_split(
        held_out, test_size=test_split / holdout, random_state=random_state
    )
    return train_rows, val_rows, test_rows


def load_local_vqa_rad_data(excel_path, image_folder_path, test_split=0.2, val_split=0.1, random_state=42):
    """
    Load VQA-RAD question/answer rows from a local Excel sheet and pair them with local images.

    Args:
        excel_path: Path to the VQA-RAD Excel file. It must have an ``IMAGEID``
            column (image URL or file name). Question/answer text comes from
            ``question``/``answer`` or, failing that, ``QUESTION``/``ANSWER``.
        image_folder_path: Folder containing the image files (not searched recursively).
        test_split: Fraction of matched examples held out for testing (0 -> empty test split).
        val_split: Fraction of matched examples held out for validation (0 -> empty val split).
        random_state: Seed for the row shuffle and for the splits.

    Returns:
        dict with keys ``'train'``, ``'val'`` and ``'test'``. Each value is a list of
        example dicts with keys ``image``, ``image_path``, ``question``, ``answer``,
        ``question_type``, ``image_name`` and ``original_image_url``, plus
        ``qid_unique`` / ``qid_linked`` when those columns exist.

    Raises:
        FileNotFoundError: if the Excel file or the image folder does not exist.
        ValueError: if the sheet has no ``IMAGEID`` column or no row matches a local image.
    """
    print("‚¨áÔ∏è Loading VQA-RAD data from local files...\n")

    if not os.path.exists(excel_path):
        raise FileNotFoundError(f"‚ùå Excel file not found: {excel_path}")
    if not os.path.isdir(image_folder_path):
        raise FileNotFoundError(f"Image folder not found: {image_folder_path}")

    df = pd.read_excel(excel_path)
    print(f"‚úÖ Excel file loaded successfully: {len(df)} rows")

    # Preview only the raw VQA-RAD columns that actually exist, so a sheet that
    # already uses lowercase names does not crash here.
    print("\nüìã First rows of data:")
    preview_columns = [c for c in ('QUESTION', 'ANSWER', 'IMAGEID') if c in df.columns]
    print(df[preview_columns].head() if preview_columns else df.head())

    if 'IMAGEID' not in df.columns:
        raise ValueError("Excel file has no 'IMAGEID' column; cannot locate images.")

    # Map column names: copy upper-case VQA-RAD columns to lowercase names
    # unless a lowercase column already exists.
    for source, target in (('QUESTION', 'question'), ('ANSWER', 'answer'), ('IMAGEID', 'image_name')):
        if source in df.columns and target not in df.columns:
            df[target] = df[source]

    # Shuffle rows
    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    image_extensions = ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG', '.bmp', '.BMP', '.tiff', '.TIFF']
    all_image_files, image_files_dict, images_by_number = _build_image_index(
        image_folder_path, image_extensions
    )
    print(f"\nüìÅ Found {len(all_image_files)} images in folder: {image_folder_path}")
    print(f"   Sample image names: {[os.path.basename(f) for f in all_image_files[:5]]}")

    valid_rows = []
    found_images = 0
    missing_images = 0
    for idx, row in df.iterrows():
        try:
            # IMAGEID may be a full URL; its last path component is the image name.
            image_url = str(row['IMAGEID'])
            image_name = os.path.basename(image_url)
            image_name_no_ext = os.path.splitext(image_name)[0].lower()

            image_path = _resolve_image_path(image_name_no_ext, image_files_dict, images_by_number)
            if image_path is None:
                missing_images += 1
                print(f"‚ö†Ô∏è Could not find image for {image_name_no_ext} (Row {idx})")
                continue
            found_images += 1

            example = {
                'image': image_path,  # Path; the Dataset opens it lazily
                'image_path': image_path,
                'question': str(row['question']) if 'question' in row else str(row['QUESTION']),
                'answer': str(row['answer']) if 'answer' in row else str(row['ANSWER']),
                'question_type': str(row.get('Q_TYPE', 'N/A')),
                'image_name': os.path.basename(image_path),
                'original_image_url': image_url
            }
            # Optional metadata columns
            if 'QID_unique' in row:
                example['qid_unique'] = str(row['QID_unique'])
            if 'QID_linked' in row:
                example['qid_linked'] = str(row['QID_linked'])

            valid_rows.append(example)

        except Exception as e:
            # Best-effort: a malformed row is reported and skipped, not fatal.
            print(f"‚ö†Ô∏è Error processing row {idx}: {e}")
            continue

    if not valid_rows:
        raise ValueError(
            "No rows could be matched to a local image; check the image folder and file names."
        )

    train_data, val_data, test_data = _split_examples(valid_rows, test_split, val_split, random_state)
    dataset = {'train': train_data, 'val': val_data, 'test': test_data}

    print(f"\nüìä Dataset Summary:")
    print(f"   - Training examples: {len(train_data)}")
    print(f"   - Validation examples: {len(val_data)}")
    print(f"   - Testing examples: {len(test_data)}")
    print(f"   - Images found: {found_images}")
    print(f"   - Missing images: {missing_images}")

    if missing_images > 0:
        print(f"\n‚ö†Ô∏è Warning: {missing_images} images not found locally.")
        print("   Script looks for images with names like 'synpic54610' or just '54610'")
        print("   Make sure image filenames match synpic numbers in Excel file.")

    return dataset

# Define paths - update to match your local paths
EXCEL_PATH = "/content/VQA_RAD Dataset Public.xlsx"  # Path to your Excel file
IMAGE_FOLDER_PATH = "/content/VQA_RAD Image Folder"  # Path to your image folder

# Load data. On failure, report the error and re-raise it. In IPython/Jupyter
# `exit` is the shell's own exit helper rather than `sys.exit`, so `exit(1)` may
# not stop the running cell, and later code would then fail on an undefined
# `dataset`. Re-raising halts both plain scripts and notebook cells.
try:
    dataset = load_local_vqa_rad_data(
        excel_path=EXCEL_PATH,
        image_folder_path=IMAGE_FOLDER_PATH,
        test_split=0.2,
        val_split=0.1
    )
    if not dataset['train']:
        raise ValueError("The training split is empty; nothing to train on.")

    print("\n‚úÖ Data loaded successfully!")
    print(f"üìä Dataset Info:")
    print(f"   - Training examples count: {len(dataset['train'])}")
    print(f"   - Validation examples count: {len(dataset['val'])}")
    print(f"   - Testing examples count: {len(dataset['test'])}")
    print(f"   - Total examples: {sum(len(dataset[split]) for split in dataset)}")

    # Show example
    print("\nüìù Example from data:")
    example = dataset['train'][0]
    print(f"   Question: {example.get('question', 'N/A')}")
    print(f"   Answer: {example.get('answer', 'N/A')}")
    print(f"   Question type: {example.get('question_type', 'N/A')}")
    print(f"   Image path: {example.get('image_path', 'N/A')}")

except Exception as e:
    print(f"‚ùå Loading error: {e}")
    raise

## 🔍 4. Analyze Answer Distribution
# Collect every non-empty answer from all splits, normalised the same way the
# Dataset normalises labels (str -> lower -> strip).
# NOTE(review): the answer vocabulary is built from train+val+test answers, so
# it can contain classes that never occur in training — confirm this is intended.
all_answers = [
    str(example.get('answer', '')).lower().strip()
    for split in ['train', 'val', 'test']
    for example in dataset[split]
    if example.get('answer', '')
]

# Analyze distribution
answer_counts = Counter(all_answers)
print(f"\nüìä Answer Statistics:")
print(f"   - Unique answers count: {len(answer_counts)}")
print(f"   - Top 10 most common answers:")
for answer, count in answer_counts.most_common(10):
    print(f"      '{answer}': {count} times")

# Keep only the TOP_K most frequent answers as classification targets.
TOP_K_ANSWERS = 500  # Can be adjusted as needed
top_answers = [ans for ans, _ in answer_counts.most_common(TOP_K_ANSWERS)]
# Report the real vocabulary size: it is smaller than TOP_K_ANSWERS when the
# data has fewer unique answers.
print(f"\n‚úÖ We will use the top {len(top_answers)} most common answers")

## 🏗️ 5. Build Advanced Model
class MultiHeadCrossAttention(nn.Module):
    """Multi-head scaled dot-product cross attention.

    Queries come from one sequence and keys/values from another.

    Args:
        dim: Width of the query/key/value inputs and of the output.
        num_heads: Number of attention heads; must divide ``dim``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape [batch, seq, dim] -> [batch, num_heads, seq, head_dim]."""
        return x.view(x.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, key_padding_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Tensor [batch, seq_q, dim].
            key, value: Tensors [batch, seq_k, dim].
            key_padding_mask: Optional [batch, seq_k] tensor (e.g. a tokenizer
                attention mask). Positions where it is 0/False are padding and
                get no attention weight. ``None`` attends to every key.

        Returns:
            Tensor [batch, seq_q, dim].
        """
        batch_size = query.size(0)

        Q = self._split_heads(self.query(query))
        K = self._split_heads(self.key(key))
        V = self._split_heads(self.value(value))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if key_padding_mask is not None:
            # [batch, seq_k] -> [batch, 1, 1, seq_k], broadcast over heads and queries.
            is_pad = ~key_padding_mask.bool()[:, None, None, :]
            # Use the most negative finite value rather than -inf: a fully masked
            # row then becomes uniform instead of producing NaNs.
            scores = scores.masked_fill(is_pad, torch.finfo(scores.dtype).min)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        context = torch.matmul(attention, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        return self.out(context)


class AdvancedFusionModule(nn.Module):
    """Fuse vision and text token sequences into one vector per example.

    The steps are:
      1. project both modalities to ``hidden_dim``;
      2. text->vision cross attention, then vision->text;
      3. pool each sequence;
      4. combine with a learned sigmoid gate;
      5. apply a residual feed-forward layer.
    """

    def __init__(self, vision_dim=768, text_dim=768, hidden_dim=512, num_heads=8, dropout=0.2):
        super().__init__()

        # Projection layers
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)

        # Cross Attention: Text attends to Vision
        self.text_to_vision_attention = MultiHeadCrossAttention(hidden_dim, num_heads, dropout)

        # Cross Attention: Vision attends to Text
        self.vision_to_text_attention = MultiHeadCrossAttention(hidden_dim, num_heads, dropout)

        # Gating mechanism: per-feature weight in (0, 1) between vision and text
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Layer Normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, vision_features, text_features, text_mask=None):
        """Return fused features of shape [batch, hidden_dim].

        Args:
            vision_features: [batch, seq_v, vision_dim].
            text_features: [batch, seq_t, text_dim].
            text_mask: Optional [batch, seq_t] mask (1 = real token, 0 = padding).
                When given, padding tokens are excluded from vision->text
                attention and from text pooling. ``None`` treats every token as real.
        """
        # Project to same dimension
        vision = self.vision_proj(vision_features)  # [batch, seq_v, hidden]
        text = self.text_proj(text_features)        # [batch, seq_t, hidden]

        # Cross Attention
        text_attended = self.text_to_vision_attention(text, vision, vision)
        text = self.norm1(text + self.dropout(text_attended))

        vision_attended = self.vision_to_text_attention(
            vision, text, text, key_padding_mask=text_mask
        )
        vision = self.norm2(vision + self.dropout(vision_attended))

        # Pooling: mean over sequence dimension (text: over real tokens only)
        vision_pooled = vision.mean(dim=1)  # [batch, hidden]
        if text_mask is None:
            text_pooled = text.mean(dim=1)  # [batch, hidden]
        else:
            weights = text_mask.to(text.dtype).unsqueeze(-1)  # [batch, seq_t, 1]
            text_pooled = (text * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        # Gating mechanism for adaptive fusion
        concat_features = torch.cat([vision_pooled, text_pooled], dim=-1)
        gate = self.gate(concat_features)

        # Gated fusion
        fused = gate * vision_pooled + (1 - gate) * text_pooled
        fused = self.norm3(fused)

        # Feed-forward (residual)
        fused = fused + self.ffn(fused)

        return fused


class VQAModel(nn.Module):
    """VQA answer classifier: ViT image encoder + BioBERT question encoder + cross-attention fusion.

    Args:
        num_classes: Size of the answer vocabulary (logits per example).
        hidden_dim: Width of the fusion module and classifier head.
        num_heads: Attention heads used in the fusion module.
        dropout: Dropout probability in the fusion module and classifier.

    The matching ``vision_processor`` and ``tokenizer`` are loaded with the
    encoders and exposed as attributes so data pipelines can reuse them.
    """

    def __init__(self, num_classes, hidden_dim=512, num_heads=8, dropout=0.2):
        super().__init__()

        # Vision Encoder: ViT
        self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.vision_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
        vision_dim = self.vision_encoder.config.hidden_size  # 768

        # Text Encoder: BioBERT
        self.text_encoder = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        self.tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        text_dim = self.text_encoder.config.hidden_size  # 768

        # Fusion Module
        self.fusion = AdvancedFusionModule(
            vision_dim=vision_dim,
            text_dim=text_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout
        )

        # Classification Head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return answer logits of shape [batch, num_classes].

        ``images`` are pixel values from ``vision_processor``. ``input_ids`` and
        ``attention_mask`` come from ``tokenizer``. The attention mask is also
        passed to the fusion module so question padding is ignored there.
        """
        # Extract vision features
        vision_outputs = self.vision_encoder(pixel_values=images)
        vision_features = vision_outputs.last_hidden_state  # [batch, 197, 768]

        # Extract text features
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state  # [batch, seq_len, 768]

        # Fuse features (padding tokens masked out)
        fused_features = self.fusion(vision_features, text_features, text_mask=attention_mask)

        # Classify
        logits = self.classifier(fused_features)

        return logits


# Summary of the architecture defined above. The model itself is instantiated
# later, once the number of answer classes is known.
_MODEL_COMPONENTS = (
    "Vision Transformer (ViT) - google/vit-base-patch16-224-in21k",
    "BioBERT - dmis-lab/biobert-base-cased-v1.2",
    "Multi-Head Cross Attention Fusion",
    "Deep Classification Head",
)
print("\n‚úÖ Model built successfully!")
print("\nüìê Model Components:")
for _position, _component in enumerate(_MODEL_COMPONENTS, start=1):
    print(f"   {_position}. {_component}")

## 📦 6. Prepare Dataset and DataLoader
class VQARadDataset(Dataset):
    """Torch ``Dataset`` turning VQA-RAD example dicts into model-ready tensors.

    Each item is a dict with:
      * ``pixel_values``: from ``vision_processor``;
      * ``input_ids`` / ``attention_mask``: the question, padded/truncated to ``max_length``;
      * ``labels``: class index from ``label_encoder``, or -1 for answers outside its vocabulary;
      * the raw ``answer_text``, ``question_text`` and ``image_path``.

    Args:
        data: List of example dicts as produced by ``load_local_vqa_rad_data``.
        vision_processor: Image processor returning ``pixel_values``.
        tokenizer: Tokenizer returning ``input_ids`` and ``attention_mask``.
        label_encoder: Fitted encoder; its ``classes_`` define the answer vocabulary.
        split: Name of the split (stored; not used for any split-specific processing).
        max_length: Question token length after padding/truncation.
    """

    def __init__(self, data, vision_processor, tokenizer, label_encoder, split='train', max_length=128):
        self.data = data
        self.vision_processor = vision_processor
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.max_length = max_length
        self.split = split
        # O(1) answer -> class index lookup, built once. LabelEncoder gives each
        # class its position in ``classes_``, so this matches transform().
        # The mapping is a snapshot: refitting label_encoder later will not update it.
        self._answer_to_label = {
            str(answer): index for index, answer in enumerate(label_encoder.classes_)
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        # Load the image; the context manager closes the file handle. On any
        # failure fall back to a blank white image so one bad file does not
        # abort an epoch.
        try:
            with Image.open(example['image']) as raw_image:
                image = raw_image.convert('RGB')
        except Exception as e:
            print(f"‚ö†Ô∏è Could not load image {example['image']}: {e}")
            image = Image.new('RGB', (224, 224), color='white')

        # Process image using Vision Processor (drop the batch dim it adds)
        image_inputs = self.vision_processor(images=image, return_tensors='pt')
        pixel_values = image_inputs['pixel_values'].squeeze(0)

        # Process question
        question = str(example.get('question', ''))
        text_inputs = self.tokenizer(
            question,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Normalise the answer exactly as when the vocabulary was built.
        answer = str(example.get('answer', '')).lower().strip()
        # -1 marks an out-of-vocabulary answer; the loss ignores it.
        label = self._answer_to_label.get(answer, -1)

        return {
            'pixel_values': pixel_values,  # already tensor
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long),
            'answer_text': answer,
            'question_text': question,
            'image_path': example.get('image_path', '')
        }

# Custom function to collate data in batch
def custom_collate_fn(batch):
    """Merge a list of dataset items into a single batch dict.

    Keys are taken from the first item, in its order. String fields
    (``answer_text``, ``question_text``, ``image_path``) are gathered into
    plain lists. Every other field is combined with ``torch.stack`` along a
    new leading batch dimension.
    """
    string_fields = {'answer_text', 'question_text', 'image_path'}
    first_item = batch[0]
    return {
        field: (
            [sample[field] for sample in batch]
            if field in string_fields
            else torch.stack([sample[field] for sample in batch])
        )
        for field in first_item.keys()
    }

# Prepare Label Encoder: one class per answer in the TOP_K vocabulary.
label_encoder = LabelEncoder()
label_encoder.fit(top_answers)

num_classes = len(label_encoder.classes_)
print(f"\nüìä Number of Classes: {num_classes}")

# Initialize model
model = VQAModel(num_classes=num_classes, hidden_dim=512, num_heads=8, dropout=0.2)
model = model.to(device)

# Prepare Datasets; every split reuses the model's own image processor and tokenizer.
train_dataset, val_dataset, test_dataset = (
    VQARadDataset(
        dataset[split_name],
        model.vision_processor,
        model.tokenizer,
        label_encoder,
        split=split_name
    )
    for split_name in ('train', 'val', 'test')
)

# DataLoaders
BATCH_SIZE = 8  # Can be adjusted based on available memory
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; on a
# CPU-only run it is useless and may trigger a warning, so tie it to CUDA.
PIN_MEMORY = device.type == 'cuda'

_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=0,  # Use 0 to avoid multiprocessing issues
    pin_memory=PIN_MEMORY,
    collate_fn=custom_collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

print(f"\n‚úÖ Datasets and DataLoaders prepared")
print(f"   - Training batch size: {BATCH_SIZE}")
print(f"   - Training batches count: {len(train_loader)}")
print(f"   - Validation batches count: {len(val_loader)}")
print(f"   - Testing batches count: {len(test_loader)}")

## 🎯 7. Training Setup
# Training hyper-parameters
NUM_EPOCHS = 15
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 100

# A single AdamW instance over every parameter: the pretrained encoders and the
# freshly initialised fusion/classifier layers share one learning rate.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))

# One scheduler step per optimizer step: WARMUP_STEPS of linear warm-up, then
# (per the transformers docs) linear decay over the remaining steps.
steps_per_epoch = len(train_loader)
num_training_steps = steps_per_epoch * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps
)

# Cross-entropy with label smoothing; label -1 (answer outside the vocabulary,
# as assigned by VQARadDataset) contributes nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1, label_smoothing=0.1)

print("\n‚úÖ Training components prepared")
for _setting, _setting_value in (
    ("Number of Epochs", NUM_EPOCHS),
    ("Learning Rate", LEARNING_RATE),
    ("Warmup Steps", WARMUP_STEPS),
    ("Total training steps", num_training_steps),
):
    print(f"   - {_setting}: {_setting_value}")

## 🚀 8. Training and Evaluation Functions
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Batches are dicts as produced by ``custom_collate_fn``. Examples labelled -1
    (answer outside the vocabulary) are excluded from the loss and the accuracy.
    A batch with no labelled example is skipped entirely: no forward pass, no
    optimizer step, no scheduler step.

    Args:
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``.
        dataloader: Iterable of batch dicts.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per optimizer step.
        criterion: Loss taking ``(logits, labels)``.
        device: Device to move batch tensors to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        (avg_loss, accuracy):
          * avg_loss: mean loss over the batches actually trained on (0 if none);
          * accuracy: fraction of labelled examples predicted correctly (0 if none).
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} - Training")

    for batch in progress_bar:
        # Move to device
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Keep only examples whose answer is in the vocabulary
        valid_mask = labels != -1
        if not valid_mask.any():
            continue

        # Forward pass
        optimizer.zero_grad()
        logits = model(pixel_values, input_ids, attention_mask)
        valid_logits = logits[valid_mask]
        valid_labels = labels[valid_mask]
        loss = criterion(valid_logits, valid_labels)

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # Running statistics
        predictions = valid_logits.detach().argmax(dim=-1)
        correct += (predictions == valid_labels).sum().item()
        total += valid_labels.numel()
        total_loss += loss.item()
        trained_batches += 1

        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })

    # Average over batches that contributed a loss; dividing by len(dataloader)
    # would count skipped all-unknown batches and understate the loss.
    avg_loss = total_loss / trained_batches if trained_batches > 0 else 0
    accuracy = correct / total if total > 0 else 0

    return avg_loss, accuracy


def evaluate(model, dataloader, criterion, device, label_encoder):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Only examples whose label is not -1 (answer inside the vocabulary) are
    scored. Accuracy and F1 therefore describe in-vocabulary examples only.

    Args:
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``.
        dataloader: Iterable of batch dicts as produced by ``custom_collate_fn``.
        criterion: Loss taking ``(logits, labels)``.
        device: Device to move batch tensors to.
        label_encoder: Unused here; kept for interface compatibility.

    Returns:
        (avg_loss, accuracy, f1, predictions, labels):
          * avg_loss: mean over the batches that had at least one labelled example;
          * f1: support-weighted F1;
          * predictions, labels: the per-example class indices that were scored.
    """
    model.eval()
    total_loss = 0.0
    evaluated_batches = 0
    all_predictions = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc="Evaluating")

    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Keep only examples whose answer is in the vocabulary
            valid_mask = labels != -1
            if not valid_mask.any():
                continue

            logits = model(pixel_values, input_ids, attention_mask)
            valid_logits = logits[valid_mask]
            valid_labels = labels[valid_mask]

            total_loss += criterion(valid_logits, valid_labels).item()
            evaluated_batches += 1

            all_predictions.extend(valid_logits.argmax(dim=-1).cpu().numpy())
            all_labels.extend(valid_labels.cpu().numpy())

    # Average only over batches that produced a loss (skipped ones add nothing).
    avg_loss = total_loss / evaluated_batches if evaluated_batches > 0 else 0

    if len(all_labels) > 0:
        accuracy = accuracy_score(all_labels, all_predictions)
        # zero_division=0 matches the default's numeric result, minus the warning.
        f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)
    else:
        accuracy = 0
        f1 = 0

    return avg_loss, accuracy, f1, all_predictions, all_labels


print("\n‚úÖ Training and evaluation functions prepared")

## 🏋️ 9. Start Training
# Per-epoch metrics; a snapshot is stored in every checkpoint.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': []
}

best_val_acc = 0
best_model_path = 'best_vqa_model_vit_biobert.pth'

print("\nüèãÔ∏è Starting training...\n")
print("="*70)

for epoch in range(NUM_EPOCHS):
    # Training
    train_loss, train_acc = train_epoch(
        model, train_loader, optimizer, scheduler, criterion, device, epoch
    )

    # Validation
    val_loss, val_acc, val_f1, _, _ = evaluate(
        model, val_loader, criterion, device, label_encoder
    )

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    # Print results
    print(f"\nüìä Epoch {epoch+1}/{NUM_EPOCHS} Results:")
    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"   Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}% | Val F1: {val_f1:.4f}")

    # Checkpoint on improvement. The first epoch is always saved, so that
    # best_model_path exists for the test-set evaluation even if validation
    # accuracy never rises above 0 (e.g. an empty or all-unknown val split).
    if epoch == 0 or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_f1': val_f1,
            'label_encoder': label_encoder,
            'history': history,
            'model_config': {
                'num_classes': num_classes,
                'hidden_dim': 512,
                'num_heads': 8,
                'dropout': 0.2
            }
        }, best_model_path)
        print(f"   ‚úÖ Best model saved! (Val Acc: {val_acc*100:.2f}%)")

    print("="*70)

print("\nüéâ Training completed!")
print(f"‚úÖ Best validation accuracy: {best_val_acc*100:.2f}%")

## 🧪 10. Evaluate Model on Test Data
print("\nüß™ Evaluating model on test data...")

# Load the best checkpoint written by the training loop. It pickles a fitted
# sklearn LabelEncoder, which torch.load refuses under the weights_only=True
# default introduced in PyTorch 2.6. The file was produced by this script, so
# a full (trusted) unpickle is acceptable here.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Evaluate on test data
test_loss, test_acc, test_f1, test_predictions, test_labels = evaluate(
    model, test_loader, criterion, device, label_encoder
)

print(f"üìä Test Results:")
print(f"   Test Loss: {test_loss:.4f}")
print(f"   Test Accuracy: {test_acc*100:.2f}%")
print(f"   Test F1-Score: {test_f1:.4f}")

## 📈 11. Predict an Answer for a Single Image and Question
def predict_answer(image_path, question, model, label_encoder, device):
    """
    Predict the answer string for one image/question pair.

    Args:
        image_path: Path to the image file.
        question: Question text.
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``.
            If it exposes ``vision_processor`` / ``tokenizer`` attributes (as
            ``VQAModel`` does), those are reused.
        label_encoder: Fitted encoder whose ``inverse_transform`` maps a class
            index back to an answer string.
        device: Device the model lives on.

    Returns:
        One of:
          * the predicted answer;
          * ``"Unable to load image"`` if the image cannot be opened;
          * ``"Unknown answer"`` if the predicted index cannot be decoded.

    The model is run in eval mode (dropout off); its previous train/eval mode
    is restored afterwards.
    """
    # Load image (context manager closes the file handle)
    try:
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
    except Exception as e:
        print(f"‚ùå Error loading image: {e}")
        return "Unable to load image"

    # Reuse the model's processor/tokenizer rather than re-instantiating them
    # via from_pretrained on every call; fall back only if they are absent.
    vision_processor = getattr(model, 'vision_processor', None)
    if vision_processor is None:
        vision_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
    tokenizer = getattr(model, 'tokenizer', None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

    # Process image
    image_inputs = vision_processor(images=image, return_tensors='pt')
    pixel_values = image_inputs['pixel_values'].to(device)

    # Process question (same settings as VQARadDataset's default max_length)
    text_inputs = tokenizer(
        question,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = text_inputs['input_ids'].to(device)
    attention_mask = text_inputs['attention_mask'].to(device)

    # Prediction with dropout disabled; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(pixel_values, input_ids, attention_mask)
    finally:
        model.train(was_training)
    prediction = torch.argmax(logits, dim=-1)

    # Convert class index to answer text
    try:
        answer_text = label_encoder.inverse_transform([prediction.cpu().item()])[0]
    except (ValueError, IndexError):
        answer_text = "Unknown answer"

    return answer_text

# Smoke-test predict_answer on the first held-out test example, when one exists.
if dataset['test']:
    example = dataset['test'][0]
    image_path, question = example['image_path'], example['question']

    print(f"\nüîç Testing prediction function:")
    print(f"   Question: {question}")
    print(f"   Image path: {image_path}")

    answer = predict_answer(image_path, question, model, label_encoder, device)
    for _caption, _shown in (("Predicted answer", answer), ("Actual answer", example['answer'])):
        print(f"   {_caption}: {_shown}")

## 💾 12. Save Final Model
final_model_path = 'final_vqa_model_vit_biobert.pth'

# NOTE(review): `model` currently holds the weights re-loaded from the
# best-validation checkpoint, while `optimizer` holds its state after the last
# training epoch. This file therefore pairs best-epoch weights with last-epoch
# optimizer state; confirm this is what consumers of the "final" file expect.
_final_model_config = {
    'num_classes': num_classes,
    'hidden_dim': 512,
    'num_heads': 8,
    'dropout': 0.2
}
_final_training_info = {
    'learning_rate': LEARNING_RATE,
    'batch_size': BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'best_val_acc': best_val_acc
}
_final_checkpoint = {
    'epoch': NUM_EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'test_accuracy': test_acc,
    'test_f1': test_f1,
    'label_encoder': label_encoder,
    'history': history,
    'model_config': _final_model_config,
    'training_info': _final_training_info
}
torch.save(_final_checkpoint, final_model_path)

print(f"\nüíæ Models saved:")
for _saved_path, _saved_kind in ((best_model_path, "best model"), (final_model_path, "final model")):
    print(f"   - {_saved_path} ({_saved_kind})")
print(f"\nüéØ Processing and training completed successfully!")