In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import argparse
import os

In [2]:
# Inclusive bounds on a disease term's case count for it to get a column in the
# label matrices; both values are also embedded in the output file names.
min_num, max_num = 500, 10_000

In [3]:
pkl_dir = '../out/'
data_dir = '../../Prophet/data/'
output_dir = '../out/'

In [None]:


# Input files
term2caseids_path = os.path.join(pkl_dir, 'term2caseids.pkl')
term2casedates_path = os.path.join(pkl_dir, 'term2casedates.pkl')
term2controlids_path = os.path.join(pkl_dir, 'term2controlids.pkl')
proteomics_data_path = os.path.join(data_dir, 'preprocessed_proteomics_data.csv')

# Output files (the suffix records the inclusive case-count bounds)
output_date_matrix_path = os.path.join(output_dir, f'label_date_matrix_{min_num}_{max_num}.npy')
output_mask_matrix_path = os.path.join(output_dir, f'label_mask_matrix_{min_num}_{max_num}.npy')

# --- 2. Load the preprocessed data ---
print("正在加载预处理好的 .pkl 文件...")
loaded_dicts = []
for pkl_path in (term2caseids_path, term2casedates_path, term2controlids_path):
    with open(pkl_path, 'rb') as fh:
        loaded_dicts.append(pkl.load(fh))
term2caseids, term2casedates, term2controlids = loaded_dicts

# The proteomics table is loaded only for its participant list and row order.
print("正在加载参与者列表...")
proteomics_df = pd.read_csv(proteomics_data_path)

# --- Build the label matrices ---
# Row i <-> i-th row of proteomics_df (EIDs compared as str); column j <-> disease_list[j].
eid2idx = dict(zip(proteomics_df.EID.values.astype(str), range(len(proteomics_df))))
disease_list = [
    term for term, case_ids in term2caseids.items()
    if min_num <= len(case_ids) <= max_num
]
print(f"找到 {len(disease_list)} 种合乎规定的疾病。")

matrix_shape = (len(proteomics_df), len(disease_list))
# Diagnosis date for each case, NaT everywhere else.
label_date_matrix = np.full(matrix_shape, np.datetime64('NaT'), dtype='datetime64[D]')
# 1 everywhere, except 0 where the participant is a case or control for that term.
# NOTE(review): presumably 1 marks "exclude from this term" downstream — confirm.
label_mask_matrix = np.ones(matrix_shape, dtype=int)
for col, term in enumerate(disease_list):
    cases, controls = term2caseids[term], term2controlids[term]
    # assumes every case/control EID (as str) is a row of proteomics_df; otherwise KeyError
    case_rows = [eid2idx[eid] for eid in cases]
    label_date_matrix[case_rows, col] = term2casedates[term]
    label_mask_matrix[case_rows, col] = 0
    label_mask_matrix[[eid2idx[eid] for eid in controls], col] = 0

for out_path, matrix in ((output_date_matrix_path, label_date_matrix),
                         (output_mask_matrix_path, label_mask_matrix)):
    np.save(out_path, matrix)

In [7]:
# Proteomics sampling date per participant, keyed by EID as str.
# The path is resolved through data_dir instead of a user-specific absolute
# path so the notebook runs on other machines; from the working directory
# printed by `!pwd` below, this is the same proteomics_test_date.csv file.
proteomics_date_path = os.path.join(data_dir, 'proteomics_test_date.csv')
proteomics_date = pd.read_csv(proteomics_date_path)
proteomics_date['proteomics_test_date'] = pd.to_datetime(proteomics_date['proteomics_test_date'])
# If an EID appears more than once, the last row wins (dict(zip(...)) semantics).
id2proteindates = dict(zip(proteomics_date['EID'].astype(str), proteomics_date['proteomics_test_date']))

In [None]:
# Split every term's cases by diagnosis date relative to the participant's
# proteomics sampling date (id2proteindates):
#   pre (prevalent) cases: diagnosed strictly before sampling
#   inc (incident) cases:  all other cases (same day or later)
# Controls have no dates. pre_controls exclude this term's incident cases and
# inc_controls exclude its prevalent cases; time is not considered for them.
term2pre_cases = {}
term2pre_controls = {}

term2inc_cases = {}
term2inc_controls = {}

for term, case_eids in term2caseids.items():
    control_eids = term2controlids[term]
    case_dates = term2casedates[term]
    # Dates are stored in parallel with the case EIDs; refuse to misalign them.
    if len(case_dates) != len(case_eids):
        raise ValueError(
            f"{term!r}: {len(case_eids)} case EIDs but {len(case_dates)} case dates"
        )
    pre_cases = []
    inc_cases = []
    for eid, date in zip(case_eids, case_dates):
        # NOTE(review): a missing date (NaT) would likely compare False and be
        # counted as incident — confirm term2casedates has no missing values.
        if date < id2proteindates[eid]:
            pre_cases.append(eid)
        else:
            inc_cases.append(eid)
    # Set membership keeps control filtering linear instead of
    # O(len(controls) * len(cases)) with list `in` checks.
    pre_case_set = set(pre_cases)
    inc_case_set = set(inc_cases)
    pre_controls = [eid for eid in control_eids if eid not in inc_case_set]
    inc_controls = [eid for eid in control_eids if eid not in pre_case_set]
    term2pre_cases[term] = pre_cases
    term2pre_controls[term] = pre_controls
    term2inc_cases[term] = inc_cases
    term2inc_controls[term] = inc_controls
        
        

