In [None]:
import pickle
from pathlib import Path

import pandas as pd
from torch.utils.data import random_split

from dldd.utils import TwoGraphData

In [None]:
# Load the pre-built interaction records and wrap each one as a TwoGraphData.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
# `with` closes the file handle as soon as the records are read.
with open('final_list.pkl', 'rb') as filehandler:
    raw_records = pickle.load(filehandler)
# presumably each record is a dict of TwoGraphData keyword arguments -- TODO confirm
data_list = [TwoGraphData(**record) for record in raw_records]

In [None]:

def split_number(number: int, train_frac=0.7, val_frac=0.2):
    """Turn a total item count into ``(train, val, test)`` subset sizes.

    The train and val sizes are the corresponding fractions of ``number``
    truncated with ``int``; the test size is whatever remains, so the three
    sizes always add up to ``number``.
    """
    n_train, n_val = (int(frac * number) for frac in (train_frac, val_frac))
    n_test = number - n_train - n_val
    return n_train, n_val, n_test

def split_data_random(dataset, train_frac=0.7, val_frac=0.2, generator=None):
    """Randomly split ``dataset`` into train/val/test subsets.

    Subset sizes come from ``split_number``, so the test subset absorbs any
    rounding remainder.

    Args:
        dataset: sized, indexable collection accepted by ``random_split``.
        train_frac: fraction of items assigned to train.
        val_frac: fraction of items assigned to val; test gets the rest.
        generator: optional ``torch.Generator`` to make the split
            reproducible. When None, torch's default global RNG is used
            (the previous behavior).

    Returns:
        The ``random_split`` result: three ``Subset`` objects (train, val, test).
    """
    lengths = split_number(len(dataset), train_frac, val_frac)
    # Only forward `generator` when given, so the default path calls
    # random_split exactly as before.
    extra = {} if generator is None else {"generator": generator}
    return random_split(dataset, list(lengths), **extra)

def split_cold_drug_cold_prot(dataset, drug, prot, train_frac=0.7, val_frac=0.2, generator=None):
    """Cold split over both proteins and drugs.

    The unique drugs and unique proteins are each randomly divided into
    train/val/test groups (sizes from ``split_number``). Each interaction is
    then routed by its protein accession and drug SMILES:

    * train -- protein AND drug are both in their train groups;
    * test  -- protein OR drug is in its test group;
    * val   -- everything else: neither is in a test group and at least one
      is outside its train group (e.g. a train protein with a val drug).

    Assuming ``drug`` and ``prot`` hold unique values, every val/test
    interaction therefore has at least one entity never seen in train, but
    may share the other entity with train.

    Args:
        dataset: iterable of interactions supporting ``item['accession']``
            and ``item['smiles']``.
        drug: sequence of drug SMILES to partition.
        prot: sequence of protein accessions to partition.
        train_frac: fraction of each entity list assigned to train.
        val_frac: fraction of each entity list assigned to val; test gets the rest.
        generator: optional ``torch.Generator`` for a reproducible split.
            When None, torch's default global RNG is used (previous behavior).

    Returns:
        ``(train, val, test)`` lists of interactions.
    """
    extra = {} if generator is None else {"generator": generator}
    drug_train, _, drug_test = random_split(drug, split_number(len(drug), train_frac, val_frac), **extra)
    prot_train, _, prot_test = random_split(prot, split_number(len(prot), train_frac, val_frac), **extra)
    # torch Subset defines no __contains__, so `in` on it is a linear scan for
    # every interaction; hash sets make each lookup O(1).
    train_drugs, test_drugs = set(drug_train), set(drug_test)
    train_prots, test_prots = set(prot_train), set(prot_test)

    train, val, test = [], [], []
    for item in dataset:
        accession, smiles = item['accession'], item['smiles']
        if accession in train_prots and smiles in train_drugs:
            train.append(item)
        elif accession in test_prots or smiles in test_drugs:
            test.append(item)
        else:
            val.append(item)
    return train, val, test

