# MIMIC-IV 临床时间序列建模模板

此 notebook 提供：
- 数据加载（假设 CSV 已导出到 `data/`）
- 慢性病患者筛选、人口学与静态特征抽取
- 时序数据重采样与缺失值插补
- 基线模型（Logistic Regression, Random Forest），ARIMA 示例
- 深度学习模型（LSTM, Transformer）训练示例
- 解释性示例（SHAP）与注意力可视化提示

> **注意**：这是通用模板。请根据你实际的 `chartevents` / `labevents` 字段、ITEMID 列表和计算资源进行调整。


In [2]:
# 写入本地 MIMIC 目录说明（已将示例路径替换为你的本地路径）
import os
# 修改为你的 MIMIC 导出目录（包含 hosp/ 和 icu/ 子目录）。
DATA_DIR = '/Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1'
print('当前工作目录：', os.getcwd())
print(f'已将 DATA_DIR 默认设置为: {DATA_DIR}')
print('请确保下列子目录存在并包含 CSV：')
print(' - hosp/ 例如: hosp/admissions.csv, hosp/patients.csv, hosp/diagnoses_icd.csv, hosp/labevents.csv, hosp/prescriptions.csv')
print(' - icu/ 例如: icu/chartevents.csv, icu/d_items.csv, icu/icustays.csv')


当前工作目录： /Users/yuchenzhou/Documents/duke/compsci526/final_proj/git_proj/yuchen_testing
已将 DATA_DIR 默认设置为: /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1
请确保下列子目录存在并包含 CSV：
 - hosp/ 例如: hosp/admissions.csv, hosp/patients.csv, hosp/diagnoses_icd.csv, hosp/labevents.csv, hosp/prescriptions.csv
 - icu/ 例如: icu/chartevents.csv, icu/d_items.csv, icu/icustays.csv


In [7]:
# 快速预览：列出并显示关键 CSV 的列名和前三行（如果存在）
import os, pandas as pd
paths = {
    'admissions': os.path.join(DATA_DIR, 'hosp', 'admissions.csv'),
    'patients': os.path.join(DATA_DIR, 'hosp', 'patients.csv'),
    'diagnoses_icd': os.path.join(DATA_DIR, 'hosp', 'diagnoses_icd.csv'),
    'chartevents': os.path.join(DATA_DIR, 'icu', 'chartevents.csv'),
    'labevents': os.path.join(DATA_DIR, 'hosp', 'labevents.csv'),
}
from IPython.display import display
for name, p in paths.items():
    print('---', name, '->', p)
    if os.path.exists(p):
        try:
            df = pd.read_csv(p, nrows=5)
            print('columns:', list(df.columns))
            print('preview:')
            display(df.head(3))
        except Exception as e:
            print('读取失败:', e)
    else:
        print('NOT FOUND')


--- admissions -> /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1/hosp/admissions.csv
columns: ['subject_id', 'hadm_id', 'admittime', 'dischtime', 'deathtime', 'admission_type', 'admit_provider_id', 'admission_location', 'discharge_location', 'insurance', 'language', 'marital_status', 'race', 'edregtime', 'edouttime', 'hospital_expire_flag']
preview:


Unnamed: 0,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admit_provider_id,admission_location,discharge_location,insurance,language,marital_status,race,edregtime,edouttime,hospital_expire_flag
0,10000032,22595853,2180-05-06 22:23:00,2180-05-07 17:15:00,,URGENT,P49AFC,TRANSFER FROM HOSPITAL,HOME,Medicaid,English,WIDOWED,WHITE,2180-05-06 19:17:00,2180-05-06 23:30:00,0
1,10000032,22841357,2180-06-26 18:27:00,2180-06-27 18:49:00,,EW EMER.,P784FA,EMERGENCY ROOM,HOME,Medicaid,English,WIDOWED,WHITE,2180-06-26 15:54:00,2180-06-26 21:31:00,0
2,10000032,25742920,2180-08-05 23:44:00,2180-08-07 17:50:00,,EW EMER.,P19UTS,EMERGENCY ROOM,HOSPICE,Medicaid,English,WIDOWED,WHITE,2180-08-05 20:58:00,2180-08-06 01:44:00,0


--- patients -> /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1/hosp/patients.csv
columns: ['subject_id', 'gender', 'anchor_age', 'anchor_year', 'anchor_year_group', 'dod']
preview:


Unnamed: 0,subject_id,gender,anchor_age,anchor_year,anchor_year_group,dod
0,10000032,F,52,2180,2014 - 2016,2180-09-09
1,10000048,F,23,2126,2008 - 2010,
2,10000058,F,33,2168,2020 - 2022,


--- diagnoses_icd -> /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1/hosp/diagnoses_icd.csv
columns: ['subject_id', 'hadm_id', 'seq_num', 'icd_code', 'icd_version']
preview:


Unnamed: 0,subject_id,hadm_id,seq_num,icd_code,icd_version
0,10000032,22595853,1,5723,9
1,10000032,22595853,2,78959,9
2,10000032,22595853,3,5715,9


--- chartevents -> /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1/icu/chartevents.csv
preview:


Unnamed: 0,subject_id,hadm_id,stay_id,caregiver_id,charttime,storetime,itemid,value,valuenum,valueuom,warning
0,10000032,29079034,39553978,18704,2180-07-23 12:36:00,2180-07-23 14:45:00,226512,39.4,39.4,kg,0
1,10000032,29079034,39553978,18704,2180-07-23 12:36:00,2180-07-23 14:45:00,226707,60.0,60.0,Inch,0
2,10000032,29079034,39553978,18704,2180-07-23 12:36:00,2180-07-23 14:45:00,226730,152.0,152.0,cm,0


--- labevents -> /Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1/hosp/labevents.csv
columns: ['labevent_id', 'subject_id', 'hadm_id', 'specimen_id', 'itemid', 'order_provider_id', 'charttime', 'storetime', 'value', 'valuenum', 'valueuom', 'ref_range_lower', 'ref_range_upper', 'flag', 'priority', 'comments']
preview:


Unnamed: 0,labevent_id,subject_id,hadm_id,specimen_id,itemid,order_provider_id,charttime,storetime,value,valuenum,valueuom,ref_range_lower,ref_range_upper,flag,priority,comments
0,1,10000032,,2704548,50931,P69FQC,2180-03-23 11:51:00,2180-03-23 15:56:00,___,95.0,mg/dL,70.0,100.0,,ROUTINE,"IF FASTING, 70-100 NORMAL, >125 PROVISIONAL DI..."
1,2,10000032,,36092842,51071,P69FQC,2180-03-23 11:51:00,2180-03-23 16:00:00,NEG,,,,,,ROUTINE,
2,3,10000032,,36092842,51074,P69FQC,2180-03-23 11:51:00,2180-03-23 16:00:00,NEG,,,,,,ROUTINE,


In [8]:
# --------- 计算 30 天再入院标签 ---------
# 读取 admissions 数据并解析日期列
print('正在计算 30 天再入院标签...')

# 辅助函数：尝试解析日期列
def _parse_dates_if_exists(df, cols):
    """如果列存在，则解析为 datetime"""
    for col in cols:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], errors='coerce')
    return df

# 读取 admissions 数据
admissions_path = os.path.join(DATA_DIR, 'hosp', 'admissions.csv')
admissions = pd.read_csv(admissions_path)
admissions = _parse_dates_if_exists(admissions, ['admittime', 'dischtime', 'deathtime', 'edregtime', 'edouttime'])

# 按 subject_id 和 admittime 排序
admissions = admissions.sort_values(['subject_id', 'admittime']).reset_index(drop=True)

# 为每个入院记录计算是否有 30 天内的再入院
readmit_30d = []
for idx, row in admissions.iterrows():
    subject_id = row['subject_id']
    hadm_id = row['hadm_id']
    dischtime = row['dischtime']
    
    # 如果没有出院时间，跳过（无法计算再入院）
    if pd.isna(dischtime):
        readmit_30d.append(0)
        continue
    
    # 查找同一患者的后续入院记录
    future_admissions = admissions[
        (admissions['subject_id'] == subject_id) & 
        (admissions['hadm_id'] != hadm_id) &  # 排除当前入院
        (admissions['admittime'] > dischtime)  # 必须在出院后
    ]
    
    # 检查是否有在 30 天内的再入院
    if not future_admissions.empty:
        # 找到最早的后续入院
        next_admit = future_admissions.iloc[0]['admittime']
        days_to_readmit = (next_admit - dischtime).total_seconds() / (24 * 3600)
        
        if days_to_readmit <= 30:
            readmit_30d.append(1)
        else:
            readmit_30d.append(0)
    else:
        readmit_30d.append(0)

# 将标签添加到 admissions 数据框
admissions['readmit_30d'] = readmit_30d

