# In [None]:
# -*- coding: utf-8 -*-
"""
LLM-NID Forecasting
================================================
This script implements the *public* trunk of the infectious-disease forecasting pipeline.
For each (disease, outcome) time series in the input workbook, it fine-tunes a separate
LoRA-adapted Qwen2.5-3B regression model and evaluates it on a chronological 60/20/20
train/validation/test split.
"""


import random
import pandas as pd
import numpy as np
from tqdm import tqdm
from copy import deepcopy
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
from datetime import datetime, timedelta
import re


# ================== Hyperparameters  ==================
# Seed shared by Python's `random`, NumPy and PyTorch (applied in main()).
random_seed = 3407

# Local checkpoint directory of the pretrained backbone (placeholder path).
pretrained_path = "/path/to/Qwen2.5-3B"

# Input workbook (./data/example_input.xlsx relative to the working directory).
excel_path = os.path.join(".", "data", "example_input.xlsx")

# --- optimisation / regularisation ---
dropout_rate = 0.2        # dropout used before and inside the regression head
max_epoch_num = 20        # upper bound on training epochs per series
early_stop_patience = 3   # epochs without validation improvement before stopping

# --- device ---
cuda_device = "cpu"
if torch.cuda.is_available():
    cuda_device = "cuda:0"

# Module-level early-stopping state.
# NOTE(review): main() keeps its own local counters, so these two names appear
# unused in this file.
best_valid_loss = float("inf")
no_improve_epochs = 0


# ================== Step 1: Read and convert Excel to unified DataFrame ==================
def process_excel_to_df(excel_file):
    """
    Read the input workbook and flatten it into one row per (month, disease, measure).

    Expected sheet layout (first row is the header):
      - "指标" (measure): e.g. "发病数" (case count) or "死亡数" (death count)
      - "日期" (date): Excel serial number, datetime, or date string
      - every remaining column (3rd onward): one disease; cells hold the value
    If the "指标"/"日期" headers are absent, the first and second columns are used
    as the measure and date columns respectively.

    Args:
        excel_file: Path or file-like object accepted by ``pd.read_excel``.

    Returns:
        DataFrame with columns
          - "instruction": "<date><disease><measure>", e.g. "2009年1月鼠疫发病数"
          - "output": the cell value; empty (NaN) cells become 0.0

    Raises:
        ValueError: if the sheet has fewer than three columns (no disease column).
    """
    df_excel = pd.read_excel(excel_file)
    col_names = df_excel.columns.tolist()
    if len(col_names) < 3:
        raise ValueError(
            f"Expected measure, date and at least one disease column; got {col_names}"
        )

    # Prefer the documented header names; fall back to column position otherwise.
    measure_col = "指标" if "指标" in col_names else col_names[0]
    date_col = "日期" if "日期" in col_names else col_names[1]
    disease_cols = col_names[2:]

    data_list = []
    for _, row in df_excel.iterrows():
        measure = str(row[measure_col]).strip()  # e.g. "发病数" / "死亡数"
        # Convert the date to "yyyy年m月" where possible.
        date_str = convert_date_to_ym_str(row[date_col])

        for disease in disease_cols:
            value = row[disease]
            if pd.isna(value):
                value = 0.0  # empty cells are treated as zero

            # e.g. "2009年1月鼠疫发病数"
            instruction = f"{date_str}{str(disease).strip()}{measure}"
            data_list.append({"instruction": instruction, "output": value})

    # Explicit columns keep the schema even when the sheet has no data rows,
    # so downstream code can still access df["output"].
    return pd.DataFrame(data_list, columns=["instruction", "output"])


