In [2]:
#sudo apt update
#sudo apt install -y python3 python3-venv python3-pip
#python3 -V                    # sanity check
#python3 -m venv .venv
#source .venv/bin/activate
#python -m pip install --upgrade pip
#pip install torch transformers evaluate rouge-score pandas
#python -m pip install --upgrade pip setuptools wheel
#pip install "bitsandbytes>=0.43.3" accelerate
#pip install scikit-learn

In [3]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  

import re, json, numpy as np, pandas as pd, torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import StoppingCriteria, StoppingCriteriaList
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer, scoring

print("CUDA available:", torch.cuda.is_available())

  from .autonotebook import tqdm as notebook_tqdm


CUDA available: True


In [4]:
MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507"
CSV_PATH = "data/MeDAL/pretrain_subset/test.csv"
HUMAN_CSV = "data/MeDAL/pretrain_subset/human_summaries_for_rouge.csv"
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)


In [5]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,   
)

In [6]:
print("Loading tokenizer & model...")

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

if tok.pad_token_id is None and tok.eos_token_id is not None:
    tok.pad_token = tok.eos_token

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        offload_folder="offload",
        max_memory={0: "30GiB", "cpu": "64GiB"},
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        offload_folder="offload",
        max_memory={"cpu": "64GiB"},
    )

model.eval()
print("Model loaded successfully!")


Loading tokenizer & model...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 16/16 [02:04<00:00,  7.76s/it]


Model loaded successfully!


In [7]:
def build_inputs(tok, system_msg, user_text):
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_text},
    ]
    chat = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return messages, chat

def strip_all_think_blocks(t: str) -> str:
    # Remove any nested <think>...</think> sections robustly
    return re.sub(r"<think>.*?</think>", "", t, flags=re.DOTALL | re.IGNORECASE).strip()

def safe_trim_to_first_n_sentences(text, n=4, min_keep=3):
    text = strip_all_think_blocks(text)
    # Also remove any leading "Hmm," style prefaces
    text = re.sub(r"^\s*(?:Hmm[.,].*?\n+)+", "", text, flags=re.IGNORECASE | re.DOTALL).strip()
    sents = re.split(r"(?<=[.!?])\s+", text)
    if len(sents) > n:
        sents = sents[:n]
    if len(sents) < min_keep and len(text) > 0:
        return text
    return " ".join(sents).strip()

def postprocess(decoded: str) -> str:
    # Keep only the text after the LAST </think> (if any slipped through)
    if "</think>" in decoded:
        decoded = decoded.split("</think>")[-1]
    # If <summary>...</summary> present, extract the inner text
    m = re.search(r"<summary>(.*?)</summary>", decoded, flags=re.DOTALL | re.IGNORECASE)
    final = (m.group(1) if m else decoded).strip()
    # Normalize empty/ellipsis
    if final.strip() in {"", "...", "…"}:
        return ""
    return safe_trim_to_first_n_sentences(final, n=4, min_keep=3)

In [8]:
def get_embedding_device(model):
    try:
        return model.get_input_embeddings().weight.device
    except Exception:
        return next(model.parameters()).device

EMBED_DEVICE = get_embedding_device(model)
print("Embedding device:", EMBED_DEVICE)

class StopOnSubstrings(StoppingCriteria):
    """
    Stop generation once any of the provided stop strings appears at the end (token-wise).
    This is complemented by a decoded-text cut for robustness.
    """
    def __init__(self, stop_strings, tokenizer):
        super().__init__()
        self.stop_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tokens = input_ids[0].tolist()
        for s_ids in self.stop_ids:
            if len(last_tokens) >= len(s_ids) and last_tokens[-len(s_ids):] == s_ids:
                return True
        return False

STOP_STRINGS = ["</summary>"]
STOPPER = StoppingCriteriaList([StopOnSubstrings(STOP_STRINGS, tok)])

def cut_after_last_summary_tag(text: str) -> str:
    """
    Hard-cut decoded text right after the first occurrence of </summary>.
    This complements token-level stopping and ensures a clean end.
    """
    end_tag = "</summary>"
    pos = text.find(end_tag)
    if pos != -1:
        return text[:pos + len(end_tag)]
    return text

def get_model_ctx(model, default_ctx=262144):
    """
    Try to read the model's max context length; fall back to a large safe default.
    """
    ctx = getattr(getattr(model, "config", object()), "max_position_embeddings", None)
    if isinstance(ctx, int) and ctx > 0:
        return ctx
    return default_ctx  # Qwen3 models advertise very long contexts

