# NOW utils/split_dataset.py

In [156]:
import pandas as pd
import numpy as np
import re

In [157]:
SEED = 112
db = pd.read_csv('db.csv')

In [158]:
db

Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class
0,AAAAAAAAAAGIGKFLHSAKKFGKAFVGEIMNS,125.878150,pepID1,8
1,AAAAAAAIKMLMDLVNERIMALNKKAKK_amd,10.000000,pepID2,5
2,AAAAGSVWGAVNYTSDCNGECKRRGYKGGYCGSFANVNCWCET,100.000000,pepID3,8
3,AAAKAALNAVLVGANA,80.000000,pepID4,8
4,AACSDRAHGHICESFKSFCKDSGRNGVKLRANCKKTCGLC,1.780176,pepID5,2
...,...,...,...,...
15298,TmaTmaBmamTmaTmaTmaTmaPheTmaTmaTmaPheTmaTmaTma...,64.000000,polyID43_S96,7
15299,TmaTmaTmaTmaTmaTmaPheTmaTmaMepBmamBmamTmaTmaTm...,64.000000,polyID43_S97,7
15300,TmaTmaTmaOlamTmaTmaMepOlamPheMepTmaMepTmaTmaTm...,64.000000,polyID43_S98,7
15301,TmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaBmamTma...,64.000000,polyID43_S99,7


In [159]:
poly_idx = db['ID'].str.contains('poly')
db['group'] = db[poly_idx]['ID'].apply(lambda x: int(re.split(r'_', x)[0][6:]))

In [160]:
db

Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
0,AAAAAAAAAAGIGKFLHSAKKFGKAFVGEIMNS,125.878150,pepID1,8,
1,AAAAAAAIKMLMDLVNERIMALNKKAKK_amd,10.000000,pepID2,5,
2,AAAAGSVWGAVNYTSDCNGECKRRGYKGGYCGSFANVNCWCET,100.000000,pepID3,8,
3,AAAKAALNAVLVGANA,80.000000,pepID4,8,
4,AACSDRAHGHICESFKSFCKDSGRNGVKLRANCKKTCGLC,1.780176,pepID5,2,
...,...,...,...,...,...
15298,TmaTmaBmamTmaTmaTmaTmaPheTmaTmaTmaPheTmaTmaTma...,64.000000,polyID43_S96,7,43.0
15299,TmaTmaTmaTmaTmaTmaPheTmaTmaMepBmamBmamTmaTmaTm...,64.000000,polyID43_S97,7,43.0
15300,TmaTmaTmaOlamTmaTmaMepOlamPheMepTmaMepTmaTmaTm...,64.000000,polyID43_S98,7,43.0
15301,TmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaBmamTma...,64.000000,polyID43_S99,7,43.0


In [161]:
start_idx = max(db['group'].value_counts().index.to_list())

vals = list(db['group'].value_counts().to_list())

if np.unique(vals).size == 1:
    k = vals[0]

In [162]:
pep_idx = db['ID'].str.startswith('pepID')