def convert_date_to_ym_str(raw_date):
    """
    Convert an Excel date cell to the string "yyyy年m月".

    Accepted inputs:
      - Excel serial day numbers (Python or NumPy int/float), counted from the
        1899-12-30 epoch; the fractional part (time of day) is ignored.
      - ``datetime`` / ``pd.Timestamp`` objects.
      - Strings: a "yyyy年m月..." prefix is normalised (e.g. "2009年01月15日" ->
        "2009年1月"); other strings are parsed with ``pd.to_datetime``.

    Returns:
        "yyyy年m月", or "UnknownDate" for missing values (None, NaN, NaT) and
        values that cannot be interpreted as a date. A string containing both
        "年" and "月" but no numeric "yyyy年m月" pattern is returned unchanged.
    """
    # Missing values: None, float NaN, pd.NaT, np.datetime64("NaT").
    if pd.api.types.is_scalar(raw_date) and pd.isna(raw_date):
        return "UnknownDate"

    # Excel serial number. NumPy scalars are included explicitly because
    # np.int64 is not a subclass of Python int.
    if isinstance(raw_date, (int, float, np.integer, np.floating)):
        try:
            real_date = datetime(1899, 12, 30) + timedelta(days=float(raw_date))
        except (OverflowError, ValueError):
            return "UnknownDate"
        return f"{real_date.year}年{real_date.month}月"

    # datetime and pd.Timestamp (NaT was already handled above).
    if isinstance(raw_date, datetime):
        return f"{raw_date.year}年{raw_date.month}月"

    s = str(raw_date).strip()

    # Normalise "2009年01月" / "2009年1月1日" to "2009年1月". Returning such strings
    # verbatim would leak the day part into the instruction text
    # (e.g. "2009年1月1日鼠疫发病数").
    ym = re.search(r"(\d{4})\s*年\s*(\d{1,2})\s*月", s)
    if ym:
        return f"{int(ym.group(1))}年{int(ym.group(2))}月"
    if "年" in s and "月" in s:
        return s

    try:
        dt = pd.to_datetime(s)
    except (ValueError, TypeError, OverflowError):
        return "UnknownDate"
    # Strings such as "" or "NaT" parse to NaT without raising.
    if pd.isna(dt):
        return "UnknownDate"
    return f"{dt.year}年{dt.month}月"


# Load the workbook once at import time; main() reads this module-level frame.
df = process_excel_to_df(excel_file=excel_path)


# ================== Helpers: parse date and disease name from instruction ==================
def extract_date(text):
    """
    Return the first "yyyy年m月" date in *text* as an ``(year, month)`` int pair.

    Example: "2009年1月鼠疫发病数" -> (2009, 1)

    Raises:
        ValueError: if *text* contains no "yyyy年m月" pattern.
    """
    found = re.search(r"(\d{4})年(\d{1,2})月", text)
    if found is None:
        raise ValueError(f"Date pattern not found in instruction: {text}")
    year_str, month_str = found.groups()
    return int(year_str), int(month_str)


def extract_disease(instruction):
    """
    Return the disease name between the "yyyy年m月" date and the measure suffix.

    Examples:
      - "2009年1月鼠疫发病数" -> "鼠疫"
      - "2011年10月流行性感冒死亡数" -> "流行性感冒"

    The "发病数" (case count) pattern is tried before "死亡数" (death count).
    Returns "Unknown" when neither pattern matches.
    """
    patterns = (
        r"\d{4}年\d{1,2}月(.*?)发病数",  # case count
        r"\d{4}年\d{1,2}月(.*?)死亡数",  # death count
    )
    for pattern in patterns:
        found = re.search(pattern, instruction)
        if found is not None:
            return found.group(1)
    return "Unknown"


# Cast targets to float and add a log1p-transformed copy ("output_log"),
# which main() min-max scales per series before training.
df = df.astype({"output": float})
df["output_log"] = np.log1p(df["output"])


# ================== Dataset class and collate_fn ==================
class RegressionDataset(Dataset):
    """
    Map-style dataset of (scaled target, instruction text) pairs.

    All texts are tokenized once, up front, with ``truncation=True`` and
    ``padding=True`` so that items of one dataset share a sequence length and can
    be stacked by ``collate_fn``.

    Each item is a dict holding the tokenizer's fields as tensors, plus
    "targets" (float32 scalar tensor) and "text" (the raw instruction string).
    """

    def __init__(self, data, tokenizer, max_length=128):
        # Every element of data is indexed as (target, text).
        self.targets = [pair[0] for pair in data]
        self.texts = [pair[1] for pair in data]
        self.encodings = tokenizer(
            self.texts,
            truncation=True,
            padding=True,
            max_length=max_length,
        )

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        sample = {}
        for field, per_text_values in self.encodings.items():
            sample[field] = torch.tensor(per_text_values[idx])
        sample["targets"] = torch.tensor(self.targets[idx], dtype=torch.float)
        sample["text"] = self.texts[idx]
        return sample