def generate_summary(model, tok, chat, system_msg_for_rebuild,
                     keep_tokens_for_answer=512, source_text=None,
                     gen_temp=0.4, gen_top_p=0.9, greedy=False):
    # Determine context and set a prompt budget
    max_ctx = int(get_model_ctx(model))
    reserve = keep_tokens_for_answer + 64
    prompt_budget = max(256, max_ctx - reserve)

    # Tokenize + move to embedding device
    inputs = tok([chat], return_tensors="pt", truncation=True, max_length=prompt_budget)
    inputs = {k: v.to(EMBED_DEVICE) for k, v in inputs.items()}

    # Rebuild from head if still too long
    if inputs["input_ids"].shape[1] >= prompt_budget and source_text:
        head_enc = tok(source_text, return_tensors="pt", truncation=True, max_length=prompt_budget)
        trimmed = tok.decode(head_enc["input_ids"][0], skip_special_tokens=True)
        _, chat2 = build_inputs(tok, system_msg_for_rebuild, trimmed)
        inputs = tok([chat2], return_tensors="pt", truncation=True, max_length=prompt_budget)
        inputs = {k: v.to(EMBED_DEVICE) for k, v in inputs.items()}

    # Ensure pad token again (defensive)
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token

    gen_kwargs = dict(
        max_new_tokens=320,
        eos_token_id=tok.eos_token_id,     # allow EOS to stop as well
        stopping_criteria=STOPPER,         # primary hard stop at </summary>
        no_repeat_ngram_size=3,
    )
    if greedy:
        gen_kwargs.update(dict(do_sample=False))
    else:
        gen_kwargs.update(dict(do_sample=True, temperature=gen_temp, top_p=gen_top_p))

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    decoded = tok.decode(new_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)

    # Hard-cut after </summary> in decoded space for robustness
    decoded = cut_after_last_summary_tag(decoded)
    return decoded

Embedding device: cuda:0


In [16]:
try:
    df = pd.read_csv(CSV_PATH)
    source_text = str(df.iloc[0]["TEXT"]).strip()
    # capture ABSTRACT_ID if present for alignment with human references
    abs_id = int(df.iloc[0]["ABSTRACT_ID"]) if "ABSTRACT_ID" in df.columns else None
    print("Loaded CSV sample. Characters:", len(source_text))
except Exception as e:
    print("CSV load warning:", e)
    source_text = (
        "Background: Hypertension is a common cardiovascular risk factor. "
        "Methods: We conducted a randomized, controlled trial evaluating a new ACE inhibitor versus placebo "
        "in 1,200 adults with stage 2 hypertension over 24 weeks. "
        "Results: The treatment group showed a mean systolic BP reduction of 18 mmHg versus 6 mmHg with placebo; "
        "adverse events were mild and comparable. "
        "Conclusion: The ACE inhibitor significantly reduced blood pressure with acceptable safety."
    )
    abs_id = None
    print("Using fallback abstract. Characters:", len(source_text))

Loaded CSV sample. Characters: 1039


In [17]:
SYSTEM_MSG_TAGGED = (
    "Summarize the user's medical abstract in 3–4 sentences. "
    "Be clear and factual. Keep key clinical details (condition, intervention, measurements, outcomes). "
    "Return ONLY the summary wrapped exactly as:\n<summary>...</summary>\nNo preface, no analysis, no extra text."
)

SYSTEM_MSG_FALLBACK = (
    "Summarize the user's medical abstract in 3–4 sentences. "
    "Be clear and factual. Keep key clinical details (condition, intervention, measurements, outcomes). "
    "Output ONLY the 3–4 sentence summary—no preface, no analysis."
)

GEN_TEMP, GEN_TOPP = 0.4, 0.9

In [18]:
# Attempt 1
_, chat = build_inputs(tok, SYSTEM_MSG_TAGGED, source_text)
print("Generating summary (attempt 1)...")
decoded = generate_summary(
    model, tok, chat, system_msg_for_rebuild=SYSTEM_MSG_TAGGED,
    keep_tokens_for_answer=512, source_text=source_text,
    gen_temp=GEN_TEMP, gen_top_p=GEN_TOPP, greedy=False
)
summary = postprocess(decoded)

# Fallback: deterministic pass
if not summary:
    print("Generating summary (attempt 2, deterministic fallback)...")
    _, chat_fb = build_inputs(tok, SYSTEM_MSG_FALLBACK, source_text)
    decoded_fb = generate_summary(
        model, tok, chat_fb, system_msg_for_rebuild=SYSTEM_MSG_FALLBACK,
        keep_tokens_for_answer=512, source_text=source_text,
        gen_temp=0.7, gen_top_p=0.95, greedy=True
    )
    summary = postprocess(decoded_fb)