# Shuffle and assign groups directly
db.loc[pep_idx, 'group'] = (
    db[pep_idx]
    .sample(frac=1, random_state=SEED)  # Shuffle only the matching rows
    .assign(group=lambda x: (np.arange(len(x)) // k) + start_idx + 1)['group']  # Group by size 2
)

In [163]:
db['group'] = db['group'].astype(int) 
db

Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
0,AAAAAAAAAAGIGKFLHSAKKFGKAFVGEIMNS,125.878150,pepID1,8,67
1,AAAAAAAIKMLMDLVNERIMALNKKAKK_amd,10.000000,pepID2,5,61
2,AAAAGSVWGAVNYTSDCNGECKRRGYKGGYCGSFANVNCWCET,100.000000,pepID3,8,46
3,AAAKAALNAVLVGANA,80.000000,pepID4,8,76
4,AACSDRAHGHICESFKSFCKDSGRNGVKLRANCKKTCGLC,1.780176,pepID5,2,112
...,...,...,...,...,...
15298,TmaTmaBmamTmaTmaTmaTmaPheTmaTmaTmaPheTmaTmaTma...,64.000000,polyID43_S96,7,43
15299,TmaTmaTmaTmaTmaTmaPheTmaTmaMepBmamBmamTmaTmaTm...,64.000000,polyID43_S97,7,43
15300,TmaTmaTmaOlamTmaTmaMepOlamPheMepTmaMepTmaTmaTm...,64.000000,polyID43_S98,7,43
15301,TmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaTmaBmamTma...,64.000000,polyID43_S99,7,43


In [164]:
def shuffle_groups(df, group_column):
    # Get unique groups and shuffle them
    shuffled_groups = np.random.RandomState(seed=SEED).permutation(df[group_column].unique())
    
    # Reorder the DataFrame based on the shuffled groups
    reordered_df = pd.concat([df[df[group_column] == group] for group in shuffled_groups])
    
    return reordered_df.reset_index(drop=True)

# Apply the function
df_shuffled = shuffle_groups(db, 'group')

# View the shuffled DataFrame
df_shuffled

Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
0,AIPWSIWWHLLFKG,50.00,pepID94,7,105
1,AKRLKKLAKKIWKWK_amd,2.00,pepID134,2,105
2,ALFKTMLKKLGTMAL_amd,4.50,pepID141,4,105
3,ALRSAVRTVARVGRAVLPHVAI_amd,6.30,pepID160,4,105
4,FIGLLISAGKAIHDLIRRRH,0.25,pepID592,1,105
...,...,...,...,...,...
15298,MNLEVVVQLGSLSLIVLAGPIIVLLLASQKGNL,9999.00,pepID10711,11,145
15299,QYVKDPDKQVVARIFLDLQLVQR,9999.00,pepID10774,11,145
15300,LGIYSDAGTQTCSK,9999.00,pepID10832,11,145
15301,EDAYKFTTW,9999.00,pepID10897,11,145


In [280]:
def get_idx(start_idx, set_len):
    end_row = df_shuffled.iloc[start_idx + set_len, :]
    print(end_row)
    end_group = end_row['group']
    print(end_group)
    end_idx = end_row.name
    print(end_idx)
    
    end_group_idx = df_shuffled['group'][df_shuffled['group'] == end_group].index.to_list()
    print(end_group_idx)

    if (end_idx - min(end_group_idx)) >= 0.5*len(end_group_idx):
        print(max(end_group_idx))
        return max(end_group_idx)
    else:
        print(min(end_group_idx)-1)
        return min(end_group_idx) - 1


SPLIT = '0.4,0.3,0.3'  # train, val, test
SPLIT = re.split(',', SPLIT)
SPLIT = list(map(float, SPLIT))  

# test set
train_start_idx = 0
train_set_len = int(np.around(len(df_shuffled)*SPLIT[0]))
train_end_idx = get_idx(train_start_idx, train_set_len) + 1

train_set = df_shuffled.iloc[train_start_idx:train_end_idx,:]
train_set

sequence           GMWSKILKHLIR
MIC_ecoli                   2.0
ID                    pepID2179
MIC_ecoli_class               2
group                       149
Name: 6121, dtype: object
149
6121
[6100, 6101, 6102, 6103, 6104, 6105, 6106, 6107, 6108, 6109, 6110, 6111, 6112, 6113, 6114, 6115, 6116, 6117, 6118, 6119, 6120, 6121, 6122, 6123, 6124, 6125, 6126, 6127, 6128, 6129, 6130, 6131, 6132, 6133, 6134, 6135, 6136, 6137, 6138, 6139, 6140, 6141, 6142, 6143, 6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151, 6152, 6153, 6154, 6155, 6156, 6157, 6158, 6159, 6160, 6161, 6162, 6163, 6164, 6165, 6166, 6167, 6168, 6169, 6170, 6171, 6172, 6173, 6174, 6175, 6176, 6177, 6178, 6179, 6180, 6181, 6182, 6183, 6184, 6185, 6186, 6187, 6188, 6189, 6190, 6191, 6192, 6193, 6194, 6195, 6196, 6197, 6198, 6199]
6099


Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
0,AIPWSIWWHLLFKG,50.00,pepID94,7,105
1,AKRLKKLAKKIWKWK_amd,2.00,pepID134,2,105
2,ALFKTMLKKLGTMAL_amd,4.50,pepID141,4,105
3,ALRSAVRTVARVGRAVLPHVAI_amd,6.30,pepID160,4,105
4,FIGLLISAGKAIHDLIRRRH,0.25,pepID592,1,105
...,...,...,...,...,...
6095,GIPCAESCVYIPCITAALGCSCKNKVCYRN,9999.00,pepID10507,11,136
6096,VIGGDICNINEHNFLVALYE,9999.00,pepID10689,11,136
6097,GVIPCGESCVFIPCISSVVGCTCKNKVCYRD,9999.00,pepID10737,11,136
6098,GFPTCGETCTLGTCNTPGCTCSWPICTRD,9999.00,pepID10770,11,136


In [281]:
# val set
val_start_idx = train_end_idx
val_set_len = int(np.around(len(df_shuffled)*SPLIT[1]))
val_end_idx = get_idx(val_start_idx, val_set_len) + 1

val_set = df_shuffled.iloc[val_start_idx:val_end_idx,:]
val_set

sequence           TmaPheTmaTmaTmaTmaTmaOlamTmaTmaTmaTmaTmaTmaTma...
MIC_ecoli                                                      128.0
ID                                                      polyID36_S92
MIC_ecoli_class                                                    8
group                                                             36
Name: 10691, dtype: object
36
10691
[10600, 10601, 10602, 10603, 10604, 10605, 10606, 10607, 10608, 10609, 10610, 10611, 10612, 10613, 10614, 10615, 10616, 10617, 10618, 10619, 10620, 10621, 10622, 10623, 10624, 10625, 10626, 10627, 10628, 10629, 10630, 10631, 10632, 10633, 10634, 10635, 10636, 10637, 10638, 10639, 10640, 10641, 10642, 10643, 10644, 10645, 10646, 10647, 10648, 10649, 10650, 10651, 10652, 10653, 10654, 10655, 10656, 10657, 10658, 10659, 10660, 10661, 10662, 10663, 10664, 10665, 10666, 10667, 10668, 10669, 10670, 10671, 10672, 10673, 10674, 10675, 10676, 10677, 10678, 10679, 10680, 10681, 10682, 10683, 10684, 10685, 10686, 10687, 10

Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
6100,AKRLKKLAKKIWKWK,4.0,pepID133,3,149
6101,ALWKNMLKGIGK_amd,6.0,pepID181,4,149
6102,ALWKTIIKGAGKMIGSLAKNLLGSQAQPES,50.0,pepID183,7,149
6103,FFHHIFRGIVHVGRTIHKLVTGG,8.0,pepID510,5,149
6104,FGPVIGLLSGILKSLL,200.0,pepID568,9,149
...,...,...,...,...,...
10695,TmaTmaTmaPheTmaTmaTmaAegTmaPheTmaTmaTmaTmaTmaT...,128.0,polyID36_S96,8,36
10696,OlamOlamTmaTmaTmaTmaPheOlamTmaTmaTmaTmaTmaPheT...,128.0,polyID36_S97,8,36
10697,TmaTmaTmaTmaTmaOlamTmaTmaTmaTmaTmaTmaTmaTmaTma...,128.0,polyID36_S98,8,36
10698,TmaTmaTmaPheTmaTmaTmaTmaTmaTmaTmaTmaAegTmaTmaT...,128.0,polyID36_S99,8,36


In [282]:
# test set
test_start_idx = val_end_idx
test_set_len = int(np.around(len(df_shuffled)*SPLIT[2]))
test_end_idx = get_idx(test_start_idx, test_set_len) + 1

test_set = df_shuffled.iloc[test_start_idx:test_end_idx,:]
test_set

sequence           QEKPYWPPPIYPM
MIC_ecoli                 9999.0
ID                     pepID9857
MIC_ecoli_class               11
group                        145
Name: 15291, dtype: object
145
15291
[15203, 15204, 15205, 15206, 15207, 15208, 15209, 15210, 15211, 15212, 15213, 15214, 15215, 15216, 15217, 15218, 15219, 15220, 15221, 15222, 15223, 15224, 15225, 15226, 15227, 15228, 15229, 15230, 15231, 15232, 15233, 15234, 15235, 15236, 15237, 15238, 15239, 15240, 15241, 15242, 15243, 15244, 15245, 15246, 15247, 15248, 15249, 15250, 15251, 15252, 15253, 15254, 15255, 15256, 15257, 15258, 15259, 15260, 15261, 15262, 15263, 15264, 15265, 15266, 15267, 15268, 15269, 15270, 15271, 15272, 15273, 15274, 15275, 15276, 15277, 15278, 15279, 15280, 15281, 15282, 15283, 15284, 15285, 15286, 15287, 15288, 15289, 15290, 15291, 15292, 15293, 15294, 15295, 15296, 15297, 15298, 15299, 15300, 15301, 15302]
15302


Unnamed: 0,sequence,MIC_ecoli,ID,MIC_ecoli_class,group
10700,ACYCRIGACVSGERLTGACGLNGRIYRLCCR,14.974747,pepID47,5,126
10701,AGRGKQGGKVRAKAKTRSSRAGLQFPVGRVHRLLRKGNY,1.355447,pepID64,2,126
10702,AIHKLAHKLLKKLLKAVKKLAK,2.236068,pepID83,3,126
10703,FFGSVLKLIPKIL_amd,25.000000,pepID485,6,126
10704,FLKGIVGMLGKLW_amd,3.000000,pepID814,3,126
...,...,...,...,...,...
15298,MNLEVVVQLGSLSLIVLAGPIIVLLLASQKGNL,9999.000000,pepID10711,11,145
15299,QYVKDPDKQVVARIFLDLQLVQR,9999.000000,pepID10774,11,145
15300,LGIYSDAGTQTCSK,9999.000000,pepID10832,11,145
15301,EDAYKFTTW,9999.000000,pepID10897,11,145


In [285]:
 {'train': train_set.set_index('ID')['sequence'].to_dict(),
  'val': val_set.set_index('ID')['sequence'].to_dict(),
  'test': test_set.set_index('ID')['sequence'].to_dict()}

{'train': {'pepID94': 'AIPWSIWWHLLFKG',
  'pepID134': 'AKRLKKLAKKIWKWK_amd',
  'pepID141': 'ALFKTMLKKLGTMAL_amd',
  'pepID160': 'ALRSAVRTVARVGRAVLPHVAI_amd',
  'pepID592': 'FIGLLISAGKAIHDLIRRRH',
  'pepID610': 'FILGKLWKGVKSIF_amd',
  'pepID632': 'FKCRRWQWRMKALGA',
  'pepID686': 'FLFSLIPKAIGGLISAFK_amd',
  'pepID801': 'FLIIRRPIVLGLL',
  'pepID867': 'FLPIIASVAAKVFSKIFCAISKKC',
  'pepID1084': 'FSEAIKKIIDFLGEGLFDIIKKIAESF',
  'pepID1231': 'GFGDSVKEGLKNAAVTILNKIKCKISECPPA',
  'pepID1376': 'GIFSKFVGKGLKNLFMKGAKTIGREVGMDVLRTGIDIAGCKIKGEC',
  'pepID1509': 'GIHHILKYGKPS',
  'pepID1685': 'GKWMHLLKHILK',
  'pepID1705': 'GKWMSLWKHILK_amd',
  'pepID1868': 'GLICESCRKIIQKLEDMVGPQPNEDTVTQAASRVCDKMKILRGVCKKIMRTFLRRISKDILTGKKPQAICVDIKICKE',
  'pepID2144': 'GMASKAGSVLGKVAKVALKAAL_amd',
  'pepID2222': 'GRFRRLGRKFKKLFKKYGP',
  'pepID2250': 'GRPNPVNNKPTSHPRPIRV_amd',
  'pepID2322': 'GVIDAAAKVVNVLKNLF',
  'pepID2577': 'IKHQGLPQE',
  'pepID2857': 'IRMRIRVLL',
  'pepID2975': 'KIAGKIAKKAGKIAK_amd',
  'pepID3020

In [74]:
pep_idx = db['ID'].str.contains('pep')
peptides = db[pep_idx].sample(frac=1, random_state=42).reset_index(drop=True)

# Assign group numbers
db[pep_idx]['group'] = (np.arange(len(peptides)) // k) + start + 1