def collate_fn(batch):
    """
    Stack per-sample tensors into batch tensors for the DataLoader.

    "input_ids", "attention_mask" and "targets" are stacked along a new first
    dimension. If the samples carry a "text" entry, the raw strings are collected
    into a list under "texts".
    """
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_ids", "attention_mask", "targets")
    }
    if "text" in batch[0]:
        stacked["texts"] = [sample["text"] for sample in batch]
    return stacked


# ================== Model-related functions ==================
def freeze_and_configure_base_model(base_model, num_trainable_layers=2):
    """
    Set ``requires_grad`` on every parameter of *base_model* in place.

    Left trainable:
      - token embeddings (parameters whose name contains "embed_tokens")
      - the last ``num_trainable_layers`` transformer blocks, i.e. the highest
        indices found in parameter names of the form "...layers.<idx>.<...>"
    Everything else is frozen.

    The block indices are derived from the parameter names rather than hard-coded,
    so "the last N layers" holds whatever the backbone's depth is.

    NOTE(review): ``get_model`` wraps the backbone with ``peft.get_peft_model``
    after this call. PEFT's LoRA wrapper is documented to mark only adapter
    weights as trainable, so the flags set here may be overridden — confirm
    which parameters actually end up trainable.
    """
    # Trailing "." stops "layers.1" from also matching "layers.10".
    layer_pattern = re.compile(r"(?:^|\.)layers\.(\d+)\.")

    def _layer_index(param_name):
        match = layer_pattern.search(param_name)
        return int(match.group(1)) if match else None

    named_params = list(base_model.named_parameters())
    layer_ids = sorted(
        {idx for idx in (_layer_index(name) for name, _ in named_params) if idx is not None}
    )
    # Guard: layer_ids[-0:] would select every layer.
    trainable_layers = set(layer_ids[-num_trainable_layers:]) if num_trainable_layers > 0 else set()

    for name, param in named_params:
        param.requires_grad = "embed_tokens" in name or _layer_index(name) in trainable_layers


def get_model():
    """
    Load the pretrained backbone from ``pretrained_path`` and attach LoRA adapters.

    The whole backbone is placed on ``cuda_device`` (the module-level device
    string), matching where the rest of the script sends its tensors, instead of
    a hard-coded GPU 0. Weights are loaded in float16 on CUDA and in float32
    otherwise.

    LoRA: rank 16, alpha 64, dropout 0.05, on the attention projections
    (q/k/v/o) and the MLP down/up projections.

    Returns:
        The PEFT-wrapped model.
    """
    on_cuda = torch.device(cuda_device).type == "cuda"
    base_model = AutoModel.from_pretrained(
        pretrained_path,
        device_map={"": cuda_device},
        torch_dtype=torch.float16 if on_cuda else torch.float32,
    )
    freeze_and_configure_base_model(base_model)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "down_proj", "up_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        # NOTE(review): "SEQ_CLS" selects PEFT's sequence-classification wrapper,
        # although the backbone here is a head-less AutoModel whose
        # last_hidden_state is pooled by the caller. "FEATURE_EXTRACTION" may be
        # the better fit — confirm before changing; the value is also written
        # into the saved adapter config.
        task_type="SEQ_CLS",
    )
    model = get_peft_model(base_model, peft_config)
    return model


