# Step1. 导入数据集

In [1]:
import pandas as pd  # 先导入 pandas

df = pd.read_stata(r"..\data\raw\h_elsa_g3.dta")
print(df.head())

   idauniq idauniqc  pn pnc  hh1hhid  hh2hhid  hh3hhid  hh4hhid  hh5hhid  \
0   100001   100001   4   4  14458.0  11870.0  15553.0  17241.0  11052.0   
1   100005   100005   2   2  13012.0      NaN      NaN      NaN  13327.0   
2   100006   100006   3   3  12415.0  12608.0  13496.0  14156.0  14253.0   
3   100007   100007   4   4  15591.0  12585.0  14930.0  16165.0  16347.0   
4   100009   100009   3   3  15798.0  15014.0  13922.0  11245.0  10024.0   

   hh6hhid  ...  s6slfpos2a  s7slfpos2a  r7slfpos2 s7slfpos2 r6slfneg2a  \
0      NaN  ...         NaN         NaN        NaN       NaN        NaN   
1      NaN  ...         NaN         NaN        NaN       NaN        NaN   
2  10656.0  ...         NaN         NaN        NaN       NaN        NaN   
3  15127.0  ...         6.0         NaN        NaN       NaN        NaN   
4  11311.0  ...         4.0         3.0        NaN       NaN        NaN   

  r7slfneg2a s6slfneg2a s7slfneg2a r7slfneg2 s7slfneg2  
0        NaN        NaN        NaN 

# Step2. 筛选变量
由于elsa是个很广泛的数据集，我们专注AD早筛的话，很多变量都是不太相关的，可以直接去掉。于是我们采用逐渐缩小范围的方式筛选数据
## 1. 筛除不相关变量
第一次筛选，留下b,k,l,m,q中的本人相关变量（配偶的就去掉了），以及A中，性别、年龄、婚姻、教育相关的变量

In [52]:
import sys
import os

from temp import count_unique_vars

# 获取所有列名列表
cols = list(df.columns)

#sectionA:demographics_marriage
cols_marriage = cols[cols.index("r1mstat"):cols.index("r9mstat")+1]

# sectionB:health
cols_health = cols[cols.index("r1shlt"):cols.index("s9memrys")+1]
cols_health = [col for col in cols_health if not col.startswith('s')]

# sectionK:physical measures
cols_pm = cols[cols.index("r1wspeed1"):cols.index("s6chrothr")+1]
cols_pm = [col for col in cols_pm if not col.startswith('s')]

# sectionL: assistance and caregiving   
cols_ac = cols[cols.index("r6dresshlp"):cols.index("s9gcaresckhpw")+1]
cols_ac = [col for col in cols_ac if not col.startswith('s')]

# sectionM:stress
cols_stress = cols[cols.index("r2satjob"):cols.index("s5dcother")+1]
cols_stress = [col for col in cols_stress if not col.startswith('s')]

# sectionQ: psychosocial
cols_psycho = cols[cols.index("r1depres"):cols.index("s7slfneg2")+1]
cols_psycho = [col for col in cols_psycho if not col.startswith('s')]

all_cols1 = ["ragender", "rabyear", "raeducl"] + cols_marriage + cols_health + cols_pm + cols_ac + cols_stress + cols_psycho

var_names, var_count = count_unique_vars(all_cols1)
print("变量种类数：", var_count)

变量种类数： 779


## 2. 手动删除重复变量
针对前面的df_filtered1,手动删除重复信息
1. 数据泄露问题，比如是否诊断阿兹海默



In [53]:
from temp import remove_multiple_ranges

all_cols2 = remove_multiple_ranges(all_cols1, [("r1alzhe","radiagdemen"),("r2alzhs","r9memrys")])

var_names, var_count = count_unique_vars(all_cols2)
print("变量种类数：", var_count)

变量种类数： 768


2. 不相关变量去除

In [54]:
from temp import remove_multiple_ranges

#删除 sectionQ——社会阶级地位
all_cols2 = remove_multiple_ranges(all_cols2, [("r1cantril","r3cantrilc")])

var_names, var_count = count_unique_vars(all_cols2)
print("变量种类数：", var_count)

变量种类数： 766


2. 单个问题&汇总得分之间的重复。可以抛弃单个，只取汇总得分。

In [55]:
from temp import remove_multiple_ranges_with_keep, remove_multiple_ranges