# 显示标签分布
print(f'总入院次数: {len(admissions)}')
print(f'30天再入院数: {sum(readmit_30d)} ({100*sum(readmit_30d)/len(readmit_30d):.2f}%)')
print(f'非再入院数: {len(readmit_30d) - sum(readmit_30d)} ({100*(len(readmit_30d)-sum(readmit_30d))/len(readmit_30d):.2f}%)')
print('\n前 10 个入院记录的再入院标签:')
print(admissions[['subject_id', 'hadm_id', 'admittime', 'dischtime', 'readmit_30d']].head(10))

正在计算 30 天再入院标签...
总入院次数: 546028
30天再入院数: 107560 (19.70%)
非再入院数: 438468 (80.30%)

前 10 个入院记录的再入院标签:
   subject_id   hadm_id           admittime           dischtime  readmit_30d
0    10000032  22595853 2180-05-06 22:23:00 2180-05-07 17:15:00            0
1    10000032  22841357 2180-06-26 18:27:00 2180-06-27 18:49:00            1
2    10000032  29079034 2180-07-23 12:35:00 2180-07-25 17:55:00            1
3    10000032  25742920 2180-08-05 23:44:00 2180-08-07 17:50:00            0
4    10000068  25022803 2160-03-03 23:16:00 2160-03-04 06:26:00            0
5    10000084  23052089 2160-11-21 01:56:00 2160-11-25 14:52:00            0
6    10000084  29888819 2160-12-28 05:11:00 2160-12-28 16:07:00            0
7    10000108  27250926 2163-09-27 23:17:00 2163-09-28 09:04:00            0
8    10000117  22927623 2181-11-15 02:05:00 2181-11-15 14:52:00            0
9    10000117  27988844 2183-09-18 18:10:00 2183-09-21 16:30:00            0


In [3]:
# 主代码（较长）
import os
import gc
import numpy as np
import pandas as pd
from datetime import timedelta
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.impute import KNNImputer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# 深度学习依赖 (PyTorch)
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ARIMA
import statsmodels.api as sm

# SHAP
import shap



In [10]:
# --------- 设置 ---------
# 使用上面第一个 cell 中设置的 DATA_DIR。如果未定义，请在本 cell 顶部手动修改。
try:
    DATA_DIR  # noqa: F821
except NameError:
    DATA_DIR = '/Users/yuchenzhou/documents/duke/compsci526/final_proj/mimic-iv-3.1'
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)



<torch._C.Generator at 0x16955d5f0>

In [11]:
# --------- 读数据（示例） ---------
print('加载 CSV（示例）...')
# 尝试从 hosp/ 和 icu/ 子目录读取文件，优先使用 hosp/ 下的行政数据和 icu/ 下的监测数据。
hosp_dir = os.path.join(DATA_DIR, 'hosp')
icu_dir = os.path.join(DATA_DIR, 'icu')
def _read(path, **kwargs):
    if os.path.exists(path):
        return pd.read_csv(path, **kwargs)
    else:
        raise FileNotFoundError(f'未找到文件: {path}')

加载 CSV（示例）...


In [13]:
# --------- 读取患者和诊断数据 ---------
# 注意：admissions 数据已在上一个 cell 中加载并包含 readmit_30d 标签
patients = _read(os.path.join(hosp_dir, 'patients.csv'))
diagnoses = _read(os.path.join(hosp_dir, 'diagnoses_icd.csv'))

# 读取时序数据（chartevents 和 labevents）
# 注意：这些文件可能很大，实际使用时考虑分块读取或筛选
print('正在加载时序数据（可能需要几分钟）...')
chartevents = _read(os.path.join(icu_dir, 'chartevents.csv'))
chartevents = _parse_dates_if_exists(chartevents, ['charttime', 'storetime'])
labevents = _read(os.path.join(hosp_dir, 'labevents.csv'))
labevents = _parse_dates_if_exists(labevents, ['charttime', 'storetime'])
print(f'chartevents: {len(chartevents)} 行')
print(f'labevents: {len(labevents)} 行')

# --------- 定义慢性病筛选 ICD 列表（示例：糖尿病、心衰、COPD、CKD） ---------
chronic_icd_list = {
    'diabetes': ['250.00', '250.01', 'E11'],
    'heart_failure': ['428.0','I50'],
    'copd': ['496','J44'],
    'ckd': ['585','N18']
}

def patient_has_chronic(subject_id):
    rows = diagnoses[diagnoses['subject_id']==subject_id]
    if rows.empty:
        return False
    icd_codes = rows['icd_code'].astype(str).str.upper().tolist()
    for name, code_list in chronic_icd_list.items():
        for c in code_list:
            for icd in icd_codes:
                if c in icd:
                    return True
    return False

