# Reproduce Existing Multimodal Model


In [None]:
from src.data import load_omnimed_dataset

# Verbosity flag: a truthy value is meant to enable the training-loss / epoch
# progress printouts, a falsy value to silence them.
debug_prints = 1

In [None]:
# Load the train / validation / test frames and report how many rows each holds.
train_df, val_df, test_df = load_omnimed_dataset()

for split_label, split_frame in (
    ("Train", train_df),
    ("Validation", val_df),
    ("Test", test_df),
):
    print(f"{split_label} size:", len(split_frame))


In [None]:
def prepare_omnimed_dataframe(df, include_answer=True):
    """
    Build text + image pairs for multimodal training from an OmniMedVQA DataFrame.

    The input frame is left untouched; a new frame is returned. (The previous
    implementation added 'label' and 'text_input' columns to the caller's
    frame as a side effect.)

    Args:
        df: DataFrame with 'image_path', 'question' and 'gt_answer' columns.
        include_answer: If True, 'text_input' is
            "Question: <question>\\nAnswer: <gt_answer>"; otherwise the answer
            is left open as "Question: <question>\\nAnswer:".

    Returns:
        A new DataFrame with exactly the columns
        ['image_path', 'text_input', 'label'], sharing the index of ``df``.
        'label' is a copy of 'gt_answer'.
    """
    questions = df['question']
    answers = df['gt_answer']

    # Build the prompts with plain iteration: unlike a row-wise DataFrame.apply,
    # this also yields a well-formed (empty) column for an empty frame.
    if include_answer:
        text_input = [
            f"Question: {question}\nAnswer: {answer}"
            for question, answer in zip(questions, answers)
        ]
    else:
        text_input = [f"Question: {question}\nAnswer:" for question in questions]

    # Start from an explicit copy so the assignments below never touch `df`.
    prepared = df[['image_path']].copy()
    prepared['text_input'] = text_input
    prepared['label'] = answers  # label comes from gt_answer

    return prepared[['image_path', 'text_input', 'label']]

# Train and validation prompts carry the ground-truth answer; the test prompt
# stops at "Answer:".
train_ready, val_ready, test_ready = (
    prepare_omnimed_dataframe(split_frame, include_answer=with_answer)
    for split_frame, with_answer in (
        (train_df, True),
        (val_df, True),
        (test_df, False),
    )
)

print(train_ready.head())

In [None]:
from open_flamingo import create_model_and_transforms

# Construction arguments for the OpenFlamingo model: CLIP ViT-L/14 vision
# encoder (OpenAI weights) and the MPT-1B RedPajama language model/tokenizer,
# with cross-attention inserted every layer.
flamingo_kwargs = {
    "clip_vision_encoder_path": "ViT-L-14",
    "clip_vision_encoder_pretrained": "openai",
    "lang_encoder_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "tokenizer_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "cross_attn_every_n_layers": 1,
}

model, image_processor, tokenizer = create_model_and_transforms(**flamingo_kwargs)

# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; load_state_dict then copies the values into the
# model's own parameters.
state_dict = torch.load(checkpoint_path, map_location="cpu")
# strict=False: keys missing from / unexpected in the checkpoint are reported
# in the returned value instead of raising.
model.load_state_dict(state_dict, strict=False)


In [None]:
from pathlib import Path
from torch.utils.data import Dataset
from PIL import Image
import torch

class OmniMedDataset(Dataset):
    def __init__(self, dataframe, image_processor, tokenizer, include_answer=True, image_root="data/OmniMedVQA/Images"):
        self.df = dataframe
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.include_answer = include_answer
        self.image_root = Path(image_root)  # ensure consistent paths

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Fix the image path
        image_path = Path(row["image_path"])

        # Remove redundant "Images/" if present
        parts = image_path.parts
        if parts[0] == "Images":
            image_path = Path(*parts[1:])  # remove leading "Images"

        # Join with root
        image_path = (self.image_root / image_path).resolve()

        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")

        # Load & preprocess image
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.image_processor(image).unsqueeze(0).unsqueeze(0)  # (1,1,3,H,W)

        # Build text prompt
        if self.include_answer:
            text = f"Prompt: {row['text_input']}\nAnswer: {row['label']}"
        else:
            text = f"Prompt: {row['text_input']}\nAnswer:"

        # Tokenize
        inputs = self.tokenizer(text, return_tensors="pt", padding="longest")

        return {
            "vision_x": image_tensor,
            "lang_x": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": row["label"]
        }