# 删除section B中的单个问题,留下RwADLTOT6/RwIADLTOT1M_E/RwNAGI10
all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1walkra", "r9nagi8a")],
    [("r1adltot6", "r9adltot6"), ("r1iadltot1m_e", "r9iadltot1m_e"), ("r1nagi10", "r9nagi10")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("adls处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1fall", "r9hip")],
    [("r1fallnum", "r9fallnum")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("falls处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r1painfr", "r9painfr")],
)
var_names, var_count = count_unique_vars(all_cols2)
print("painfr处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r2shltc","r9hips")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("change处理后变量种类数：", var_count)

# 删除section K中的单个测量值，留下平均值
all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r1wspeed1","r9wspeed2"), ("r1walksft","r9walkothr")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("walk处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [
        ("r2systo1", "r8systo3"),
        ("r2diasto1", "r8diasto3"),
        ("r2pulse1", "r8pulse3"),
        ("r2bpsft","r8bpothr")
    ]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Blood Pressure and Heart Rate Measurements处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [
        ("r2lgrip1", "r8lgrip3"),
        ("r2rgrip1", "r8rgrip3"),
        ("r2gripsft","r8gripothr")
    ]
)
var_names, var_count = count_unique_vars(all_cols2)
print("hand grip处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [
        ("r2hghtsft", "r9wghtothr"),
    ]
)
var_names, var_count = count_unique_vars(all_cols2)
print("height&weight处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [
        ("r2puff1", "r4puff3"),
        ("r2fvc1", "r4fvc3"),
        ("r2fev1", "r4fev3"),
        ("r2puffsft","r6puffothr_e")
    ]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Lung Function Measurements处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r2sbstan", "r6balothr")],
    [("r2balance_e", "r6balance_e")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Balance Tests处理后变量种类数：", var_count)


all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r2legrsft", "r6legrothr")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Leg Raise Tests处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r2chrsft", "r6chrothr")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Chair Stand Tests处理后变量种类数：", var_count)