class EnhancedLoRAModel(nn.Module):
    """
    Scalar regressor on top of the LoRA-adapted backbone returned by ``get_model``.

    Forward pass:
      1. The backbone's ``last_hidden_state`` is mean-pooled over the positions
         where ``attention_mask`` is 1.
      2. Learned year and month embeddings are concatenated. The date is parsed
         from each instruction text.
      3. The result goes through dropout and a two-layer MLP, which produces one
         value per sample.

    Args:
        disease_list: Unused; accepted for signature compatibility with callers.
        year_offset: Calendar year mapped to year-embedding row 0.
        max_years: Number of year-embedding rows. Years outside
            ``[year_offset, year_offset + max_years - 1]`` are clamped to the
            nearest row instead of causing an out-of-range embedding lookup.
        embed_dim: Size of each of the year and month embeddings.
    """

    def __init__(self, disease_list, year_offset=2004, max_years=40, embed_dim=128):
        super().__init__()
        self.lora_model = get_model()
        hidden_size = self.lora_model.config.hidden_size

        # Year/month embeddings
        self.year_offset = year_offset
        self.max_years = max_years
        self.year_emb = nn.Embedding(self.max_years, embed_dim)
        self.month_emb = nn.Embedding(12, embed_dim)

        # Regressor input: pooled hidden state + year embedding + month embedding
        self.dropout = nn.Dropout(dropout_rate)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size + embed_dim * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, 1),
        ).to(cuda_device)

        self.mse_loss = nn.MSELoss(reduction="none")

    def _temporal_features(self, texts, device):
        """
        Return ``[year_emb | month_emb]`` rows, one per text, on *device*.

        Raises:
            ValueError: if a text has no "yyyy年m月" date (via ``extract_date``)
                or its month is outside 1-12.
        """
        dates = [extract_date(instr) for instr in texts]
        bad = [instr for instr, (_, m) in zip(texts, dates) if not 1 <= m <= 12]
        if bad:
            raise ValueError(f"Month outside 1-12 in instruction(s): {bad}")

        year_idx = torch.tensor(
            [y - self.year_offset for y, _ in dates], dtype=torch.long, device=device
        )
        # Clamp so that out-of-range years reuse the edge embeddings rather than
        # indexing out of bounds (a device-side assert on CUDA).
        year_idx = year_idx.clamp(0, self.max_years - 1)
        month_idx = torch.tensor([m - 1 for _, m in dates], dtype=torch.long, device=device)
        return torch.cat([self.year_emb(year_idx), self.month_emb(month_idx)], dim=-1)

    def forward(self, input_ids, attention_mask, targets=None, texts=None):
        """
        Predict one value per sample and, if *targets* is given, a weighted MSE loss.

        Args:
            input_ids, attention_mask: Token tensors for the backbone.
            targets: Optional scaled targets, one per sample.
            texts: Instruction strings the year/month features are parsed from.
                Required, because the regressor's input width includes those features.

        Returns:
            {"loss": scalar tensor or None, "predictions": tensor with one value per sample}

        Raises:
            ValueError: if *texts* is None.
        """
        if texts is None:
            # Fail fast, before the expensive backbone call; without the
            # temporal features the regressor input width would not match.
            raise ValueError("texts is required: the regressor expects year/month features")

        outputs = self.lora_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden = outputs.last_hidden_state
        # Masked mean pooling over non-padding positions.
        input_mask = attention_mask.unsqueeze(-1).float()
        pooled = (last_hidden * input_mask).sum(dim=1) / input_mask.sum(dim=1).clamp(min=1e-9)

        pooled = torch.cat([pooled, self._temporal_features(texts, pooled.device)], dim=-1)
        pooled = self.dropout(pooled)
        predictions = self.regressor(pooled).squeeze(-1)

        loss = None
        if targets is not None:
            # Up-weight samples with larger scaled targets. Clamping at 0 keeps
            # the weight >= 1; targets below the training minimum (possible on
            # valid/test splits) would otherwise produce negative weights.
            alpha = 10.0
            weight = 1.0 + alpha * targets.clamp(min=0.0)
            raw_loss = self.mse_loss(predictions, targets)
            loss = (weight * raw_loss).mean()

        return {"loss": loss, "predictions": predictions}


