In [72]:
#sudo apt update
#sudo apt install -y python3 python3-venv python3-pip
#python3 -V                    # sanity check
#python3 -m venv .venv
#source .venv/bin/activate
#python -m pip install --upgrade pip
#pip install torch transformers evaluate rouge-score pandas
#python -m pip install --upgrade pip setuptools wheel
#pip install "bitsandbytes>=0.43.3" accelerate

In [73]:
import os, re, pandas as pd, torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Mitigate CUDA allocator fragmentation (set before CUDA allocations)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print("CUDA available:", torch.cuda.is_available())


CUDA available: True


In [74]:
MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507"
CSV_PATH = "data/MeDAL/pretrain_subset/test.csv" 


In [75]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,   
)


In [76]:
print("Loading tokenizer & model...")

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

if tok.pad_token_id is None and tok.eos_token_id is not None:
    tok.pad_token = tok.eos_token

if torch.cuda.is_available():
    # GPU path: 4-bit quant + automatic device placement & offload
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",                
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        offload_folder="offload",
        max_memory={0: "30GiB", "cpu": "64GiB"},  
    )
else:
    # CPU fallback (very slow for a 30B model, but keeps code error-free)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        offload_folder="offload",
        max_memory={"cpu": "64GiB"},
    )

model.eval()
print("Model loaded successfully!")


Loading tokenizer & model...


Loading checkpoint shards: 100%|██████████| 16/16 [02:24<00:00,  9.06s/it]


Model loaded successfully!


In [77]:
def build_inputs(tok, system_msg, user_text):
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_text},
    ]
    chat = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return messages, chat


def safe_trim_to_first_n_sentences(text, n=4, min_keep=3):
    text = text.replace("<think>", "").replace("</think>", "").strip()
    text = re.sub(r"^\s*(?:Hmm[.,].*?\n+)+", "", text, flags=re.IGNORECASE | re.DOTALL).strip()
    sents = re.split(r"(?<=[.!?])\s+", text)
    if len(sents) > n:
        sents = sents[:n]
    if len(sents) < min_keep and len(text) > 0:
        return text
    return " ".join(sents).strip()


def postprocess(decoded: str) -> str:
    # Keep only the text after the LAST </think>
    if "</think>" in decoded:
        decoded = decoded.split("</think>")[-1]
    # Extract <summary>...</summary> if present
    m = re.search(r"<summary>(.*?)</summary>", decoded, flags=re.DOTALL | re.IGNORECASE)
    final = (m.group(1) if m else decoded).strip()
    # Normalize empty/ellipsis
    if final.strip() in {"", "...", "…"}:
        return ""
    return safe_trim_to_first_n_sentences(final, n=4, min_keep=3)


In [78]:
from transformers import StoppingCriteria, StoppingCriteriaList
import torch

def get_embedding_device(model):
    try:
        return model.get_input_embeddings().weight.device
    except Exception:
        return next(model.parameters()).device

EMBED_DEVICE = get_embedding_device(model)
print("Embedding device:", EMBED_DEVICE)

class StopOnSubstrings(StoppingCriteria):
    """
    Stop generation once any of the provided stop strings appears at the end.
    Compares on token space for reliability.
    """
    def __init__(self, stop_strings, tokenizer):
        super().__init__()
        self.stop_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for s_ids in self.stop_ids:
            if len(input_ids[0]) >= len(s_ids) and input_ids[0].tolist()[-len(s_ids):] == s_ids:
                return True
        return False

def make_bad_words_ids(tokenizer, bad_list):
    ids = []
    for w in bad_list:
        wid = tokenizer.encode(w, add_special_tokens=False)
        if wid:
            ids.append(wid)
    return ids

STOP_STRINGS = ["</summary>"]
BAD_WORDS = ["I ", "I'", "I'm", "I’m", "I need", "I will", "I’ll", "I think", "user", "preface", "analysis"]

STOPPER = StoppingCriteriaList([StopOnSubstrings(STOP_STRINGS, tok)])
BAD_WORDS_IDS = make_bad_words_ids(tok, BAD_WORDS)

