In [1]:
import os

import numpy as np
import pandas as pd

from tqdm import tqdm

In [2]:
n_classes = 120
n_negatives = 100_000

# One positive-sample count per class.
n_positives = np.loadtxt('../../data/drug/raw/num_positives.txt', dtype=int)
assert n_positives.size == n_classes

# Read every line of the SMILES file with surrounding whitespace removed.
with open('../../data/drug/raw/smiles.txt', 'r') as f:
    smiles = [line.strip() for line in f]
assert len(smiles) == n_negatives + n_positives.sum()

# The file is consumed front to back: the first n_negatives entries are the
# negatives, then one consecutive chunk of positives per class, in class order.
neg_smiles, smiles = smiles[:n_negatives], smiles[n_negatives:]

pos_smiles = []
for tmp_n_positives in n_positives:
    # Peel this class's chunk off the front; `smiles` keeps only the remainder.
    class_chunk, smiles = smiles[:tmp_n_positives], smiles[tmp_n_positives:]
    pos_smiles.append(class_chunk)

In [3]:
group_size = 4
assert n_classes % group_size == 0

# The seed is tied to group_size, so the class order is reproducible for a
# given group size.
np.random.seed(group_size)
shuffled_classes = np.random.permutation(np.arange(n_classes))

# The output directory depends on group_size and np.savetxt does not create
# missing parent directories, so make sure it exists before writing.
split_dir = f'../../data/drug/split/{group_size}'
os.makedirs(split_dir, exist_ok=True)
np.savetxt(f'{split_dir}/shuffled_classes.txt', shuffled_classes, fmt='%d')

print(shuffled_classes)

[ 13   2  25  16  19  41   5  24  82  20  65  79  34  86  77 116  10  12
  92  26  97 101  85 117  84  35  76  47  64 113  54  29  75  14  93  31
  99  89  11   4  63  78  61  60  37 114  43  27 109  15  88 108  71  51
  53  96  18 111   7  81  80  39   6  74  91  70  68  62  67  22  23  48
  59  17 102 105  98  28  83  33  45  42  40  32 110  90  49   8  30 119
  66  56 100  73  95  21 106   0   3  52  38  44 112  36 115  57 107 118
  94 103  58   9  50  72  87 104   1  69  55  46]


In [4]:
# Build one CSV per problem. Each problem contains every negative (label 0)
# plus the positives of `group_size` consecutive entries of `shuffled_classes`
# (labels 1..group_size, in that order). Rows are written as `smiles,class`
# without a header or index column.
for problem_ind in tqdm(range(n_classes // group_size)):
    # Slice out this problem's classes instead of consuming `shuffled_classes`,
    # so the array stays intact and the cell can be re-run safely.
    group_start = problem_ind * group_size
    problem_classes = shuffled_classes[group_start:group_start + group_size]

    tmp_smiles = list(neg_smiles)
    labels = [0] * n_negatives

    for class_ind, tmp_class in enumerate(problem_classes):
        tmp_smiles += pos_smiles[tmp_class]
        labels += [class_ind + 1] * int(n_positives[tmp_class])

    # Write once per problem, after all of its classes have been appended
    # (writing inside the class loop rewrote the same file group_size times).
    df = pd.DataFrame({'smiles': tmp_smiles, 'class': labels})
    df.to_csv(
        f'../../data/drug/split/{group_size}/problem{problem_ind}.csv',
        index=False, header=False
    )

100%|███████████████████████████████████████████████████████████████████████████████| 30/30 [00:20<00:00,  1.45it/s]