# ================== Training & evaluation ==================
def train_epoch(model, dataloader, optimizer, grad_scaler, scheduler):
    """
    Run one training epoch and return the mean per-batch loss.

    Uses float16 autocast when ``cuda_device`` is a CUDA device; gradients go
    through *grad_scaler*.

    Before norm clipping (max norm 1.0), gradients are unscaled so the threshold
    applies to the true gradient magnitudes. Clipping scaled gradients would
    shrink every update by the loss-scale factor.

    The LR scheduler only advances when the scaler actually performed an
    optimizer step.

    Returns:
        Mean training loss over batches, or NaN if *dataloader* is empty.
    """
    model.train()
    device_type = torch.device(cuda_device).type
    use_amp = device_type == "cuda"
    # The set of trainable parameters is fixed for the epoch; collect it once.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader, desc="Training", leave=False):
        inputs = {
            "input_ids": batch["input_ids"].to(cuda_device),
            "attention_mask": batch["attention_mask"].to(cuda_device),
            "targets": batch["targets"].to(cuda_device),
            "texts": batch["texts"],
        }
        optimizer.zero_grad()

        # mixed-precision forward pass
        with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
            outputs = model(**inputs)
            loss = outputs["loss"]

        grad_scaler.scale(loss).backward()
        # Gradients still carry the loss scale here; unscale them in place so
        # clip_grad_norm_ compares the true norm against max_norm.
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)

        scale_before = grad_scaler.get_scale()
        grad_scaler.step(optimizer)  # skipped internally if gradients contain inf/NaN
        grad_scaler.update()
        # A skipped step shows up as a reduced scale; don't advance the LR
        # schedule for it.
        if grad_scaler.get_scale() >= scale_before:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


def evaluate(model, dataloader, data_scaler):
    """
    Predict on *dataloader* and score the predictions on the original count scale.

    Targets and predictions are mapped back in two steps:
      1. The inverse MinMax transform (*data_scaler*) recovers log1p-space values.
      2. ``expm1`` recovers counts.

    Log-space predictions are clipped to be non-negative, so predicted counts are
    >= 0. They are also capped so that ``expm1`` stays finite. All of this is done
    in float64, because ``expm1`` overflows float32 for log values above ~88.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., targets=..., texts=...)``;
            must return a dict with a "predictions" tensor.
        dataloader: Yields batches built by ``collate_fn`` (including "texts").
        data_scaler: The fitted MinMaxScaler used to scale the log1p targets.

    Returns:
        ``(metrics, targets_actual, predictions_actual, texts)``, where metrics is
        ``{"mse": ..., "mae": ...}``. For an empty loader the metrics are NaN and
        the arrays are empty.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    all_texts = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs = {
                "input_ids": batch["input_ids"].to(cuda_device),
                "attention_mask": batch["attention_mask"].to(cuda_device),
                "targets": batch["targets"].to(cuda_device),
                "texts": batch["texts"],
            }
            predictions = model(**inputs)["predictions"]

            all_targets.extend(inputs["targets"].float().cpu().numpy())
            all_predictions.extend(predictions.detach().float().cpu().numpy())
            all_texts.extend(batch["texts"])

    if not all_targets:
        # The scaler and sklearn metrics reject empty input.
        empty = np.empty(0, dtype=np.float64)
        return {"mse": float("nan"), "mae": float("nan")}, empty, empty.copy(), all_texts

    targets_scaled = np.asarray(all_targets, dtype=np.float64).reshape(-1, 1)
    predictions_scaled = np.asarray(all_predictions, dtype=np.float64).reshape(-1, 1)
    all_targets_log = data_scaler.inverse_transform(targets_scaled).flatten()
    all_predictions_log = data_scaler.inverse_transform(predictions_scaled).flatten()

    # Counts cannot be negative. The upper cap keeps expm1 finite: sklearn's
    # metrics raise on infinite inputs.
    max_log = float(np.log(np.finfo(np.float64).max)) - 1.0
    all_predictions_log = np.clip(all_predictions_log, 0.0, max_log)

    all_targets_actual = np.expm1(all_targets_log)
    all_predictions_actual = np.expm1(all_predictions_log)

    mse = mean_squared_error(all_targets_actual, all_predictions_actual)
    mae = mean_absolute_error(all_targets_actual, all_predictions_actual)

    metrics = {"mse": mse, "mae": mae}
    return metrics, all_targets_actual, all_predictions_actual, all_texts


# ================== Main ==================
def _safe_name(name):
    """Return ``str(name)`` with characters that are invalid in file names replaced by "_"."""
    return re.sub(r'[\\/:*?"<>|]', "_", str(name))


def _sort_chronologically(sub_df):
    """
    Return *sub_df* sorted by the (year, month) in its "instruction" column.

    Rows whose instruction has no "yyyy年m月" date are dropped, with a message.
    ``EnhancedLoRAModel.forward`` parses that date via ``extract_date``, which
    raises on them.
    """
    sort_keys = []
    for instr in sub_df["instruction"]:
        try:
            year, month = extract_date(instr)
        except ValueError:
            sort_keys.append(np.nan)
        else:
            # year*100+month orders exactly like the (year, month) tuple for the
            # 1-2 digit months that extract_date accepts.
            sort_keys.append(year * 100 + month)

    keyed = sub_df.assign(_sort_key=sort_keys)
    n_dropped = int(keyed["_sort_key"].isna().sum())
    if n_dropped:
        print(f"Dropped {n_dropped} row(s) without a parseable date")
    keyed = keyed.dropna(subset=["_sort_key"])
    keyed = keyed.sort_values("_sort_key", kind="mergesort")
    return keyed.drop(columns=["_sort_key"]).reset_index(drop=True)


def _make_dataloader(part, scaler, tokenizer, shuffle):
    """
    Build a DataLoader (batch size 16) over one split.

    Targets are the split's "output_log" values transformed by *scaler*, which is
    fit on the training split only. Inputs are the "instruction" strings.
    """
    targets = scaler.transform(part["output_log"].values.reshape(-1, 1)).flatten()
    dataset = RegressionDataset(list(zip(targets, part["instruction"])), tokenizer, max_length=128)
    return DataLoader(dataset, batch_size=16, shuffle=shuffle, collate_fn=collate_fn, num_workers=0)


def _train_with_early_stopping(model, train_dataloader, valid_dataloader, scaler):
    """
    Fine-tune *model* in place and restore the weights with the lowest validation MSE.

    Training stops after ``early_stop_patience`` consecutive epochs without
    improvement, or after ``max_epoch_num`` epochs.

    Returns:
        The best validation MSE seen (inf if it never improved).
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=5e-5, weight_decay=0.05)

    total_steps = len(train_dataloader) * max_epoch_num
    warmup_steps = min(200, max(1, int(0.1 * total_steps)))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    grad_scaler = torch.amp.GradScaler("cuda")

    best_valid_mse = float("inf")
    best_state = None
    no_improve_epochs = 0

    for epoch in range(max_epoch_num):
        print(f"Epoch {epoch + 1}/{max_epoch_num}")
        train_loss = train_epoch(model, train_dataloader, optimizer, grad_scaler, scheduler)
        valid_metrics, _, _, _ = evaluate(model, valid_dataloader, scaler)
        print(
            f"TrainLoss: {train_loss:.4f} | "
            f"ValidMSE: {valid_metrics['mse']:.4f} | ValidMAE: {valid_metrics['mae']:.4f}"
        )

        if valid_metrics["mse"] < best_valid_mse:
            best_valid_mse = valid_metrics["mse"]
            # Snapshot only the trainable tensors, on CPU. Frozen weights never
            # change, and deep-copying the full state_dict would duplicate the
            # whole backbone in GPU memory.
            best_state = {
                name: p.detach().to("cpu", copy=True)
                for name, p in model.named_parameters()
                if p.requires_grad
            }
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= early_stop_patience:
                print("Early stopping triggered.")
                break

    if best_state is not None:
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name in best_state:
                    p.copy_(best_state[name])

    return best_valid_mse


