#**Med-PaLM**


##Zero-Shot

In [None]:
from transformers import pipeline
import torch

# Load an open stand-in backbone. Med-PaLM itself is built on Flan-PaLM, which is
# not publicly released; FLAN-T5 is an openly available instruction-tuned model.
model_name = "google/flan-t5-large"
qa_pipeline = pipeline("text2text-generation", model=model_name,
                       device=0 if torch.cuda.is_available() else -1)

# List of zero-shot questions (all 10 tests)
questions = [
    # 1. Zero-shot QA
    "Context: Aspirin is used to reduce fever, pain, and inflammation. It prevents blood clots. Question: What does aspirin reduce?",

    # 2. Text Similarity
    "Compare the similarity between: 'Aspirin reduces inflammation' and 'Aspirin lowers swelling'. Return score 0–1 and short explanation.",

    # 3. Masked Language Modeling (MLM)
    "Fill in the blank: Insulin is produced by the ______.",

    # 4. Token Embeddings
    "Explain the roles of the tokens in the sentence: 'BRCA1 is a gene linked to breast cancer.'",

    # 5. Classification Hack (True/False Probe)
    "Is this statement true or false: Aspirin reduces fever.",

    # 6. Domain Mismatch QA
    "Context: Paris is the capital of France. Question: What is the capital of France?",

    # 7. Word Similarity
    "Are 'myocardial infarction' and 'heart attack' the same thing? Explain briefly.",

    # 8. Long Context QA
    "Context: Alzheimer's disease is a progressive neurodegenerative disorder. It is strongly associated with the accumulation of beta-amyloid plaques in the brain. Question: Which protein is implicated in Alzheimer's disease?",

    # 9. Contradiction Probe
    "Does aspirin reduce inflammation? Also, does aspirin increase inflammation?",

    # 10. Random Robustness QA
    "Context: Banana walks on Mars with a scalpel. Question: What walks on Mars?"
]

# Run Med-PaLM style zero-shot QA
for q in questions:
    prompt = f"Question: {q}\nAnswer:"
    result = qa_pipeline(prompt, max_new_tokens=64, do_sample=False)[0]['generated_text']
    print(f"Q: {q}")
    print(f"A: {result.strip()}")
    print("-" * 60)


Device set to use cuda:0


Q: Context: Aspirin is used to reduce fever, pain, and inflammation. It prevents blood clots. Question: What does aspirin reduce?
A: fever
------------------------------------------------------------
Q: Compare the similarity between: 'Aspirin reduces inflammation' and 'Aspirin lowers swelling'. Return score 0–1 and short explanation.
A: 1
------------------------------------------------------------
Q: Fill in the blank: Insulin is produced by the ______.
A: pancreas
------------------------------------------------------------
Q: Explain the roles of the tokens in the sentence: 'BRCA1 is a gene linked to breast cancer.'
A: BRCA1 is a gene linked to breast cancer
------------------------------------------------------------
Q: Is this statement true or false: Aspirin reduces fever.
A: no
------------------------------------------------------------
Q: Context: Paris is the capital of France. Question: What is the capital of France?
A: Paris
------------------------------------------------------------
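
Several entries in `questions` already start with `Context: ... Question: ...`, so the template `f"Question: {q}\nAnswer:"` yields a nested `Question: Context: ... Question: ...` prompt. The sketch below shows one way to avoid the duplicate label. `build_zero_shot_prompt` is a hypothetical helper, not part of the notebook, and whether it changes the answers above has not been tested here.

```python
def build_zero_shot_prompt(item: str) -> str:
    """Append 'Answer:' and prepend 'Question:' only when the item has no field label."""
    item = item.strip()
    # Items that already carry a 'Context:' or 'Question:' label are left as they are.
    if item.startswith(("Context:", "Question:")):
        return f"{item}\nAnswer:"
    return f"Question: {item}\nAnswer:"


labelled = build_zero_shot_prompt(
    "Context: Paris is the capital of France. Question: What is the capital of France?"
)
plain = build_zero_shot_prompt("Is this statement true or false: Aspirin reduces fever.")
print(labelled)
print(plain)
```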

##Few-Shot

In [None]:
# -----------------------------------------------
# Setup & Imports
# -----------------------------------------------
!pip install transformers -q

import torch
from transformers import pipeline

# Load FLAN-T5 as an open stand-in (Med-PaLM itself is built on Flan-PaLM)
model_name = "google/flan-t5-large"
qa_pipeline = pipeline(
    "text2text-generation",
    model=model_name,
    device=0 if torch.cuda.is_available() else -1
)

# -----------------------------------------------
# Few-shot examples (demonstrations)
# -----------------------------------------------
demo_examples = """
Q: What is the main function of insulin?
A: Regulates blood sugar.

Q: What is BRCA1 linked to?
A: Breast cancer.

Q: What vitamin deficiency causes rickets?
A: Vitamin D.
"""

# -----------------------------------------------
# Test questions
# -----------------------------------------------
questions = [
    "What does aspirin reduce?",
    "What organ produces insulin?",
    "What deficiency causes scurvy?",
    "What disease is caused by HIV?",
    "Which protein is implicated in Alzheimer's disease?",
    "What is the capital of France?",
    "Are 'myocardial infarction' and 'heart attack' the same thing?",
    "Does aspirin reduce inflammation? Also, does aspirin increase inflammation?",
]

# -----------------------------------------------
# Run Few-shot QA
# -----------------------------------------------
for q in questions:
    # Build prompt = demonstrations + new question
    prompt = demo_examples + f"\nQ: {q}\nA:"
    result = qa_pipeline(prompt, max_new_tokens=64, do_sample=False)[0]['generated_text']

    print(f"Q: {q}")
    print(f"A: {result.strip()}")
    print("-" * 60)


Device set to use cuda:0


Q: What does aspirin reduce?
A: (D).
------------------------------------------------------------
Q: What organ produces insulin?
A: pancreas
------------------------------------------------------------
Q: What deficiency causes scurvy?
A: Vitamin C
------------------------------------------------------------
Q: What disease is caused by HIV?
A: Aids
------------------------------------------------------------
Q: Which protein is implicated in Alzheimer's disease?
A: -Tyrosine kinase
------------------------------------------------------------
Q: What is the capital of France?
A: paris
------------------------------------------------------------
Q: Are 'myocardial infarction' and 'heart attack' the same thing?
A: yes
------------------------------------------------------------
Q: Does aspirin reduce inflammation? Also, does aspirin increase inflammation?
A: Aspirin is an anti-inflammatory drug.
------------------------------------------------------------
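
The few-shot prompt above is a hand-written string followed by the new question. As a minimal sketch, the same layout can be built from a list of (question, answer) pairs, which makes it easier to add or swap demonstrations. `DEMOS` and `build_few_shot_prompt` are hypothetical names introduced for illustration; the result matches the notebook's prompt apart from the leading newline.

```python
from typing import Sequence, Tuple

# Same three demonstrations as in the notebook, stored as pairs.
DEMOS = [
    ("What is the main function of insulin?", "Regulates blood sugar."),
    ("What is BRCA1 linked to?", "Breast cancer."),
    ("What vitamin deficiency causes rickets?", "Vitamin D."),
]


def build_few_shot_prompt(demos: Sequence[Tuple[str, str]], question: str) -> str:
    """Join demonstrations as 'Q: ...' / 'A: ...' blocks, then append the open question."""
    blocks = [f"Q: {q}\nA: {a}" for q, a in demos]
    blocks.append(f"Q: {question}\nA:")  # the model completes this last answer
    return "\n\n".join(blocks)


prompt = build_few_shot_prompt(DEMOS, "What does aspirin reduce?")
print(prompt)
```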


##Fine-tuning

In [1]:
!pip install transformers datasets evaluate -q

import os, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# ===== 1) Load FLAN-T5 (open stand-in; Med-PaLM itself is built on Flan-PaLM) =====
model_name = "google/flan-t5-base"   # base fits limited VRAM; try flan-t5-large if memory allows
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# ===== 2) Load Dataset =====
dataset = load_dataset("Eladio/emrqa-msquad")

# ===== 3) Preprocess =====
max_input = 512
max_target = 64

def preprocess(ex):
    inputs = [f"question: {q} context: {c}" for q, c in zip(ex["question"], ex["context"])]
    targets = []
    for a in ex["answers"]:
        if isinstance(a, dict) and "text" in a:
            targets.append(a["text"][0])
        elif isinstance(a, list) and len(a) > 0:
            if isinstance(a[0], dict) and "text" in a[0]:
                targets.append(a[0]["text"])
            else:
                targets.append(str(a[0]))
        else:
            targets.append("")

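    # Note: padding="max_length" on the targets leaves pad token ids in the labels,
    # so padded positions count toward the loss. Dropping padding= here would let
    # DataCollatorForSeq2Seq pad labels with -100, which the loss ignores.
    model_inputs = tokenizer(inputs, max_length=max_input, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=max_target, truncation=True, padding="max_length")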
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)

# ===== 4) Data Collator =====
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ===== 5) Training Arguments =====
training_args = Seq2SeqTrainingArguments(
    output_dir="./results_flan_t5_emrqa",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    save_total_limit=1,
    predict_with_generate=True,
    logging_steps=100,
    report_to=[]   # 🚫 disables W&B, TensorBoard, etc.
)

# ===== 6) Trainer =====
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"].select(range(3000)),   # small subset for Colab
    eval_dataset=tokenized["validation"].select(range(500)),
    data_collator=collator,
    tokenizer=tokenizer,
)

# ===== 7) Train + Evaluate =====
trainer.train()
print(trainer.evaluate())




| Step | Training Loss |
|------|---------------|
| 100 | 10.4551 |
| 200 | 1.3835 |
| 300 | 0.5081 |
| 400 | 0.4185 |
| 500 | 0.3865 |
| 600 | 0.3358 |
| 700 | 0.3625 |


{'eval_loss': 0.29760318994522095, 'eval_runtime': 26.0694, 'eval_samples_per_second': 19.18, 'eval_steps_per_second': 4.795, 'epoch': 1.0}
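
The `preprocess` function above pads the target sequences to `max_length`, so the label tensors contain pad token ids that are scored like real tokens. A common alternative is to replace those positions with `-100`, the index that PyTorch's cross-entropy loss ignores by default. The sketch below illustrates the idea on a toy batch. `mask_pad_labels` is a hypothetical helper, and the token ids are made up apart from the T5 conventions of pad id 0 and `</s>` id 1.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_pad_labels(label_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Replace every pad token id in a padded label batch with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in seq]
        for seq in label_ids
    ]


# Toy padded batch: two targets padded to length 5 (illustrative ids).
batch = [[71, 1712, 1, 0, 0], [363, 1, 0, 0, 0]]
masked = mask_pad_labels(batch, pad_token_id=0)
print(masked)
```

In the notebook this would be applied to `labels["input_ids"]` with `tokenizer.pad_token_id` before assigning `model_inputs["labels"]`. The loss values reported above were produced without such masking, so they are not directly comparable to a masked run.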


In [2]:
# ============================================
# Evaluate Fine-Tuned FLAN-T5 on emrqa-msquad
# ============================================

!pip install transformers datasets evaluate -q

import os, torch, evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# ==== Disable W&B ====
os.environ["WANDB_DISABLED"] = "true"

# ===== 1) Path to fine-tuned checkpoint =====
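# 3000 training examples / batch size 4 = 750 steps for one epoch on a single device
model_dir = "./results_flan_t5_emrqa/checkpoint-750"   # <-- change if your step count differs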
print(f"Loading fine-tuned model from: {model_dir}")

# ===== 2) Load tokenizer + model =====
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# Pipeline for generation
qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer,
                       device=0 if torch.cuda.is_available() else -1)

# ===== 3) Load dataset (validation set for testing) =====
dataset = load_dataset("Eladio/emrqa-msquad")
test_ds = dataset["validation"].select(range(200))  # test on 200 samples only

# ===== 4) Run predictions =====
preds, refs = [], []
for i, ex in enumerate(test_ds):
    prompt = f"question: {ex['question']} context: {ex['context']}"
    output = qa_pipeline(prompt, max_new_tokens=64, do_sample=False)[0]["generated_text"]

    # predictions in squad format
    preds.append({"id": str(i), "prediction_text": output})

    # normalize answers for squad metric
    answers = ex["answers"]
    if isinstance(answers, dict) and "text" in answers:
        texts = answers["text"]
    elif isinstance(answers, list):
        texts = [a["text"] if isinstance(a, dict) and "text" in a else a for a in answers]
    else:
        texts = [str(answers)]

    refs.append({"id": str(i), "answers": {"text": texts, "answer_start": [0]*len(texts)}})

# ===== 5) Evaluate with SQuAD metric =====
metric = evaluate.load("squad")
results = metric.compute(predictions=preds, references=refs)

print("\n📊 Fine-tuned FLAN-T5 (Med-PaLM backbone) Performance:")
print("Exact Match:", results["exact_match"])
print("F1:", results["f1"])


Loading fine-tuned model from: ./results_flan_t5_emrqa/checkpoint-750


Device set to use cuda:0
Token indices sequence length is longer than the specified maximum sequence length for this model (1171 > 512). Running this sequence through the model will result in indexing errors
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



📊 Fine-tuned FLAN-T5 (Med-PaLM backbone) Performance:
Exact Match: 25.0
F1: 60.68106443854735
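
The gap between Exact Match and F1 reflects how the two scores treat partial answers. The sketch below follows the usual SQuAD v1.1 scoring recipe (lowercase, strip punctuation and articles, then compare tokens) for a single prediction and reference. It is a simplified illustration, not the code behind `evaluate.load("squad")`, which additionally takes the best score over all references and averages over examples on a 0 to 100 scale. The example strings are invented, not taken from the dataset.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token-level precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# A prediction with one extra token: no exact match, but high token overlap.
em = exact_match("aspirin 81 mg daily", "aspirin 81 mg")
f1 = token_f1("aspirin 81 mg daily", "aspirin 81 mg")
print(em, round(f1, 3))
```

A prediction that contains the reference plus extra words scores zero on Exact Match but keeps most of its F1, which is one plausible reason a generative model can show a low Exact Match alongside a much higher F1.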


In [4]:
# ============================================
# Compare Fine-Tuned FLAN-T5 vs Raw T5-base/large
# ============================================

!pip install transformers datasets evaluate -q

import os, torch, evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# ==== Disable W&B ====
os.environ["WANDB_DISABLED"] = "true"

# ===== Load dataset =====
dataset = load_dataset("Eladio/emrqa-msquad")
test_ds = dataset["validation"].select(range(200))  # test on 200 samples

# ===== Helper: Evaluate a model =====
def evaluate_model(model_name, model_dir=None, label=""):
    print(f"\n🔹 Evaluating {label} ({model_name if model_dir is None else model_dir})")

    # Load model + tokenizer
    tok = AutoTokenizer.from_pretrained(model_dir or model_name, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir or model_name)

    qa_pipe = pipeline("text2text-generation", model=model, tokenizer=tok,
                       device=0 if torch.cuda.is_available() else -1)

    # Run predictions
    preds, refs = [], []
    for i, ex in enumerate(test_ds):
        prompt = f"question: {ex['question']} context: {ex['context']}"
        output = qa_pipe(prompt, max_new_tokens=64, do_sample=False)[0]["generated_text"]

        preds.append({"id": str(i), "prediction_text": output})

        # normalize answers for squad metric
        answers = ex["answers"]
        if isinstance(answers, dict) and "text" in answers:
            texts = answers["text"]
        elif isinstance(answers, list):
            texts = [a["text"] if isinstance(a, dict) and "text" in a else a for a in answers]
        else:
            texts = [str(answers)]

        refs.append({"id": str(i), "answers": {"text": texts, "answer_start": [0]*len(texts)}})

    # Compute metrics
    metric = evaluate.load("squad")
    results = metric.compute(predictions=preds, references=refs)

    print(f"📊 {label} Results")
    print("Exact Match:", results["exact_match"])
    print("F1:", results["f1"])
    return results

# ===== Run evaluations =====
# 1) Fine-tuned FLAN-T5
# The checkpoint was fine-tuned from flan-t5-base; model_name is only a label when model_dir is set
ft_results = evaluate_model("google/flan-t5-base", model_dir="./results_flan_t5_emrqa/checkpoint-750", label="Fine-tuned FLAN-T5")

# 2) Raw T5-base
base_results = evaluate_model("t5-base", label="Raw T5-base")

# 3) Raw T5-large
large_results = evaluate_model("t5-large", label="Raw T5-large")



🔹 Evaluating Fine-tuned FLAN-T5 (./results_flan_t5_emrqa/checkpoint-750)


Device set to use cuda:0
Token indices sequence length is longer than the specified maximum sequence length for this model (1171 > 512). Running this sequence through the model will result in indexing errors


📊 Fine-tuned FLAN-T5 Results
Exact Match: 25.0
F1: 60.68106443854735

🔹 Evaluating Raw T5-base (t5-base)


Device set to use cuda:0


📊 Raw T5-base Results
Exact Match: 2.5
F1: 37.12529794467818

🔹 Evaluating Raw T5-large (t5-large)


Device set to use cuda:0


📊 Raw T5-large Results
Exact Match: 8.5
F1: 42.93208483744865
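
The comparison cell stores its scores in `ft_results`, `base_results`, and `large_results` but only prints them one model at a time. As a small sketch, the three result dictionaries can be collected into one side-by-side table. `format_comparison` is a hypothetical helper, and the `reported` values are copied from the output above so that the snippet runs without reloading any model.

```python
from typing import Dict


def format_comparison(results: Dict[str, Dict[str, float]]) -> str:
    """Render {label: {'exact_match': ..., 'f1': ...}} as an aligned text table."""
    header = f"{'Model':<20} {'EM':>6} {'F1':>6}"
    lines = [header, "-" * len(header)]
    for label, scores in results.items():
        lines.append(f"{label:<20} {scores['exact_match']:>6.1f} {scores['f1']:>6.1f}")
    return "\n".join(lines)


# Scores as reported in the cell output above (200 validation samples each).
reported = {
    "Fine-tuned FLAN-T5": {"exact_match": 25.0, "f1": 60.68106443854735},
    "Raw T5-base": {"exact_match": 2.5, "f1": 37.12529794467818},
    "Raw T5-large": {"exact_match": 8.5, "f1": 42.93208483744865},
}
table = format_comparison(reported)
print(table)
```

In the notebook itself, the dictionary could be built directly from the returned values, for example `{"Fine-tuned FLAN-T5": ft_results, "Raw T5-base": base_results, "Raw T5-large": large_results}`.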
