In [None]:
"""
Phi-3 Vision Analysis Pipeline for Google Colab
Handles benchmark replication, domain adaptation, and tuning strategies
Compatible with Colab's GPU limitations (no flash_attn2)
"""

# ============================================================================
# SECTION 1: Environment Setup and Installation
# ============================================================================

# Install required packages
!pip install -q transformers
!pip install -q accelerate
!pip install -q datasets
!pip install -q pillow
!pip install -q torch
!pip install -q peft
!pip install -q bitsandbytes
!pip install -q evaluate
!pip install -q scikit-learn
!pip install -q tqdm
!pip install -q pycocotools
!pip install -q nltk
!pip install -q rouge-score
!pip install -q bert-score
!pip install -q matplotlib
!pip install -q seaborn
!pip install -q pandas


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from tqdm.auto import tqdm
import gc
from collections import defaultdict
import copy

from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer


In [None]:

# Fetch the NLTK resources used by the metrics code (silently).
for _nltk_resource in ('punkt', 'wordnet'):
    nltk.download(_nltk_resource, quiet=True)

# Global plotting defaults for every figure in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (14, 6),
    'font.size': 10,
})

# Report which accelerator (if any) the runtime exposes.
_cuda_available = torch.cuda.is_available()
print(f"GPU Available: {_cuda_available}")
if _cuda_available:
    _gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_properties.total_memory / 1e9:.2f} GB")


GPU Available: True
GPU Name: Tesla T4
GPU Memory: 15.83 GB


In [None]:

# ============================================================================
# SECTION 2: Configuration Classes
# ============================================================================

@dataclass
class ModelConfig:
    """Configuration for Phi-3 Vision model

    Plain settings container describing which checkpoint to load and how
    (quantization, device placement, dtype, attention backend). Every field
    has a default, so ``ModelConfig()`` is usable as-is.
    """
    # Hugging Face Hub identifier of the checkpoint to load
    model_name: str = "microsoft/Phi-3-vision-128k-instruct"
    # Whether the weights should be loaded through a quantization config
    use_quantization: bool = True
    # Requested bit width for quantized loading
    quantization_bits: int = 4
    # Device placement strategy forwarded to `from_pretrained`
    device_map: str = "auto"
    # Use float32 by default to avoid fp16 mask underflow causing device asserts
    torch_dtype: torch.dtype = torch.float32
    # Allow execution of the model repository's custom modelling code
    trust_remote_code: bool = True
    # Attention backend; "eager" avoids the flash-attention dependency
    attn_implementation: str = "eager"  # Colab-compatible


@dataclass
class FineTuningConfig:
    """Configuration for feature-based fine-tuning

    Groups the optimisation hyperparameters, the adapter architecture
    settings, the output locations and the evaluation cadence used by the
    adapter training code in this notebook. All fields have defaults.
    """
    # Training hyperparameters
    num_epochs: int = 5                    # passes over the training set
    batch_size: int = 2                    # samples per forward/backward pass
    learning_rate: float = 1e-4            # peak AdamW learning rate
    warmup_steps: int = 100                # scheduler steps of linear warm-up
    gradient_accumulation_steps: int = 4   # micro-batches per optimizer update
    max_grad_norm: float = 1.0             # gradient-norm clipping threshold
    weight_decay: float = 0.01             # AdamW weight decay

    # Feature adapter configuration
    adapter_hidden_dim: int = 256          # passed to FeatureAdapter as `hidden_dim`
    adapter_dropout: float = 0.1           # dropout probability inside each adapter
    adapter_bottleneck_dim: int = 64       # width of the adapter bottleneck
    num_adapter_layers: int = 2            # passed to FeatureAdapter as `num_layers`

    # Output directories
    output_dir: str = "./phi3_finetuning_outputs"
    checkpoints_dir: str = "./phi3_finetuning_outputs/checkpoints"
    visualizations_dir: str = "./phi3_finetuning_outputs/visualizations"

    # Evaluation settings (the *_steps values are compared against the
    # trainer's global step counter)
    eval_steps: int = 50                   # evaluate when global_step is a multiple of this
    save_steps: int = 100                  # checkpoint when global_step is a multiple of this
    logging_steps: int = 10                # log when global_step is a multiple of this
    # Cap on training samples; presumably None means "use everything" -- TODO confirm at call site
    max_samples_train: Optional[int] = 50
    # Cap on evaluation samples; a falsy value falls back to the full dataset
    max_samples_eval: Optional[int] = 10

# ============================================================================
# SECTION 3: Feature Adapter Architecture
# ============================================================================