In [None]:
# Data loaders
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_fn_train(batch):
    """Collate samples for training, padding text on the right.

    Token ids are padded with the global tokenizer's pad id and the attention
    mask with 0. 'labels' is a copy of the padded ids in which every pad-id
    position is replaced by -100 so that it is ignored in the loss.
    """
    pad_id = tokenizer.pad_token_id

    vision_x = torch.stack([sample["vision_x"] for sample in batch])

    token_seqs = [sample["lang_x"] for sample in batch]
    mask_seqs = [sample["attention_mask"] for sample in batch]

    lang_x = pad_sequence(token_seqs, batch_first=True, padding_value=pad_id)
    attention_mask = pad_sequence(mask_seqs, batch_first=True, padding_value=0)

    # masked_fill returns a new tensor, so lang_x itself keeps its pad ids.
    labels = lang_x.masked_fill(lang_x == pad_id, -100)

    return dict(
        vision_x=vision_x,
        lang_x=lang_x,
        attention_mask=attention_mask,
        labels=labels,
    )


def collate_fn_eval(batch):
    """Collate samples for generation, padding text on the left.

    The original author notes left padding as a MosaicGPT requirement.
    'labels' is the padded id tensor with pad-id positions set to -100.
    """
    pad_id = tokenizer.pad_token_id

    def _left_pad(sequences, fill_value):
        # pad_sequence pads on the right, so reverse every sequence, pad, and
        # reverse the padded batch back along the sequence dimension.
        reversed_seqs = [seq.flip(0) for seq in sequences]
        padded = pad_sequence(reversed_seqs, batch_first=True, padding_value=fill_value)
        return padded.flip(1)

    vision_x = torch.stack([sample["vision_x"] for sample in batch])
    lang_x = _left_pad([sample["lang_x"] for sample in batch], pad_id)
    attention_mask = _left_pad([sample["attention_mask"] for sample in batch], 0)

    labels = lang_x.masked_fill(lang_x == pad_id, -100)

    return dict(
        vision_x=vision_x,
        lang_x=lang_x,
        attention_mask=attention_mask,
        labels=labels,
    )




# Work on small random subsets so a debugging run stays fast. Every subset size
# is capped at the split size because DataFrame.sample raises when asked for
# more rows than exist (sampling without replacement).
train_ready = train_ready.sample(n=min(100, len(train_ready)))
train_dataset = OmniMedDataset(train_ready, image_processor, tokenizer, include_answer=True)


val_ready = val_ready.sample(n=min(50, len(val_ready)))  # up to 50 samples for debugging
val_dataset = OmniMedDataset(val_ready, image_processor, tokenizer, include_answer=True)

# NOTE(review): with include_answer=True the ground-truth label is part of the
# text that the evaluation cells pass to model.generate(). Those cells also read
# the reference answer back out of the tokenized text, so switching this flag
# needs a coordinated change there - confirm the intended evaluation protocol.
test_ready = test_ready.sample(n=min(20, len(test_ready)))
test_dataset = OmniMedDataset(test_ready, image_processor, tokenizer, include_answer=True)

# Training uses right padding and shuffling; evaluation loaders use left padding.
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=collate_fn_train)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn_eval)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn_eval)



In [None]:
# Print the model's repr to inspect its structure.
print(model)

In [None]:
from peft import get_peft_model

