<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [4]</a>'.</span>

# BLIP-2 Medical VQA Training - Trans-cerebellum (1 Epoch Validation)

Fine-tune BLIP-2 on question–answer pairs for fetal ultrasound trans-cerebellum brain images for one epoch, to validate the training pipeline end to end.

- **Model**: Salesforce/blip2-opt-2.7b (loaded in 8-bit, fine-tuned with LoRA adapters)
- **Data**: 5 images, 40 Q&A pairs, split 3/1/1 (train/val/test)
- **Training time**: ~10 minutes

In [1]:
# Parameters (can be overridden by papermill)
# Kept as plain literal assignments so each value can be replaced by name per run.
num_images = 5  # number of leading spreadsheet rows loaded via df.head(num_images)
num_epochs = 1  # forwarded to TrainingArguments as num_train_epochs
batch_size = 1  # used for both per-device train and eval batch sizes
learning_rate = 1e-4  # forwarded to TrainingArguments as learning_rate

In [2]:
import os
import sys
import torch
import pandas as pd
from pathlib import Path
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, TrainingArguments, Trainer
from torch.utils.data import Dataset
import time

# Report whether a CUDA device is visible and, if so, the name of device 0.
has_cuda = torch.cuda.is_available()
print(f"CUDA: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

CUDA: True
GPU: NVIDIA GeForce RTX 4070 Laptop GPU


In [3]:
# Paths
# The project root defaults to the original workstation location, but it can be
# redirected by setting the FADA_BASE_DIR environment variable so the notebook
# also runs on machines where the repository lives somewhere else.
BASE = Path(os.environ.get("FADA_BASE_DIR", r"C:\Users\elyas\Workspace\PyCharm\fada-v3"))
DATA_DIR = BASE / "data/Fetal Ultrasound Labeled"
IMAGE_DIR = BASE / "data/Fetal Ultrasound/Trans-cerebellum"
OUTPUT_DIR = BASE / "outputs/blip2_trans_cerebellum"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Surface a misconfigured root immediately instead of in a later cell.
# This is only a notice; the cells that actually read the data decide how to fail.
for required_dir in (DATA_DIR, IMAGE_DIR):
    if not required_dir.is_dir():
        print(f"WARNING: input directory not found: {required_dir}")

MODEL_NAME = "Salesforce/blip2-opt-2.7b"
print(f"Output: {OUTPUT_DIR}")

Output: C:\Users\elyas\Workspace\PyCharm\fada-v3\outputs\blip2_trans_cerebellum


<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [4]:
# Load data
# Reads the label spreadsheet, builds one {'image_path', 'question', 'answer'}
# record per answered question, and splits the records into train/val/test.
labels_path = DATA_DIR / "Trans-cerebellum_image_list.xlsx"
if not labels_path.is_file():
    # Same exception type pandas would raise, but with a message that points
    # at the configuration rather than at pandas internals.
    raise FileNotFoundError(
        f"Label spreadsheet not found: {labels_path} "
        f"(expected inside DATA_DIR={DATA_DIR})"
    )
df = pd.read_excel(labels_path)
# Question columns are the string-named columns starting with 'Q'.
q_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('Q')]

pairs_by_image = []   # one list of Q&A records per usable image, in spreadsheet order
missing_images = []   # file names referenced by the spreadsheet but absent on disk
for _, row in df.head(num_images).iterrows():
    img_path = IMAGE_DIR / str(row['Image Name'])
    if not img_path.exists():
        missing_images.append(img_path.name)
        continue
    image_pairs = []
    for q_col in q_cols:
        # Use the text after the first newline of the header (if any), capped at 100 chars.
        q = q_col.split('\n', 1)[1][:100] if '\n' in q_col else q_col[:100]
        raw_answer = row[q_col]
        # Test for missing values BEFORE converting to str; str(NaN) is the
        # non-missing string 'nan', which made the original notna() check a no-op.
        if pd.isna(raw_answer):
            continue
        a = str(raw_answer).strip()
        if a.lower() in ('nan', 'none', ''):
            continue
        image_pairs.append({'image_path': str(img_path), 'question': q, 'answer': a})
    if image_pairs:
        pairs_by_image.append(image_pairs)

if missing_images:
    print(f"Skipped {len(missing_images)} row(s) with missing image files: {missing_images}")

vqa_data = [pair for image_pairs in pairs_by_image for pair in image_pairs]
if not vqa_data:
    raise ValueError(
        f"No Q&A pairs could be built from {labels_path}; "
        f"check that the images exist under {IMAGE_DIR} and that the sheet has 'Q...' columns."
    )

# Split: ~70% train, ~15% val, ~15% test.
# The split is done per IMAGE so every question about a given image lands in the
# same subset (no image is shared between train and val/test). Val and test each
# get at least one image, which yields 3/1/1 for five images.
n_images = len(pairs_by_image)
if n_images >= 3:
    n_val = max(1, round(n_images * 0.15))
    n_test = max(1, round(n_images * 0.15))
    n_train = n_images - n_val - n_test
    train_data = [p for grp in pairs_by_image[:n_train] for p in grp]
    val_data = [p for grp in pairs_by_image[n_train:n_train + n_val] for p in grp]
    test_data = [p for grp in pairs_by_image[n_train + n_val:] for p in grp]