def split_cold_drug(dataset, drug, train_frac=0.7, val_frac=0.2, generator=None):
    """Cold-drug split: partition interactions by their drug SMILES.

    The drugs in ``drug`` are randomly divided into train/val/test groups
    (sizes from ``split_number``). An interaction goes to train if its drug is
    in the train group, to test if it is in the test group, and to val
    otherwise (the val group, or a drug missing from ``drug``). Routing depends
    only on the drug, so all interactions of one drug land in the same list.

    Args:
        dataset: iterable of interactions supporting ``item['smiles']``.
        drug: sequence of drug SMILES to partition.
        train_frac: fraction of drugs assigned to train.
        val_frac: fraction of drugs assigned to val; test gets the rest.
        generator: optional ``torch.Generator`` for a reproducible split.
            When None, torch's default global RNG is used (previous behavior).

    Returns:
        ``(train, val, test)`` lists of interactions.
    """
    extra = {} if generator is None else {"generator": generator}
    drug_train, _, drug_test = random_split(drug, split_number(len(drug), train_frac, val_frac), **extra)
    # torch Subset defines no __contains__, so `in` on it is a linear scan for
    # every interaction; hash sets make each lookup O(1).
    train_drugs, test_drugs = set(drug_train), set(drug_test)

    train, val, test = [], [], []
    for item in dataset:
        smiles = item['smiles']
        if smiles in train_drugs:
            train.append(item)
        elif smiles in test_drugs:
            test.append(item)
        else:
            val.append(item)
    return train, val, test

def split_cold_prot(dataset, prot, train_frac=0.7, val_frac=0.2, generator=None):
    """Cold-protein split: partition interactions by their protein accession.

    The proteins in ``prot`` are randomly divided into train/val/test groups
    (sizes from ``split_number``). An interaction goes to train if its protein
    is in the train group, to val if it is in the val group, and to test
    otherwise (the test group, or a protein missing from ``prot``). Routing
    depends only on the protein, so all interactions of one protein land in
    the same list.

    Args:
        dataset: iterable of interactions supporting ``item['accession']``.
        prot: sequence of protein accessions to partition.
        train_frac: fraction of proteins assigned to train.
        val_frac: fraction of proteins assigned to val; test gets the rest.
        generator: optional ``torch.Generator`` for a reproducible split.
            When None, torch's default global RNG is used (previous behavior).

    Returns:
        ``(train, val, test)`` lists of interactions.
    """
    extra = {} if generator is None else {"generator": generator}
    prot_train, prot_val, _ = random_split(prot, split_number(len(prot), train_frac, val_frac), **extra)
    # torch Subset defines no __contains__, so `in` on it is a linear scan for
    # every interaction; hash sets make each lookup O(1).
    train_prots, val_prots = set(prot_train), set(prot_val)

    train, val, test = [], [], []
    for item in dataset:
        accession = item['accession']
        if accession in train_prots:
            train.append(item)
        elif accession in val_prots:
            val.append(item)
        else:
            test.append(item)
    return train, val, test

In [None]:
# Unique protein accessions and drug SMILES used by the cold splits.
# Iteration order of a set of strings changes between sessions (hash
# randomization), so the lists are sorted to give a stable order; a seeded
# split of them is then reproducible. Assumes the values are mutually
# comparable (presumably all str -- TODO confirm).
prot = sorted({x['accession'] for x in data_list})
drug = sorted({x['smiles'] for x in data_list})


In [None]:
# Baseline: uniformly random split of the interactions (no cold entities).
train, val, test = split_data_random(dataset=data_list)

In [None]:
# Cold-drug split; rebinds train/val/test from any previously run split cell.
train, val, test = split_cold_drug(dataset=data_list, drug=drug)

In [None]:
# Cold-protein split; rebinds train/val/test from any previously run split cell.
train, val, test = split_cold_prot(dataset=data_list, prot=prot)

In [None]:
# Cold split on both drugs and proteins; rebinds train/val/test from any
# previously run split cell.
train, val, test = split_cold_drug_cold_prot(dataset=data_list, drug=drug, prot=prot)

In [None]:
# Create the output directory (no error if it already exists); pure Python so
# it also works where a POSIX `mkdir -p` is unavailable.
Path('data').mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the split computed most recently above. Each split cell rebinds
# train/val/test, so under Run All this is the cold drug + cold protein split.
# `with` guarantees each file is flushed and closed; previously the last handle
# stayed open in the namespace, so data/test.pkl could be left incomplete.
for split_name, split_items in (('train', train), ('val', val), ('test', test)):
    with open(f'data/{split_name}.pkl', 'wb') as split_file:
        # random_split returns torch Subsets, which pickle together with their
        # whole parent dataset; materialising a list stores only this split's
        # items (plain lists from the cold splits are just copied).
        pickle.dump(list(split_items), split_file)