<a href="https://colab.research.google.com/github/Yiwen91/MED-VQA/blob/main/MED_VQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Hugging Face stack at pinned versions; scikit-learn, tqdm,
# pillow and torchvision are left unpinned and upgraded to their latest release.
!pip install -U \
  transformers==4.44.2 \
  datasets==2.20.0 \
  accelerate==0.32.1 \
  peft==0.8.2 \
  evaluate==0.4.1 \
  fsspec==2024.5.0 \
  scikit-learn tqdm pillow torchvision
# NOTE(review): sentence-transformers is removed after the install, presumably
# to avoid a conflict with the pinned transformers version -- confirm the reason.
!pip uninstall -y sentence-transformers

In [None]:
import torch
import random
import hashlib
import io
import string
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import accuracy_score

# Hugging Face Libraries
from datasets import load_dataset
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    BertTokenizer, DataCollatorForLanguageModeling
)

# PyTorch & TorchVision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# Evaluation
import evaluate

# Reproducibility: every random number generator used by the notebook is
# seeded with the same value.
seed = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    torch.cuda.manual_seed_all(seed)

# Device Configuration: use the GPU when one is visible, else the CPU.
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load VQA-RAD dataset
# `ds` is indexed further down by the split names "train" and "test"; each
# item is read through its "image", "question" and "answer" fields.
ds = load_dataset("flaviagiammarino/vqa-rad")
print("Dataset Structure:", ds)