def _save_test_results(result_file, test_texts, test_actual, test_pred):
    """Write test-set actual vs. predicted values (one row per month) to *result_file* (Excel)."""
    test_dates = []
    for instr in test_texts:
        y, m = extract_date(instr)
        test_dates.append(datetime(y, m, 1))

    test_result_df = pd.DataFrame(
        {
            "Date": test_dates,
            "Instruction": test_texts,
            "Actual": test_actual,
            "Predicted": test_pred,
        }
    )
    os.makedirs(os.path.dirname(result_file) or ".", exist_ok=True)
    test_result_df.to_excel(result_file, index=False)
    print(f"Test results saved to: {result_file}")


def _save_artifacts(model, tokenizer, scaler, save_dir):
    """
    Save the trained predictor into *save_dir*:
      - tokenizer files
      - LoRA adapters (in "lora_adapters")
      - regressor weights ("regressor.pth")
      - year/month embeddings ("temporal_embeddings.pth")
      - MinMaxScaler ``min_``/``scale_`` ("scaler_params.json")
    """
    os.makedirs(save_dir, exist_ok=True)

    tokenizer.save_pretrained(save_dir)
    model.lora_model.save_pretrained(os.path.join(save_dir, "lora_adapters"))
    torch.save(model.regressor.state_dict(), os.path.join(save_dir, "regressor.pth"))
    # The year/month embeddings are optimised together with the regressor.
    # Without them, the saved regressor cannot reproduce the predictions.
    torch.save(
        {
            "year_offset": model.year_offset,
            "max_years": model.max_years,
            "year_emb": model.year_emb.state_dict(),
            "month_emb": model.month_emb.state_dict(),
        },
        os.path.join(save_dir, "temporal_embeddings.pth"),
    )

    scaler_params = {"min_": scaler.min_.tolist(), "scale_": scaler.scale_.tolist()}
    with open(os.path.join(save_dir, "scaler_params.json"), "w", encoding="utf-8") as f:
        json.dump(scaler_params, f, ensure_ascii=False)

    print(f"Model saved to: {save_dir}")


