In [9]:
import pandas as pd
import numpy as np

In [10]:
# Fold ids iterated over when the train/valid files are produced.
folds = list(range(5))

# Locations of the raw per-cell-line tables.
A375_path = "../data/original_data/A375.csv"
A549_path = "../data/original_data/A549.csv"
Jurkat_path = "../data/original_data/Jurkat.csv"

# Load each table; the Jurkat labels are additionally cast to int.
A375_df, A549_df = pd.read_csv(A375_path), pd.read_csv(A549_path)
Jurkat_df = pd.read_csv(Jurkat_path).astype({'label': int})

In [11]:
# Preview the loaded A375 table as the cell's rich output.
A375_df

Unnamed: 0,gene1_name,gene2_name,cell_line,label
0,MAP2K2,UBC,A375,0
1,HSP90AA1,MAP2K2,A375,0
2,CHEK2,MAP2K2,A375,0
3,AKT1,MAP2K2,A375,0
4,BCL2,MAP2K2,A375,0
...,...,...,...,...
530,MAPK3,WEE1,A375,1
531,BRCA2,MAPK3,A375,1
532,MAPK3,MCL1,A375,1
533,HSP90AA1,WEE1,A375,1


In [12]:
# Collect every A375 row whose (gene1_name, gene2_name) pair occurs more than
# once; keep=False flags all occurrences rather than only the later repeats.
duplicates = A375_df.loc[
    A375_df.duplicated(subset=['gene1_name', 'gene2_name'], keep=False)
]

# Show the repeated rows, followed by how many there are.
print(duplicates)
print(duplicates.shape[0])

    gene1_name gene2_name cell_line  label
66       BRCA2      MAPK3      A375      0
129   HSP90AA1       WEE1      A375      0
130   HSP90AA1      PARP1      A375      0
133   HSP90AA1       MTOR      A375      0
144     BCL2L1      PARP1      A375      0
145       MTOR      PARP1      A375      0
146       MCL1      PARP1      A375      0
156     BCL2L1       MCL1      A375      0
157      MAPK3       MCL1      A375      0
179      MAPK3       WEE1      A375      0
180     BCL2L1      MAPK3      A375      0
216     BCL2L1       WEE1      A375      0
270      BRCA1      PARP1      A375      0
273      BRCA2      PARP1      A375      0
521     BCL2L1       MCL1      A375      1
522   HSP90AA1      PARP1      A375      1
523     BCL2L1      MAPK3      A375      1
524       MCL1      PARP1      A375      1
525      BRCA2      PARP1      A375      1
526      BRCA1      PARP1      A375      1
527     BCL2L1       WEE1      A375      1
528     BCL2L1      PARP1      A375      1
529       M

In [13]:
def filter_df(df):
    """Resolve repeated gene pairs so positive labels win over negative ones.

    Rows are grouped by ``(gene1_name, gene2_name)``. For each pair:

    * if at least one row has ``label == 1``, every ``label == 1`` row of the
      pair is kept and all other rows of the pair are dropped;
    * otherwise only the first ``label == 0`` row of the pair is kept.

    Rows whose label is neither 0 nor 1 are therefore never kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``gene1_name``, ``gene2_name`` and ``label``.

    Returns
    -------
    pandas.DataFrame
        The surviving rows with their original index labels and columns,
        concatenated in the order in which ``groupby`` yields the pairs.
        An empty frame with the same columns is returned when nothing
        survives (for example when ``df`` itself is empty).
    """
    kept_groups = []

    for _, group in df.groupby(['gene1_name', 'gene2_name']):
        positives = group[group['label'] == 1]

        if not positives.empty:
            # A positive observation overrides any negative one for the pair.
            kept_groups.append(positives)
        else:
            # No positive row: a single negative row represents the pair.
            kept_groups.append(group[group['label'] == 0].head(1))

    if not kept_groups:
        # pd.concat raises ValueError on an empty list, so build the empty
        # result directly while preserving the column layout.
        return df.iloc[0:0].copy()

    return pd.concat(kept_groups)

# Run filter_df on each cell-line table, in the order A375, A549, Jurkat.
A375_filtered, A549_filtered, Jurkat_filtered = map(
    filter_df, (A375_df, A549_df, Jurkat_df)
)