# Hash images to group multiple questions per image
def hash_image(img):
    """Return the MD5 hex digest of ``img`` serialised as PNG.

    ``img`` only needs a ``save(file_obj, format=...)`` method. The digest is
    used as a grouping key for identical images, not for security.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return hashlib.md5(png_bytes).hexdigest()

# Bucket every QA record under the hash of its image, remembering which
# split the record came from.
image_question_map = defaultdict(list)
for split in ("train", "test"):
    for record in ds[split]:
        annotated = dict(record)  # copy so the dataset row itself is not mutated
        annotated["split"] = split
        image_question_map[hash_image(record["image"])].append(annotated)

print(f"Unique Images: {len(image_question_map)}")

# Body Part Detection (Anatomical Region Grouping)
# Regions are tried in insertion order and the first region with a keyword
# occurring anywhere in the question wins, so a keyword listed under two
# regions (e.g. "ventricle") always resolves to the earlier one.
body_parts = {
    "brain": ["brain", "cerebrum", "cerebellum", "ventricle", "cortex"],
    "lung": ["lung", "lungs", "pulmonary", "pleura", "chest"],
    "heart": ["heart", "cardiac", "ventricle", "atrium", "pericardium"],
    "abdomen": [
        "abdomen", "liver", "kidney", "pancreas",
        "stomach", "spleen", "intestine", "gallbladder",
    ],
    "pelvis": ["pelvis", "bladder", "prostate", "uterus", "ovary", "pelvic"],
    "spine": ["spine", "vertebra", "cervical", "thoracic", "lumbar", "sacrum"],
    "eye": ["eye", "ocular", "retina", "cornea", "optic"],
    "other": [],
}

# Translation table that deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def detect_body_part(question):
    """Return the first region whose keyword is a substring of ``question``.

    The question is lower-cased and stripped of punctuation before matching.
    ``"other"`` is returned when no keyword matches.
    """
    normalized = question.lower().translate(_PUNCTUATION_TABLE)
    matches = (
        region
        for region, keywords in body_parts.items()
        if any(keyword in normalized for keyword in keywords)
    )
    return next(matches, "other")

# Tag every QA record with the anatomical region detected from its question.
for grouped_qas in image_question_map.values():
    for entry in grouped_qas:
        entry["body_part"] = detect_body_part(entry["question"])


# 1. Single-turn samples: every QA pair on its own, ignoring image grouping.
_SINGLE_TURN_FIELDS = ("image", "question", "answer", "body_part", "split")
single_turn_samples = [
    {field: entry[field] for field in _SINGLE_TURN_FIELDS}
    for grouped_qas in image_question_map.values()
    for entry in grouped_qas
]


# 2. Multi-turn samples: one per image whose questions form a coherent group.
def _is_coherent_group(group):
    """True when the image has 2+ questions sharing one body part and one split."""
    if len(group) <= 1:
        return False
    regions = {entry["body_part"] for entry in group}
    splits = {entry["split"] for entry in group}
    return len(regions) == 1 and len(splits) == 1


valid_multi_groups = [
    group for group in image_question_map.values() if _is_coherent_group(group)
]

multi_turn_samples = []
for group in valid_multi_groups:
    # Shortest question first; sorted() is stable, so ties keep dataset order.
    turns = sorted(group, key=lambda entry: len(entry["question"]))
    opening_turn = turns[0]
    multi_turn_samples.append({
        "image": opening_turn["image"],
        "questions": [turn["question"] for turn in turns],
        "answers": [turn["answer"] for turn in turns],
        "body_part": opening_turn["body_part"],
        "split": opening_turn["split"],
    })

print(f"Final Single-Turn Samples: {len(single_turn_samples)}")
print(f"Final Multi-Turn Samples: {len(multi_turn_samples)}")

# Show the first multi-turn conversation, if there is one.
if multi_turn_samples:
    print("\nExample Multi-Turn Case:")
    example = multi_turn_samples[0]
    for shown_question, shown_answer in zip(example["questions"], example["answers"]):
        print(f"Q: {shown_question}")
        print(f"A: {shown_answer}\n")

In [None]:
# ----------------------
# 1. Define Dataset Classes (FIXED PROMPTS)
# ----------------------
class SingleTurnDataset(Dataset):
    """Serve one (image, question, answer) sample per index.

    The output format depends on whether a ``tokenizer`` is supplied:

    * CNN-LSTM mode (``tokenizer`` given): ``processor`` is called on the RGB
      image, the question is encoded with ``tokenizer`` and the answer is
      mapped to a class id through ``answer2id`` (``-1`` when the answer is
      not in that closed vocabulary).
    * BLIP mode (``tokenizer`` is ``None``): ``processor`` is called with the
      image and the text ``"Question: <q> Answer: <a>"``. ``labels`` is a copy
      of ``input_ids`` with padding and prompt positions set to ``-100``.

    Args:
        samples: sequence of dicts with "image", "question" and "answer".
        processor: image transform (CNN-LSTM mode) or BLIP-style processor
            exposing ``.tokenizer`` (BLIP mode).
        tokenizer: question tokenizer; its presence selects CNN-LSTM mode.
        max_seq_len: fixed token length every text field is padded/truncated to.
    """

    def __init__(self, samples, processor, tokenizer=None, max_seq_len=32):
        self.samples = samples
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.answer2id = {
            "yes": 0, "no": 1, "present": 2, "absent": 3, "normal": 4, "abnormal": 5
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        image = item["image"]
        question = item["question"]
        answer = item["answer"]

        # Test against None explicitly: tokenizers define __len__, so relying
        # on truthiness would tie mode selection to the vocabulary size.
        if self.tokenizer is not None:
            return self._encode_for_classifier(image, question, answer)
        return self._encode_for_blip(image, question, answer)

    def _encode_for_classifier(self, image, question, answer):
        """Build the CNN-LSTM sample: image tensor, question ids, class label."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_tensor = self.processor(image)
        encoding = self.tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt"
        )
        label = self.answer2id.get(answer.lower().strip(), -1)
        return {
            "pixel_values": img_tensor,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

    def _encode_for_blip(self, image, question, answer):
        """Build the BLIP sample with labels aligned to ``input_ids``.

        The decoder loss is a next-token loss over ``input_ids``, so labels
        must line up with ``input_ids`` position by position. The answer is
        therefore appended to the prompt, and only answer tokens (plus the
        closing special token) are left unmasked.
        """
        prompt = f"Question: {question} Answer:"

        encoding = self.processor(
            images=image,
            text=f"{prompt} {answer}",
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt"
        )
        sample = {k: v.squeeze(0) for k, v in encoding.items()}

        text_tokenizer = self.processor.tokenizer
        labels = sample["input_ids"].clone()
        labels[labels == text_tokenizer.pad_token_id] = -100

        # Mask the prompt so the loss only covers the answer. The prompt
        # tokenised alone ends with a closing special token that is not part
        # of the joint sequence's prefix, hence the "- 1".
        # NOTE: if the prompt alone fills max_seq_len, the answer is truncated
        # away and this sample carries no supervised tokens.
        prompt_ids = text_tokenizer(prompt, add_special_tokens=True)["input_ids"]
        labels[: len(prompt_ids) - 1] = -100

        sample["labels"] = labels
        return sample


class MultiTurnDataset(Dataset):
    """Serve one multi-turn conversation per index, encoded for BLIP.

    The last question of the conversation is the one to answer; earlier
    question/answer pairs are prepended as textual context. The text fed to
    the processor is::

        Question: q1 Answer: a1 ... Question: q_last Answer: a_last

    ``labels`` is a copy of ``input_ids`` in which padding and everything up
    to and including the final ``Answer:`` is ``-100``, so only the final
    answer is supervised.

    Args:
        samples: sequence of dicts with "image", "questions" and "answers"
            (parallel lists, at least one entry).
        processor: BLIP-style processor exposing ``.tokenizer``.
        max_seq_len: fixed token length the text is padded/truncated to.
        max_turns: maximum number of turns kept, counting the current
            question; only the most recent ``max_turns - 1`` earlier turns
            are used as context. ``None`` keeps every earlier turn.
    """

    def __init__(self, samples, processor, max_seq_len=64, max_turns=3):
        self.samples = samples
        self.processor = processor
        self.max_seq_len = max_seq_len # Increased for context
        self.max_turns = max_turns

    def __len__(self):
        return len(self.samples)

    def _build_prompt(self, questions, answers):
        """Return the context + current-question prompt, ending in ``Answer:``."""
        if self.max_turns is None:
            first_context_turn = 0
        else:
            # Bounding the history keeps right-truncation from cutting off the
            # current question, which sits at the end of the prompt.
            history_len = max(self.max_turns, 1) - 1
            first_context_turn = max(len(questions) - 1 - history_len, 0)

        context_text = "".join(
            f"Question: {q} Answer: {a} "
            for q, a in zip(questions[first_context_turn:-1], answers[first_context_turn:-1])
        )
        return context_text + f"Question: {questions[-1]} Answer:"

    def __getitem__(self, idx):
        item = self.samples[idx]
        image = item["image"]
        questions = item["questions"]
        answers = item["answers"]

        full_prompt = self._build_prompt(questions, answers)
        target_answer = answers[-1]

        # The decoder loss is a next-token loss over input_ids, so the answer
        # must be part of input_ids for labels to be positionally aligned.
        encoding = self.processor(
            images=image,
            text=f"{full_prompt} {target_answer}",
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt"
        )
        sample = {k: v.squeeze(0) for k, v in encoding.items()}

        text_tokenizer = self.processor.tokenizer
        labels = sample["input_ids"].clone()
        labels[labels == text_tokenizer.pad_token_id] = -100

        # Mask the prompt; "- 1" drops the closing special token that only
        # appears when the prompt is tokenised on its own.
        # NOTE: if the prompt alone fills max_seq_len, the answer is truncated
        # away and this sample carries no supervised tokens.
        prompt_ids = text_tokenizer(full_prompt, add_special_tokens=True)["input_ids"]
        labels[: len(prompt_ids) - 1] = -100

        sample["labels"] = labels
        return sample

# ----------------------
# 2. Instantiate (Standard Re-run)
# ----------------------
# Image pipeline for the CNN branch: fixed 224x224 input, tensor conversion,
# then per-channel normalisation (these are the standard ImageNet statistics).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
cnn_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

# Text-side preprocessors: BERT tokenizer for the CNN-LSTM, BLIP processor for BLIP.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")


def _select_split(samples, split_name):
    """Return the samples whose "split" field equals ``split_name``."""
    return [sample for sample in samples if sample["split"] == split_name]


train_single = _select_split(single_turn_samples, "train")
test_single = _select_split(single_turn_samples, "test")
train_multi = _select_split(multi_turn_samples, "train")
test_multi = _select_split(multi_turn_samples, "test")

def filter_closed_ended(samples):
    """Keep the samples whose answer belongs to the closed answer vocabulary.

    Answers are compared after stripping surrounding whitespace and
    lower-casing; the input order is preserved.
    """
    closed_answers = {"yes", "no", "present", "absent", "normal", "abnormal"}
    kept = []
    for sample in samples:
        normalized_answer = sample["answer"].strip().lower()
        if normalized_answer in closed_answers:
            kept.append(sample)
    return kept

# The CNN-LSTM is a classifier, so it only sees closed-ended samples.
train_cnn_samples, test_cnn_samples = (
    filter_closed_ended(split_samples) for split_samples in (train_single, test_single)
)

# CNN-LSTM datasets: passing a tokenizer selects the classifier encoding.
train_cnn, test_cnn = (
    SingleTurnDataset(split_samples, processor=cnn_image_transform, tokenizer=bert_tokenizer)
    for split_samples in (train_cnn_samples, test_cnn_samples)
)

# BLIP single-turn datasets use every sample, open- and closed-ended alike.
train_blip_single, test_blip_single = (
    SingleTurnDataset(split_samples, processor=blip_processor)
    for split_samples in (train_single, test_single)
)

# BLIP multi-turn datasets.
train_blip_multi, test_blip_multi = (
    MultiTurnDataset(split_samples, processor=blip_processor)
    for split_samples in (train_multi, test_multi)
)

print("Datasets updated with explicit 'Answer:' prompts.")

In [None]:
from transformers import BertModel

# ----------------------
# 1. Define CNN-LSTM Model
# ----------------------
class CNNLSTMMedVQA(nn.Module):
    """ResNet-50 image encoder + LSTM question encoder + MLP classifier.

    The image is encoded by an ImageNet-pretrained ResNet-50 with its
    classification head removed. The question is embedded with frozen BERT
    word embeddings and summarised by a 2-layer LSTM. Both feature vectors
    are concatenated, fused by an MLP and classified into ``num_classes``.

    Args:
        num_classes: number of answer classes produced by the classifier.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Image Encoder: ResNet-50. `weights=` replaces the deprecated
        # `pretrained=True`; IMAGENET1K_V1 is the checkpoint that flag selected.
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.img_feature_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # Extract features (no classification)

        # Question Encoder: BERT Embedding + LSTM
        # The full BERT model is loaded only to copy its word-embedding matrix;
        # Embedding.from_pretrained keeps that matrix frozen by default.
        bert_model = BertModel.from_pretrained("bert-base-uncased")
        self.bert_emb = nn.Embedding.from_pretrained(bert_model.embeddings.word_embeddings.weight)
        # Plain int (not a parameter/buffer), so the state_dict layout is unchanged.
        pad_token_id = bert_model.config.pad_token_id
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id

        # LSTM
        self.lstm = nn.LSTM(
            input_size=768,  # BERT embedding dimension
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            dropout=0.3
        )

        # Feature Fusion & Classifier
        self.fusion = nn.Sequential(
            nn.Linear(self.img_feature_dim + 512, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU()
        )
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, image, q_ids, q_mask=None):
        """Return class logits for a batch of (image, question) pairs.

        Args:
            image: image batch accepted by ResNet-50.
            q_ids: (batch, seq_len) question token ids, right-padded.
            q_mask: optional (batch, seq_len) mask, non-zero on real tokens.
                When omitted it is derived as ``q_ids != pad_token_id``.

        Returns:
            (batch, num_classes) logits.
        """
        # Image features (batch, 2048)
        img_feat = self.resnet(image)

        # Question features (batch, seq_len, 768)
        q_emb = self.bert_emb(q_ids)

        # Questions are padded to a fixed length. Packing makes the LSTM stop
        # at each question's last real token instead of running through the
        # padding, which would otherwise dominate the final hidden state.
        # Assumes right padding, i.e. real tokens form a prefix of each row.
        if q_mask is None:
            q_mask = q_ids != self.pad_token_id
        lengths = q_mask.long().sum(dim=1).clamp(min=1).cpu()  # lengths must live on CPU
        packed = nn.utils.rnn.pack_padded_sequence(
            q_emb, lengths, batch_first=True, enforce_sorted=False
        )

        # hidden: (num_layers, batch, hidden_dim), already restored to batch order
        _, (hidden, _) = self.lstm(packed)

        # Take the last layer's hidden state (batch, 512)
        q_feat = hidden[-1]

        # Fusion & Prediction
        combined = torch.cat([img_feat, q_feat], dim=1)
        fused = self.fusion(combined)
        logits = self.classifier(fused)
        return logits

# ----------------------
# 2. Setup Classes & Loaders
# ----------------------

# The label vocabulary mirrors the map hard-coded inside SingleTurnDataset,
# which is what actually turns answers into the ids 0-5.
_CLOSED_ANSWERS = ("yes", "no", "present", "absent", "normal", "abnormal")
answer2id = {answer: index for index, answer in enumerate(_CLOSED_ANSWERS)}
num_classes = len(answer2id)
print(f"Number of Classes: {num_classes}")

# Plain DataLoaders are enough because SingleTurnDataset.__getitem__ already
# returns tensors; only the training loader is shuffled.
_CNN_BATCH_SIZE = 32
train_cnn_loader = DataLoader(train_cnn, batch_size=_CNN_BATCH_SIZE, shuffle=True)
test_cnn_loader = DataLoader(test_cnn, batch_size=_CNN_BATCH_SIZE, shuffle=False)

In [None]:
# Build the model together with its loss and optimizer.
cnn_model = CNNLSTMMedVQA(num_classes=num_classes).to(device)
# Labels equal to -1 (answers outside the closed set) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.AdamW(cnn_model.parameters(), lr=1e-4, weight_decay=1e-5)

# Training hyperparameters and per-epoch history.
epochs = 10
cnn_history = {"train_loss": [], "test_acc": []}

print("Starting CNN-LSTM Training...")

for epoch in range(epochs):
    # --- Train Phase ---
    cnn_model.train()
    running_loss = 0.0
    train_bar = tqdm(train_cnn_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

    for train_batch in train_bar:
        # The dataset yields pixel_values, input_ids, attention_mask, labels;
        # the model consumes the first two, the loss the last.
        pixels = train_batch["pixel_values"].to(device)
        token_ids = train_batch["input_ids"].to(device)
        targets = train_batch["labels"].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(cnn_model(pixels, token_ids), targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        train_bar.set_postfix({"loss": loss_value})

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = running_loss / len(train_cnn_loader)
    cnn_history["train_loss"].append(avg_train_loss)

    # --- Test Phase ---
    cnn_model.eval()
    epoch_preds, epoch_targets = [], []

    with torch.no_grad():
        for eval_batch in tqdm(test_cnn_loader, desc=f"Epoch {epoch+1}/{epochs} [Test]"):
            pixels = eval_batch["pixel_values"].to(device)
            token_ids = eval_batch["input_ids"].to(device)
            targets = eval_batch["labels"].to(device)

            predicted = cnn_model(pixels, token_ids).argmax(dim=1)

            # Samples labelled -1 have no valid class and are left out of accuracy.
            keep = targets != -1
            if keep.any():
                epoch_preds.extend(predicted[keep].cpu().numpy())
                epoch_targets.extend(targets[keep].cpu().numpy())

    # Accuracy over the valid test samples (0.0 when there were none).
    test_acc = accuracy_score(epoch_targets, epoch_preds) if epoch_targets else 0.0
    cnn_history["test_acc"].append(test_acc)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Test Acc = {test_acc:.4f}")

# Persist the trained weights.
torch.save(cnn_model.state_dict(), "cnn_lstm_medvqa.pth")
print("\nCNN-LSTM Model Saved as 'cnn_lstm_medvqa.pth'")

In [None]:
from transformers import Trainer, TrainingArguments, default_data_collator

# ----------------------
# 1. Load BLIP Model
# ----------------------
# Weights stay in float32 on purpose. With fp16=True the Trainer runs mixed
# precision with a gradient scaler, and the scaler cannot unscale gradients of
# float16 parameters ("Attempting to unscale FP16 gradients").
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

# ----------------------
# 2. Define Metrics & Collator
# ----------------------
def compute_metrics(eval_pred):
    """Return no extra metrics; Trainer.evaluate() then reports the loss only.

    Generation metrics (exact match, BLEU, ...) need a separate generation
    loop, because ``labels`` passed to forward() are used for teacher forcing.
    """
    return {}

# Data collator
# The datasets already return fixed-length tensors including `labels`, so
# stacking them is all that is needed. DataCollatorForLanguageModeling(mlm=False)
# must not be used here: it rebuilds `labels` from `input_ids`, silently
# discarding the labels prepared by the dataset.
data_collator = default_data_collator

# ----------------------
# 3. Training Arguments
# ----------------------
single_turn_args = TrainingArguments(
    output_dir="./blip_single_turn",
    per_device_train_batch_size=8,       # Increased slightly for stability
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=5,
    logging_steps=50,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    fp16=device.type == "cuda",          # mixed precision only on GPU
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none"
)

# ----------------------
# 4. Initialize Trainer
# ----------------------
single_turn_trainer = Trainer(
    model=blip_model,
    args=single_turn_args,
    train_dataset=train_blip_single,
    eval_dataset=test_blip_single,
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

# ----------------------
# 5. Train
# ----------------------
print("Training Single-Turn BLIP...")
single_turn_trainer.train()

# Save final model
single_turn_trainer.save_model("./blip_single_turn_final")
print("Single-Turn BLIP Saved to './blip_single_turn_final'")

In [None]:
# Training arguments for multi-turn
# `device` is a torch.device, and comparing it to the string "cuda" is always
# False, so the device *type* is checked instead. Mixed precision additionally
# needs float32 weights (the gradient scaler cannot unscale float16 gradients),
# so fp16 is only enabled when the model is stored in float32.
_multi_turn_fp16 = (
    device.type == "cuda"
    and next(blip_model.parameters()).dtype == torch.float32
)

multi_turn_args = TrainingArguments(
    output_dir="./blip_multi_turn",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,   # effective batch of 16, as in single-turn
    learning_rate=5e-5,
    num_train_epochs=5,
    logging_steps=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    fp16=_multi_turn_fp16,
    remove_unused_columns=False,     # same setting as the single-turn run
    push_to_hub=False,
    report_to="none"
)

# Trainer for multi-turn
multi_turn_trainer = Trainer(
    model=blip_model,  # Reuse single-turn fine-tuned model
    args=multi_turn_args,
    train_dataset=train_blip_multi,
    eval_dataset=test_blip_multi,
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

# Train multi-turn BLIP
print("Training Multi-Turn BLIP...")
multi_turn_trainer.train()
multi_turn_trainer.save_model("./blip_multi_turn_final")
print("Multi-Turn BLIP Saved to './blip_multi_turn_final'")

In [None]:
def evaluate_cnn_lstm():
    """Evaluate the saved CNN-LSTM checkpoint on the closed-ended test data.

    Loads ``cnn_lstm_medvqa.pth`` and reports:

    1. overall accuracy over ``test_cnn_loader``,
    2. accuracy per body part (metadata looked up in ``test_cnn_samples``),
    3. accuracy over the closed-ended questions found in ``test_multi``,
       each question being scored independently of its conversation.

    Returns:
        ``(overall_acc, part_acc, cnn_multi_acc)``; ``(0, {}, 0)`` when the
        test loader holds no sample with a valid label.
    """
    # ----------------------
    # 1. Load Model
    # ----------------------
    model = CNNLSTMMedVQA(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load("cnn_lstm_medvqa.pth", map_location=device))
    model.eval()

    print("Evaluating CNN-LSTM on Single-Turn Closed-Ended Questions...")

    all_preds = []
    all_labels = []
    kept_indices = []   # dataset index of every prediction that was kept
    samples_seen = 0

    # ----------------------
    # 2. Prediction Loop
    # ----------------------
    with torch.no_grad():
        for batch in tqdm(test_cnn_loader, desc="Evaluating"):
            image = batch["pixel_values"].to(device)
            q_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            logits = model(image, q_ids)
            preds = torch.argmax(logits, dim=1)

            # Drop samples labelled -1, remembering where the survivors sit in
            # the dataset so their metadata can be recovered afterwards.
            valid_mask = labels != -1
            valid_positions = torch.nonzero(valid_mask).flatten().tolist()
            kept_indices.extend(samples_seen + pos for pos in valid_positions)
            all_preds.extend(preds[valid_mask].tolist())
            all_labels.extend(labels[valid_mask].tolist())
            samples_seen += labels.size(0)

    # ----------------------
    # 3. Calculate Metrics
    # ----------------------
    if not all_labels:
        print("No valid closed-ended samples found in test set.")
        return 0, {}, 0

    overall_acc = accuracy_score(all_labels, all_preds)

    # Per-body-part accuracy. This relies on test_cnn_loader iterating
    # test_cnn_samples in order (shuffle=False); tracking dataset indices keeps
    # the lookup correct even if some samples were filtered out above.
    part_correct = defaultdict(list)
    for sample_idx, pred, label in zip(kept_indices, all_preds, all_labels):
        if sample_idx < len(test_cnn_samples):
            part = test_cnn_samples[sample_idx]["body_part"]
            part_correct[part].append(1 if pred == label else 0)

    part_acc = {p: np.mean(acc) for p, acc in part_correct.items()}

    # Print results
    print("\n=== CNN-LSTM Evaluation Results ===")
    print(f"Overall Closed-Ended Accuracy: {overall_acc:.4f}")
    print("\nPer-Anatomical Region Accuracy:")
    for part, acc in sorted(part_acc.items()):
        print(f"  {part}: {acc:.4f}")

    # ----------------------
    # 4. Multi-Question Evaluation (Manual Loop)
    # ----------------------
    cnn_multi_acc = 0.0
    if len(test_multi) > 0:
        print("\nEvaluating on Multi-Turn Sequences (Treating individually)...")
        multi_closed_preds = []
        multi_closed_labels = []

        # Tokenise with the same length the test dataset uses.
        question_max_len = getattr(test_cnn_loader.dataset, "max_seq_len", 32)

        # Inference only: without no_grad every forward pass would keep its
        # autograd graph alive.
        with torch.no_grad():
            for sample in test_multi:
                img_tensor = None  # built lazily, once per image
                for q, a in zip(sample["questions"], sample["answers"]):
                    # Only evaluate if answer is in our closed vocabulary
                    target_id = answer2id.get(a.strip().lower())
                    if target_id is None:
                        continue

                    # 1. Image (preprocessed manually; not served by a DataLoader)
                    if img_tensor is None:
                        img = sample["image"]
                        if img.mode != "RGB":
                            img = img.convert("RGB")
                        img_tensor = cnn_image_transform(img).unsqueeze(0).to(device)

                    # 2. Text
                    q_enc = bert_tokenizer(
                        q,
                        padding="max_length",
                        truncation=True,
                        max_length=question_max_len,
                        return_tensors="pt"
                    )
                    q_ids = q_enc["input_ids"].to(device)

                    # 3. Predict
                    logits = model(img_tensor, q_ids)
                    pred = torch.argmax(logits, dim=1).item()

                    multi_closed_preds.append(pred)
                    multi_closed_labels.append(target_id)

        if multi_closed_preds:
            cnn_multi_acc = accuracy_score(multi_closed_labels, multi_closed_preds)
            print(f"CNN-LSTM Multi-Question Accuracy (Closed-Ended): {cnn_multi_acc:.4f}")
        else:
             print("No closed-ended questions found in multi-turn test set.")

    return overall_acc, part_acc, cnn_multi_acc

# Run evaluation
cnn_single_acc, cnn_part_acc, cnn_multi_acc = evaluate_cnn_lstm()

In [None]:
# ---------------------------------------------------------
# CELL 11: Final Optimized Evaluation (Correct Models)
# ---------------------------------------------------------
# NOTE(review): evaluate_blip_final is not defined in the cells visible here;
# it must be defined before this cell runs -- confirm where it lives.

# 1. Evaluate Single-Turn Task using the Single-Turn Model
# This model hasn't "forgotten" the general task yet.
print("\n--- Evaluating Single-Turn Performance ---")
blip_acc_single = evaluate_blip_final(
    "./blip_single_turn_final",  # Load the Stage 1 model
    test_single,
    "Single-Turn Test Set (Stage 1 Model)",
    mode="single"
)

# 2. Evaluate Multi-Turn Task using the Multi-Turn Model
# This model is fine-tuned specifically for context.
print("\n--- Evaluating Multi-Turn Performance ---")
# Default so the summary below still works when there is nothing to evaluate
# (0.0 matches the default the CNN-LSTM evaluation uses for the same case).
blip_acc_multi = 0.0
if len(test_multi) > 0:
    blip_acc_multi = evaluate_blip_final(
        "./blip_multi_turn_final", # Load the Stage 2 model
        test_multi,
        "Multi-Turn Test Set (Stage 2 Model)",
        mode="multi"
    )
else:
    print("No multi-turn test samples available; skipping multi-turn evaluation.")

# 3. Summary for Slide 16
print("\n" + "="*40)
print("FINAL RESULTS FOR SLIDES")
print("="*40)
print(f"Single-Turn Accuracy (BLIP): {blip_acc_single:.4f}")
print(f"Multi-Turn Accuracy (BLIP):  {blip_acc_multi:.4f}")

In [None]:
def format_score(value):
    """Format an accuracy with four decimals, or "N/A" when it is missing."""
    return "N/A" if value is None else f"{value:.4f}"


def describe_multi_turn_gap(blip_score, cnn_score):
    """Return the summary line comparing both models' multi-turn accuracy.

    The wording follows the sign of ``blip_score - cnn_score`` instead of
    assuming BLIP wins, and the gap is printed with an explicit sign so a
    negative gap can no longer render as "+-0.05".
    """
    if blip_score is None or cnn_score is None:
        return "- Multi-turn comparison unavailable (missing scores)."
    gap = blip_score - cnn_score
    if gap > 0:
        return f"- BLIP outperforms CNN-LSTM in multi-turn metrics ({gap:+.2f} accuracy)"
    if gap < 0:
        return f"- CNN-LSTM outperforms BLIP in multi-turn metrics ({gap:+.2f} accuracy for BLIP)"
    return "- BLIP and CNN-LSTM are tied in multi-turn metrics (+0.00 accuracy)"


# Pick up the results of earlier cells; a metric whose cell was not run
# defaults to 0.0. At module level globals() is the same mapping as locals().
cnn_single = globals().get('cnn_single_acc', 0.0)
cnn_multi = globals().get('cnn_multi_acc', 0.0)

blip_single = globals().get('blip_acc_single', 0.0)
blip_multi = globals().get('blip_acc_multi', 0.0)

# Define Consistency Score (Using Multi-Turn Accuracy as proxy based on your results)
blip_consistency_acc = blip_multi

# Kept as a module-level value for later cells; None when a score is missing.
improvement = (
    blip_multi - cnn_multi
    if blip_multi is not None and cnn_multi is not None
    else None
)

# ---------------------------------------------------------
# Print Summary
# ---------------------------------------------------------
print("\n" + "="*70)
print("FINAL MODEL COMPARISON SUMMARY")
print("="*70)

# CNN-LSTM Results
print(f"\n1. CNN-LSTM (Closed-Ended Only)")
print(f"   Single-Turn Accuracy: {format_score(cnn_single)}")
print(f"   Multi-Turn Accuracy:  {format_score(cnn_multi)}")
print(f"   Key Strength: Computational Efficiency")
print(f"   Key Limitation: No open-ended support + low multi-turn consistency")

# BLIP Results
print(f"\n2. BLIP (Single + Multi-Turn)")
print(f"   Single-Turn Exact Match: {format_score(blip_single)}")
print(f"   Multi-Turn Exact Match:  {format_score(blip_multi)}")
print(f"   Multi-Turn Consistency:  {format_score(blip_consistency_acc)}")
print(f"   Key Strength: Open-ended support + high multi-turn consistency")
print(f"   Key Limitation: Higher computational requirements")

# Critical Findings
# NOTE: the CNN-LSTM figure covers closed-ended questions only, while the BLIP
# figure is an exact-match score, so the gap is indicative rather than exact.
print(f"\nCritical Findings:")
print(describe_multi_turn_gap(blip_multi, cnn_multi))

if blip_consistency_acc is not None:
    print(f"- BLIP maintains {blip_consistency_acc:.1%} consistency across sequential queries.")
else:
    print("- BLIP consistency across sequential queries: N/A")
print(f"- CNN-LSTM is suitable only for simple closed-ended queries.")
print("="*70)