# 抽样筛选（示例）
unique_subjects = admissions['subject_id'].unique()[:5000]
selected_subjects = [s for s in unique_subjects if patient_has_chronic(s)]
print(f'筛到慢性病患者: {len(selected_subjects)} (示例抽样)')

# 静态特征
adm = admissions[admissions['subject_id'].isin(selected_subjects)].copy()
pat = patients[patients['subject_id'].isin(selected_subjects)].copy()
df = adm.merge(pat, on='subject_id', how='left')

# 计算年龄：MIMIC-IV 使用 anchor_age 而不是出生日期
# anchor_age 是患者在 anchor_year 时的年龄
# 计算入院时的年龄：anchor_age + (admittime 年份 - anchor_year)
df['admit_year'] = df['admittime'].dt.year
df['age'] = df['anchor_age'] + (df['admit_year'] - df['anchor_year'])
df['age'] = df['age'].clip(lower=0)  # 确保年龄非负

df['gender'] = df['gender'].map({'M':1,'F':0}).fillna(0).astype(int)
df['los_days'] = (df['dischtime'] - df['admittime']).dt.total_seconds()/(3600*24)

# 使用 readmit_30d 标签而不是 dead_in_hospital
static_features = ['subject_id', 'hadm_id', 'age', 'gender', 'los_days', 'readmit_30d']
static_df = df[static_features].drop_duplicates().reset_index(drop=True)
print(static_df.head())
print(f'\n标签分布 - readmit_30d: {static_df["readmit_30d"].value_counts().to_dict()}')

# 时序变量的 ITEMID 示例（请用真实 ITEMID 表替换）
vital_items_of_interest = {
    'heart_rate': [211, 220045],
    'sys_bp': [220179, 51],
    'dias_bp': [220180, 8368],
    'resp_rate': [220210, 618],
    'temp': [223761, 678],
    'spo2': [220277]
}
lab_items_of_interest = {
    'creatinine': [50912],
    'glucose': [807, 823],
    'wbc': [730],
}

def get_patient_events(hadm_id, window_hours=72, resample_freq='1H'):
    adm_row = adm[adm['hadm_id']==hadm_id].iloc[0]
    t0 = adm_row['admittime']
    t_end = t0 + pd.Timedelta(hours=window_hours)
    ce = chartevents[(chartevents['hadm_id']==hadm_id) & (chartevents['charttime']>=t0) & (chartevents['charttime']<=t_end)]
    le = labevents[(labevents['hadm_id']==hadm_id) & (labevents['charttime']>=t0) & (labevents['charttime']<=t_end)]
    rows = []
    for var, itemids in vital_items_of_interest.items():
        tmp = ce[ce['itemid'].isin(itemids)][['charttime','value']].copy()
        if tmp.empty:
            continue
        tmp = tmp.rename(columns={'charttime':'time','value':var})
        rows.append(tmp.set_index('time')[var])
    for var, itemids in lab_items_of_interest.items():
        tmp = le[le['itemid'].isin(itemids)][['charttime','value']].copy()
        if tmp.empty:
            continue
        tmp = tmp.rename(columns={'charttime':'time','value':var})
        rows.append(tmp.set_index('time')[var])
    if not rows:
        return None
    combined = pd.concat(rows, axis=1)
    combined = combined.resample(resample_freq).mean()
    combined = combined[:t0 + pd.Timedelta(hours=window_hours)]
    return combined

MAX_HOURS = 72
RESAMPLE_FREQ = '1H'
TIMESTEPS = int(MAX_HOURS)

def build_dataset(hadm_ids):
    X_static = []
    X_ts = []
    y = []
    for hid in hadm_ids:
        s = static_df[static_df['hadm_id']==hid]
        if s.empty:
            continue
        ts = get_patient_events(hid, window_hours=MAX_HOURS, resample_freq=RESAMPLE_FREQ)
        if ts is None:
            continue
        all_cols = list(vital_items_of_interest.keys()) + list(lab_items_of_interest.keys())
        for c in all_cols:
            if c not in ts.columns:
                ts[c] = np.nan
        ts = ts[all_cols]
        if len(ts) < TIMESTEPS:
            pad_len = TIMESTEPS - len(ts)
            pad_df = pd.DataFrame(np.nan, index=pd.date_range(ts.index[-1]+pd.Timedelta(hours=1), periods=pad_len, freq=RESAMPLE_FREQ), columns=ts.columns)
            ts = pd.concat([ts, pad_df])
        else:
            ts = ts.iloc[:TIMESTEPS]
        X_static.append(s[['age','gender','los_days']].iloc[0].values.astype(float))
        X_ts.append(ts.values.astype(float))
        y.append(int(s['readmit_30d'].iloc[0]))  # 使用 30 天再入院标签
    X_static = np.array(X_static)
    X_ts = np.array(X_ts)
    y = np.array(y)
    return X_static, X_ts, y