else:
    # Too few images for an image-level split: fall back to splitting the flat
    # list of Q&A pairs, which lets the same image appear in several subsets.
    print(f"Only {n_images} usable image(s); falling back to a per-question split.")
    n_train = int(len(vqa_data) * 0.7)
    n_val = int(len(vqa_data) * 0.15)
    train_data = vqa_data[:n_train]
    val_data = vqa_data[n_train:n_train + n_val]
    test_data = vqa_data[n_train + n_val:]

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\elyas\\Workspace\\PyCharm\\fada-v3\\data\\Fetal Ultrasound Labeled\\Trans-cerebellum_image_list.xlsx'

In [None]:
# Load model with LoRA
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig

processor = Blip2Processor.from_pretrained(MODEL_NAME)
# 8-bit loading is requested through an explicit quantization config; passing
# load_in_8bit=True directly to from_pretrained() is the deprecated spelling.
model = Blip2ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto"
)

# Prepare the quantized model for training
model = prepare_model_for_kbit_training(model)

# Add LoRA adapters
# NOTE(review): "q_proj"/"v_proj" presumably match the attention projections of
# the OPT language model only — confirm with the trainable-parameter report below.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print(f"Memory: {model.get_memory_footprint() / 1e9:.2f} GB")

In [None]:
# Dataset
class VQADataset(Dataset):
    """Map Q&A records to tensors for causal-LM fine-tuning of BLIP-2.

    Args:
        data: Sequence of dicts with 'image_path', 'question' and 'answer' keys.
        processor: Callable taking ``images``/``text`` and returning
            'pixel_values', 'input_ids' and 'attention_mask' tensors with a
            leading batch dimension of 1.
        max_length: Token length every example is padded/truncated to. It covers
            the prompt AND the answer, since both live in the same sequence.

    Each item is a dict with 'pixel_values', 'input_ids', 'attention_mask' and
    'labels'. Labels are a copy of ``input_ids`` with padded positions set to
    -100 so they are ignored by the loss. Prompt tokens are kept in the labels,
    so they contribute to the loss as well.
    """

    def __init__(self, data, processor, max_length=256):
        self.data = data
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        img = Image.open(item['image_path']).convert('RGB')
        prompt = f"Question: {item['question']} Answer:"
        # The answer must be part of the input sequence: the language model is
        # trained with teacher forcing, so labels have to line up token-for-token
        # with input_ids. Tokenizing the answer separately (with a different
        # length) leaves the answer out of the inputs entirely.
        text = f"{prompt} {item['answer']}"

        inputs = self.processor(
            images=img,
            text=text,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )

        # squeeze(0) removes only the batch dimension added by return_tensors="pt";
        # a bare squeeze() would also drop any other size-1 dimension.
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # Padding must not contribute to the loss: -100 is the ignore index.
        # NOTE(review): some processor versions may insert image placeholder
        # tokens into input_ids; consider masking those too — confirm against
        # the installed transformers version.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Wrap the train and validation records in datasets that share the one processor.
train_dataset, val_dataset = (
    VQADataset(records, processor) for records in (train_data, val_data)
)
print(f"Datasets ready: {len(train_dataset)} train, {len(val_dataset)} val")

In [None]:
# Training config
# Evaluation and checkpointing both run once per epoch; the two strategies have
# to match because load_best_model_at_end=True is enabled.
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    logging_steps=5,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none",
    # The dataset already yields exactly the model inputs, so keep every column.
    remove_unused_columns=False,
    # Name the label column explicitly: the model is wrapped by PEFT, and the
    # Trainer may not be able to infer label names from a wrapped model. Without
    # labels no eval loss is reported, which load_best_model_at_end relies on.
    label_names=["labels"]
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

print("Trainer ready")

In [None]:
# Train: run the trainer and record the wall-clock duration in seconds
start_time = time.time()
train_result = trainer.train()
training_time = time.time() - start_time

print(f"\nTraining complete!")
print(f"Time: {training_time/60:.2f} minutes")
print(f"Final loss: {train_result.training_loss:.4f}")

# Save the model (through the trainer) and the processor into the same folder
final_model_dir = str(OUTPUT_DIR / "final_model")
for save_artifact in (trainer.save_model, processor.save_pretrained):
    save_artifact(final_model_dir)
print(f"Model saved to {OUTPUT_DIR / 'final_model'}")

In [None]:
# Test inference: generate an answer for the first record of the test split
model.eval()
test_item = test_data[0]
test_prompt = f"Question: {test_item['question']} Answer:"
test_img = Image.open(test_item['image_path']).convert('RGB')

# Build the model inputs, then move them to the device the model lives on
inputs = processor(images=test_img, text=test_prompt, return_tensors="pt")
inputs = inputs.to(model.device)

# Generation needs no gradients; decoding the token ids happens afterwards
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)
prediction = processor.batch_decode(outputs, skip_special_tokens=True)[0]

print(f"\nTest Inference:")
print(f"Q: {test_item['question']}")
print(f"Predicted: {prediction}")
print(f"Ground truth: {test_item['answer'][:100]}")

In [None]:
# Summary: collect the run's key facts, persist them as JSON, and print them
import json

summary = {
    'model': MODEL_NAME,
    'category': 'Trans-cerebellum',
    'num_images': num_images,
    'train_samples': len(train_data),
    'val_samples': len(val_data),
    'test_samples': len(test_data),
    'epochs': num_epochs,
    'training_time_min': training_time/60,
    'final_loss': train_result.training_loss,
    'output_dir': str(OUTPUT_DIR)
}

# Serialize first, then write the whole document in one call
(OUTPUT_DIR / 'training_summary.json').write_text(json.dumps(summary, indent=2))

rule = "="*60
print("\n" + rule)
print("TRAINING SUMMARY")
print(rule)
for key, value in summary.items():
    print(f"{key}: {value}")
print(rule)