In [14]:
# Write each filtered table to its own CSV under save_path, without the index.
save_path = "../data/train_data"

filtered_outputs = {
    f"{save_path}/A375_filtered.csv": A375_filtered,
    f"{save_path}/A549_filtered.csv": A549_filtered,
    f"{save_path}/Jurkat_filtered.csv": Jurkat_filtered,
}
for csv_path, filtered_frame in filtered_outputs.items():
    filtered_frame.to_csv(csv_path, index=False)

In [15]:
# split train and valid
# 5 fold
def split_train_valid(df, n_folds=5, seed=None):
    """Assign every row of ``df`` a random fold id in ``[0, n_folds)``.

    Each row receives an independent, uniformly drawn fold id, and the result
    is labelled with ``df.index`` so it can be assigned straight back as a
    column. The ids are drawn for the rows directly (one per row, in row
    order) instead of per gene-pair group, so the returned values always line
    up with ``df.index`` and the function also works for an empty frame or a
    frame containing missing gene names.

    Note that rows sharing a gene pair are assigned independently and may end
    up in different folds.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose rows are to be distributed over folds.
    n_folds : int, default 5
        Number of folds; ids range from 0 to ``n_folds - 1``.
    seed : int or None, default None
        When given, a private ``RandomState(seed)`` is used so the assignment
        is reproducible and the global RNG is left untouched. When ``None``
        the draw comes from NumPy's global ``np.random`` state.

    Returns
    -------
    pandas.Series
        Integer fold id per row, indexed like ``df``.
    """
    # Fall back to the global stream when no seed is requested so callers that
    # rely on np.random.seed(...) keep controlling the outcome.
    rng = np.random if seed is None else np.random.RandomState(seed)

    fold_ids = rng.randint(0, n_folds, size=len(df))

    return pd.Series(fold_ids, index=df.index)

# Seed NumPy's global RNG so the fold assignment below is reproducible;
# split_train_valid draws its fold ids from np.random.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)

# Draw the fold id of every row exactly once per cell line, BEFORE iterating
# over the folds. Drawing inside the loop would give each fold its own
# unrelated random partition, so validation sets would overlap or miss rows
# instead of forming a proper k-fold split.
A375_splits = split_train_valid(A375_filtered)
A549_splits = split_train_valid(A549_filtered)
Jurkat_splits = split_train_valid(Jurkat_filtered)

A375_filtered['split'] = A375_splits
A549_filtered['split'] = A549_splits
Jurkat_filtered['split'] = Jurkat_splits

for fold in folds:
    # Rows tagged with the current fold are held out for validation; all other
    # rows form the training set. The helper 'split' column is dropped again
    # so it does not end up in the exported files.
    A375_train = A375_filtered[A375_filtered['split'] != fold].drop(columns=['split'])
    A375_valid = A375_filtered[A375_filtered['split'] == fold].drop(columns=['split'])

    A549_train = A549_filtered[A549_filtered['split'] != fold].drop(columns=['split'])
    A549_valid = A549_filtered[A549_filtered['split'] == fold].drop(columns=['split'])

    Jurkat_train = Jurkat_filtered[Jurkat_filtered['split'] != fold].drop(columns=['split'])
    Jurkat_valid = Jurkat_filtered[Jurkat_filtered['split'] == fold].drop(columns=['split'])

    # Only the Jurkat folds are written; the A375 / A549 exports are disabled.
    # A375_train.to_csv(f"{save_path}/Cell_line_specific/A375/train_{fold}.csv", index=False)
    # A375_valid.to_csv(f"{save_path}/Cell_line_specific/A375/valid_{fold}.csv", index=False)

    # A549_train.to_csv(f"{save_path}/Cell_line_specific/A549/train_{fold}.csv", index=False)
    # A549_valid.to_csv(f"{save_path}/Cell_line_specific/A549/valid_{fold}.csv", index=False)

    Jurkat_train.to_csv(f"{save_path}/Cell_line_specific/Jurkat/train_{fold}.csv", index=False)
    Jurkat_valid.to_csv(f"{save_path}/Cell_line_specific/Jurkat/valid_{fold}.csv", index=False)