# 删除section L——Care for ADLs or IADLS: Receives Any Care中的细分项
all_cols2 = remove_multiple_ranges(
    all_cols2,
    [("r6racany", "r9rcany")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Care for ADLs or IADLS: Receives Any Care处理后变量种类数：", var_count)

# 删除section M中的单项，留下汇总分
all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1sustdfe", "r9ssupportm")],
    [("r1ssupport6","r9ssupport6")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Social Support: Spouse处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1kustdfe", "r9ksupportm")],
    [("r1ksupport6","r9ksupport6")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Social Support: children处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1oustdfe", "r9osupportm")],
    [("r1osupport6","r9osupport6")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Social Support: Other Family Members处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1fustdfe", "r9fsupportm")],
    [("r1fsupport6","r9fsupport6")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Social Support: friends处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1depres", "r9cesdm")],
    [("r1cesd","r9cesd")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Depressive Symptoms: CESD处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r2lideal", "r9satlifez")],
    [("r2lsatsc", "r9lsatsc")]
)
var_names, var_count = count_unique_vars(all_cols2)
print("Satisfaction with Life Scale处理后变量种类数：", var_count)

all_cols2 = remove_multiple_ranges_with_keep(
    all_cols2,
    [("r1ageprv", "r9casp12")],
    [
        ("r1cntrlndx6", "r9cntrlndx6"),   # Control (6 items)
        ("r1autondx5",  "r9autondx5"),    # Autonomy (5 items)
        ("r1plsrndx4",  "r9plsrndx4"),   # Pleasure (4 items)
        ("r1slfrlndx4", "r9slfrlndx4"),   # Self-realization (4 items)
        ("r1casp19",    "r9casp19")       # CASP-19 total score
    ]
)
var_names, var_count = count_unique_vars(all_cols2)
print("CASP处理后变量种类数：", var_count)

adls处理后变量种类数： 690
falls处理后变量种类数： 684
painfr处理后变量种类数： 683
change处理后变量种类数： 660
walk处理后变量种类数： 653
Blood Pressure and Heart Rate Measurements处理后变量种类数： 639
hand grip处理后变量种类数： 629
height&weight处理后变量种类数： 619
Lung Function Measurements处理后变量种类数： 601
Balance Tests处理后变量种类数： 588
Leg Raise Tests处理后变量种类数： 584
Chair Stand Tests处理后变量种类数： 579
Care for ADLs or IADLS: Receives Any Care处理后变量种类数： 576
Social Support: Spouse处理后变量种类数： 566
Social Support: children处理后变量种类数： 556
Social Support: Other Family Members处理后变量种类数： 546
Social Support: friends处理后变量种类数： 536
Depressive Symptoms: CESD处理后变量种类数： 527
Satisfaction with Life Scale处理后变量种类数： 512
CASP处理后变量种类数： 488


# 3. 手动筛掉wave太少的变量
统计488个变量中，wave-变量的分布情况


In [56]:
# 重新导入模块以获取新添加的函数
import importlib
import temp as data_process_utils
importlib.reload(data_process_utils)
from temp import count_wave_numbers

wave_counts = count_wave_numbers(all_cols2)
print("各波次变量数量统计：", wave_counts)
# 计算各波次缺失变量数量
total_vars = 488  # 总变量数
missing_counts = {}
for wave, count in wave_counts.items():
    missing_counts[wave] = total_vars - count
print("各波次缺失变量数量统计：", missing_counts)

# 将一维变量列表转换为二维DataFrame
import pandas as pd
import re

def create_wave_matrix(col_list):
    """
    将变量列表转换为二维矩阵
    行：变量名（去掉波次前缀）
    列：波次（r1-r9）
    值：1表示该波次有该变量，NaN表示没有
    注意：
    - 保留有波次的变量（r+数字+变量名格式）
    - 不符合格式的变量，另外收集起来
    """
    var_names = set()
    wave_var_dict = {}
    non_wave_vars = []   # 保存不符合格式的变量
    
    for col in col_list:
        match = re.match(r'r(\d+)([a-zA-Z0-9_]+)', col.lower())
        if match:
            wave = int(match.group(1))
            var_name = match.group(2)
            var_names.add(var_name)
            if var_name not in wave_var_dict:
                wave_var_dict[var_name] = set()
            wave_var_dict[var_name].add(wave)
        else:
            non_wave_vars.append(col)
    
    waves = list(range(1, 10))  # r1到r9
    var_names_sorted = sorted(var_names)
    
    matrix_data = {}
    for wave in waves:
        matrix_data[f'r{wave}'] = [
            1 if wave in wave_var_dict.get(var, set()) else None
            for var in var_names_sorted
        ]
    
    df_matrix = pd.DataFrame(matrix_data, index=var_names_sorted)
    return df_matrix, non_wave_vars

# 创建波次变量矩阵 & 非波次变量列表
wave_matrix, non_wave_vars = create_wave_matrix(all_cols2)
print(f"\n变量矩阵形状: {wave_matrix.shape}")
print(f"非波次变量数量: {len(non_wave_vars)}")

# 统计每个变量在多少个波次中出现
var_wave_counts = wave_matrix.count(axis=1)
print(f"\n变量在波次中的分布统计:")
print(f"出现在所有9个波次的变量数: {sum(var_wave_counts == 9)}")
print(f"出现在8个波次的变量数: {sum(var_wave_counts == 8)}")

各波次变量数量统计： {1: 126, 2: 234, 3: 162, 4: 205, 5: 185, 6: 290, 7: 298, 8: 201, 9: 201}
各波次缺失变量数量统计： {1: 362, 2: 254, 3: 326, 4: 283, 5: 303, 6: 198, 7: 190, 8: 287, 9: 287}

变量矩阵形状: (440, 9)
非波次变量数量: 48

变量在波次中的分布统计:
出现在所有9个波次的变量数: 98
出现在8个波次的变量数: 39


从all_cols2中去掉出现在7个及以下wave里的变量。

In [57]:
# 从all_cols2中去掉出现在7个及以下wave里的变量，只保留出现在8个或9个wave的变量

import re
import pandas as pd

def filter_vars_by_wave_count(col_list, min_waves=8):
    """
    过滤变量列表，只保留出现在min_waves个或更多波次中的变量

    Args:
        col_list: 变量列表
        min_waves: 最少出现的波次数（默认8）

    Returns:
        filtered_cols: 过滤后的变量列表
        removed_cols: 被移除的变量列表
    """
    # 分离有波次的变量和无波次的变量
    wave_vars = {}  # {变量名: [出现的波次]}
    non_wave_vars = []

    for col in col_list:
        match = re.match(r'r(\d+)([a-zA-Z0-9_]+)', col.lower())
        if match:
            wave = int(match.group(1))
            var_name = match.group(2)
            if var_name not in wave_vars:
                wave_vars[var_name] = []
            wave_vars[var_name].append(wave)
        else:
            non_wave_vars.append(col)

    # 统计每个变量出现在多少个波次中
    var_wave_counts = {var: len(waves) for var, waves in wave_vars.items()}

    # 筛选出现在min_waves个或更多波次的变量
    valid_vars = [var for var, count in var_wave_counts.items() if count >= min_waves]
    removed_vars = [var for var, count in var_wave_counts.items() if count < min_waves]

    # 重建符合条件的变量列表
    filtered_cols = []
    for var in valid_vars:
        for wave in wave_vars[var]:
            filtered_cols.append(f"r{wave}{var}")

    # 添加非波次变量（这些变量保留）
    filtered_cols.extend(non_wave_vars)

    # 保持原始顺序
    filtered_cols = [col for col in col_list if col in filtered_cols]

    removed_cols = [col for col in col_list if col not in filtered_cols]

    print(f"波次过滤统计:")
    print(f"- 原始变量总数: {len(col_list)}")
    print(f"- 有波次的变量数: {len(wave_vars)}")
    print(f"- 非波次变量数: {len(non_wave_vars)}")
    print(f"- 出现在{min_waves}个或更多波次的变量数:{len(valid_vars)}")
    print(f"- 被移除的变量数: {len(removed_vars)}")
    print(f"- 过滤后变量总数: {len(filtered_cols)}")

    if removed_vars:
        print(f"\n被移除的变量（出现波次<{min_waves}）:")
        removed_stats = {var: var_wave_counts[var] for var in removed_vars[:10]}  # 只显示前10个
        for var, count in removed_stats.items():
            print(f"  {var}: {count}个波次")
        if len(removed_vars) > 10:
            print(f"  ... 还有{len(removed_vars)-10}个变量被移除")      

    return filtered_cols, removed_cols

# 应用过滤器，只保留出现在8个或更多波次的变量
all_cols2_filtered, removed_cols = filter_vars_by_wave_count(all_cols2, min_waves=8)

# 更新all_cols2
all_cols2 = all_cols2_filtered

# 重新统计变量种类数
from temp import count_unique_vars
var_names, var_count = count_unique_vars(all_cols2)
print(f"\n过滤后变量种类数: {var_count}")


print(f"\n过滤完成！all_cols2已更新为只包含出现在8个或更多波次的变量。")

波次过滤统计:
- 原始变量总数: 1950
- 有波次的变量数: 440
- 非波次变量数: 48
- 出现在8个或更多波次的变量数:137
- 被移除的变量数: 303
- 过滤后变量总数: 1242

被移除的变量（出现波次<8）:
  shltf: 1个波次
  shlta: 2个波次
  shltaf: 1个波次
  rxhrtat: 2个波次
  rxosteo: 6个波次
  rxdepres: 3个波次
  trdepres: 3个波次
  trhchol: 3个波次
  rxhchol: 7个波次
  hipr: 2个波次
  ... 还有293个变量被移除

过滤后变量种类数: 185

过滤完成！all_cols2已更新为只包含出现在8个或更多波次的变量。


## 4. 补充处理
对于一些表面缺失、实际上可以修复的变量，采用手动处理的方式

In [58]:
# 在 r2shlt 和 r4shlt 之间插入 r3shlta
try:
    r2shlt_index = all_cols2.index("r2shlt")
    # 在 r2shlt 后面插入 r3shlta
    all_cols2.insert(r2shlt_index + 1, "r3shlta")
    print(f"成功在位置 {r2shlt_index + 1} 插入 r3shlta")
    
except ValueError:
    print("错误: 在 all_cols2 中未找到 r2shlt")
except NameError:
    print("错误: all_cols2 未定义，请先运行前面的数据处理步骤")

# 插入sectionK中的变量： 血压 心率 握力 身高体重 BMI
to_insert = ["SYSTO", "DIASTO", "PULSE", "GRIPSUM", "MHEIGHT", "MWEIGHT", "MBMI"]
waves = [2, 4, 6, 8]

# 按每个变量依次扩展波次并变小写
expanded_vars = [f"r{w}{v}".lower() for v in to_insert for w in waves]

# 找到 r9walkre 的索引
idx = all_cols2.index("r9walkre")

# 插入到 r9walkre 后面
all_cols2 = all_cols2[:idx+1] + expanded_vars + all_cols2[idx+1:]

# 查看插入后的部分
print(all_cols2[idx:idx+20])


成功在位置 14 插入 r3shlta
['r9walkre', 'r2systo', 'r4systo', 'r6systo', 'r8systo', 'r2diasto', 'r4diasto', 'r6diasto', 'r8diasto', 'r2pulse', 'r4pulse', 'r6pulse', 'r8pulse', 'r2gripsum', 'r4gripsum', 'r6gripsum', 'r8gripsum', 'r2mheight', 'r4mheight', 'r6mheight']


# Step3. 计算出发病的lable
根据挑选出的认知变量计算总分，如果在3个变量里，有缺失＞30%，则该样本的这一栏分数为空

In [None]:
cols_cognition = cols[cols.index("r1tr20"):cols.index("r9tr20")+1] + cols[cols.index("r1orient"):cols.index("r9orient")+1] + cols[cols.index("r1verbf"): cols.index("r9verbf")+1]
cols_diagnosed = cols[cols.index("r1memrye"):cols.index("r9memrye")+1]
df_cognition = df[cols_cognition + cols_diagnosed]


# 1. 计算认知分数baseline（baseline取wave1）
def compute_baseline_stats(df_baseline):
    results = []

    # 去掉前缀 r1
    domain_names = [col[2:] if col.startswith('r1') else col for col in df_baseline.columns]

    # 计算每个领域均值和标准差
    for col, domain in zip(df_baseline.columns, domain_names):
        mean_val = df_baseline[col].mean()
        std_val = df_baseline[col].std()
        results.append({'domain': domain, 'mean': mean_val, 'std': std_val})

    # 计算每个领域 z 分数
    z_df = (df_baseline - df_baseline.mean()) / df_baseline.std()
    df_baseline['global'] = z_df.mean(axis=1)

    # 计算全局分数均值和标准差
    global_mean = df_baseline['global'].mean()
    global_std = df_baseline['global'].std()
    results.append({'domain': 'global', 'mean': global_mean, 'std': global_std})

    # 构建结果 DataFrame
    result_df = pd.DataFrame(results).set_index('domain')

    return result_df
baseline = compute_baseline_stats(df_cognition[["r1tr20", "r1orient","r1verbf"]])

# 2. 针对每个wave计算发病lable
# （是否患病：racogimp_label cont:-1（从不患病）/1-9（开始患病wave）/NaN（由于信息缺失无效）; 
# 当前wave距离患病wave的时间：r[wave]cogimpt cont：-1（无效，比如从未发病，或者当前wave是onset之后，yearcont为负数）/year cont（≥0）/NaN（（由于conimp缺失无效）））

# 2.1 根据是否被诊断出认知问题判断

# 2.2 计算认知分数并和baseline比较判断
import numpy as np

def compute_cogimp_labels(df_compute, baseline_stats, thred=-1.5, wave_interval=2):
    """
    计算各 wave 的认知障碍标签和发病信息。
    
    参数：
        df_compute: DataFrame，包含各 wave 认知领域和诊断列
        baseline_stats: DataFrame，index=领域名, columns=['mean','std']，wave1 baseline
        thred: float，全局 z 阈值，低于视为认知障碍
        wave_interval: int，每两个 wave 间隔年数
    
    返回：
        result_df: DataFrame，每行受访者
            - racogimp_label: 1–9 onset wave, -1 从不发病, NaN 信息缺失
            - r[wave]cogimpt: 当前 wave 到 onset wave 的年份（>=0），无效设 -1
    """
    result_df = pd.DataFrame(index=df_compute.index)
    global_z = pd.DataFrame(index=df_compute.index)
    
    # 遍历 wave1–wave8，计算出9个wave的全局 z
    for wave in range(1, 9):
        # 找出该 wave 的领域列
        memory_var = f'r{wave}tr20'
        orientation_var = f'r{wave}orient'
        executive_var = f'r{wave}verbf'

        all_vars = {'tr20': memory_var, 'orient': orientation_var, 'verbf': executive_var}
        
        # 计算每个领域的 z 分数
        domain_z = pd.DataFrame(index=df_compute.index)
        for domain, var in all_vars.items():
            if var:
                mean_val = baseline_stats.loc[domain, 'mean']
                std_val = baseline_stats.loc[domain, 'std']
                domain_z[domain] = (df_compute[var] - mean_val) / std_val
        
        # 全局 z
        global_z[wave] = domain_z.mean(axis=1, skipna=True, min_count=2)
        global_z[wave] = (global_z[wave] - baseline_stats.loc["global", 'mean'])/baseline_stats.loc["global", 'std']

    # 生成 racogimp_label
    def compute_racogimp_label(global_z, df_compute, threshold=1.5, missing_ratio=0.3):
        def label_row(row):
            # 1. 缺失值占比
            if row.isna().mean() > missing_ratio:
                return np.nan
            # 2. 找第一个满足条件的 wave
            for wave in row.index:
                diag_val = df_compute.loc[row.name, f'r{wave}memrye']
                if row[wave] < threshold or diag_val == "1.yes":
                    return int(wave)  # 返回 wave 编号
            # 3. 找不到则返回 -1
            return -1
        
        return global_z.apply(label_row, axis=1)
    racogimp_label = compute_racogimp_label(global_z, df_compute)

    
    # 生成 r[wave]cogimpt
    for wave in range(2, 9):
        yearcont = (onset_wave - wave) * wave_interval
        # 无效：从未发病(-1) 或当前 wave 在 onset 之后 (yearcont<0)
        yearcont_mask = (racogimp_label==-1) | (yearcont<0)
        yearcont[yearcont_mask] = -1
        result_df[f'r{wave}cogimpt'] = yearcont
    
    result_df['racogimp_label'] = racogimp_label
    
    return result_df



cognitive_results = compute_cogimp_labels(df_cognition,baseline)

识别到认知变量:
- 记忆: 9 个变量
- 定向: 9 个变量
- 执行功能: 8 个变量

计算结果:
- 总样本数: 19802
- 有有效认知数据的样本数: 18866
- 曾经发病的样本数: 3310
- 从未发病的样本数: 15556
- 无有效数据的样本数: 936

发病wave分布:
- Wave 1: 794 人
- Wave 2: 436 人
- Wave 3: 390 人
- Wave 4: 450 人
- Wave 5: 342 人
- Wave 7: 331 人
- Wave 8: 273 人
- Wave 9: 294 人

结果预览:
   cognitive_impairment_label  onset_wave  r1cognitive_impairment_time  \
0                        <NA>        <NA>                          NaN   
1                          -1        <NA>                          NaN   
2                          -1        <NA>                          NaN   
3                          -1        <NA>                          NaN   
4                          -1        <NA>                          NaN   
5                        <NA>        <NA>                          NaN   
6                          -1        <NA>                          NaN   
7                           1           3                          4.0   
8                          -1        <NA>      

# Step4. 筛选样本
1. 构造具有所有待研究变量的df（19802，186+2=188）

In [60]:
df_filtered_final = pd.concat([df[all_cols2], cognitive_results], axis=1)

# 修改手动添加的r3shlta的名字
df_filtered_final = df_filtered_final.rename(columns={"r3shlta": "r3shlt"})

# 统计变量个数
from temp import count_unique_vars
var_names, var_count = count_unique_vars(list(df_filtered_final))
print("变量处理完后的变量个数：", var_count)

# 显示认知障碍标签的分布
print(f"\n认知障碍标签分布:")
print(cognitive_results['cognitive_impairment_label'].value_counts().sort_index())

# 显示时间变量的一些例子
print(f"\n时间变量示例:")
sample_cols = ['cognitive_impairment_label', 'onset_wave', 'r1cognitive_impairment_time', 'r3cognitive_impairment_time', 'r5cognitive_impairment_time']
available_cols = [col for col in sample_cols if col in cognitive_results.columns]
print(cognitive_results[available_cols].head(10))

# 对于发病案例，显示时间变量的计算逻辑验证
print(f"\n发病案例的时间变量验证（发病在wave3的情况）:")
onset_wave3_mask = (cognitive_results['onset_wave'] == 3) & (cognitive_results['cognitive_impairment_label'] == 1)
if onset_wave3_mask.any():
    sample_case = cognitive_results[onset_wave3_mask].head(1)
    time_cols = [col for col in cognitive_results.columns if 'cognitive_impairment_time' in col]
    print("发病wave:", sample_case['onset_wave'].values[0])
    print("时间变量值:")
    for col in time_cols[:5]:  # 只显示前5个
        wave_num = col[1]  # 提取wave数字
        expected_time = (3 - int(wave_num)) * 2  # 期望的时间值
        actual_time = sample_case[col].values[0]
        print(f"  {col}: {actual_time} (期望: {expected_time})")

变量处理完后的变量个数： 196

认知障碍标签分布:
cognitive_impairment_label
-1    15556
1      3310
Name: count, dtype: Int64

时间变量示例:
   cognitive_impairment_label  onset_wave  r1cognitive_impairment_time  \
0                        <NA>        <NA>                          NaN   
1                          -1        <NA>                          NaN   
2                          -1        <NA>                          NaN   
3                          -1        <NA>                          NaN   
4                          -1        <NA>                          NaN   
5                        <NA>        <NA>                          NaN   
6                          -1        <NA>                          NaN   
7                           1           3                          4.0   
8                          -1        <NA>                          NaN   
9                        <NA>        <NA>                          NaN   

   r3cognitive_impairment_time  r5cognitive_impairment_time  
0        

2. 筛选样本：① cognitive_impairment_label不为NaN（有lable）；②稀疏度＜30%

In [61]:
# 计算每行稀疏度（缺失比例）
row_sparsity = df_filtered_final.isna().mean(axis=1)

# 打印基本统计
print("=== 稀疏度（缺失比例）统计 ===")
print(f"最小稀疏度: {row_sparsity.min():.3f}")
print(f"最大稀疏度: {row_sparsity.max():.3f}")
print(f"平均稀疏度: {row_sparsity.mean():.3f}")

# 看看分布（例如按区间统计）
print("\n=== 稀疏度分布（分桶）===")
print(row_sparsity.value_counts(bins=[0,0.3,0.5,0.7,0.9,1]))


=== 稀疏度（缺失比例）统计 ===
最小稀疏度: 0.071
最大稀疏度: 0.984
平均稀疏度: 0.603

=== 稀疏度分布（分桶）===
(0.7, 0.9]       4531
(0.9, 1.0]       4280
(-0.001, 0.3]    3811
(0.3, 0.5]       3759
(0.5, 0.7]       3421
Name: count, dtype: int64


In [None]:
# 首先筛选 global_cognitive_z_score 不为 NaN 的样本
condition1 = df_filtered_final['cognitive_impairment_label'].notna()
print(f"条件1 - cognitive_impairment_label不为NaN的样本数: {condition1.sum()}")

# 计算每行的非缺失值比例（即非稀疏度）
# 计算稀疏度：缺失值比例 = 1 - 非缺失值比例
sparsity = df_filtered_final.isna().mean(axis=1)
non_sparsity = 1 - sparsity

# 条件2：稀疏度 ＜ 30%
condition2 = sparsity < 0.3
print(f"条件2 - 稀疏度＜30%的样本数: {condition2.sum()}")

# 同时满足两个条件的样本
both_conditions = condition1 & condition2
print(f"同时满足两个条件的样本数: {both_conditions.sum()}")

# 应用筛选
df_filtered_final = df_filtered_final[both_conditions].copy()

print(f"\n筛选后的数据集形状: {df_filtered_final.shape}")
print(f"筛选后的样本数: {df_filtered_final.shape[0]}")

# 显示筛选后数据的基本信息
print(f"\n筛选后数据集的基本统计:")
print(f"- global_cognitive_z_score的描述统计:")
print(df_filtered_final['global_cognitive_z_score'].describe())
print(f"- 整体稀疏度统计:")
final_sparsity = df_filtered_final.isna().mean(axis=1)
print(f"  最小稀疏度: {final_sparsity.min():.3f}")
print(f"  最大稀疏度: {final_sparsity.max():.3f}")
print(f"  平均稀疏度: {final_sparsity.mean():.3f}")

条件1 - global_cognitive_z_score不为NaN的样本数: 8059
条件2 - 稀疏度＜30%的样本数: 3811
同时满足两个条件的样本数: 3287

筛选后的数据集形状: (3287, 1283)
筛选后的样本数: 3287

筛选后数据集的基本统计:
- global_cognitive_z_score的描述统计:
count    3287.000000
mean       -0.186756
std         1.090888
min        -5.859625
25%        -0.654722
50%        -0.012898
75%         0.498222
max         2.382079
Name: global_cognitive_z_score, dtype: float64
- 整体稀疏度统计:
  最小稀疏度: 0.071
  最大稀疏度: 0.299
  平均稀疏度: 0.194


# Step5. 缺失值处理
1. 非随机变量的中值插补

In [63]:
# 基础变量
vars_base = ["systo", "diasto", "pulse", "gripsum", "mheight", "mweight", "mbmi"]

df_filled = df_filtered_final.copy()

for v in vars_base:
    # --- wave1 = wave2 ---
    colname = f"r1{v}"
    refcol = f"r2{v}"
    df_filled.insert(df_filled.columns.get_loc(refcol), colname, df_filled[refcol])

    # --- wave3 = (wave2, wave4) 中位数 ---
    colname = f"r3{v}"
    refcol = f"r4{v}"
    df_filled.insert(df_filled.columns.get_loc(refcol), colname, 
                     df_filled[[f"r2{v}", f"r4{v}"]].median(axis=1, skipna=True))

    # --- wave5 = (wave4, wave6) 中位数 ---
    colname = f"r5{v}"
    refcol = f"r6{v}"
    df_filled.insert(df_filled.columns.get_loc(refcol), colname, 
                     df_filled[[f"r4{v}", f"r6{v}"]].median(axis=1, skipna=True))

    # --- wave7 = (wave6, wave8) 中位数 ---
    colname = f"r7{v}"
    refcol = f"r8{v}"
    df_filled.insert(df_filled.columns.get_loc(refcol), colname, 
                     df_filled[[f"r6{v}", f"r8{v}"]].median(axis=1, skipna=True))

    # --- wave9 = wave8 ---
    colname = f"r9{v}"
    refcol = f"r8{v}"  # 插在r8后面
    df_filled.insert(df_filled.columns.get_loc(refcol) + 1, colname, df_filled[refcol])


2. 整个数据集缺失值的多重插补

In [None]:
import os
from datetime import datetime
import pandas as pd
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.impute import SimpleImputer

# 你的原始数据
df_final = df_filled.copy()

# 插补次数
m = 5  

# 保存目录
save_dir = r"D:\AA_hias\projects\02ADHD\adhd\data\processed"
os.makedirs(save_dir, exist_ok=True)

# 拆分数值列 & 非数值列
num_cols = df_final.select_dtypes(include=["number"]).columns
cat_cols = df_final.select_dtypes(exclude=["number"]).columns

df_num = df_final[num_cols]
df_cat = df_final[cat_cols]

# 分类变量：用众数填充
imputer_cat = SimpleImputer(strategy="most_frequent")
df_cat_imputed = pd.DataFrame(
    imputer_cat.fit_transform(df_cat),
    columns=cat_cols
)

# 保存每个插补后的数据集
imputed_datasets = []

for i in range(m):
    # 数值变量多重插补·
    imputer_num = IterativeImputer(random_state=i)
    df_num_imputed = pd.DataFrame(
        imputer_num.fit_transform(df_num),
        columns=num_cols
    )

    # 合并：保持原始列顺序
    df_imputed = pd.concat([df_num_imputed, df_cat_imputed], axis=1)[df_final.columns]

    imputed_datasets.append(df_imputed)

    # 自动生成文件名
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"df_filtered_final_imputed_{i+1}_{timestamp}.csv"
    save_path = os.path.join(save_dir, filename)

    # 保存文件
    df_imputed.to_csv(save_path, index=False, encoding="utf-8-sig")

    print(f"已保存: {save_path}")

print(f"共生成并保存 {m} 个插补后的数据集")




已保存: D:\AA_hias\projects\02ADHD\adhd\data\processed\df_filtered_final_imputed_1_20250901_012512.csv




已保存: D:\AA_hias\projects\02ADHD\adhd\data\processed\df_filtered_final_imputed_2_20250901_015120.csv




已保存: D:\AA_hias\projects\02ADHD\adhd\data\processed\df_filtered_final_imputed_3_20250901_021211.csv




已保存: D:\AA_hias\projects\02ADHD\adhd\data\processed\df_filtered_final_imputed_4_20250901_023709.csv




已保存: D:\AA_hias\projects\02ADHD\adhd\data\processed\df_filtered_final_imputed_5_20250901_030243.csv
共生成并保存 5 个插补后的数据集


In [None]:
# import os
# from datetime import datetime

# # 指定保存目录
# save_dir = r"D:\AA_hias\projects\02ADHD\adhd\data\processed"  
# os.makedirs(save_dir, exist_ok=True)  # 如果目录不存在，就自动创建

# # 自动生成文件名
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# filename = f"df_filtered_final_imputed_{timestamp}.csv"

# # 拼接完整路径
# save_path = os.path.join(save_dir, filename)

# # 保存
# df_filtered_final_imputed.to_csv(save_path, index=False, encoding="utf-8-sig")

# print(f"文件已保存到: {save_path}")
