# In [None]:
# -*- coding: utf-8 -*-
"""
Model interpretability via attention attribution (post-hoc, no retraining).

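For a (fine-tuned) task-specific model:
1) Extract the final-layer self-attention matrix during inference.
2) Average attention across heads.
3) Select query rows according to QUERY_MODE. The default is the last token's
   attention distribution (the final query row). Alternatives are the mean of
   the last K rows or the mean of all rows.
4) Normalize token attention weights so they sum to one over the input sequence.
5) Aggregate token-level attention into three semantically defined prompt components.
   Zero-length (special) tokens are skipped, and the three component weights are
   renormalized to sum to one:
   - Date   : year + month
   - Disease: disease name (also receives tokens that fall outside every span)
   - Outcome: cases vs. deaths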

"""

import re
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

try:
    from peft import PeftModel
except Exception:
    PeftModel = None


# =============================
# Settings (keep key parameters unchanged)
# =============================
BASE_MODEL_PATH = "path/to/your/base_model"  # passed to from_pretrained for tokenizer and model
LORA_ADAPTER_PATH = None  # set to an adapter path to wrap the base model with LoRA weights

# Prefer the first GPU in half precision; fall back to CPU in full precision.
if torch.cuda.is_available():
    device, dtype = "cuda:0", torch.float16
else:
    device, dtype = "cpu", torch.float32

MAX_LEN = 128  # tokenizer truncation length (tokens)

# How the attention profile is read out (no retraining involved):
#   "last"  : attention row of the final query token (recommended)
#   "lastk" : mean of the last LAST_K query rows
#   "mean"  : mean over every query row (legacy robustness check)
QUERY_MODE = "last"
LAST_K = 4

# One demo prompt per supported date style: "YYYY年M月", "YYYY-MM", "YYYY/MM".
DEMO_PROMPTS = [
    "2009年1月 肺结核 发病数",
    "2016-07 Influenza-associated pediatric mortality deaths",
    "2020/12 Dengue Fever cases",
]


# =============================
# Attention backend: force eager 
# =============================
def disable_sdpa_flash() -> None:
    """
    Best-effort reconfiguration of PyTorch's scaled-dot-product-attention
    backends: turn the flash and memory-efficient kernels off and keep the
    math kernel on.

    Each toggle is attempted independently. Any exception (for example, a
    toggle missing from the installed torch build) is ignored so that model
    loading can continue.
    """
    toggles = (
        ("enable_flash_sdp", False),
        ("enable_mem_efficient_sdp", False),
        ("enable_math_sdp", True),
    )
    for setter_name, enabled in toggles:
        try:
            getattr(torch.backends.cuda, setter_name)(enabled)
        except Exception:
            # Deliberately best-effort: an unavailable toggle is not fatal.
            pass


def load_tokenizer_fast_or_fail(model_path: str):
    """
    Load the tokenizer for ``model_path``, insisting on a "fast" implementation.

    Only fast tokenizers can return ``offset_mapping``, which is needed to map
    tokens back to character spans of the prompt.

    Raises:
        RuntimeError: the loaded tokenizer does not report ``is_fast``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
    if getattr(tokenizer, "is_fast", False):
        return tokenizer
    raise RuntimeError("Fast tokenizer is required for offset_mapping.")


def load_base_model_eager(model_path: str):
    """
    Load ``model_path`` with ``AutoModel``, requesting eager attention so that
    ``output_attentions=True`` yields attention maps. Then move the model to
    ``device`` and switch it to eval mode.

    If ``from_pretrained`` rejects the ``attn_implementation`` keyword with a
    TypeError, the model is reloaded without it. ``set_attn_implementation("eager")``
    is then tried, best-effort, on the model and on its ``base_model`` /
    ``model`` attributes when present.
    """
    disable_sdpa_flash()
    load_kwargs = {"trust_remote_code": True, "torch_dtype": dtype}
    try:
        model = AutoModel.from_pretrained(model_path, attn_implementation="eager", **load_kwargs)
    except TypeError:
        model = AutoModel.from_pretrained(model_path, **load_kwargs)
        candidates = (
            model,
            getattr(model, "base_model", None),
            getattr(model, "model", None),
        )
        for target in (c for c in candidates if c is not None):
            try:
                target.set_attn_implementation("eager")
            except Exception:
                # Not every model/submodule exposes this method; try the rest.
                pass

    model.to(device)
    model.eval()
    return model


def maybe_load_lora(model, adapter_path: str | None):
    if adapter_path is None:
        return model
    if PeftModel is None:
        raise RuntimeError("peft is not available. Install peft to load LoRA adapters.")
    model_peft = PeftModel.from_pretrained(model, adapter_path)
    model_peft.to(device)
    model_peft.eval()
    return model_peft


# =============================
# Component parsing
# =============================
_OUT_PAT = re.compile(r"(发病数|死亡数|cases|case|deaths|death)", flags=re.IGNORECASE)


def _trim_span(s: str, a: int, b: int) -> tuple[int, int]:
    a = max(a, 0)
    b = min(b, len(s))
    while a < b and s[a] in " \t\r\n:：,，;；-—_()（）[]【】":
        a += 1
    while b > a and s[b - 1] in " \t\r\n:：,，;；-—_()（）[]【】":
        b -= 1
    return a, b


def parse_components(prompt: str) -> dict[str, tuple[int, int]]:
    """
    Locate the date, disease and outcome parts of ``prompt`` as character spans.

    - date    : span of the last "YYYY年M月" match. Only if there is none, the
                last "YYYY-M" / "YYYY/M" match (whitespace allowed around the
                separator) is used. (0, 0) when no date is found.
    - outcome : span of the last outcome keyword matched by ``_OUT_PAT``;
                (len(prompt), len(prompt)) when none is found.
    - disease : text from the end of the date (or 0) up to the outcome start.
                If the outcome does not start after the date, it runs to the
                end of the prompt instead. Separators are trimmed from both ends.

    Keys are returned in the order date, disease, outcome.
    """
    n = len(prompt)

    date_span = (0, 0)
    for pattern in (
        r"(?P<year>\d{4})年(?P<month>\d{1,2})月",
        r"(?P<year>\d{4})\s*[-/]\s*(?P<month>\d{1,2})",
    ):
        matches = list(re.finditer(pattern, prompt))
        if matches:
            date_span = matches[-1].span()
            break

    outcome_matches = list(_OUT_PAT.finditer(prompt))
    outcome_span = outcome_matches[-1].span() if outcome_matches else (n, n)

    disease_start = date_span[1]
    disease_stop = outcome_span[0] if outcome_span[0] > disease_start else n

    return {
        "date": date_span,
        "disease": _trim_span(prompt, disease_start, disease_stop),
        "outcome": outcome_span,
    }


def overlap_len(a: tuple[int, int], b: tuple[int, int]) -> int:
    """Return the length of the intersection of (start, end) spans ``a`` and ``b``, or 0 if they are disjoint."""
    lo = a[0] if a[0] > b[0] else b[0]
    hi = a[1] if a[1] < b[1] else b[1]
    return hi - lo if hi > lo else 0


def assign_component(tok_span: tuple[int, int], comp_spans: dict[str, tuple[int, int]]) -> str:
    """
    Return the name of the component whose span overlaps ``tok_span`` the most.

    Ties go to the component that comes first in ``comp_spans`` iteration
    order. When no span overlaps at all (or ``comp_spans`` is empty), the
    token defaults to "disease".
    """
    if not comp_spans:
        return "disease"
    # max() keeps the first of equally-scored items, preserving dict order on ties.
    name, span = max(comp_spans.items(), key=lambda item: overlap_len(tok_span, item[1]))
    return name if overlap_len(tok_span, span) > 0 else "disease"


# =============================
# Attention extraction + aggregation
# =============================
@torch.no_grad()
def token_attention_profile(
    model,
    tokenizer,
    prompt: str,
    max_len: int = 128,
    *,
    query_mode: str | None = None,
    last_k: int | None = None,
):
    """
    Compute a per-token attention profile for ``prompt`` from the final layer.

    The prompt is tokenized (truncated to ``max_len`` tokens) and run through
    ``model`` with ``output_attentions=True``. The last layer's attention is
    averaged over heads, and one query-row summary is taken:

      - "last"  : the attention row of the final token
      - "lastk" : mean of the last ``last_k`` rows (at least one row is used)
      - "mean"  : mean over all query rows

    Negative values are clipped to zero. The profile is then normalized to sum
    to one; it is left as is if it sums to zero.

    Args:
        model: model that populates ``.attentions`` when called with
            ``output_attentions=True`` (eager attention required).
        tokenizer: fast tokenizer supporting ``return_offsets_mapping``.
        prompt: input text.
        max_len: truncation length in tokens.
        query_mode: overrides the module-level QUERY_MODE when given.
        last_k: overrides the module-level LAST_K when given ("lastk" only).

    Returns:
        (offsets, prof): ``offsets`` is a list of [start, end] character spans,
        one per kept token. ``prof`` is a float numpy array of the same length.

    Raises:
        ValueError: unknown query mode, or the prompt produced no tokens.
        RuntimeError: offset mapping or attentions are unavailable.
    """
    mode = QUERY_MODE if query_mode is None else query_mode
    n_last = LAST_K if last_k is None else last_k
    # Validate before the (expensive) forward pass rather than after it.
    if mode not in ("last", "lastk", "mean"):
        raise ValueError(f"Unknown QUERY_MODE={mode}")

    enc = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        return_offsets_mapping=True,
    )
    if "offset_mapping" not in enc:
        raise RuntimeError("offset_mapping missing (fast tokenizer required).")

    offsets = enc["offset_mapping"][0].tolist()
    input_ids = enc["input_ids"].to(device)
    # Some tokenizers can be configured not to return a mask; a single
    # unpadded sequence attends to every position.
    mask = enc.get("attention_mask")
    attention_mask = torch.ones_like(input_ids) if mask is None else mask.to(device)

    seq_len = int(attention_mask[0].sum().item())
    if seq_len == 0:
        # Otherwise row seq_len - 1 == -1 of an empty matrix raises IndexError.
        raise ValueError("Prompt produced no tokens; cannot build an attention profile.")

    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
        use_cache=False,
        return_dict=True,
    )
    if out.attentions is None:
        raise RuntimeError("attentions=None. Ensure eager attention (no sdpa/flash).")

    # No padding is requested, so the first seq_len positions are the real tokens.
    offsets = offsets[:seq_len]

    # Final layer, batch item 0, mean over heads -> (seq_len, seq_len) [query, key].
    # Assumes the Hugging Face (batch, heads, query, key) layout; remote-code
    # models may differ -- TODO confirm for the model in use.
    attn = out.attentions[-1][0].float().mean(dim=0)[:seq_len, :seq_len]

    if mode == "last":
        prof_t = attn[seq_len - 1, :]
    elif mode == "lastk":
        # Use at least one row: a non-positive last_k would slice zero rows and
        # the mean would be NaN, silently poisoning every downstream weight.
        k = max(1, int(n_last))
        prof_t = attn[max(0, seq_len - k):seq_len, :].mean(dim=0)
    else:  # "mean"
        prof_t = attn.mean(dim=0)

    prof = np.maximum(prof_t.cpu().numpy(), 0.0)
    total = prof.sum()
    if total > 0:
        prof = prof / total

    return offsets, prof


def aggregate_attention_to_components(prompt: str, offsets, attn_vec: np.ndarray) -> dict[str, float]:
    """
    Sum per-token attention into "date" / "disease" / "outcome" shares.

    Each token with a non-empty character span is assigned to one component
    by ``assign_component``, using the spans from ``parse_components(prompt)``.
    Zero-length spans are skipped; this includes the (0, 0) offsets that fast
    tokenizers give special tokens. The three totals are then rescaled to sum
    to one; they are left unscaled if their sum is not positive.

    Args:
        prompt: the prompt text that ``offsets`` index into.
        offsets: per-token (start, end) character offsets.
        attn_vec: per-token attention weights aligned with ``offsets``
            (pairing stops at the shorter of the two).

    Returns:
        dict with keys "date", "disease", "outcome" (in that order).
    """
    spans = parse_components(prompt)
    totals = dict.fromkeys(("date", "disease", "outcome"), 0.0)

    for span, weight in zip(offsets, attn_vec):
        start, end = int(span[0]), int(span[1])
        if end <= start:
            continue
        label = assign_component((start, end), spans)
        if label in totals:
            totals[label] += float(weight)

    grand_total = sum(totals.values())
    if grand_total > 0:
        totals = {name: value / grand_total for name, value in totals.items()}
    return totals


def main():
    """
    Load the tokenizer and model (plus the optional LoRA adapter). Then, for
    every prompt in DEMO_PROMPTS, print its date/disease/outcome attention shares.
    """
    tokenizer = load_tokenizer_fast_or_fail(BASE_MODEL_PATH)
    model = maybe_load_lora(load_base_model_eager(BASE_MODEL_PATH), LORA_ADAPTER_PATH)

    for idx, text in enumerate(DEMO_PROMPTS, start=1):
        spans, profile = token_attention_profile(model, tokenizer, text, MAX_LEN)
        shares = aggregate_attention_to_components(text, spans, profile)
        print(f"[{idx}] prompt={text!r}")
        print(f"    date={shares['date']:.6f}  disease={shares['disease']:.6f}  outcome={shares['outcome']:.6f}")


# Entry point: run the demo when this file (or notebook cell) is executed directly.
if __name__ == "__main__":
    main()
