In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../..')))
from seq2seq import *
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from collections import Counter
import random

In [2]:
# Load the ATC dataset into a pandas DataFrame.
# The path is relative, so it is resolved against the current working directory.
# Later cells read the 'ATC Codes' column of this frame.
df = pd.read_csv('../../../Data/splittedATC.csv')

In [3]:
def convert_string_list(element):
    """Split a '; '-separated string of ATC codes into a list of codes.

    Parameters
    ----------
    element : str
        ATC codes joined by ``'; '`` (semicolon followed by one space),
        e.g. ``'A01AB; B02CD'``.

    Returns
    -------
    list of str
        The individual codes in their original order. A string that does not
        contain the separator yields a one-element list.
    """
    # The codes are returned exactly as they appear between separators:
    # no brackets, quotes or whitespace are stripped.
    return element.split('; ')

In [4]:
def set_seeds(seed):
    """Seed NumPy's global generator and Python's ``random`` module with ``seed``."""
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def _split_atc_codes(atc_list, rng):
    """Randomly divide one compound's ATC codes into a train part and a test part.

    The train part receives half of the codes, rounded up when the count is
    odd; the test part receives the remaining codes in their original order.

    Parameters
    ----------
    atc_list : list of str
        ATC codes of a single compound. The list is not modified.
    rng : random.Random
        Generator used to pick the train codes.

    Returns
    -------
    tuple of (str, str)
        ``(train_codes, test_codes)``, each joined with ``'; '``.
    """
    remaining = list(atc_list)
    if len(remaining) == 2:
        # Pairs keep using choice() so the sequence of random draws (and hence
        # the resulting partitions for a given seed) stays the same as before.
        selected = [rng.choice(remaining)]
    else:
        # (n + 1) // 2 is n // 2 for even n and n // 2 + 1 for odd n.
        selected = rng.sample(remaining, (len(remaining) + 1) // 2)
    # Remove by value, one occurrence per selected code, so repeated codes
    # are handled as a multiset difference.
    for code in selected:
        remaining.remove(code)
    return "; ".join(selected), "; ".join(remaining)


def create_partitions(df, seed, val_size=0.15):
    """Build and save the train / validation / test partitions for one seed.

    Compounds with more than one ATC code have their codes divided at random:
    one part stays with the compound in the training data and the rest forms
    the compound's row in the test set. Compounds with a single ATC code never
    appear in the test set. The training data (single-code compounds plus the
    train part of multi-code compounds) is then split into train and
    validation sets.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset with an 'ATC Codes' column holding '; '-separated codes.
        The frame is not modified.
    seed : int
        Seed for the code selection, for the scikit-learn split/shuffle calls,
        and suffix of every output file name.
    val_size : float, default 0.15
        Passed as ``test_size`` to ``train_test_split`` when carving the
        validation set out of the training data.

    Returns
    -------
    None
        The partitions are written to the current working directory:
        ``Multiples_train_set{seed}.csv`` (train part of multi-code compounds),
        ``Rep_train_set{seed}.csv``, ``Rep_val_set{seed}.csv``,
        ``Rep_test_set{seed}.csv`` and ``Rep_complete_train_set{seed}.csv``
        (train and validation together, reshuffled).

    Notes
    -----
    Code selection uses a private ``random.Random(seed)``. It produces the same
    draws as the module-level ``random`` functions do right after
    ``random.seed(seed)``, but does not depend on or alter the global state.
    """
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    # multiple_ATC is True for compounds with more than one ATC code.
    labelled = df.assign(
        multiple_ATC=df['ATC Codes'].apply(lambda codes: len(convert_string_list(codes)) > 1)
    )

    # Divide the dataset depending on the multiple_ATC column
    group_more_than_one = labelled[labelled['multiple_ATC']].reset_index(drop=True)
    group_one = labelled[~labelled['multiple_ATC']].reset_index(drop=True)

    # Split the codes of every multi-code compound, in row order
    rng = random.Random(seed)
    train_codes = []
    test_codes = []
    for atc in group_more_than_one['ATC Codes']:
        codes_for_train, codes_for_test = _split_atc_codes(convert_string_list(atc), rng)
        train_codes.append(codes_for_train)
        test_codes.append(codes_for_test)

    # Copy every column of the multi-code rows and replace only 'ATC Codes'.
    # Building each frame in one step avoids growing it row by row.
    multi_train = group_more_than_one.assign(**{'ATC Codes': train_codes})
    test_set = group_more_than_one.assign(**{'ATC Codes': test_codes})

    multi_train.to_csv(f'Multiples_train_set{seed}.csv', index=False)

    # Single-code compounds and the train part of multi-code compounds are each
    # split into train and validation, then combined.
    train_one, val_one = train_test_split(group_one, test_size=val_size, random_state=seed)
    train_more, val_more = train_test_split(multi_train, test_size=val_size, random_state=seed)

    train_set = shuffle(pd.concat([train_one, train_more]), random_state=seed)
    val_set = shuffle(pd.concat([val_one, val_more]), random_state=seed)
    complete_train_set = shuffle(pd.concat([train_set, val_set]), random_state=seed)

    train_set.to_csv(f'Rep_train_set{seed}.csv', index=False)
    test_set.to_csv(f'Rep_test_set{seed}.csv', index=False)
    val_set.to_csv(f'Rep_val_set{seed}.csv', index=False)
    complete_train_set.to_csv(f'Rep_complete_train_set{seed}.csv', index=False)

In [5]:
# Seeds for the repeated partitioning runs. create_partitions embeds the seed in
# every output file name, so each seed produces its own set of CSV files.
seeds = [42, 123, 47899, 2025, 1, 20, 99, 1020, 345, 78] 

for seed in seeds:
    # Seed the global `random` and NumPy generators before each run so the
    # partitioning is reproducible for a given seed.
    set_seeds(seed)
    create_partitions(df, seed)