# Final guardrail: strip leading lone "I ..." lines
if summary and summary[:1].lower() == "i":
    summary = re.sub(r"(^|\n)I[^\n]*", "", summary).strip()

if not summary:
    preview = (decoded or "")[:400].replace("\n", " ")
    print("\n[Debug] Model raw (first 400 chars):", preview)
    summary = "Summary unavailable: the model did not produce a clean summary."

print("\n--- SUMMARY ---\n")
print(summary)

Generating summary (attempt 1)...

--- SUMMARY ---

Looking at the abstract, it describes a sheep model of allergic airWAY disease. The researchers exposed sheep to Ascaris allergen until they developed chronic inflammation. They divided the sheep into three groups: Group A (control, saline only), Group B (non-reactive), and Group C (reactive with stable increased airway resistance). Key measurements included airway mechanics (RL, FRC), lung function tests (plethysomography, helium rebreathe), and BAL analyses (cell counts, protein levels).


In [19]:
os.makedirs("outputs", exist_ok=True)
out_path = "outputs/summary_first_test.txt"
with open(out_path, "w", encoding="utf-8") as f:
    f.write(summary + "\n")
print(f"\nSaved to {out_path}")


Saved to outputs/summary_first_test.txt


In [None]:
# -----------------------------------------------------------------------------
# TF-IDF evaluation 
# -----------------------------------------------------------------------------
def split_sentences(text: str):
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

def _safe_max_df(n_docs: int, max_df):
   
    if isinstance(max_df, float):
        if n_docs <= 10:
            return 1.0
        return max_df
    return max_df  

def tfidf_eval(source_text: str, summary: str,
               ngram_range=(1,2), max_df=0.9, min_df=1,
               coverage_threshold=0.10, use_stopwords=True):
   
    stop = 'english' if use_stopwords else None

    # --- Global similarity (2 docs) ---
    docs = [source_text, summary]
    vec_global = TfidfVectorizer(lowercase=True, stop_words=stop,
                                 ngram_range=ngram_range,
                                 max_df=_safe_max_df(len(docs), max_df),
                                 min_df=min_df)
    X = vec_global.fit_transform(docs)
    global_sim = float(cosine_similarity(X[0], X[1])[0, 0])

    # --- Coverage (many src sentences + 1 summary) ---
    src_sents = split_sentences(source_text)[:200] or [source_text]
    cov_docs = src_sents + [summary]
    vec_cov = TfidfVectorizer(lowercase=True, stop_words=stop,
                              ngram_range=ngram_range,
                              max_df=_safe_max_df(len(cov_docs), max_df),
                              min_df=min_df)
    X_cov = vec_cov.fit_transform(cov_docs)
    S, q = X_cov[:-1], X_cov[-1]
    sent_sims = cosine_similarity(S, q).ravel()
    coverage = float((sent_sims >= coverage_threshold).mean())

    # --- Redundancy (within-summary) ---
    summ_sents = split_sentences(summary)
    if len(summ_sents) >= 2:
        vec_red = TfidfVectorizer(lowercase=True, stop_words=stop,
                                  ngram_range=ngram_range,
                                  max_df=1.0, min_df=1)  
        X_red = vec_red.fit_transform(summ_sents)
        C = cosine_similarity(X_red)
        redundancy = float((C.sum() - np.trace(C)) / (C.shape[0]*C.shape[1] - C.shape[0]))
    else:
        redundancy = 0.0

    # --- Top keywords (single doc: summary) ---
    top_keywords = []
    if len(re.findall(r"\w+", summary)) >= 2:
        vec_kw = TfidfVectorizer(lowercase=True, stop_words=stop,
                                 ngram_range=ngram_range,
                                 max_df=1.0, min_df=1)  
        try:
            X_kw = vec_kw.fit_transform([summary])
            vocab = np.array(vec_kw.get_feature_names_out())
            scores = X_kw.toarray()[0]
            top_idx = scores.argsort()[::-1][:10]
            top_keywords = [(vocab[i], float(scores[i])) for i in top_idx if scores[i] > 0]
        except ValueError:
            top_keywords = []

    return {
        "tfidf_cosine_similarity": global_sim,
        f"coverage@{coverage_threshold:.2f}": coverage,
        "redundancy_avg_pairwise": redundancy,
        "top_keywords": top_keywords,
        "notes": "Higher similarity & coverage are good; lower redundancy is better."
    }