def main():
    """
    Train and evaluate one forecasting model per (disease, outcome) series in ``df``.

    Outcomes are "发病数" (case count) and/or "死亡数" (death count). For each series:
      1. Sort it by month.
      2. Split it 60/20/20 in time order.
      3. Fit a MinMaxScaler on the training "output_log" values only.
      4. Fine-tune a fresh ``EnhancedLoRAModel`` with early stopping on
         validation MSE.
      5. Write test predictions to ``./results/<disease>_<task>.xlsx`` and model
         artifacts to ``./models/<disease>_<task>/``.
    """
    import gc

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_path, trust_remote_code=True)

    # Parse each row's disease once. Rows are later selected by exact disease
    # name: a substring test would also pull e.g. "甲型肝炎" rows into "肝炎".
    row_diseases = df["instruction"].apply(extract_disease)
    all_diseases = sorted(d for d in row_diseases.unique() if d != "Unknown")

    # Detect which outcomes (tasks) exist in the dataset.
    all_tasks = [v for v in ("发病数", "死亡数") if any(v in ins for ins in df["instruction"])]

    for disease in all_diseases:
        for task in all_tasks:
            print(f"\nStart training: Disease={disease} | Task={task}")

            mask = (row_diseases == disease) & df["instruction"].str.contains(task, regex=False)
            if not mask.any():
                print(f"No data, skip: {disease} | {task}")
                continue

            sub_df = _sort_chronologically(df[mask])
            N = len(sub_df)
            if N < 5:
                print(f"Too few data points, skip: {disease} | {task}")
                continue

            # Split 60/20/20 in time order.
            train_size = int(0.6 * N)
            valid_size = int(0.2 * N)
            train_sub = sub_df.iloc[:train_size]
            valid_sub = sub_df.iloc[train_size : train_size + valid_size]
            test_sub = sub_df.iloc[train_size + valid_size :]

            # Fit the scaler on the training split only, so no later values leak in.
            scaler = MinMaxScaler()
            scaler.fit(train_sub["output_log"].values.reshape(-1, 1))

            train_dataloader = _make_dataloader(train_sub, scaler, tokenizer, shuffle=True)
            valid_dataloader = _make_dataloader(valid_sub, scaler, tokenizer, shuffle=False)
            test_dataloader = _make_dataloader(test_sub, scaler, tokenizer, shuffle=False)

            # disease_list is passed for signature compatibility only.
            model = EnhancedLoRAModel(disease_list=[disease]).to(cuda_device)
            _train_with_early_stopping(model, train_dataloader, valid_dataloader, scaler)

            test_metrics, test_actual, test_pred, test_texts = evaluate(model, test_dataloader, scaler)
            series_name = f"{_safe_name(disease)}_{_safe_name(task)}"
            _save_test_results(
                os.path.join("./results", f"{series_name}.xlsx"), test_texts, test_actual, test_pred
            )
            print(f"Test metrics: MSE={test_metrics['mse']:.6f}, MAE={test_metrics['mae']:.6f}")

            _save_artifacts(model, tokenizer, scaler, os.path.join("./models", series_name))

            # Free this series' model before the next one is built. Otherwise the
            # old backbone stays referenced while the new one loads.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
