# Reproduce Existing Multimodal Model


In [None]:
from src.data import load_omnimed_dataset

In [None]:
train_df, val_df, test_df = load_omnimed_dataset()

print("Train size:", len(train_df))
print("Validation size:", len(val_df))
print("Test size:", len(test_df))

# Check for image overlap between every pair of splits: an image shared across
# splits lets visual information leak from training into evaluation.
train_images = set(train_df['image_path'])
val_images = set(val_df['image_path'])
test_images = set(test_df['image_path'])
print("Overlap train-test:", len(train_images & test_images))
print("Overlap train-val:", len(train_images & val_images))
print("Overlap val-test:", len(val_images & test_images))


In [None]:
train_df.head()

In [None]:
def prepare_omnimed_dataframe(df, include_answer=True):
    """
    Build the (image, text, label) view of an OmniMedVQA DataFrame for multimodal training.

    The input frame is NOT modified; a new DataFrame is returned, so re-running
    this cell (or calling it twice on the same split) is safe.

    Args:
        df: DataFrame with at least 'image_path', 'question' and 'gt_answer' columns.
        include_answer: if True, 'text_input' is "Question: <q>\\nAnswer: <gt_answer>"
            (teacher-forcing text); if False it is "Question: <q>\\nAnswer:" with the
            answer slot left open for generation.

    Returns:
        New DataFrame with exactly the columns ['image_path', 'text_input', 'label'],
        one row per input row, index preserved from ``df``. 'label' is 'gt_answer'.
    """
    questions = df['question'].astype(str)
    if include_answer:
        text_input = "Question: " + questions + "\nAnswer: " + df['gt_answer'].astype(str)
    else:
        text_input = "Question: " + questions + "\nAnswer:"

    # Vectorized string concatenation (unlike a row-wise apply) also works on an
    # empty frame; assign() returns a copy instead of mutating the caller's df.
    prepared = df.assign(label=df['gt_answer'], text_input=text_input)
    return prepared[['image_path', 'text_input', 'label']]

train_ready = prepare_omnimed_dataframe(train_df, include_answer=True)
val_ready   = prepare_omnimed_dataframe(val_df,   include_answer=True)   # answer included, like training
test_ready  = prepare_omnimed_dataframe(test_df,  include_answer=False)  # text ends at "Answer:" for generation

# prepare_omnimed_dataframe keeps one output row per input row.
assert len(train_ready) == len(train_df)
assert len(val_ready) == len(val_df)
assert len(test_ready) == len(test_df)

train_ready.head()

In [None]:
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1
)

# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")

# map_location="cpu": the model is only moved to `device` in a later cell, so
# load the tensors on CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint is a plain state dict).
state_dict = torch.load(checkpoint_path, map_location="cpu")

# strict=False: keys absent from the checkpoint keep the weights that
# create_model_and_transforms initialized. Summarize the result instead of
# letting the cell dump the full key lists.
load_result = model.load_state_dict(state_dict, strict=False)
print(
    f"Checkpoint loaded: {len(load_result.missing_keys)} missing keys, "
    f"{len(load_result.unexpected_keys)} unexpected keys"
)
if load_result.unexpected_keys:
    # Unexpected keys mean the checkpoint does not match this architecture
    # (e.g. a different cross_attn_every_n_layers); show a sample to diagnose.
    print("Unexpected keys (first 10):", load_result.unexpected_keys[:10])
del state_dict  # release the CPU copy of the weights


In [None]:
# import torch
# from PIL import Image
# import requests
# from torchvision import transforms
# from open_flamingo import create_model_and_transforms
# from huggingface_hub import hf_hub_download


# # ---- Dummy image ----
# import requests
# from PIL import Image
# from io import BytesIO


# image = Image.open("E:\MediVision-Flare25\data\OmniMedVQA\Images\ACRIMA\Im002_ACRIMA.png").convert("RGB")
# # image_tensor shape: (C,H,W)
# image_tensor = image_processor(image)  # shape (3,H,W)

# # Add batch, T_img, F dimensions
# image_tensor = image_tensor.unsqueeze(0).unsqueeze(1).unsqueeze(2)  
# # shape: (1, 1, 1, 3, H, W)


# # ---- Text prompt ----
# prompt = "ACRIMA"

# # ---- Tokenize ----
# inputs = tokenizer(prompt, return_tensors="pt")

# # ---- Forward pass ----
# with torch.no_grad():
#     out = model.generate(
#         vision_x=image_tensor,   # add time dim
#         lang_x=inputs["input_ids"],
#         max_new_tokens=20
#     )

# print("Generated:", tokenizer.decode(out[0]))


In [None]:
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image
import torch

class OmniMedDataset(Dataset):
    r"""
    Map-style dataset turning prepared OmniMedVQA rows into OpenFlamingo inputs.

    ``dataframe`` rows must provide 'image_path', 'text_input' and 'label'
    (the columns returned by ``prepare_omnimed_dataframe``).

    Each item is a dict:
        vision_x:       processed image with two leading singleton dims added
                        (T_img=1, F=1); (1, 1, C, H, W) assuming ``image_processor``
                        returns (C, H, W).
        lang_x:         1-D token ids of the prompt.
        attention_mask: 1-D attention mask for ``lang_x``.
        label:          the row's raw 'label' value.

    Prompt layout (question part = 'text_input' up to its first "\nAnswer:"):
        include_answer=True:  "<image>Prompt: {question part}\nAnswer: {label}<|endofchunk|>"
        include_answer=False: "<image>Prompt: {question part}\nAnswer:"
    """

    # Special tokens registered on the tokenizer by OpenFlamingo's
    # create_model_and_transforms. Cross-attention only sees the image for text
    # positions after a media token, so a prompt without "<image>" gets no visual
    # conditioning at all.
    MEDIA_TOKEN = "<image>"
    END_OF_CHUNK_TOKEN = "<|endofchunk|>"
    # Separator prepare_omnimed_dataframe puts between question and answer.
    ANSWER_MARKER = "\nAnswer:"

    def __init__(self, dataframe, image_processor, tokenizer, include_answer=True, image_root="data/OmniMedVQA/Images"):
        self.df = dataframe
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.include_answer = include_answer
        self.image_root = Path(image_root)  # ensure consistent paths

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, raw_path):
        """Map a dataframe 'image_path' value to an absolute path under ``image_root``."""
        image_path = Path(raw_path)
        parts = image_path.parts
        # Paths may start with "Images/" while image_root already points at the
        # Images directory; drop the redundant component. `parts` is empty for "".
        if parts and parts[0] == "Images":
            image_path = Path(*parts[1:])
        return (self.image_root / image_path).resolve()

    def _build_prompt(self, row):
        """Return the text prompt for ``row`` (see class docstring for the layout)."""
        # text_input may already end with "\nAnswer: <label>" or "\nAnswer:";
        # keep only the question part so the answer is neither duplicated nor
        # leaked into include_answer=False prompts.
        # NOTE(review): assumes the question itself never contains "\nAnswer:".
        question_part = str(row["text_input"]).split(self.ANSWER_MARKER, 1)[0]
        prompt = f"{self.MEDIA_TOKEN}Prompt: {question_part}\nAnswer:"
        if self.include_answer:
            prompt += f" {row['label']}{self.END_OF_CHUNK_TOKEN}"
        return prompt

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        image_path = self._resolve_image_path(row["image_path"])
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")

        # The context manager closes the file handle; convert() loads the pixels
        # into a new image before the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image_tensor = self.image_processor(image).unsqueeze(0).unsqueeze(0)  # (1,1,3,H,W)

        inputs = self.tokenizer(self._build_prompt(row), return_tensors="pt", padding="longest")

        return {
            "vision_x": image_tensor,
            "lang_x": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": row["label"]
        }





In [None]:
# Data loaders
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, pad_token_id=None):
    """
    Collate OmniMedDataset items into a right-padded batch.

    Args:
        batch: list of item dicts with 'vision_x' (same shape for every item)
            and 1-D 'lang_x' / 'attention_mask' tensors of possibly different lengths.
        pad_token_id: id used to pad ``lang_x``. Defaults to the notebook-level
            ``tokenizer.pad_token_id`` so ``DataLoader(..., collate_fn=collate_fn)``
            keeps working unchanged.

    Returns:
        dict with
            vision_x:       item tensors stacked along a new leading batch dim,
            lang_x:         (B, L) token ids, right-padded with ``pad_token_id``,
            attention_mask: (B, L), right-padded with 0,
            labels:         (B, L) copy of ``lang_x`` for causal-LM loss with padded
                            positions set to -100 (PyTorch cross-entropy's default
                            ignore_index).
        The per-item raw 'label' value is not carried into the batch.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    token_ids = [item["lang_x"] for item in batch]
    masks = [item["attention_mask"] for item in batch]

    return {
        "vision_x": torch.stack([item["vision_x"] for item in batch]),
        "lang_x": pad_sequence(token_ids, batch_first=True, padding_value=pad_token_id),
        "attention_mask": pad_sequence(masks, batch_first=True, padding_value=0),
        "labels": pad_sequence(token_ids, batch_first=True, padding_value=-100),
    }


# # Take a random sample of 500 rows (adjust number as needed)
# train_ready_small = train_ready.sample(n=500, random_state=42)
# val_ready_small   = val_ready.sample(n=100, random_state=42)

# # Rebuild datasets using sampled data
# train_dataset = OmniMedDataset(train_ready_small, image_processor, tokenizer, include_answer=True)
# val_dataset   = OmniMedDataset(val_ready_small,   image_processor, tokenizer, include_answer=True)

# train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, num_workers=4)
# val_loader   = DataLoader(val_dataset,   batch_size=2, shuffle=False, collate_fn=collate_fn, num_workers=4)

# Small fixed-seed subsets for a quick end-to-end smoke test of training and
# evaluation. Increase DEBUG_N_SAMPLES for real runs.
DEBUG_N_SAMPLES = 20
DEBUG_SEED = 42

train_ready_debug = train_ready.sample(n=min(DEBUG_N_SAMPLES, len(train_ready)), random_state=DEBUG_SEED)
train_dataset_debug = OmniMedDataset(train_ready_debug, image_processor, tokenizer, include_answer=True)
train_loader = DataLoader(train_dataset_debug, batch_size=1, shuffle=True, collate_fn=collate_fn)

# The evaluation cell iterates over `val_loader`, so it must be defined here for
# the notebook to run top to bottom. Build it from question-only text
# (include_answer=False) so the ground-truth answer is not part of the prompt
# the model generates from. batch_size=1 avoids padding prompts for generate().
val_prompts = prepare_omnimed_dataframe(val_df, include_answer=False)
val_prompts_debug = val_prompts.sample(n=min(DEBUG_N_SAMPLES, len(val_prompts)), random_state=DEBUG_SEED)
val_dataset_debug = OmniMedDataset(val_prompts_debug, image_processor, tokenizer, include_answer=False)
val_loader = DataLoader(val_dataset_debug, batch_size=1, shuffle=False, collate_fn=collate_fn)


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Move all model parameters and buffers onto the selected device.
model = model.to(device)

In [None]:
from peft import get_peft_model

def apply_lora_to_flamingo(model, config, adapter_name="default"):
    """
    Wrap the attention projections of an OpenFlamingo language encoder with PEFT/LoRA.

    Targets, wherever present:
      * decoder self-attention ``attn.Wqkv``
      * gated cross-attention ``attn.to_q`` and ``attn.to_kv``

    OpenFlamingo reaches the same modules through several lists
    (``transformer.blocks`` wrappers, ``old_decoder_blocks`` and
    ``gated_cross_attn_layers``). Each (parent module, attribute) pair is
    wrapped at most once, so a module reachable from several lists is not
    wrapped twice.

    Args:
        model: OpenFlamingo model; ``model.lang_encoder`` is modified in place.
        config: PEFT config passed through to ``get_peft_model``.
        adapter_name: adapter name passed through to ``get_peft_model``.

    NOTE(review): ``get_peft_model`` is designed for whole models; confirm it
    accepts a bare projection layer with this config (injecting adapters into
    ``lang_encoder`` with ``target_modules`` may be the more robust route).
    """
    lang_encoder = model.lang_encoder
    wrapped = set()  # (id(parent), attribute name) pairs already replaced

    def _wrap(parent, attr):
        # Replace parent.<attr> with its PEFT-wrapped version, at most once.
        key = (id(parent), attr)
        if key in wrapped:
            return
        setattr(parent, attr, get_peft_model(getattr(parent, attr), config, adapter_name=adapter_name))
        wrapped.add(key)

    def _wrap_decoder(decoder_layer):
        # Decoder self-attention fused QKV projection.
        if decoder_layer is not None and hasattr(decoder_layer, "attn"):
            _wrap(decoder_layer.attn, "Wqkv")

    def _wrap_cross_attn(gated_block):
        # Gated cross-attention; entries may be None for layers without cross-attention.
        if gated_block is not None and hasattr(gated_block, "attn"):
            _wrap(gated_block.attn, "to_q")
            _wrap(gated_block.attn, "to_kv")

    for block in lang_encoder.transformer.blocks:
        _wrap_decoder(getattr(block, "decoder_layer", None))
        _wrap_cross_attn(getattr(block, "gated_cross_attn_layer", None))

    for decoder_layer in getattr(lang_encoder, "old_decoder_blocks", []):
        _wrap_decoder(decoder_layer)

    for gated_block in getattr(lang_encoder, "gated_cross_attn_layers", []):
        _wrap_cross_attn(gated_block)

    print("✅ LoRA applied successfully!")


In [None]:
# Training loop
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss

NUM_EPOCHS = 3          # demo setting
LEARNING_RATE = 1e-5

# Only parameters with requires_grad are optimized; frozen weights never get gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE)

# Mixed precision and gradient scaling only on CUDA; on CPU both are disabled
# (scaler.step then simply calls optimizer.step()).
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(NUM_EPOCHS):
    model.train()
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Move batch to the model's device
        vision_x = batch["vision_x"].to(device, non_blocking=True)
        lang_x = batch["lang_x"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        # collate_fn sets padded positions to -100 in "labels"; passing lang_x as
        # labels instead would train the model to predict pad tokens.
        labels = batch["labels"].to(device, non_blocking=True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            out = model(
                vision_x=vision_x,
                lang_x=lang_x,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = out.loss

        # Backward with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if use_amp:
            # Release cached blocks to keep VRAM headroom (costs some speed).
            torch.cuda.empty_cache()

        print(f"Epoch {epoch}, Step {step}, loss={loss.item():.4f}")

    print(f"Epoch {epoch} finished")


In [None]:
# Evaluation: generate an answer for each validation prompt and print it.
model.eval()
with torch.no_grad():
    for batch in val_loader:
        # Inputs must be on the same device as the model.
        vision_x = batch["vision_x"].to(device)
        lang_x = batch["lang_x"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        generated = model.generate(
            vision_x=vision_x,
            lang_x=lang_x,
            attention_mask=attention_mask,
            max_new_tokens=30
        )
        # The returned sequences start with the prompt tokens; keep only the
        # continuation and drop special tokens such as <image> / <|endofchunk|>.
        new_tokens = generated[:, lang_x.shape[1]:]
        for answer in tokenizer.batch_decode(new_tokens, skip_special_tokens=True):
            print(answer.strip())


## 6. Key points from Moor et al. (Med-Flamingo)

Initialize from pretrained OpenFlamingo weights (done above ✅).

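Freeze most parameters and adapt only the vision–language cross-attention and language-model layers. A LoRA helper (`apply_lora_to_flamingo`) is defined above but is not yet called before training — **TODO**.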

Format QA prompts as built by `prepare_omnimed_dataframe` above: "Question: ... Answer:".

Evaluate with accuracy, BLEU, ROUGE on medical QA tasks.

In [None]:
import torch
print(torch.__version__)          # expected for this setup: 2.0.1+cu118
print(torch.cuda.is_available())  # expected: True
print(torch.version.cuda)         # expected: 11.8 (None on CPU-only builds)
# get_device_name raises when no CUDA device is available, so guard it.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device visible; PyTorch will run on CPU.")