def apply_lora_to_flamingo(model, config, adapter_name="default"):
    """Wrap the attention projections of the language encoder with PEFT adapters.

    Targets, in place on ``model.lang_encoder``:
      * ``attn.Wqkv`` of every decoder layer,
      * ``attn.to_q`` and ``attn.to_kv`` of every gated cross-attention layer.

    The layers are visited through ``transformer.blocks`` and again through
    ``old_decoder_blocks`` / ``gated_cross_attn_layers``. If those containers
    hold the same module objects (this appears to be the case in OpenFlamingo -
    TODO confirm), a projection would be wrapped twice; each projection is
    therefore wrapped at most once per call.

    NOTE(review): get_peft_model is normally applied to a whole model with
    target modules named in the config - confirm that wrapping individual
    projection layers behaves as intended with the installed peft version.

    Args:
        model: Flamingo model exposing ``lang_encoder``.
        config: PEFT config passed to ``get_peft_model``.
        adapter_name: Adapter name passed to ``get_peft_model``.
    """
    # ids of the wrapper objects installed by this call; the wrappers stay
    # referenced by the model, so their ids remain unique for its duration.
    wrapped_ids = set()

    def _wrap(owner, attr_name):
        """Replace owner.<attr_name> with its PEFT-wrapped version, at most once."""
        target = getattr(owner, attr_name)
        if id(target) in wrapped_ids:
            return  # already wrapped via another container
        wrapped = get_peft_model(target, config, adapter_name=adapter_name)
        setattr(owner, attr_name, wrapped)
        wrapped_ids.add(id(wrapped))

    for block in model.lang_encoder.transformer.blocks:
        # Decoder layer attention
        if hasattr(block, "decoder_layer") and hasattr(block.decoder_layer, "attn"):
            _wrap(block.decoder_layer.attn, "Wqkv")
        # Cross-attention layer
        if hasattr(block, "gated_cross_attn_layer") and hasattr(block.gated_cross_attn_layer, "attn"):
            cross_attn = block.gated_cross_attn_layer.attn
            _wrap(cross_attn, "to_q")
            _wrap(cross_attn, "to_kv")

    # Old decoder blocks
    for block in model.lang_encoder.old_decoder_blocks:
        _wrap(block.attn, "Wqkv")

    # Other gated cross-attn layers
    for block in model.lang_encoder.gated_cross_attn_layers:
        if block is None:
            # No cross-attention at this position - nothing to wrap.
            continue
        _wrap(block.attn, "to_q")
        _wrap(block.attn, "to_kv")

    print("LoRA applied successfully!")


In [None]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss

num_epochs = 3
optimizer = AdamW(model.parameters(), lr=1e-5)
# Not used by the loops below (the model returns its own loss as `out.loss`);
# kept so that the name stays defined.
loss_fn = CrossEntropyLoss()

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    total_train_loss = 0
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        out = model(
            vision_x=batch["vision_x"],
            lang_x=batch["lang_x"],
            attention_mask=batch["attention_mask"],
            # Use the collated labels: pad positions are set to -100 there so
            # they do not contribute to the loss (raw lang_x would include them).
            labels=batch["labels"],
        )
        loss = out.loss
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        # Per-step progress is debug output; epoch summaries are always shown.
        if debug_prints:
            print(f"Epoch {epoch}, Step {i}, loss={step_loss:.4f}")

        total_train_loss += step_loss

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch}: avg_train_loss={avg_train_loss:.4f}")

    # ---- Validation ----
    model.eval()
    total_val_loss = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        # Separate step counter: reusing the training loop's `i` would report
        # the last training step for every validation batch.
        for val_step, batch in enumerate(val_loader):
            out = model(
                vision_x=batch["vision_x"],
                lang_x=batch["lang_x"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            val_loss = out.loss.item()
            total_val_loss += val_loss

            # Generate predictions
            generated = model.generate(
                vision_x=batch["vision_x"],
                lang_x=batch["lang_x"],
                attention_mask=batch["attention_mask"],
                max_new_tokens=30
            )
            decoded_preds = [tokenizer.decode(g, skip_special_tokens=True) for g in generated]
            all_preds.extend(decoded_preds)

            # Reference text: the label ids with the -100 (ignored) positions dropped.
            labels = batch["labels"]
            decoded_labels = [tokenizer.decode(l[l != -100], skip_special_tokens=True) for l in labels]
            all_labels.extend(decoded_labels)

            if debug_prints:
                print(f"Epoch {epoch}, Step {val_step}, val_loss={val_loss:.4f}")

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch}: val_loss={avg_val_loss:.4f}")


In [None]:
# Notebook inspection: display whatever `batch` currently refers to (it is the
# loop variable of the training/validation loops above).
batch

In [None]:
# Run generation over the test loader and collect decoded predictions together
# with the decoded reference text for the metrics computed afterwards.
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        # Generate text
        generated = model.generate(
            vision_x=batch["vision_x"],
            lang_x=batch["lang_x"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=30,
        )

        # Decode the predictions
        decoded_preds = []
        for sequence in generated:
            decoded_preds.append(tokenizer.decode(sequence, skip_special_tokens=True))
        all_preds += decoded_preds

        # Decode the labels (ground-truth), dropping the -100 padding markers
        labels = batch["labels"]
        decoded_labels = []
        for label_ids in labels:
            kept_ids = label_ids[label_ids != -100]
            decoded_labels.append(tokenizer.decode(kept_ids, skip_special_tokens=True))
        all_labels += decoded_labels

        # Optional: print for inspection
        for pred, label in zip(decoded_preds, decoded_labels):
            print("Pred:", pred)
            print("GT  :", label)
            print("---")


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score
import re

_TRAILING_ANSWER_RE = re.compile(r"\bAnswer\b$")


def clean_text(text):
    """Strip whitespace and leftover "Answer" markers from a decoded string.

    After trimming, a standalone trailing "Answer" is removed, then every
    "\\nAnswer" occurrence is deleted, and the result is trimmed again.
    """
    without_tail = _TRAILING_ANSWER_RE.sub("", text.strip())
    return without_tail.replace("\nAnswer", "").strip()

# --- Extract answers ---
def _answer_segment(decoded):
    """Return the cleaned text that follows the last "Answer:" in ``decoded``."""
    return clean_text(decoded.split("Answer:")[-1])


y_true = list(map(_answer_segment, all_labels))
y_pred = list(map(_answer_segment, all_preds))

print("y_true:", y_true)
print("y_pred:", y_pred)

# --- Token-level partial matching ---
def token_overlap_scores(gt, pred):
    """Return (precision, recall, f1) over the unique whitespace tokens.

    Both strings are split on whitespace and compared as sets, so repeated
    tokens count once and matching is case-sensitive.

    Degenerate cases: both empty -> (1, 1, 1); only ``gt`` empty -> (0, 1, 0);
    only ``pred`` empty -> (1, 0, 0).
    """
    reference = set(gt.split())
    candidate = set(pred.split())

    if not reference:
        return (0, 1, 0) if candidate else (1, 1, 1)
    if not candidate:
        return 1, 0, 0

    # Both sets are non-empty here, so neither denominator can be zero.
    shared = len(reference & candidate)
    precision = shared / len(candidate)
    recall = shared / len(reference)

    score_sum = precision + recall
    f1 = 2 * precision * recall / score_sum if score_sum > 0 else 0

    return precision, recall, f1

# --- Compute scores across all examples ---
# Score every (reference, prediction) pair, then macro-average over examples.
pair_scores = [token_overlap_scores(gt_text, pred_text) for gt_text, pred_text in zip(y_true, y_pred)]

precisions = [scores[0] for scores in pair_scores]
recalls = [scores[1] for scores in pair_scores]
f1s = [scores[2] for scores in pair_scores]
# Exact-match indicator per example (1 when the cleaned strings are identical).
matches = [1 if gt_text == pred_text else 0 for gt_text, pred_text in zip(y_true, y_pred)]

accuracy = sum(matches) / len(matches)
precision = sum(precisions) / len(precisions)
recall = sum(recalls) / len(recalls)
f1 = sum(f1s) / len(f1s)

for metric_name, metric_value in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1", f1),
):
    print(f"{metric_name}:", metric_value)


6. Key Points from Moor et al. (Med-Flamingo)

Initialize with pretrained Flamingo weights.

Freeze most parameters, adapt only vision-language cross-attn and LM layers (LoRA).

Format QA prompts the same way the cells above already do: "Question: ... Answer:".

Evaluate with accuracy, BLEU, ROUGE on medical QA tasks.