hadm_ids = static_df['hadm_id'].unique()[:500]
X_static, X_ts, y = build_dataset(hadm_ids)
print('Shape:', X_static.shape, X_ts.shape, y.shape)

def impute_time_series_array(X_ts):
    N, T, F = X_ts.shape
    X_imputed = X_ts.copy()
    for i in range(N):
        df_ts = pd.DataFrame(X_ts[i], columns=[f'f{j}' for j in range(F)])
        df_ts = df_ts.interpolate(method='linear', limit_direction='both', axis=0).ffill().bfill()
        df_ts = df_ts.fillna(df_ts.mean())
        X_imputed[i] = df_ts.values
    return X_imputed

X_ts_imputed = impute_time_series_array(X_ts)
scaler = StandardScaler()
X_static_scaled = scaler.fit_transform(X_static)

def pool_time_series_features(X_ts):
    N, T, F = X_ts.shape
    feats = []
    for i in range(N):
        arr = X_ts[i]
        mean = np.nanmean(arr, axis=0)
        std = np.nanstd(arr, axis=0)
        mn = np.nanmin(arr, axis=0)
        mx = np.nanmax(arr, axis=0)
        feats.append(np.concatenate([mean,std,mn,mx]))
    return np.array(feats)

X_pool = pool_time_series_features(X_ts_imputed)
X_final = np.hstack([X_static_scaled, X_pool])

X_train, X_test, y_train, y_test = train_test_split(X_final, y, test_size=0.2, random_state=RANDOM_SEED, stratify=y)

lr = LogisticRegression(max_iter=1000)
lr.fit(X_train, y_train)
y_pred_proba = lr.predict_proba(X_test)[:,1]
print('LogReg AUC:', roc_auc_score(y_test, y_pred_proba))

rf = RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED, n_jobs=4)
rf.fit(X_train, y_train)
y_pred_proba_rf = rf.predict_proba(X_test)[:,1]
print('RandomForest AUC:', roc_auc_score(y_test, y_pred_proba_rf))



正在加载时序数据（可能需要几分钟）...


ParserError: Error tokenizing data. C error: Expected 11 fields in line 163446604, saw 21


In [12]:
# --------- 模型训练与评估（分层 CV） ---------
print('开始模型训练与评估（分层 CV）...')
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
import joblib

# 使用 X_final, y 已由上文构造
assert 'X_final' in globals() and 'y' in globals(), '请先运行前面的数据构造单元以生成 X_final 和 y'

# 简单的 5 折分层交叉验证评估（LogReg + RandomForest）
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
models = {
    'logreg': LogisticRegression(max_iter=1000, class_weight='balanced'),
    'rf': RandomForestClassifier(n_estimators=200, random_state=RANDOM_SEED, n_jobs=4, class_weight='balanced')
}
cv_results = {m: {'auroc': [], 'auprc': []} for m in models}

for fold, (train_idx, val_idx) in enumerate(skf.split(X_final, y)):
    print(f'Fold {fold+1}/5')
    Xtr, Xval = X_final[train_idx], X_final[val_idx]
    ytr, yval = y[train_idx], y[val_idx]

    for name, model in models.items():
        model.fit(Xtr, ytr)
        probs = model.predict_proba(Xval)[:,1]
        auroc = roc_auc_score(yval, probs) if len(set(yval))>1 else 0.5
        auprc = average_precision_score(yval, probs) if len(set(yval))>1 else 0.0
        cv_results[name]['auroc'].append(auroc)
        cv_results[name]['auprc'].append(auprc)
        print(f'  {name} AUROC={auroc:.4f} AUPRC={auprc:.4f}')

print('\nCV Summary:')
for name in models:
    print(f"{name}: AUROC mean={np.mean(cv_results[name]['auroc']):.4f} std={np.std(cv_results[name]['auroc']):.4f} | AUPRC mean={np.mean(cv_results[name]['auprc']):.4f} std={np.std(cv_results[name]['auprc']):.4f}")