In [None]:
# Persist the four prevalent/incident term -> EID-list dictionaries to pkl_dir.
for file_name, term_dict in (
    ('term2pre_cases.pkl', term2pre_cases),
    ('term2pre_controls.pkl', term2pre_controls),
    ('term2inc_cases.pkl', term2inc_cases),
    ('term2inc_controls.pkl', term2inc_controls),
):
    with open(os.path.join(pkl_dir, file_name), 'wb') as fh:
        pkl.dump(term_dict, fh)


In [4]:
# Reload the prevalent/incident splits so later cells can run without recomputing them.
def _read_pickle(file_name):
    """Unpickle and return the object stored at ``pkl_dir/file_name``.

    NOTE(review): pickle.load can execute arbitrary code; these files are
    assumed to be the ones this notebook writes, not untrusted input.
    """
    with open(os.path.join(pkl_dir, file_name), 'rb') as fh:
        return pkl.load(fh)


term2pre_cases = _read_pickle('term2pre_cases.pkl')
term2pre_controls = _read_pickle('term2pre_controls.pkl')
term2inc_cases = _read_pickle('term2inc_cases.pkl')
term2inc_controls = _read_pickle('term2inc_controls.pkl')


In [5]:
# Print the kernel's working directory; every relative path in this notebook
# (pkl_dir, data_dir, output_dir) is resolved against it.
!pwd

/home/dataset-assist-0/yaosen/lihan/ght/Prophet-Meta-temp/data/scripts


In [8]:
# Random 80/10/10 split of participants into train / valid / test.
# Seeding NumPy's global RNG makes this split reproducible across kernel restarts.
# It also fixes the later np.random.permutation calls that assign disease
# terms to each split.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
eids = np.random.permutation(list(id2proteindates.keys()))
n_eids = len(eids)
train_end, valid_end = int(0.8 * n_eids), int(0.9 * n_eids)
train_eids = list(eids[:train_end])
valid_eids = list(eids[train_end:valid_end])
test_eids = list(eids[valid_end:])
assert len(train_eids) + len(valid_eids) + len(test_eids) == n_eids

In [9]:
# Eligible terms: strictly between 50 and 1000 cases and more than 50 controls,
# computed separately for the prevalent (pre) and incident (inc) definitions.
valid_pre_terms = [
    t for t, cases in term2pre_cases.items()
    if 50 < len(cases) < 1000 and len(term2pre_controls[t]) > 50
]
valid_case_terms = [
    t for t, cases in term2inc_cases.items()
    if 50 < len(cases) < 1000 and len(term2inc_controls[t]) > 50
]
len(valid_pre_terms), len(valid_case_terms)

(907, 1219)

In [10]:
# Per-split dictionaries: term -> case / control EIDs restricted to that split's participants.
term2pre_cases_train, term2pre_cases_valid, term2pre_cases_test = {}, {}, {}
term2pre_controls_train, term2pre_controls_valid, term2pre_controls_test = {}, {}, {}

# Test terms are chosen first. Walk the eligible prevalent terms in random order
# and restrict each term's cases/controls to test_eids. Keep a term when it has
# > 50 cases and more controls than cases. Stop once more than 10% of the
# eligible terms are chosen.
test_quota = 0.1 * len(valid_pre_terms)
for term in np.random.permutation(valid_pre_terms):
    test_cases = np.intersect1d(test_eids, term2pre_cases[term])
    test_controls = np.intersect1d(test_eids, term2pre_controls[term])
    if len(test_cases) > 50 and len(test_controls) > len(test_cases):
        term2pre_cases_test[term] = test_cases
        term2pre_controls_test[term] = test_controls
    if len(term2pre_cases_test) > test_quota:
        break

In [11]:
# Validation terms come from the eligible prevalent terms not already used for test.
# Cases/controls are restricted to valid_eids, with the same > 50 cases and
# controls > cases rule. Stop once more than 10% of the eligible terms are chosen.
# NOTE(review): mutates dicts created in an earlier cell, so this cell is not
# idempotent — re-run that cell first.
valid_quota = 0.1 * len(valid_pre_terms)
for term in np.random.permutation(valid_pre_terms):
    if term in term2pre_cases_test:
        continue
    valid_cases = np.intersect1d(valid_eids, term2pre_cases[term])
    valid_controls = np.intersect1d(valid_eids, term2pre_controls[term])
    if len(valid_cases) > 50 and len(valid_controls) > len(valid_cases):
        term2pre_cases_valid[term] = valid_cases
        term2pre_controls_valid[term] = valid_controls
    if len(term2pre_cases_valid) > valid_quota:
        break