metrics = tfidf_eval(
    source_text,
    summary,
    ngram_range=(1,2),
    max_df=0.9,
    min_df=1,
    coverage_threshold=0.10,
    use_stopwords=True  
)

print("\n--- TF-IDF EVAL ---")
for k, v in metrics.items():
    if k == "top_keywords":
        print(f"{k}: {[w for w,_ in v]}")
    else:
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

with open("outputs/tfidf_eval.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, ensure_ascii=False, indent=2)
print("Saved TF-IDF metrics to outputs/tfidf_eval.json")


--- TF-IDF EVAL ---
tfidf_cosine_similarity: 0.2578
coverage@0.10: 1.0000
redundancy_avg_pairwise: 0.0188
top_keywords: ['sheep', 'group', 'airway', 'reactive', 'stable', 'sheep model', 'sheep groups', 'sheep ascaris', 'saline group', 'saline']
notes: Higher similarity & coverage are good; lower redundancy is better.
Saved TF-IDF metrics to outputs/tfidf_eval.json


In [None]:
# ============================ ROUGE evaluation ==============================
def _find_ref_col(df_refs: pd.DataFrame) -> str:
    for c in ["HUMAN_SUMMARY", "SUMMARY", "reference", "REF"]:
        if c in df_refs.columns:
            return c
    raise ValueError(
        "Could not find a reference-summary column in HUMAN_CSV. "
        "Expected one of: HUMAN_SUMMARY, SUMMARY, reference, REF."
    )

def _to_plain_text(t: str) -> str:
    t = strip_all_think_blocks(str(t))
    t = re.sub(r"</?summary>", "", t, flags=re.I)
    return t.strip()

def _load_human_reference(human_csv_path: str, abstract_id):
    hdf = pd.read_csv(human_csv_path)
    ref_col = _find_ref_col(hdf)

    if abstract_id is not None and "ABSTRACT_ID" in hdf.columns:
        hit = hdf.loc[hdf["ABSTRACT_ID"] == abstract_id]
        if not hit.empty:
            refs = [str(r).strip() for r in hit[ref_col].dropna().tolist() if str(r).strip()]
            if refs:
                return refs

    refs = [str(r).strip() for r in hdf[ref_col].dropna().tolist() if str(r).strip()]
    if not refs:
        raise ValueError("No non-empty human reference summaries found in HUMAN_CSV.")
    return refs

print("\n--- ROUGE EVAL ---")
try:
    references = _load_human_reference(HUMAN_CSV, abs_id)
    pred = _to_plain_text(summary)
    refs_plain = [_to_plain_text(r) for r in references]

    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeLsum"], use_stemmer=True)

    if len(refs_plain) > 1:
        s = scorer.score_multi(refs_plain, pred)  
    else:
        s = scorer.score(refs_plain[0], pred)

    per_sample_rows = []
    def _row_from_score(tag, scr):
        return {
            "metric": tag,
            "precision": float(scr.precision),
            "recall": float(scr.recall),
            "f1": float(scr.fmeasure),
        }
    for tag, scr in s.items():
        per_sample_rows.append(_row_from_score(tag, scr))

    agg = scoring.BootstrapAggregator()
    agg.add_scores(s)
    agg_res = agg.aggregate()

    for tag in ["rouge1", "rouge2", "rougeLsum"]:
        mid = agg_res[tag].mid
        print(f"{tag.upper():9s} F1={mid.fmeasure:.4f}")

    with open(os.path.join(OUT_DIR, "rouge_per_sample.json"), "w", encoding="utf-8") as f:
        json.dump(per_sample_rows, f, ensure_ascii=False, indent=2)

    agg_out = {
        k: {
            "precision": float(v.mid.precision),
            "recall": float(v.mid.recall),
            "f1": float(v.mid.fmeasure),
            "low_f1": float(v.low.fmeasure),
            "high_f1": float(v.high.fmeasure),
        }
        for k, v in agg_res.items()
    }
    with open(os.path.join(OUT_DIR, "rouge_aggregate.json"), "w", encoding="utf-8") as f:
        json.dump(agg_out, f, ensure_ascii=False, indent=2)

    print("Saved ROUGE results to outputs/rouge_per_sample.json and outputs/rouge_aggregate.json")

except Exception as e:
    print("ROUGE evaluation error:", e)


--- ROUGE EVAL ---
ROUGE1    F1=0.4714
ROUGE2    F1=0.1014
ROUGELSUM F1=0.2429
Saved ROUGE results to outputs/rouge_per_sample.json and outputs/rouge_aggregate.json