class FeatureAdapter(nn.Module):
    """
    Lightweight residual bottleneck adapter (in the spirit of CLIP-Adapter).

    The adapter is a small MLP ``input_dim -> bottleneck_dim -> ... -> input_dim``
    whose output is scaled by a learnable scalar ``alpha`` and added back to
    the input, so the module refines features instead of replacing them.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        bottleneck_dim: int = 64,
        dropout: float = 0.1,
        num_layers: int = 2
    ):
        """
        Args:
            input_dim: size of the last dimension of the incoming features.
            hidden_dim: stored on the instance; not used to size any layer.
            bottleneck_dim: width of the bottleneck projections.
            dropout: dropout probability applied after every projection.
            num_layers: ``num_layers - 2`` extra bottleneck stages are
                inserted between the down- and up-projection.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bottleneck_dim = bottleneck_dim

        def _bottleneck_stage(in_features: int) -> list:
            # Linear -> LayerNorm -> GELU -> Dropout, always ending at bottleneck width.
            return [
                nn.Linear(in_features, bottleneck_dim),
                nn.LayerNorm(bottleneck_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]

        # Down-projection stage first, then any optional middle stages.
        stages = _bottleneck_stage(input_dim)
        for _ in range(num_layers - 2):
            stages.extend(_bottleneck_stage(bottleneck_dim))

        # Up-projection back to the input width.
        stages += [nn.Linear(bottleneck_dim, input_dim), nn.Dropout(dropout)]

        self.adapter = nn.Sequential(*stages)

        # Learnable residual scale, initialised to 0.1.
        self.alpha = nn.Parameter(torch.ones(1) * 0.1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for every linear layer."""
        linear_layers = (m for m in self.adapter.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Return ``x + alpha * adapter(x)``.

        x: input features whose last dimension equals ``input_dim``.
        """
        return x + self.alpha * self.adapter(x)

class Phi3WithFeatureAdapters(nn.Module):
    """
    Phi-3 Vision model with feature adapters

    Wraps a frozen base model together with three ``FeatureAdapter`` modules
    (vision, language, fusion). Only the fusion adapter takes part in the
    computation: it is applied to the hidden states that feed the base
    model's output head, for both ``forward`` and ``generate``. The vision and
    language adapters are created, saved and loaded, but are not wired into
    the computation.

    The fusion adapter is injected with a forward pre-hook on the output head.
    That way the logits (and therefore the loss) are computed from the adapted
    hidden states, which is what gives the adapter parameters a gradient even
    though every base-model parameter is frozen.
    """

    def __init__(self, base_model, config: FineTuningConfig):
        """
        Args:
            base_model: causal LM exposing ``config.hidden_size`` and an
                output head (``get_output_embeddings()`` or ``lm_head``).
            config: adapter hyperparameters (``adapter_*`` / ``num_adapter_layers``).
        """
        super().__init__()

        self.base_model = base_model
        self.config = config

        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Get hidden dimension from model config
        self.hidden_dim = base_model.config.hidden_size

        # Create feature adapters (construction order is kept stable so that
        # seeded initialisation is reproducible).
        self.vision_adapter = self._build_adapter()
        self.language_adapter = self._build_adapter()
        # Cross-modal fusion adapter
        self.fusion_adapter = self._build_adapter()

        # Adapters are created on the CPU; move them next to the output head so
        # the hidden states do not have to cross devices on every step.
        adapter_device = self._infer_adapter_device()
        if adapter_device is not None:
            for adapter in (self.vision_adapter, self.language_adapter, self.fusion_adapter):
                adapter.to(adapter_device)

        print(f"\n✓ Feature adapters initialized")
        print(f"  - Vision adapter: {sum(p.numel() for p in self.vision_adapter.parameters())} params")
        print(f"  - Language adapter: {sum(p.numel() for p in self.language_adapter.parameters())} params")
        print(f"  - Fusion adapter: {sum(p.numel() for p in self.fusion_adapter.parameters())} params")
        print(f"  - Total trainable: {self.count_trainable_parameters()} params")

    def _build_adapter(self):
        """Create one adapter sized for the base model's hidden dimension."""
        return FeatureAdapter(
            input_dim=self.hidden_dim,
            hidden_dim=self.config.adapter_hidden_dim,
            bottleneck_dim=self.config.adapter_bottleneck_dim,
            dropout=self.config.adapter_dropout,
            num_layers=self.config.num_adapter_layers
        )

    def _get_output_head(self):
        """Return the module that maps hidden states to vocabulary logits.

        Raises:
            AttributeError: if the base model exposes neither
                ``get_output_embeddings()`` nor ``lm_head``.
        """
        head = None
        getter = getattr(self.base_model, 'get_output_embeddings', None)
        if callable(getter):
            head = getter()
        if head is None:
            head = getattr(self.base_model, 'lm_head', None)
        if head is None:
            raise AttributeError(
                "Base model exposes no output head (get_output_embeddings() / lm_head); "
                "cannot attach the fusion adapter."
            )
        return head

    def _infer_adapter_device(self):
        """Best-effort lookup of a real (non-meta) device to host the adapters."""
        try:
            candidates = list(self._get_output_head().parameters())
        except AttributeError:
            candidates = []
        candidates.extend(self.base_model.parameters())
        for param in candidates:
            if param.device.type != 'meta':
                return param.device
        return None

    def _apply_fusion_adapter(self, hidden_states):
        """Run the fusion adapter in its own dtype/device, then cast back."""
        reference = self.fusion_adapter.alpha
        adapter_input = hidden_states.to(device=reference.device, dtype=reference.dtype)
        adapted = self.fusion_adapter(adapter_input)
        return adapted.to(device=hidden_states.device, dtype=hidden_states.dtype)

    def _register_adapter_hook(self, captured=None):
        """Attach the fusion adapter in front of the output head.

        Args:
            captured: optional list; every adapted tensor is appended to it.

        Returns:
            The hook handle; the caller must ``remove()`` it.
        """
        head = self._get_output_head()

        def _adapt_head_input(module, args):
            if not args:
                return None  # nothing positional to adapt; leave the call untouched
            adapted = self._apply_fusion_adapter(args[0])
            if captured is not None:
                captured.append(adapted)
            return (adapted,) + tuple(args[1:])

        return head.register_forward_pre_hook(_adapt_head_input)

    def count_trainable_parameters(self):
        """Count trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, **kwargs):
        """Forward pass with adapter integration

        All keyword arguments are forwarded to the base model (hidden states
        are always requested). The returned output's ``logits``/``loss`` are
        computed from the adapted hidden states, and the last entry of
        ``hidden_states`` is replaced by the adapted tensor when the shapes
        line up.
        """
        kwargs['output_hidden_states'] = True

        captured = []
        handle = self._register_adapter_hook(captured)
        try:
            outputs = self.base_model(**kwargs)
        finally:
            # Never leave the hook behind, even if the base model raised.
            handle.remove()

        hidden_states = getattr(outputs, 'hidden_states', None)
        if captured and hidden_states:
            hs_list = list(hidden_states)
            # Only swap when the head really consumed the full last hidden state.
            if captured[-1].shape == hs_list[-1].shape:
                hs_list[-1] = captured[-1]
                outputs.hidden_states = tuple(hs_list)

        return outputs

    def generate(self, **kwargs):
        """Generation method for inference (uses the fusion adapter as well)"""
        handle = self._register_adapter_hook()
        try:
            return self.base_model.generate(**kwargs)
        finally:
            handle.remove()

    def save_adapters(self, save_path: str):
        """Save only adapter weights"""
        adapter_state = {
            'vision_adapter': self.vision_adapter.state_dict(),
            'language_adapter': self.language_adapter.state_dict(),
            'fusion_adapter': self.fusion_adapter.state_dict(),
        }
        # Make sure the target directory exists before writing.
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(adapter_state, save_path)
        print(f"Adapters saved to: {save_path}")

    def load_adapters(self, load_path: str):
        """Load adapter weights"""
        # map_location="cpu" lets a GPU-saved checkpoint load on any machine;
        # load_state_dict then copies into the adapters' current device.
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        adapter_state = torch.load(load_path, map_location="cpu")
        self.vision_adapter.load_state_dict(adapter_state['vision_adapter'])
        self.language_adapter.load_state_dict(adapter_state['language_adapter'])
        self.fusion_adapter.load_state_dict(adapter_state['fusion_adapter'])
        print(f"Adapters loaded from: {load_path}")



In [None]:
# ============================================================================
# SECTION 4: Model Loader (Reusing from original code)
# ============================================================================

class Phi3VisionLoader:
    """Handles model loading with Colab-specific configurations"""

    @staticmethod
    def _build_quantization_config(config: ModelConfig):
        """Translate ``config.quantization_bits`` into a BitsAndBytesConfig.

        Raises:
            ValueError: if the requested bit width is neither 4 nor 8.
        """
        bits = config.quantization_bits
        if bits == 4:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                # float32 compute avoids the fp16 underflow issues noted in ModelConfig
                bnb_4bit_compute_dtype=torch.float32,
                bnb_4bit_use_double_quant=True,
                llm_int8_enable_fp32_cpu_offload=True,  # allow CPU offload
            )
        if bits == 8:
            return BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,  # allow CPU offload
            )
        raise ValueError(f"Unsupported quantization_bits: {bits} (expected 4 or 8)")

    @staticmethod
    def load_model_and_processor(config: ModelConfig):
        """Load Phi-3 Vision with Colab-compatible settings

        Args:
            config: checkpoint name, quantization, device and dtype settings.

        Returns:
            ``(model, processor)``; the model has ``use_cache`` disabled.

        Raises:
            ValueError: if quantization is enabled with an unsupported bit width.
        """
        print("Loading Phi-3 Vision model...")

        quantization_config = None
        if config.use_quantization:
            quantization_config = Phi3VisionLoader._build_quantization_config(config)

        processor = AutoProcessor.from_pretrained(
            config.model_name,
            trust_remote_code=config.trust_remote_code
        )

        model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            device_map=config.device_map,
            torch_dtype=config.torch_dtype,
            trust_remote_code=config.trust_remote_code,
            quantization_config=quantization_config,
            attn_implementation=config.attn_implementation,
            # NOTE(review): the underscore-prefixed spelling is passed as well,
            # presumably for the checkpoint's remote code -- confirm it is still needed.
            _attn_implementation=config.attn_implementation
        )

        model.config.use_cache = False

        print(f"✓ Model loaded on {next(model.parameters()).device}")
        print(f"✓ Model memory: {model.get_memory_footprint() / 1e9:.2f} GB")

        return model, processor

# ============================================================================
# SECTION 5: Dataset Handlers (Reusing from original code)
# ============================================================================

class DatasetHandler:
    """Unified dataset handler for training"""

    @staticmethod
    def load_training_dataset(dataset_name: str = "DocVQA", max_samples: Optional[int] = None):
        """Load dataset for fine-tuning

        Args:
            dataset_name: one of "DocVQA", "ScienceQA" or "COCO".
            max_samples: if truthy, keep only the first ``max_samples`` rows.

        Raises:
            ValueError: for an unknown ``dataset_name``.
        """
        print(f"Loading {dataset_name} dataset for training...")

        if dataset_name == "DocVQA":
            dataset = load_dataset("nielsr/docvqa_1200_examples_donut", split="train")
        elif dataset_name == "ScienceQA":
            dataset = load_dataset("derek-thomas/ScienceQA", split="train")
            # Text-only questions cannot be used with the image prompt.
            dataset = dataset.filter(lambda x: x['image'] is not None)
        elif dataset_name == "COCO":
            dataset = load_dataset("HuggingFaceM4/COCO", split="train")
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        if max_samples:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        print(f"✓ Loaded {len(dataset)} training samples")
        return dataset

    @staticmethod
    def prepare_sample(sample, processor, dataset_name: str):
        """Prepare sample based on dataset type

        Builds the chat prompt (question plus reference answer), runs the
        processor, sanitises token ids and creates the labels.

        Returns:
            ``(inputs, reference)`` where ``inputs`` is the processor output
            extended with ``labels`` and ``reference`` is the target text, or
            ``(None, None)`` if anything fails (the error is printed).
        """
        try:
            image = sample['image']
            if not isinstance(image, Image.Image):
                image = Image.open(image).convert('RGB')

            if dataset_name == "DocVQA":
                question = sample['query']['en'] if isinstance(sample['query'], dict) else sample['query']
                answer = sample.get('answers', [''])[0] if 'answers' in sample else ''
                reference = answer

                prompt = f"""<|user|>
    <|image_1|>
    Read the document and answer: {question}<|end|>
    <|assistant|>
    {answer}<|end|>"""

            elif dataset_name == "ScienceQA":
                question = sample['question']
                choices = sample['choices']
                answer_idx = sample['answer']
                answer = str(answer_idx)
                reference = answer

                choices_text = "\n".join([f"{i}. {choice}" for i, choice in enumerate(choices)])
                prompt = f"""<|user|>
    <|image_1|>
    Question: {question}
    Choices:
    {choices_text}
    Answer (0-{len(choices)-1}):<|end|>
    <|assistant|>
    {answer}<|end|>"""

            elif dataset_name == "COCO":
                captions = sample.get('sentences', {}).get('raw', [''])
                caption = captions[0] if captions else ''
                reference = caption

                prompt = f"""<|user|>
    <|image_1|>
    Describe this image:<|end|>
    <|assistant|>
    {caption}<|end|>"""

            else:
                # Fail with a clear message instead of an undefined-variable error.
                raise ValueError(f"Unknown dataset: {dataset_name}")

            # Tokenize
            inputs = processor(
                text=prompt,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )

            # pad token (fall back to EOS and register it on the tokenizer)
            pad_id = processor.tokenizer.pad_token_id
            if pad_id is None:
                pad_id = processor.tokenizer.eos_token_id
                processor.tokenizer.pad_token = processor.tokenizer.eos_token

            actual_vocab_size = len(processor.tokenizer)
            unk_id = processor.tokenizer.unk_token_id
            if unk_id is None:
                # Tokenizers without an UNK token: reuse the pad id for clipping.
                unk_id = pad_id

            input_ids = inputs['input_ids']

            # Remember which positions held ids outside the embedding range so
            # they can be excluded from the loss below.
            out_of_range = (input_ids < 0) | (input_ids >= actual_vocab_size)

            # NOTE(review): negative ids look like the processor's image
            # placeholders; overwriting them may stop the model from attending to
            # the image at all -- confirm against the model's remote code.
            input_ids[input_ids < 0] = pad_id

            # clamp overflow tokens
            input_ids[input_ids >= actual_vocab_size] = unk_id

            # Write back
            inputs['input_ids'] = input_ids

            # Labels: copy of the sanitised ids, ignoring (-100) both padding and
            # every position that originally held an out-of-range id. Leaving
            # such ids in the labels would make the loss index out of bounds.
            labels = input_ids.clone()
            labels = labels.masked_fill(out_of_range, -100)
            labels = labels.masked_fill(inputs['attention_mask'] == 0, -100)
            inputs['labels'] = labels

            # Final sanity check
            min_id = int(inputs['input_ids'].min().item())
            max_id = int(inputs['input_ids'].max().item())

            if min_id < 0 or max_id >= actual_vocab_size:
                raise ValueError(
                    f"Token ids out of range after fix: min={min_id}, max={max_id}, actual_vocab_size={actual_vocab_size}"
                )

            return inputs, reference

        except Exception as e:
            # Best effort: callers skip samples that could not be prepared.
            print(f"Error preparing sample: {e}")
            return None, None


In [None]:
# ============================================================================
# SECTION 6: Metrics Calculator (Simplified from original)
# ============================================================================

class MetricsCalculator:
    """Calculate evaluation metrics

    All scores are returned as plain Python floats on a 0-100 scale. Empty
    inputs yield 0.0 for every metric instead of NaN or missing keys.
    """

    # ROUGE variants reported by calculate_rouge (as "<variant>_f1").
    ROUGE_TYPES = ('rouge1', 'rouge2', 'rougeL')

    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(list(self.ROUGE_TYPES), use_stemmer=True)
        self.smoothing = SmoothingFunction().method1

    def calculate_bleu(self, predictions: List[str], references: List[str]) -> float:
        """Calculate BLEU-4 score

        Sentence-level BLEU (whitespace tokens, lower-cased, smoothed),
        averaged over the prediction/reference pairs. Returns 0.0 when there
        are no pairs.
        """
        scores = [
            sentence_bleu(
                [ref.lower().split()],
                pred.lower().split(),
                weights=(0.25, 0.25, 0.25, 0.25),
                smoothing_function=self.smoothing,
            )
            for pred, ref in zip(predictions, references)
        ]
        if not scores:
            return 0.0
        return float(np.mean(scores) * 100)

    def calculate_rouge(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Calculate ROUGE scores

        Returns a dict with one ``<variant>_f1`` entry per ROUGE variant. All
        entries are 0.0 when there are no prediction/reference pairs, so
        callers can always index the result.
        """
        pairs = list(zip(predictions, references))
        if not pairs:
            return {f'{rouge_type}_f1': 0.0 for rouge_type in self.ROUGE_TYPES}

        rouge_scores = defaultdict(list)
        for pred, ref in pairs:
            scores = self.rouge_scorer.score(ref, pred)
            for key, value in scores.items():
                rouge_scores[f'{key}_f1'].append(value.fmeasure * 100)
        return {k: float(np.mean(v)) for k, v in rouge_scores.items()}

    def calculate_exact_match(self, predictions: List[str], references: List[str]) -> float:
        """Calculate exact match accuracy (case- and whitespace-insensitive)"""
        if not predictions:
            return 0.0
        matches = sum(1 for pred, ref in zip(predictions, references)
                     if pred.strip().lower() == ref.strip().lower())
        return (matches / len(predictions)) * 100

    def calculate_all_metrics(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
        """Calculate all metrics (BLEU-4, exact match and ROUGE F1 scores)"""
        metrics = {
            'bleu4': self.calculate_bleu(predictions, references),
            'exact_match': self.calculate_exact_match(predictions, references)
        }
        metrics.update(self.calculate_rouge(predictions, references))
        return metrics



In [None]:
def validate_batch_for_forward(batch: Dict[str, torch.Tensor], processor) -> None:
    """
    Basic sanity checks for tensors that commonly cause CUDA asserts:
      - shapes of input_ids, attention_mask and labels match
      - input_ids are within [0, vocab_size-1]
      - labels are either -100 (ignored) or within [0, vocab_size-1]
    Empty tensors are accepted (there is nothing to range-check).
    Raises ValueError with descriptive message if something is wrong.
    """
    vocab_size = len(processor.tokenizer)
    device_info = {k: (v.device if torch.is_tensor(v) else None) for k, v in batch.items()}

    # shape checks
    if 'input_ids' in batch and 'attention_mask' in batch:
        if batch['input_ids'].shape != batch['attention_mask'].shape:
            raise ValueError(f"Shape mismatch: input_ids {batch['input_ids'].shape} vs attention_mask {batch['attention_mask'].shape}")
    if 'input_ids' in batch and 'labels' in batch:
        if batch['input_ids'].shape != batch['labels'].shape:
            raise ValueError(f"Shape mismatch: input_ids {batch['input_ids'].shape} vs labels {batch['labels'].shape}")

    # token id ranges
    if 'input_ids' in batch and batch['input_ids'].numel() > 0:
        min_id = int(batch['input_ids'].min().cpu().item())
        max_id = int(batch['input_ids'].max().cpu().item())
        if min_id < 0 or max_id >= vocab_size:
            raise ValueError(f"input_ids out of range: min={min_id}, max={max_id}, vocab_size={vocab_size}")

    # labels: -100 marks ignored positions; every other value is used as a class
    # index by the loss and must therefore be a valid token id. Values such as
    # -1 are NOT ignored and would trigger a device-side assert.
    if 'labels' in batch:
        labels = batch['labels']
        active_labels = labels[labels != -100]
        if active_labels.numel() > 0:
            min_label = int(active_labels.min().cpu().item())
            max_label = int(active_labels.max().cpu().item())
            if min_label < 0 or max_label >= vocab_size:
                raise ValueError(f"labels out of range: min={min_label}, max={max_label}, vocab_size={vocab_size}")
    print("Batch devices:", device_info)


In [None]:
# ============================================================================
# SECTION 7: Trainer Class for Feature Adapters
# ============================================================================

class FeatureAdapterTrainer:
    """Trainer for feature-based fine-tuning

    Runs a simple gradient-accumulation loop over ``train_dataset``. One
    "step" (``global_step``) is one optimizer update; logging, periodic
    evaluation and checkpointing are all keyed on that counter and happen at
    most once per update.
    """

    def __init__(
        self,
        model: Phi3WithFeatureAdapters,
        processor,
        config: FineTuningConfig,
        train_dataset,
        eval_dataset,
        dataset_name: str
    ):
        self.model = model
        self.processor = processor
        self.config = config
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.dataset_name = dataset_name
        self.device = next(model.parameters()).device

        # Training state
        self.global_step = 0
        self.best_eval_loss = float('inf')
        self.training_history = []

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Setup scheduler. The scheduler advances once per optimizer update, so
        # the horizon is counted in updates (ceil-divisions written with
        # integer arithmetic), not in micro-batches.
        accum_steps = max(1, config.gradient_accumulation_steps)
        batches_per_epoch = -(-len(train_dataset) // config.batch_size)
        updates_per_epoch = -(-batches_per_epoch // accum_steps)
        total_steps = max(1, updates_per_epoch * config.num_epochs)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=total_steps
        )

        # Create directories
        os.makedirs(config.output_dir, exist_ok=True)
        os.makedirs(config.checkpoints_dir, exist_ok=True)
        os.makedirs(config.visualizations_dir, exist_ok=True)

        print(f"\n✓ Trainer initialized")
        print(f"  - Total training steps: {total_steps}")
        print(f"  - Warmup steps: {config.warmup_steps}")
        print(f"  - Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
        if config.warmup_steps >= total_steps:
            print("  ! Warmup steps >= total training steps: the learning rate will never reach its peak")

    def train_epoch(self, epoch: int):
        """Train for one epoch

        Returns:
            The mean (un-normalised) loss over the batches that were trained on.
        """
        self.model.train()
        accum_steps = max(1, self.config.gradient_accumulation_steps)
        num_samples = len(self.train_dataset)

        epoch_loss = 0.0
        num_batches = 0
        pending_batches = 0  # micro-batches accumulated since the last update
        batch_loss = 0.0

        # Start from clean gradients (e.g. after an interrupted previous run).
        self.optimizer.zero_grad()

        progress_bar = tqdm(total=num_samples, desc=f"Epoch {epoch+1}")

        for start in range(0, num_samples, self.config.batch_size):
            stop = min(start + self.config.batch_size, num_samples)
            progress_bar.update(stop - start)

            # Get batch
            batch_samples = []
            for i in range(start, stop):
                inputs, _ = DatasetHandler.prepare_sample(
                    self.train_dataset[i],
                    self.processor,
                    self.dataset_name
                )
                if inputs is not None:
                    batch_samples.append(inputs)

            if not batch_samples:
                continue

            # Collate batch
            batch = self._collate_batch(batch_samples)

            # Validate while the tensors are still on the CPU so that bad ids
            # are reported as a Python error instead of a CUDA assert.
            try:
                validate_batch_for_forward(batch, self.processor)
            except Exception as ve:
                print("Batch validation failed:", ve)
                # skip this batch to avoid a CUDA assert
                continue

            batch = {k: v.to(self.device) for k, v in batch.items()}

            outputs = self.model(**batch)
            loss = outputs.loss

            # Track the true loss; only the gradient is scaled for accumulation.
            batch_loss = loss.item()
            (loss / accum_steps).backward()

            epoch_loss += batch_loss
            num_batches += 1
            pending_batches += 1

            # Clear memory
            del batch, outputs, loss

            if pending_batches == accum_steps:
                pending_batches = 0
                self._optimizer_step(batch_loss)
                self._run_step_hooks(epoch, epoch_loss / num_batches)

            if num_batches % 10 == 0:
                torch.cuda.empty_cache()

        # Apply gradients of a trailing, incomplete accumulation window instead
        # of silently carrying them into the next epoch.
        if pending_batches:
            self._optimizer_step(batch_loss)
            self._run_step_hooks(epoch, epoch_loss / num_batches)

        progress_bar.close()
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"\nEpoch {epoch+1} completed | Avg Loss: {avg_loss:.4f}")

        return avg_loss

    def _optimizer_step(self, last_batch_loss: float):
        """Clip gradients, update weights/scheduler and advance global_step."""
        torch.nn.utils.clip_grad_norm_(
            [p for p in self.model.parameters() if p.requires_grad],
            self.config.max_grad_norm
        )
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

        # Logging
        if self.global_step % self.config.logging_steps == 0:
            lr = self.scheduler.get_last_lr()[0]
            print(f"\nStep {self.global_step} | Loss: {last_batch_loss:.4f} | LR: {lr:.2e}")

    def _run_step_hooks(self, epoch: int, running_train_loss: float):
        """Periodic evaluation and checkpointing, called once per optimizer update."""
        # Periodic evaluation
        if self.global_step % self.config.eval_steps == 0:
            eval_metrics = self.evaluate()
            self.training_history.append({
                'step': self.global_step,
                'epoch': epoch,
                'train_loss': running_train_loss,
                **eval_metrics
            })
            self.model.train()

        # Save checkpoint
        if self.global_step % self.config.save_steps == 0:
            self.save_checkpoint(f"checkpoint_step_{self.global_step}")

    def _collate_batch(self, batch_samples):
        """Collate samples into batch

        Right-pads ``input_ids``/``attention_mask``/``labels`` to the longest
        sequence in the batch and concatenates along the batch dimension.
        """
        tokenizer = self.processor.tokenizer
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # F.pad needs a number; fall back to EOS (or 0) when no pad token is set.
            pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

        max_length = max(s['input_ids'].shape[1] for s in batch_samples)

        input_ids = []
        attention_mask = []
        pixel_values = []
        labels = []

        for sample in batch_samples:
            pad_length = max_length - sample['input_ids'].shape[1]
            input_ids.append(F.pad(sample['input_ids'], (0, pad_length), value=pad_id))
            attention_mask.append(F.pad(sample['attention_mask'], (0, pad_length), value=0))
            labels.append(F.pad(sample['labels'], (0, pad_length), value=-100))
            pixel_values.append(sample['pixel_values'])

        return {
            'input_ids': torch.cat(input_ids, dim=0),
            'attention_mask': torch.cat(attention_mask, dim=0),
            'pixel_values': torch.cat(pixel_values, dim=0),
            'labels': torch.cat(labels, dim=0)
        }

    def _build_generation_inputs(self, model_inputs):
        """Return generation inputs that do not contain the reference answer.

        The prepared prompt embeds the reference answer after the last
        "<|assistant|>" marker. For generation the sequence is cut right after
        that marker so the model has to produce the answer itself. If the
        marker cannot be located, the inputs are returned unchanged (minus
        ``labels``).
        """
        gen_inputs = {k: v for k, v in model_inputs.items() if k != 'labels'}

        tokenizer = self.processor.tokenizer
        marker_id = tokenizer.convert_tokens_to_ids("<|assistant|>")
        if not isinstance(marker_id, int) or marker_id < 0 or marker_id == tokenizer.unk_token_id:
            return gen_inputs

        ids = gen_inputs.get('input_ids')
        if ids is None or ids.dim() != 2 or ids.shape[0] != 1:
            return gen_inputs

        positions = (ids[0] == marker_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            return gen_inputs

        cut = int(positions[-1].item()) + 1
        gen_inputs['input_ids'] = ids[:, :cut]
        if 'attention_mask' in gen_inputs:
            gen_inputs['attention_mask'] = gen_inputs['attention_mask'][:, :cut]
        return gen_inputs

    def evaluate(self):
        """Evaluate model

        Returns:
            Dict with ``eval_loss`` (mean over the samples that could be
            evaluated; ``inf`` if there were none) and the text metrics.
        """
        self.model.eval()
        eval_loss = 0.0
        num_evaluated = 0
        predictions = []
        references = []

        print("\nRunning evaluation...")

        with torch.no_grad():
            for idx in tqdm(range(len(self.eval_dataset)), desc="Evaluating"):
                inputs, reference = DatasetHandler.prepare_sample(
                    self.eval_dataset[idx],
                    self.processor,
                    self.dataset_name
                )

                if inputs is None:
                    continue

                # Move to device
                inputs_eval = {k: v.to(self.device) for k, v in inputs.items()}

                # Calculate loss (teacher-forced, on the full prompt + answer)
                outputs = self.model(**inputs_eval)
                eval_loss += outputs.loss.item()
                num_evaluated += 1

                # Generate prediction from the question only
                gen_inputs = self._build_generation_inputs(inputs_eval)
                prompt_length = gen_inputs['input_ids'].shape[1]
                pred_ids = self.model.generate(**gen_inputs, max_new_tokens=100, do_sample=False)

                # generate() returns prompt + continuation; keep the continuation.
                pred_ids = pred_ids[:, prompt_length:]
                pred_text = self.processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()

                if "<|assistant|>" in pred_text:
                    pred_text = pred_text.split("<|assistant|>")[-1].strip()

                predictions.append(pred_text)
                references.append(reference)

                if idx % 10 == 0:
                    torch.cuda.empty_cache()

        # Calculate metrics
        if predictions:
            metrics = MetricsCalculator().calculate_all_metrics(predictions, references)
        else:
            metrics = {'bleu4': 0.0, 'exact_match': 0.0}
        metrics['eval_loss'] = eval_loss / num_evaluated if num_evaluated else float('inf')

        print(f"\n{'='*50}")
        print("Evaluation Results:")
        print(f"  Loss: {metrics['eval_loss']:.4f}")
        print(f"  BLEU-4: {metrics.get('bleu4', 0.0):.2f}")
        print(f"  ROUGE-L: {metrics.get('rougeL_f1', 0.0):.2f}")
        print(f"  Exact Match: {metrics.get('exact_match', 0.0):.2f}%")
        print(f"{'='*50}")

        return metrics

    def train(self):
        """Full training loop

        Returns:
            ``training_history``: one record per periodic evaluation and one
            per epoch-end evaluation.
        """
        print(f"\n{'='*60}")
        print("STARTING FEATURE-BASED FINE-TUNING")
        print(f"{'='*60}")

        for epoch in range(self.config.num_epochs):
            print(f"\n--- Epoch {epoch + 1}/{self.config.num_epochs} ---")
            epoch_loss = self.train_epoch(epoch)

            # Epoch-end evaluation
            eval_metrics = self.evaluate()
            self.training_history.append({
                'step': self.global_step,
                'epoch': epoch,
                'train_loss': epoch_loss,
                **eval_metrics
            })

            # Save best model
            if eval_metrics['eval_loss'] < self.best_eval_loss:
                self.best_eval_loss = eval_metrics['eval_loss']
                self.save_checkpoint("best_model")
                print(f"✓ New best model saved (loss: {self.best_eval_loss:.4f})")

            self.model.train()

        print(f"\n{'='*60}")
        print("TRAINING COMPLETED")
        print(f"{'='*60}")

        # Save final model
        self.save_checkpoint("final_model")

        return self.training_history

    def save_checkpoint(self, name: str):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.config.checkpoints_dir, f"{name}.pt")
        self.model.save_adapters(checkpoint_path)

# ============================================================================
# SECTION 8: Evaluation Pipeline (Pre vs Post Fine-tuning)
# ============================================================================

class ComparisonEvaluator:
    """Evaluate and compare pre-trained vs fine-tuned models.

    Runs greedy generation over a dataset, scores the generated answers with
    ``MetricsCalculator`` and records the average wall-clock generation time.
    """

    def __init__(self, processor, config: "FineTuningConfig"):
        """Store the processor and config and create the metrics calculator.

        Args:
            processor: Processor used to prepare samples and decode token ids.
            config: Fine-tuning config; ``max_samples_eval`` is read from it.
        """
        self.processor = processor
        self.config = config
        self.metrics_calc = MetricsCalculator()

    @staticmethod
    def _strip_prompt(pred_ids, prompt_ids):
        """Remove the echoed prompt tokens from a ``generate()`` result.

        Args:
            pred_ids: 2-D tensor of generated ids (batch first).
            prompt_ids: The ``input_ids`` passed to ``generate()``, or None.

        Returns:
            ``pred_ids`` without its first ``prompt_ids.shape[-1]`` columns when
            it is at least that long; otherwise ``pred_ids`` unchanged (e.g. the
            model wrapper already returned only the new tokens).
        """
        if prompt_ids is None:
            return pred_ids
        prompt_len = prompt_ids.shape[-1]
        if pred_ids.shape[-1] >= prompt_len:
            return pred_ids[:, prompt_len:]
        return pred_ids

    def evaluate_model(self, model, dataset, dataset_name: str, model_type: str):
        """Generate answers for ``dataset`` and compute text metrics.

        Args:
            model: Model exposing ``eval()``, ``parameters()`` and ``generate()``.
            dataset: Indexable dataset; samples are read by integer index.
            dataset_name: Passed to ``DatasetHandler.prepare_sample``.
            model_type: Label used only in progress/log output.

        Returns:
            The dict from ``MetricsCalculator.calculate_all_metrics`` with an
            extra ``'avg_inference_time'`` entry (seconds per sample; NaN when
            no sample could be evaluated).
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {model_type} model on {dataset_name}")
        print(f"{'='*60}")

        model.eval()
        predictions = []
        references = []
        inference_times = []

        max_samples = self.config.max_samples_eval or len(dataset)
        num_samples = min(max_samples, len(dataset))

        # The model's device does not change during evaluation; look it up once.
        device = next(model.parameters()).device

        with torch.no_grad():
            for idx in tqdm(range(num_samples), desc=f"Evaluating {model_type}"):
                inputs, reference = DatasetHandler.prepare_sample(
                    dataset[idx],
                    self.processor,
                    dataset_name
                )

                # Samples that could not be prepared are skipped.
                if inputs is None:
                    continue

                # Labels are only needed for a loss, not for generation.
                gen_inputs = {k: v.to(device) for k, v in inputs.items() if k != 'labels'}

                # Measure inference time
                start_time = time.time()
                pred_ids = model.generate(**gen_inputs, max_new_tokens=100, do_sample=False)
                inference_time = time.time() - start_time

                # Decode only the newly generated tokens. Decoding the whole
                # sequence would put the prompt into the prediction, and the
                # "<|assistant|>" marker cannot be relied on to split it off
                # because special tokens are skipped while decoding.
                pred_ids = self._strip_prompt(pred_ids, gen_inputs.get('input_ids'))

                pred_text = self.processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
                # Fallback for outputs that still carry the chat marker as text.
                if "<|assistant|>" in pred_text:
                    pred_text = pred_text.split("<|assistant|>")[-1]
                pred_text = pred_text.strip()

                predictions.append(pred_text)
                references.append(reference)
                inference_times.append(inference_time)

                # Periodically release cached GPU memory.
                if idx % 10 == 0:
                    torch.cuda.empty_cache()

        # Calculate metrics
        metrics = self.metrics_calc.calculate_all_metrics(predictions, references)
        # np.mean([]) warns and yields NaN; produce the NaN explicitly instead.
        metrics['avg_inference_time'] = (
            float(np.mean(inference_times)) if inference_times else float("nan")
        )

        print(f"\nResults for {model_type}:")
        print(f"  BLEU-4: {metrics['bleu4']:.2f}")
        print(f"  ROUGE-1: {metrics['rouge1_f1']:.2f}")
        print(f"  ROUGE-L: {metrics['rougeL_f1']:.2f}")
        print(f"  Exact Match: {metrics['exact_match']:.2f}%")
        print(f"  Avg Inference Time: {metrics['avg_inference_time']:.4f} s")
        print(f"{'='*60}\n")

        return metrics

# ============================================================================
# SECTION 9: Utility functions and example script
# ============================================================================

def print_model_summary(model: nn.Module):
    """Print the total, trainable and frozen parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set;
    the frozen count is the difference between total and trainable.
    """
    total_params = 0
    trainable = 0
    # Single pass over the parameters, tallying both counters together.
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable += size

    print(f"\nModel summary:")
    print(f"  - Total params: {total_params:,}")
    print(f"  - Trainable params: {trainable:,}")
    print(f"  - Frozen params: {total_params - trainable:,}\n")


def safe_device():
    """Pick the torch device to use: "cuda" when CUDA is available, else "cpu"."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")




In [None]:
if __name__ == "__main__":
    # End-to-end example: load the model, wrap it with adapters, evaluate it,
    # fine-tune the adapters on a small DocVQA subset and evaluate again.
    # NOTE: the training step requires a GPU and enough memory.
    try:
        # Configs
        model_cfg = ModelConfig()
        ft_cfg = FineTuningConfig()
        # Do not set `model_cfg.use_quantization = False` here: without the
        # default 4-bit quantization the model was too large for the GPU.

        # Load model & processor
        model_base, processor = Phi3VisionLoader.load_model_and_processor(model_cfg)

        # Wrap with adapters
        model_with_adapters = Phi3WithFeatureAdapters(model_base, ft_cfg)

        # Move adapters-only model to appropriate device if base model not already there
        device = safe_device()
        model_with_adapters.to(device)

        print_model_summary(model_with_adapters)

        # Load small subsets for quick smoke test (keep counts tiny to avoid long runs)
        # NOTE(review): both subsets come from load_training_dataset with the same
        # dataset name, so the eval samples may overlap the training samples -- confirm.
        train_ds = DatasetHandler.load_training_dataset(dataset_name="DocVQA", max_samples=ft_cfg.max_samples_train or 8)
        eval_ds = DatasetHandler.load_training_dataset(dataset_name="DocVQA", max_samples=ft_cfg.max_samples_eval or 4)

        # Create trainer
        trainer = FeatureAdapterTrainer(
            model=model_with_adapters,
            processor=processor,
            config=ft_cfg,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            dataset_name="DocVQA"
        )
        # Align the base model's embedding table with the tokenizer's vocabulary size.
        tokenizer = processor.tokenizer
        model_base.resize_token_embeddings(len(tokenizer))

        # Baseline first: this must run BEFORE training, otherwise the metrics
        # labelled "pretrained" would actually describe the fine-tuned adapters.
        comparator = ComparisonEvaluator(processor, ft_cfg)
        print("\nRunning quick pre-finetune evaluation (wrapped model with frozen base)...")
        pre_metrics = comparator.evaluate_model(model_with_adapters, eval_ds, "DocVQA", model_type="wrapped_pretrained")

        # WARNING: the full `trainer.train()` may be slow/expensive. Comment it
        # out (together with the post-finetune evaluation) for a baseline-only run.
        history = trainer.train()

        # Evaluate the adapters as they are in memory after training.
        # To evaluate a saved checkpoint instead, load it first, e.g.:
        #   model_with_adapters.load_adapters("./phi3_finetuning_outputs/checkpoints/best_model.pt")
        print("\nRunning post-finetune evaluation...")
        post_metrics = comparator.evaluate_model(model_with_adapters, eval_ds, "DocVQA", model_type="wrapped_finetuned")

        print("\nPre vs post fine-tuning:")
        for metric_name in ("bleu4", "rougeL_f1", "exact_match"):
            print(f"  {metric_name}: {pre_metrics[metric_name]:.2f} -> {post_metrics[metric_name]:.2f}")

        print("\nScript completed.")

    except Exception as e:
        # Report the failure in the cell output, then re-raise so it is not hidden.
        print(f"\nAn error occurred in the example entrypoint: {e}")
        raise

In [None]:
# A bare `CUDA_LAUNCH_BLOCKING=1` only binds a Python variable. The CUDA runtime
# reads this setting from the process environment, so export it there as well.
# NOTE: it only takes effect if set before CUDA is initialized; for debugging,
# restart the kernel and run this cell before the model is loaded.
CUDA_LAUNCH_BLOCKING = 1
os.environ["CUDA_LAUNCH_BLOCKING"] = str(CUDA_LAUNCH_BLOCKING)