# In [None]:
# -*- coding: utf-8 -*-
"""
Zero-shot baseline (pretraining-contamination sensitivity analysis)

What this script does
---------------------
This script implements a **zero-shot** baseline using a **frozen pretrained LLM**
(e.g., Qwen-2.5-3B) with **NO task-specific fine-tuning**, to test whether near-correct
test-period surveillance values could plausibly be explained by **memorization during
pretraining**.

Per disease–outcome task, it:
  - sorts chronologically
  - splits into train/val/test = 60%/20%/20%
  - queries the pretrained-only model **ONLY on the test partition**
  - enforces a strict prompt that requests a **one-line JSON** output: {"value": <integer>}
  - decodes **generated tokens only** (to avoid parsing numbers from the prompt)
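  - extracts a numeric value with an ordered rule-based parser (the JSON "value"
    field first; the last non-date number in the text as a fallback)
  - computes MAE/MSE on the **raw count scale**, over the test rows whose
    prediction could be parsed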
  - writes per-task Excel files + a dataset-level summary Excel

"""

import functools
import json
import os
import random
import re
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# ================== Global config ==================
# Seed for the Python / NumPy / PyTorch RNGs; main() passes it to set_all_seeds().
random_seed = 3407
# Device that prompt tensors are moved to before generation (GPU 0 when CUDA
# is available, otherwise CPU).
cuda_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Directory of the frozen base checkpoint. Override it with the QWEN25_3B_PATH
# environment variable; the fallback is a local Windows path.
pretrained_path = os.environ.get("QWEN25_3B_PATH", r"E:\Qwen2.5-3B")

# ---------- generation hyperparams ----------
# GEN_BATCH_SIZE: number of prompts per generation batch.
# MAX_NEW_TOKENS: cap on generated tokens per prompt (the expected answer is a
#   single short JSON line).
# TEMPERATURE / TOP_P / NUM_BEAMS: decoding settings used at generation time;
#   a 0.0 temperature with a single beam is intended to give deterministic output.
GEN_BATCH_SIZE = 8
MAX_NEW_TOKENS = 24
TEMPERATURE = 0.0
TOP_P = 1.0
NUM_BEAMS = 1

# ---------- prompt template ----------
# English rendering of the Chinese prompt below:
#   "You are a numeric regressor. Give the [value] for this indicator.
#    The output must be exactly one line of JSON: {"value": <integer>}
#    Requirements:
#    1) <integer> must be a non-negative integer (no decimal point)
#    2) do not output a year, month or date, and do not output characters
#       such as 年 (year) or 月 (month)
#    3) output no text other than this one JSON line
#    Indicator: {instruction}
#    Output:"
# The braces around the JSON example are doubled because the template is
# rendered with str.format(); only {instruction} is substituted.

PROMPT_TEMPLATE = (
    "你是一个数值回归器。请给出该指标对应的【数值】。\n"
    "输出必须严格为一行JSON：{{\"value\": <integer>}}\n"
    "要求：\n"
    "1) <integer> 必须是非负整数（不能有小数点）\n"
    "2) 不要输出年份、月份、日期，不要输出“年”“月”等字\n"
    "3) 除这一行JSON外不要输出任何其它文字\n"
    "指标：{instruction}\n"
    "输出："
)

# ---------- output dirs ----------
# Per-dataset output folders for the per-task and summary Excel files.
# NOTE: these directories are created at import time (module-level side effect).
OUT_ROOT = "./zero_shot_outputs"
OUT_CHINA_DIR = os.path.join(OUT_ROOT, "CHINA")
OUT_USAUS_DIR = os.path.join(OUT_ROOT, "USAUS")
os.makedirs(OUT_CHINA_DIR, exist_ok=True)
os.makedirs(OUT_USAUS_DIR, exist_ok=True)

# ================== Utils ==================
def set_all_seeds(seed: int = 3407):
    """Seed every RNG this script touches so runs are reproducible.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    generator, plus all CUDA devices when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def slugify(text):
    s = str(text)
    s = re.sub(r'[\\/:*?"<>|]', "_", s)
    s = re.sub(r"\s+", "_", s.strip())
    return s[:100]

def extract_date_from_instruction(text: str):
    """Return ``(year, month)`` from text such as ``'2009年1月...发病数'``.

    The first ``YYYY年M月`` fragment anywhere in ``str(text)`` is used.

    Raises:
        ValueError: if no such fragment is present.
    """
    match = re.search(r"(\d{4})年(\d{1,2})月", str(text))
    if match is None:
        raise ValueError(f"Date not found in text: {text}")
    year_txt, month_txt = match.groups()
    return int(year_txt), int(month_txt)

def convert_date_to_ym(raw_date):
    """Convert an Excel date cell to ``(year, month, date_str)``.

    ``date_str`` is always normalised to the ``'2009年1月'`` form (no zero
    padding), so every accepted input type yields the same prompt text for
    the same month.

    Accepted inputs:
      - missing values (None / NaN / NaT);
      - Excel serial day numbers (Python or NumPy int/float), counted in days
        from the 1899-12-30 epoch;
      - ``datetime`` objects (``pd.Timestamp`` is a ``datetime`` subclass);
      - strings containing a ``YYYY年M月`` fragment;
      - any other string that ``pd.to_datetime`` can parse.

    Returns ``(None, None, "未知日期")`` for missing or unparseable input
    instead of raising.
    """
    unknown = (None, None, "未知日期")

    # None / NaN / NaT: report unknown rather than formatting "nan年nan月".
    if pd.api.types.is_scalar(raw_date) and pd.isna(raw_date):
        return unknown

    # Excel serial numbers; NumPy scalars (e.g. np.int64) are included because
    # np.integer is not an ``int`` subclass.
    if isinstance(raw_date, (int, float, np.integer, np.floating)):
        try:
            real_date = datetime(1899, 12, 30) + timedelta(days=float(raw_date))
        except OverflowError:
            return unknown
        return (real_date.year, real_date.month, f"{real_date.year}年{real_date.month}月")

    if isinstance(raw_date, datetime):
        return (raw_date.year, raw_date.month, f"{raw_date.year}年{raw_date.month}月")

    s = str(raw_date).strip()
    # Same pattern as extract_date_from_instruction. A string without a
    # 4-digit-year fragment now falls through to pd.to_datetime instead of
    # raising ValueError.
    m = re.search(r"(\d{4})年(\d{1,2})月", s)
    if m:
        y, mo = int(m.group(1)), int(m.group(2))
        return (y, mo, f"{y}年{mo}月")

    try:
        dt = pd.to_datetime(s)
    except (ValueError, TypeError, OverflowError):
        return unknown
    if pd.isna(dt):
        return unknown
    return (dt.year, dt.month, f"{dt.year}年{dt.month}月")

# Number grammar shared by every parsing rule: optional sign, digits with
# optional thousands separators, optional decimal part and exponent.
# A leading sign is only accepted when it does NOT directly follow a digit, so
# the dash in ranges/dates such as "1000-1200" acts as a separator rather
# than turning the second number negative.
NUM_RE = r"(?:(?<!\d)[-+])?\d[\d,]*\.?\d*(?:[eE][-+]?\d+)?"

# Pre-compiled rule patterns, hoisted out of parse_prediction's per-call path.
# Rule 1 also accepts a quoted JSON number, e.g. {"value": "1,234"}.
_JSON_VALUE_RE = re.compile(r'"value"\s*:\s*"?(' + NUM_RE + r')')
_KEY_EQUALS_RE = re.compile(r'(?:formatted_number|value)\s*=\s*"?(' + NUM_RE + r')')
_AFTER_COLON_RE = re.compile(r'(?:数值|预测值|预测)\s*[:：]\s*(' + NUM_RE + r')')
_NUM_FINDER_RE = re.compile(NUM_RE)

# Date fragments removed before the last-number fallback, applied in order:
# "2021年" / "2021年11月" / "2021年11月3日", a standalone "11月" / "11月3日",
# and ISO-like "2021-11" / "2021/11/03". The digit lookarounds stop these from
# eating digits out of a longer number.
_DATE_FRAGMENT_RES = (
    re.compile(r"(?<!\d)\d{4}年(?:\d{1,2}月)?(?:\d{1,2}日)?"),
    re.compile(r"(?<!\d)\d{1,2}月(?:\d{1,2}日)?"),
    re.compile(r"(?<!\d)\d{4}[-/]\d{1,2}(?:[-/]\d{1,2})?(?!\d)"),
)


def _to_float(x: str):
    """Convert a matched numeric string to float, ignoring thousands separators.

    Returns ``np.nan`` when *x* cannot be converted (e.g. it is not a string).
    """
    try:
        return float(x.replace(",", ""))
    except (AttributeError, TypeError, ValueError):
        return np.nan


def parse_prediction(gen_text: str):
    """Extract a numeric prediction from generated text.

    Returns ``(pred_value, parse_rule)``. Rules, tried in order:
      1) ``json_value``: JSON field ``"value": number`` (number may be quoted)
      2) ``key_equals``: ``formatted_number = number`` / ``value = number``
      3) ``after_colon``: ``数值/预测值/预测`` followed by ``:``/``：`` and a number
      4) ``last_number_no_date``: remove date fragments (years, months,
         ISO-like dates), then take the LAST remaining number
    ``(nan, "none")`` is returned for ``None`` input and ``(nan, "no_number")``
    when no rule matches.
    """
    if gen_text is None:
        return np.nan, "none"

    s = str(gen_text).strip()

    for pattern, rule in (
        (_JSON_VALUE_RE, "json_value"),
        (_KEY_EQUALS_RE, "key_equals"),
        (_AFTER_COLON_RE, "after_colon"),
    ):
        m = pattern.search(s)
        if m:
            return _to_float(m.group(1)), rule

    # Fallback: the prompt forbids dates, but models still emit them. Replace
    # each fragment with a space so neighbouring numbers do not merge.
    s2 = s
    for pattern in _DATE_FRAGMENT_RES:
        s2 = pattern.sub(" ", s2)
    nums = _NUM_FINDER_RE.findall(s2)
    if nums:
        return _to_float(nums[-1]), "last_number_no_date"

    return np.nan, "no_number"

# ================== Data loaders ==================
def load_china_wide_to_long(excel_path: str):
    """Read the China wide-format Excel file and reshape it to long format.

    Expected input columns: ``指标`` (measure, e.g. 发病数 / 死亡数), ``日期``
    (date) and one column per disease. Every column other than ``指标`` and
    ``日期`` is treated as a disease series, regardless of column order.

    Returns a DataFrame with columns
    ``instruction, output, year, month, disease, measure``. Each instruction
    is ``f"{date_str}{disease}{measure}"``. Rows whose date cannot be
    converted are dropped.

    Raises:
        ValueError: if ``指标`` or ``日期`` is missing.
    """
    id_cols = ("指标", "日期")
    out_cols = ["instruction", "output", "year", "month", "disease", "measure"]

    df_excel = pd.read_excel(excel_path)
    missing = [c for c in id_cols if c not in df_excel.columns]
    if missing:
        raise ValueError(f"China file missing required column(s): {missing}")

    # Select disease columns by name rather than by position ([2:]), so an
    # extra or reordered leading column cannot be mistaken for a disease.
    disease_cols = [c for c in df_excel.columns if c not in id_cols]

    rows = []
    for _, row in df_excel.iterrows():
        measure = str(row["指标"]).strip()  # e.g. 发病数 (cases) / 死亡数 (deaths)
        year, month, date_str = convert_date_to_ym(row["日期"])

        for disease in disease_cols:
            val = row[disease]
            # NOTE(review): empty cells are imputed as a true 0 count; this
            # feeds the ground truth used for MAE/MSE — confirm intended.
            output = 0.0 if pd.isna(val) else float(val)
            disease_str = str(disease).strip()
            rows.append({
                "instruction": f"{date_str}{disease_str}{measure}",
                "output": output,
                "year": year,
                "month": month,
                "disease": disease_str,
                "measure": measure,
            })

    # Explicit columns keep the schema (and the dropna below) valid even when
    # no rows were produced.
    df = pd.DataFrame(rows, columns=out_cols)
    # drop unknown-date rows if any
    df = df.dropna(subset=["year", "month"]).copy()
    df["year"] = df["year"].astype(int)
    df["month"] = df["month"].astype(int)
    return df

def load_usaus_long(excel_path: str):
    """Read the US/AUS long-format Excel file.

    The file must already contain ``instruction, output, disease, year,
    month``. All required columns are validated before any conversion.
    ``output`` is cast to float. Rows with a missing year or month are dropped,
    in line with the China loader's handling of unknown dates; previously they
    made ``astype(int)`` fail.

    Raises:
        ValueError: if a required column is missing.
    """
    df = pd.read_excel(excel_path)
    if "output" not in df.columns:
        raise ValueError("US/AUS file must contain column 'output'.")
    for c in ["instruction", "disease", "year", "month"]:
        if c not in df.columns:
            raise ValueError(f"US/AUS file missing required column: {c}")

    df["output"] = df["output"].astype(float)

    df = df.dropna(subset=["year", "month"]).copy()
    df["year"] = df["year"].astype(int)
    df["month"] = df["month"].astype(int)
    return df

# ================== Model (base zero-shot) ==================
@functools.lru_cache(maxsize=1)
def load_base_llm(pretrained_path: str):
    """Load the frozen base tokenizer and causal LM for zero-shot generation.

    The result is memoised (one entry, keyed by *pretrained_path*). Running
    several datasets in one process therefore loads the multi-GB checkpoint
    only once.

    Returns:
        ``(tokenizer, model)``. The tokenizer pads on the left and uses EOS as
        the pad token if it has none. The model is in eval mode: fp16 on
        CUDA device 0, or fp32 on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path, trust_remote_code=True)
    # Decoder-only models: left padding is safer for batch generation
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        # fp16 kernels may be unsupported or very slow on CPU (some torch builds
        # raise "not implemented for 'Half'"), so use fp32 when there is no GPU.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map={"": 0} if use_cuda else None,
        trust_remote_code=True,
    )
    model.eval()
    return tokenizer, model

@torch.no_grad()
def generate_numbers(model, tokenizer, instructions, batch_size=8):
    """Run zero-shot generation for *instructions* and parse each answer.

    Each instruction is rendered with PROMPT_TEMPLATE and generated in batches
    of *batch_size*. Only the newly generated tokens are decoded, so digits in
    the prompt are never parsed.

    Decoding is greedy while TEMPERATURE is 0 (the configured default).
    A positive TEMPERATURE switches to sampling with TEMPERATURE/TOP_P.

    Returns:
      - gen_texts: generated-only strings
      - preds: parsed floats as an ndarray (np.nan if parsing fails)
      - rules_all: which parser rule matched, per instruction
    """
    prompts = [PROMPT_TEMPLATE.format(instruction=ins) for ins in instructions]

    do_sample = TEMPERATURE > 0
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": do_sample,
        "num_beams": NUM_BEAMS,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs; greedy runs do not pass them at all.
        gen_kwargs.update(temperature=TEMPERATURE, top_p=TOP_P)

    gen_texts_all, preds_all, rules_all = [], [], []

    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        input_ids = enc["input_ids"].to(cuda_device)
        attention_mask = enc["attention_mask"].to(cuda_device)

        gen_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

        # generate() returns [padded prompt | new tokens] for every row, so the
        # new tokens begin at the padded prompt width for ALL rows. Slicing at
        # the unpadded length (attention_mask.sum) would, under left padding,
        # keep the last `pad_count` prompt tokens of every shorter prompt.
        new_tokens = gen_ids[:, input_ids.shape[1]:]
        for row in new_tokens:
            gen_text = tokenizer.decode(row, skip_special_tokens=True).strip()
            val, rule = parse_prediction(gen_text)
            gen_texts_all.append(gen_text)
            preds_all.append(val)
            rules_all.append(rule)

    return gen_texts_all, np.array(preds_all, dtype=float), rules_all

# ================== Core runner ==================
def _chronological_test_split(sub):
    """Return the test partition (last 20%) of a 60/20/20 chronological split.

    Returns None when the series has fewer than 5 rows or the partition is empty.
    """
    sub = sub.sort_values(["year", "month"]).reset_index(drop=True)
    n = len(sub)
    if n < 5:
        return None
    train_size = int(0.6 * n)
    valid_size = int(0.2 * n)
    test_sub = sub.iloc[train_size + valid_size:].copy()
    return None if test_sub.empty else test_sub


def _error_metrics(preds, actual):
    """Return (MSE, MAE) on the raw scale over rows where both the parsed
    prediction and the ground truth are present; (nan, nan) if there are none.

    Rows with missing ground truth are skipped so that a single NaN target
    does not turn the whole task's metrics into NaN.
    """
    scored = ~np.isnan(preds) & ~np.isnan(actual)
    if not scored.any():
        return np.nan, np.nan
    err = preds[scored] - actual[scored]
    return float(np.mean(err ** 2)), float(np.mean(np.abs(err)))


def _overall_rows(summary_df, dataset_name):
    """Build the two aggregate summary rows.

    ``__OVERALL__`` holds unweighted means of the per-task Parse_rate/MSE/MAE.
    ``__OVERALL_WEIGHTED__`` holds MSE/MAE weighted by N_parsed and the pooled
    parse rate (sum N_parsed / sum N_total_test).
    """
    mse = summary_df["MSE"].to_numpy(dtype=float)
    mae = summary_df["MAE"].to_numpy(dtype=float)
    rates = summary_df["Parse_rate"].to_numpy(dtype=float)
    weights = summary_df["N_parsed"].to_numpy(dtype=float)
    n_total = int(summary_df["N_total_test"].sum())
    n_parsed = int(summary_df["N_parsed"].sum())

    def mean_or_nan(values):
        # The guard avoids numpy's "Mean of empty slice" warning on all-NaN input.
        return float(np.nanmean(values)) if (~np.isnan(values)).any() else np.nan

    def weighted_or_nan(values):
        # A task contributes to BOTH numerator and denominator only when its
        # metric is defined. Otherwise NaN-metric tasks would inflate the
        # denominator and bias the average toward zero.
        keep = ~np.isnan(values) & (weights > 0)
        if not keep.any():
            return np.nan
        return float(np.sum(values[keep] * weights[keep]) / np.sum(weights[keep]))

    def make_row(label, rate, mse_value, mae_value):
        return {
            "Dataset": dataset_name,
            "Disease": label,
            "Task": label,
            "N_total_test": n_total,
            "N_parsed": n_parsed,
            "Parse_rate": round(float(rate), 4) if not np.isnan(rate) else np.nan,
            "MSE": mse_value,
            "MAE": mae_value,
            "Detail_file": "",
        }

    pooled_rate = n_parsed / n_total if n_total > 0 else np.nan
    return pd.DataFrame([
        make_row("__OVERALL__", mean_or_nan(rates), mean_or_nan(mse), mean_or_nan(mae)),
        make_row("__OVERALL_WEIGHTED__", pooled_rate, weighted_or_nan(mse), weighted_or_nan(mae)),
    ])


def run_zero_shot_for_dataset(df, dataset_name, out_dir, tasks):
    """Evaluate the frozen base model zero-shot on each disease x task test split.

    df must contain:
      - instruction, output, disease, year, month
        (plus ``measure`` when dataset_name == "CHINA")
    tasks:
      - for CHINA: ["发病数","死亡数"] (stored in df['measure'])
      - for USAUS: ["发病数"] (no measure column required; we just label it)

    For each series the function:
      1. takes the last 20% of a chronological 60/20/20 split;
      2. generates and parses predictions;
      3. writes a per-task Excel file to *out_dir*.
    It then writes ``{dataset_name}_zero_shot_summary.xlsx`` with per-task rows
    plus two overall rows.

    Returns:
        ``(summary_df, summary_path)``.
    """
    os.makedirs(out_dir, exist_ok=True)
    tokenizer, model = load_base_llm(pretrained_path)

    summary_rows = []
    for disease in sorted(df["disease"].unique().tolist()):
        for task in tasks:
            if dataset_name == "CHINA":
                sub = df[(df["disease"] == disease) & (df["measure"] == task)]
            else:
                # USAUS: single task label, so rows are selected by disease only
                sub = df[df["disease"] == disease]

            test_sub = _chronological_test_split(sub)
            if test_sub is None:
                continue

            test_instructions = test_sub["instruction"].tolist()
            test_actual = test_sub["output"].to_numpy(dtype=float)
            test_dates = [
                datetime(int(y), int(m), 1)
                for y, m in zip(test_sub["year"], test_sub["month"])
            ]

            # -------- zero-shot generate --------
            gen_texts, preds, parse_rules = generate_numbers(
                model, tokenizer, test_instructions, batch_size=GEN_BATCH_SIZE
            )

            n_test = len(preds)
            n_ok = int((~np.isnan(preds)).sum())
            parse_rate = n_ok / n_test if n_test > 0 else 0.0
            mse, mae = _error_metrics(preds, test_actual)

            detail_df = pd.DataFrame({
                "Date": test_dates,
                "Instruction": test_instructions,
                "Actual": test_actual,
                "Generated_Text": gen_texts,      # raw generated-only text
                "Predicted_Parsed": preds,
                "Parse_rule": parse_rules,        # which parsing rule extracted the number
            })

            # NOTE: the "zeroshoot" spelling is kept so existing output paths stay stable.
            safe_disease = slugify(disease)
            fn = (
                f"{safe_disease}_{slugify(task)}_zeroshoot.xlsx"
                if dataset_name == "CHINA"
                else f"{safe_disease}_zeroshoot.xlsx"
            )
            detail_path = os.path.join(out_dir, fn)
            detail_df.to_excel(detail_path, index=False)

            summary_rows.append({
                "Dataset": dataset_name,
                "Disease": disease,
                "Task": task,
                "N_total_test": n_test,
                "N_parsed": n_ok,
                "Parse_rate": round(parse_rate, 4),
                "MSE": mse,
                "MAE": mae,
                "Detail_file": detail_path,
            })

            print(
                f"[{dataset_name}] {disease} {task} | test={n_test} parsed={n_ok} "
                f"rate={parse_rate:.2%} | MSE={mse:.4f} MAE={mae:.4f} -> {detail_path}"
            )

    summary_df = pd.DataFrame(summary_rows)
    if not summary_df.empty:
        summary_df = pd.concat(
            [summary_df, _overall_rows(summary_df, dataset_name)], ignore_index=True
        )

    summary_path = os.path.join(out_dir, f"{dataset_name}_zero_shot_summary.xlsx")
    summary_df.to_excel(summary_path, index=False)
    print(f"\n[{dataset_name}] Summary saved -> {summary_path}\n")
    return summary_df, summary_path

# ================== MAIN ==================
def main():
    """Run the zero-shot baseline for the China and US/AUS datasets.

    The input workbook paths default to ``example_china_wide.xlsx`` and
    ``example_usaus_long.xlsx`` in the working directory. They can be
    overridden with the CHINA_EXCEL_PATH / USAUS_EXCEL_PATH environment
    variables, the same mechanism QWEN25_3B_PATH uses for the model path.
    A dataset whose file does not exist is reported and skipped.
    """
    set_all_seeds(random_seed)

    china_excel_path = os.environ.get("CHINA_EXCEL_PATH", "example_china_wide.xlsx")
    usaus_excel_path = os.environ.get("USAUS_EXCEL_PATH", "example_usaus_long.xlsx")

    # ---- CHINA ----
    if os.path.exists(china_excel_path):
        df_china = load_china_wide_to_long(china_excel_path)
        # One task per distinct measure in the file (e.g. 发病数 cases / 死亡数 deaths)
        china_tasks = sorted(df_china["measure"].unique().tolist())
        run_zero_shot_for_dataset(
            df=df_china,
            dataset_name="CHINA",
            out_dir=OUT_CHINA_DIR,
            tasks=china_tasks,
        )
    else:
        print(f"[CHINA] File not found: {china_excel_path}")

    # ---- USAUS ----
    if os.path.exists(usaus_excel_path):
        df_usaus = load_usaus_long(usaus_excel_path)
        # only one task label
        usaus_tasks = ["发病数"]
        run_zero_shot_for_dataset(
            df=df_usaus,
            dataset_name="USAUS",
            out_dir=OUT_USAUS_DIR,
            tasks=usaus_tasks,
        )
    else:
        print(f"[USAUS] File not found: {usaus_excel_path}")

if __name__ == "__main__":
    main()
