In [None]:
import numpy as np
import pandas as pd 


# Intro
This notebook shows how to split the Davis, KIBA and Metz datasets into 5 folds.

There are 4 task settings:\
(d: a drug in the test set, t: a target in the test set)
1. Warm set (output folder `warm`)\
    Both d and t may also appear in the training set.
2. Cold Drug (output folder `novel-drug`)\
    d is absent from the training set.
3. Cold Protein (output folder `novel-prot`)\
    t is absent from the training set.
4. Cold Pair (output folder `novel-pair`)\
    Both d and t are absent from the training set.

**Input**\
[dta-origin-dataset](https://www.kaggle.com/datasets/christang0002/llmdta/data)
- davis.txt
- kiba.txt
- metz.txt

**Output**\
[dta-5fold-dataset](https://www.kaggle.com/datasets/christang0002/llmdta/data)
- *davis*
    - *warm*
        - fold_0_train.csv
        - ...
    - *novel-drug*
    - *novel-prot*
    - *novel-pair*
    - davis_drugs.csv
    - davis_prots.csv
    - davis_pairs.csv
- *kiba*
- *metz*

In [None]:
# All raw datasets live under the Kaggle input mount; change `input_root`
# once to run the notebook somewhere else.
input_root = '/kaggle/input'
davis_dir = f'{input_root}/davis-and-kiba/davis.txt'
kiba_dir = f'{input_root}/davis-and-kiba/kiba.txt'
metz_dir = f'{input_root}/metz-dta/metz.txt'

In [None]:
random_seed = 0
col_name = ['drug_id', 'prot_id', 'drug_seq', 'prot_seq', 'affinity']
cur_data = 'davis'  # TODO set to 'davis', 'kiba' or 'metz'

# Derive the input file from `cur_data` so the dataset name (used in every
# output path below) can never disagree with the file that is actually loaded.
data_files = {'davis': davis_dir, 'kiba': kiba_dir, 'metz': metz_dir}
if cur_data not in data_files:
    raise ValueError(f'Unknown dataset {cur_data!r}; expected one of {sorted(data_files)}')

# The raw files are space-separated and have no header row; assigning
# `columns` afterwards fails loudly if the column count is not 5.
df = pd.read_csv(data_files[cur_data], sep=' ', header=None)
df.columns = col_name

In [None]:
# First five rows of the raw table (same as df.head()).
df.iloc[:5]

# Split Drugs, Prots and Pairs

In [None]:
# (rows, columns) of the raw table — the same tuple as df.shape.
len(df), len(df.columns)

In [None]:
# One row per distinct (id, sequence) combination, in order of first appearance.
df_drugs = df[['drug_id', 'drug_seq']].drop_duplicates()
df_prots = df[['prot_id', 'prot_seq']].drop_duplicates()

# All interactions, shuffled once with the notebook seed and re-indexed 0..N-1;
# the warm split below slices this shuffled order positionally.
df_pairs = (
    df[['drug_id', 'prot_id', 'affinity']]
    .sample(frac=1, random_state=random_seed)
    .reset_index(drop=True)
)

for table in (df_drugs, df_prots, df_pairs):
    print(table.shape)

In [None]:
import os

# Root output folder for the current dataset (e.g. ./davis); only created once.
path = f'./{cur_data}'
needs_dir = not os.path.exists(path)
if needs_dir:
    os.makedirs(path)
    print(f'Create path {path}')

In [None]:
# Persist the drug table, protein table and shuffled interaction table
# as ./<dataset>/<dataset>_{drugs,prots,pairs}.csv.
for table_name, table in [('drugs', df_drugs), ('prots', df_prots), ('pairs', df_pairs)]:
    table.to_csv(f'./{cur_data}/{cur_data}_{table_name}.csv', index=False, header=True)

In [None]:
# First five interactions after shuffling.
df_pairs.head(5)

# Warm Setting

In [None]:
# Output folder for the warm-setting folds; only created once.
path = f'./{cur_data}/warm'
needs_dir = not os.path.exists(path)
if needs_dir:
    os.makedirs(path)
    print(f'Create path {path}')

In [None]:
from sklearn.model_selection import train_test_split

# Warm setting: df_pairs was shuffled once above, so k contiguous slices of it
# are k random, disjoint test folds. The last fold also takes the remainder
# rows when len(df_pairs) is not divisible by k.
k = 5
fold_size = len(df_pairs) // k
for i in range(k):
    test_start = i * fold_size
    test_end = len(df_pairs) if i == k - 1 else (i + 1) * fold_size
    testset = df_pairs.iloc[test_start:test_end]
    # Everything outside the test slice, in original order (df_pairs has a
    # unique 0..N-1 index after reset_index, so dropping by label is exact).
    tvset = df_pairs.drop(index=testset.index)
    assert len(testset) + len(tvset) == len(df_pairs)

    # split training-set and valid-set (80/20) with the notebook-wide seed
    trainset, validset = train_test_split(tvset, test_size=0.2, random_state=random_seed)
    print(f'train:{len(trainset)}, valid:{len(validset)}, test:{len(testset)}')
    trainset.to_csv(f'./{cur_data}/warm/fold_{i}_train.csv', index=False, header=True)
    validset.to_csv(f'./{cur_data}/warm/fold_{i}_valid.csv', index=False, header=True)
    testset.to_csv(f'./{cur_data}/warm/fold_{i}_test.csv', index=False, header=True)

# Novel Drug

In [None]:
# Output folder for the novel (cold) drug folds; only created once.
path = f'./{cur_data}/novel-drug'
needs_dir = not os.path.exists(path)
if needs_dir:
    os.makedirs(path)
    print(f'Create path {path}')

In [None]:
# Novel (cold) drug setting: the drug table is cut into k contiguous chunks (in
# the order drugs first appear in the raw file). Every interaction of a chunk's
# drugs goes to the test set, so those drugs never appear in train/valid.
k = 5
drugs_num = len(df_drugs)
fold_size = drugs_num // k

for i in range(k):
    test_start = i * fold_size
    # The last fold also takes the remainder when drugs_num is not divisible by k.
    test_end = drugs_num if i == k - 1 else (i + 1) * fold_size

    drugs_id = df_drugs['drug_id'].iloc[test_start:test_end]
    is_test = df_pairs['drug_id'].isin(drugs_id)
    testset = df_pairs[is_test]
    tvset = df_pairs[~is_test]
    # Guard against future edits: no held-out drug may reach train/valid.
    assert not tvset['drug_id'].isin(drugs_id).any()
    trainset, validset = train_test_split(tvset, test_size=0.2, random_state=random_seed)

    print(f'train:{len(trainset)}, valid:{len(validset)}, test:{len(testset)}')
    trainset.to_csv(f'./{cur_data}/novel-drug/fold_{i}_train.csv', index=False, header=True)
    validset.to_csv(f'./{cur_data}/novel-drug/fold_{i}_valid.csv', index=False, header=True)
    testset.to_csv(f'./{cur_data}/novel-drug/fold_{i}_test.csv', index=False, header=True)

# Novel Target

In [None]:
# Output folder for the novel (cold) protein folds; only created once.
path = f'./{cur_data}/novel-prot'
needs_dir = not os.path.exists(path)
if needs_dir:
    os.makedirs(path)
    print(f'Create path {path}')

In [None]:
# Novel (cold) protein setting: the protein table is cut into k contiguous
# chunks (in the order proteins first appear in the raw file). Every interaction
# of a chunk's proteins goes to the test set, so those proteins never appear in
# train/valid.
k = 5
prots_num = len(df_prots)
fold_size = prots_num // k

for i in range(k):
    test_start = i * fold_size
    # The last fold also takes the remainder when prots_num is not divisible by k.
    test_end = prots_num if i == k - 1 else (i + 1) * fold_size

    prots_id = df_prots['prot_id'].iloc[test_start:test_end]
    is_test = df_pairs['prot_id'].isin(prots_id)
    testset = df_pairs[is_test]
    tvset = df_pairs[~is_test]
    # Guard against future edits: no held-out protein may reach train/valid.
    assert not tvset['prot_id'].isin(prots_id).any()
    trainset, validset = train_test_split(tvset, test_size=0.2, random_state=random_seed)

    print(f'train:{len(trainset)}, valid:{len(validset)}, test:{len(testset)}')
    trainset.to_csv(f'./{cur_data}/novel-prot/fold_{i}_train.csv', index=False, header=True)
    validset.to_csv(f'./{cur_data}/novel-prot/fold_{i}_valid.csv', index=False, header=True)
    testset.to_csv(f'./{cur_data}/novel-prot/fold_{i}_test.csv', index=False, header=True)

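# Novel Pair
Each of the 5 folds is an independent random draw (seed = fold index) that holds out 40% of the drugs and 40% of the proteins, so unlike the other settings these folds are not disjoint.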

In [None]:
# Output folder for the novel (cold) pair folds; only created once.
path = f'./{cur_data}/novel-pair'
needs_dir = not os.path.exists(path)
if needs_dir:
    os.makedirs(path)
    print(f'Create path {path}')

In [None]:
# Novel (cold) pair setting. Each of the k folds is an independent random draw
# (seed = fold index), not a disjoint partition: 40% of drugs and 40% of
# proteins are held out, then
#   - test:  pairs whose drug AND protein are both held out,
#   - train: pairs whose drug AND protein are both kept,
#   - valid: all remaining pairs (exactly one side held out).
# Test pairs therefore share neither a drug nor a protein with the training set.
k = 5
for seed_i in range(k):
    drugs_id = df_drugs.sample(frac=0.4, random_state=seed_i).reset_index(drop=True)
    prots_id = df_prots.sample(frac=0.4, random_state=seed_i).reset_index(drop=True)

    drug_held_out = df_pairs['drug_id'].isin(drugs_id['drug_id'])
    prot_held_out = df_pairs['prot_id'].isin(prots_id['prot_id'])
    test_mask = drug_held_out & prot_held_out
    train_mask = ~drug_held_out & ~prot_held_out

    testset = df_pairs[test_mask]
    trainset = df_pairs[train_mask]
    # Select validation rows with the same boolean masks, never by comparing
    # index labels with a merge result (a merge returns a fresh 0..M-1 index
    # that is unrelated to df_pairs' rows).
    validset = df_pairs[~(test_mask | train_mask)]
    assert len(trainset) + len(validset) + len(testset) == len(df_pairs)

    print(f'train:{len(trainset)}, valid:{len(validset)}, test:{len(testset)}')
    trainset.to_csv(f'./{cur_data}/novel-pair/fold_{seed_i}_train.csv', index=False, header=True)
    validset.to_csv(f'./{cur_data}/novel-pair/fold_{seed_i}_valid.csv', index=False, header=True)
    testset.to_csv(f'./{cur_data}/novel-pair/fold_{seed_i}_test.csv', index=False, header=True)

In [None]:
# Number of interactions recorded for each protein, most frequent first.
df_pairs.loc[:, 'prot_id'].value_counts()