def generate_summary(model, tok, chat, system_msg_for_rebuild,
                     keep_tokens_for_answer=512, source_text=None,
                     gen_temp=0.4, gen_top_p=0.9):
    # Budget prompt
    max_ctx = int(getattr(getattr(model, "config", object()), "max_position_embeddings", 32768))
    reserve = keep_tokens_for_answer + 64
    prompt_budget = max(256, max_ctx - reserve)

    # Tokenize + move to embedding device
    inputs = tok([chat], return_tensors="pt", truncation=True, max_length=prompt_budget)
    inputs = {k: v.to(EMBED_DEVICE) for k, v in inputs.items()}

    # Rebuild from head if still too long
    if inputs["input_ids"].shape[1] >= prompt_budget and source_text:
        head_enc = tok(source_text, return_tensors="pt", truncation=True, max_length=prompt_budget)
        trimmed = tok.decode(head_enc["input_ids"][0], skip_special_tokens=True)
        _, chat2 = build_inputs(tok, system_msg_for_rebuild, trimmed)
        inputs = tok([chat2], return_tensors="pt", truncation=True, max_length=prompt_budget)
        inputs = {k: v.to(EMBED_DEVICE) for k, v in inputs.items()}

    # Ensure pad token
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=320,
            do_sample=True,
            temperature=gen_temp,
            top_p=gen_top_p,
            stopping_criteria=STOPPER,      # <- stop exactly at </summary>
            bad_words_ids=BAD_WORDS_IDS,    # <- lightly discourage meta phrases
            no_repeat_ngram_size=3,
        )

    # Decode only the newly generated tokens
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    decoded = tok.decode(new_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    return decoded


Embedding device: cuda:0


In [79]:
try:
    df = pd.read_csv(CSV_PATH)
    source_text = str(df.iloc[0]["TEXT"]).strip()
    if not source_text:
        raise ValueError("Empty TEXT cell in CSV.")
    print("Loaded CSV sample. Characters:", len(source_text))
except Exception as e:
    print("CSV load warning:", e)
    # Fallback text so the notebook runs without errors if CSV is missing
    source_text = (
        "Background: Hypertension is a common cardiovascular risk factor. "
        "Methods: We conducted a randomized, controlled trial evaluating a new ACE inhibitor versus placebo "
        "in 1,200 adults with stage 2 hypertension over 24 weeks. "
        "Results: The treatment group showed a mean systolic BP reduction of 18 mmHg versus 6 mmHg with placebo; "
        "adverse events were mild and comparable. "
        "Conclusion: The ACE inhibitor significantly reduced blood pressure with acceptable safety."
    )
    print("Using fallback abstract. Characters:", len(source_text))


Loaded CSV sample. Characters: 1039


In [None]:
SYSTEM_MSG_TAGGED = (
    "Summarize the user's medical abstract in 3–4 sentences. "
    "Be clear and factual. Keep key clinical details (condition, intervention, measurements, outcomes). "
    "Return ONLY the summary wrapped exactly as:\n<summary>...</summary>\nNo preface, no analysis, no extra text."
)

SYSTEM_MSG_FALLBACK = (
    "Summarize the user's medical abstract in 3–4 sentences. "
    "Be clear and factual. Keep key clinical details (condition, intervention, measurements, outcomes). "
    "Output ONLY the 3–4 sentence summary—no preface, no analysis."
)

GEN_TEMP, GEN_TOPP = 0.4, 0.9


In [81]:
# Attempt 1: tagged prompt that must end with </summary>
_, chat = build_inputs(tok, SYSTEM_MSG_TAGGED, source_text)
print("Generating summary (attempt 1)...")
decoded = generate_summary(
    model, tok, chat, system_msg_for_rebuild=SYSTEM_MSG_TAGGED,
    keep_tokens_for_answer=512, source_text=source_text,
    gen_temp=GEN_TEMP, gen_top_p=GEN_TOPP
)
summary = postprocess(decoded)

# Fallback if empty
if not summary:
    print("Generating summary (attempt 2, fallback)...")
    _, chat_fb = build_inputs(tok, SYSTEM_MSG_FALLBACK, source_text)
    decoded_fb = generate_summary(
        model, tok, chat_fb, system_msg_for_rebuild=SYSTEM_MSG_FALLBACK,
        keep_tokens_for_answer=512, source_text=source_text,
        gen_temp=0.7, gen_top_p=0.95
    )
    summary = postprocess(decoded_fb)
    

# Final guardrail: if the model STILL tries meta commentary, drop lines with “I …”
if summary and summary[:1].lower() in {"i"}:
    summary = re.sub(r"(^|\n)I[^\n]*", "", summary).strip()

if not summary:
    preview = (decoded or "")[:400].replace("\n", " ")
    print("\n[Debug] Model raw (first 400 chars):", preview)
    summary = "Summary unavailable: the model did not produce a clean summary."

print("\n--- SUMMARY ---\n")
print(summary)


Generating summary (attempt 1)...

--- SUMMARY ---

Looking at this abstract, it describes a sheep model study where they exposed sheep to Ascaris allergen. The key points I need to extract: they created a chronic allergic model by repeated tracheally instilling Ascarist antigen until reactive sheep (Group C) showed stable 3x increase in airway resistance (RL) compared to controls. They compared Group C (reactive, n=6) to Group B (non-reactive, 3 sheep) and Group A (saline control, 8 sheep). Measurements included RL, FRC via two techniques, lung mechanics, and BAL cell counts.


In [82]:
os.makedirs("outputs", exist_ok=True)
out_path = "outputs/summary_first_test.txt"
with open(out_path, "w", encoding="utf-8") as f:
    f.write(summary + "\n")
print(f"\nSaved to {out_path}")



Saved to outputs/summary_first_test.txt