In [12]:
# Training terms: every remaining eligible term not picked for test or validation.
# Cases/controls are restricted to train_eids, with a lower bar of > 32 cases and
# more controls than cases. There is no quota; every qualifying term is kept.
already_assigned = (term2pre_cases_test, term2pre_cases_valid)
for term in np.random.permutation(valid_pre_terms):
    if any(term in split for split in already_assigned):
        continue
    train_cases = np.intersect1d(train_eids, term2pre_cases[term])
    train_controls = np.intersect1d(train_eids, term2pre_controls[term])
    if len(train_cases) > 32 and len(train_controls) > len(train_cases):
        term2pre_cases_train[term] = train_cases
        term2pre_controls_train[term] = train_controls

In [13]:
# Number of prevalent-disease terms per split: (train, valid, test).
tuple(map(len, (term2pre_cases_train, term2pre_cases_valid, term2pre_cases_test)))

(788, 28, 91)

In [14]:
# Jaccard similarity above which two disease terms count as near-duplicates.
threshold = 0.75


def jaccard(a, b):
    """Return the Jaccard similarity |a & b| / |a | b| of two iterables.

    Both inputs are converted to sets first. Two empty inputs give 0.0
    instead of dividing by zero.
    """
    set_a = set(a)
    set_b = set(b)
    union = set_a | set_b
    if not union:
        return 0.0
    return len(set_a & set_b) / len(union)

# Drop train/valid terms whose full prevalent-case set is a near-duplicate
# (Jaccard > threshold) of some test term, so closely related diseases cannot
# straddle the split. Each entry of `removed` is
# (test_term, split_name, removed_term, jaccard).
removed = []
# Iterate over snapshots (list(...)) because the train/valid dicts shrink inside the loop.
for term in list(term2pre_cases_test.keys()):
    test_set = set(term2pre_cases[term])
    for split_name, cases_split, controls_split in (
        ('train', term2pre_cases_train, term2pre_controls_train),
        ('valid', term2pre_cases_valid, term2pre_controls_valid),
    ):
        for other_term in list(cases_split):
            j = jaccard(test_set, term2pre_cases[other_term])
            if j > threshold:
                removed.append((term, split_name, other_term, j))
                cases_split.pop(other_term, None)
                controls_split.pop(other_term, None)

print(f"Removed {len(removed)} terms from train/valid set due to high Jaccard overlap (>{threshold}) with test sets.")

Removed 8 terms from train/valid set due to high Jaccard overlap (>0.75) with test sets.


In [15]:
# Drop train terms whose full prevalent-case set is a near-duplicate
# (Jaccard > threshold) of some validation term. Each entry of `removed` is
# (valid_term, 'valid', removed_train_term, jaccard).
removed = []
# Iterate over a snapshot of the train terms because that dict shrinks inside the loop.
for term in list(term2pre_cases_valid.keys()):
    valid_set = set(term2pre_cases[term])
    for other_term in list(term2pre_cases_train):
        j = jaccard(valid_set, term2pre_cases[other_term])
        if j > threshold:
            removed.append((term, 'valid', other_term, j))
            term2pre_cases_train.pop(other_term, None)
            term2pre_controls_train.pop(other_term, None)

print(f"Removed {len(removed)} terms from train set due to high Jaccard overlap (>{threshold}) with validation sets.")


Removed 2 terms from train set due to high Jaccard overlap (>0.75) with test sets.


In [None]:
# Persist the per-split prevalent-case and control dictionaries, one pickle per
# dictionary in pkl_dir. File names and variables are listed side by side so a
# misspelled variable (e.g. an undefined `term2pre_caases_valid`) is easy to spot.
pre_split_outputs = [
    ('term2pre_cases_train.pkl', term2pre_cases_train),
    ('term2pre_cases_valid.pkl', term2pre_cases_valid),
    ('term2pre_cases_test.pkl', term2pre_cases_test),
    ('term2pre_controls_train.pkl', term2pre_controls_train),
    ('term2pre_controls_valid.pkl', term2pre_controls_valid),
    ('term2pre_controls_test.pkl', term2pre_controls_test),
]
# Cases and controls are selected and pruned together, so each split's two
# dictionaries must cover exactly the same terms.
for cases_split, controls_split in (
    (term2pre_cases_train, term2pre_controls_train),
    (term2pre_cases_valid, term2pre_controls_valid),
    (term2pre_cases_test, term2pre_controls_test),
):
    assert cases_split.keys() == controls_split.keys()
for file_name, split_dict in pre_split_outputs:
    with open(os.path.join(pkl_dir, file_name), 'wb') as f:
        pkl.dump(split_dict, f)

In [18]:
# Number of prevalent-disease terms per split after de-duplication: (train, valid, test).
tuple(map(len, (term2pre_cases_train, term2pre_cases_valid, term2pre_cases_test)))

(783, 23, 91)