# 训练最终模型在全部训练数据上并保存
print('\n训练最终模型并保存...')
for name, model in models.items():
    model.fit(X_final, y)
    joblib.dump(model, f'{name}_readmit30_model.pkl')
    print(f'Saved {name}_readmit30_model.pkl')

# 保存 scaler（如果存在）
if 'scaler' in globals():
    joblib.dump(scaler, 'static_scaler.pkl')
    print('Saved static_scaler.pkl')

print('训练与评估完成。')

开始模型训练与评估（分层 CV）...


AssertionError: 请先运行前面的数据构造单元以生成 X_final 和 y

In [None]:
# ARIMA 示例略（请在单变量非平稳/平稳检测后运行）

class TimeSeriesDataset(Dataset):
    def __init__(self, X_static, X_ts, y):
        self.X_static = torch.tensor(X_static, dtype=torch.float32)
        self.X_ts = torch.tensor(X_ts, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X_static[idx], self.X_ts[idx], self.y[idx]

class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size=64, num_layers=1, static_size=3, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=False)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size + static_size, 1)
    def forward(self, x_static, x_ts):
        out, (hn, cn) = self.lstm(x_ts)
        h_last = out[:, -1, :]
        feats = torch.cat([h_last, x_static], dim=1)
        feats = self.dropout(feats)
        logits = self.fc(feats)
        return torch.sigmoid(logits).squeeze(1)

class TransformerClassifier(nn.Module):
    def __init__(self, input_size, d_model=64, nhead=4, num_layers=2, static_size=3, dim_feedforward=128, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(d_model + static_size, 1)
    def forward(self, x_static, x_ts):
        x = self.input_proj(x_ts)
        x = self.transformer(x)
        x_pooled = x.mean(dim=1)
        feats = torch.cat([x_pooled, x_static], dim=1)
        logits = self.fc(feats)
        return torch.sigmoid(logits).squeeze(1)

def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3, device='cpu'):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    best_auc = 0.0
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for xs, xt, yb in train_loader:
            xs = xs.to(device); xt = xt.to(device); yb = yb.to(device)
            opt.zero_grad()
            ypred = model(xs, xt)
            loss = loss_fn(ypred, yb)
            loss.backward()
            opt.step()
            train_losses.append(loss.item())
        model.eval()
        ys, preds = [], []
        with torch.no_grad():
            for xs, xt, yb in val_loader:
                xs = xs.to(device); xt = xt.to(device)
                p = model(xs, xt).cpu().numpy()
                preds.extend(p.tolist())
                ys.extend(yb.numpy().tolist())
        auc = roc_auc_score(ys, preds) if len(set(ys))>1 else 0.5
        print(f'Epoch {epoch+1}/{epochs} train_loss={np.mean(train_losses):.4f} val_auc={auc:.4f}')
        if auc > best_auc:
            best_auc = auc
            torch.save(model.state_dict(), 'best_model.pt')
    return best_auc

Xts = X_ts_imputed
Xst = X_static_scaled[:Xts.shape[0]]
y_small = y[:Xts.shape[0]]
Xtr_s, Xval_s, Xtr_ts, Xval_ts, ytr, yval = train_test_split(Xst, Xts, y_small, test_size=0.2, stratify=y_small, random_state=RANDOM_SEED)
train_ds = TimeSeriesDataset(Xtr_s, Xtr_ts, ytr)
val_ds = TimeSeriesDataset(Xval_s, Xval_ts, yval)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

input_size = Xts.shape[2]
lstm_model = LSTMClassifier(input_size=input_size, hidden_size=64, static_size=Xst.shape[1])
print('训练 LSTM 模型（示例）...')
train_model(lstm_model, train_loader, val_loader, epochs=5, lr=1e-3, device='cpu')

transformer_model = TransformerClassifier(input_size=input_size, d_model=64, static_size=Xst.shape[1])
print('训练 Transformer 模型（示例）...')
train_model(transformer_model, train_loader, val_loader, epochs=5, lr=1e-3, device='cpu')

print('用 SHAP 分析 RandomForest 特征重要性（示例）')
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values[1], X_test, feature_names=None, show=False)
plt.title('SHAP Summary: RandomForest')
plt.savefig('shap_summary_rf.png')
plt.close()

import joblib
joblib.dump(scaler, 'static_scaler.pkl')
joblib.dump(rf, 'rf_model.pkl')
joblib.dump(lr, 'lr_model.pkl')
print('完成。模型与图像已保存到当前目录。')
