# _-----------Data curation--------------------_

In [None]:
# Imports


from datasets import load_dataset

from rdkit import DataStructs, Chem
from rdkit.Chem import Draw, AllChem


import pickle
import os

from tqdm import tqdm

import networkx as nx


import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from data_curation_augmentation_splitting_functions import (substructure_split_sort, 
                                                            prepare_data_set, 
                                                            generate_murcko_in_df, 
                                                            generate_anonymous_murcko_in_df, 
                                                            standardize_smiles, 
                                                            create_group_index_mapping, 
                                                            collect_unique_substructures,
                                                            get_boundary_bondtype,
                                                            remove_dummy_atom_from_mol,
                                                            remove_dummy_atom_from_smiles,
                                                            remove_stereo,
                                                            reassemble_protac,
                                                            get_murcko,
                                                            get_anonymous_murcko,
                                                            make_graph_with_pos,
                                                            get_test_trainval_smi,
                                                            count_unique_SMILES_and_MurckoScaffolds,
                                                            validate_test_set,
                                                            add_attachments,
                                                            validate_no_general_ms_leakage,
                                                            generate_protacs_indices,
                                                            get_clusters_substructures_attachments_dict,
                                                            generate_protac_from_indices_list,
                                                            align_mol_2D_ver2,
                                                            butina_clustering_substructures_with_fixed_cutoff,
                                                            barplot_from_protac_substructureindices_list,
                                                            get_bond_idx,
                                                            generate_splits)


from data_curation_augmentation_splitting_functions import (compute_FP_substructures, 
                                                            compute_countMorgFP,
                                                            compute_RDKitFP)


from data_curation_augmentation_splitting_functions import (save_as_svg,
                                                            align_molecules_2D,
                                                            align_molecules_by_coordinates,
                                                            draw_molecule_with_highlighted_bonds,
                                                            tailored_framework_example,
                                                            transform_molecule,
                                                            draw_molecule_to_svg,
                                                            combine_svgs)





# Datasets


## Download data

In [None]:
# Load the curated train/validation data. Paths are relative to the notebook's
# working directory, like the '../data/augmented/' outputs written later on.
# (A leading '/' -- '/../data/...' -- would make them absolute and resolve to
# '/data/...' at the filesystem root.)
train_dataset = pd.read_csv('../data/train_dataset.csv')
validation_dataset = pd.read_csv('../data/validation_dataset.csv')

# PROTAC SMILES and their annotated substructures, one pandas Series per split.
trainset_protac = train_dataset['smiles']
trainset_substructures = train_dataset['substructures']

validationset_protac = validation_dataset['smiles']
validationset_substructures = validation_dataset['substructures']


# Alternative, disabled data source: the Hugging Face dataset
# "ailab-bio/PROTAC-Substructures" (config "80-20-split"). It may be private; in that
# case use the CSV files loaded above. NOTE(review): the local-file branch below
# unpickles a file named '.csv'; check its format before re-enabling this code.
#cwd = os.getcwd()
#data_path = cwd + '/../data/curated_dataset.csv'

#your_huggingface_token = "token" #this source may be set to private. In this case, load the "curated_dataset.csv" instead.

#if os.path.isfile(data_path):
#    with open(data_path, "rb") as file:
#        dataset_aibe = pickle.load(file)
#else:
#    name_dataset_aibe = "80-20-split"
#    dataset_aibe = load_dataset("ailab-bio/PROTAC-Substructures", name_dataset_aibe, token=your_huggingface_token) #download use_auth_token=True

#train_dataset = dataset_aibe['train']
#validation_dataset = dataset_aibe['validation']

#trainset_protac = train_dataset['text']
#trainset_substructures = train_dataset['labels']

#validationset_protac = validation_dataset['text']
#validationset_substructures = validation_dataset['labels']

## Load data & remove ambiguous substructure matches (matching can be improved)

In [None]:
# Merge training and validation data, and create a dataframe of public PROTACs
# with their POI, Linker and E3 substructures.
#
# The split columns are pandas Series. For Series, `+` is element-wise and
# index-aligned (it would concatenate the i-th training SMILES with the i-th
# validation SMILES and produce NaN elsewhere), so convert to lists first to
# append the validation rows after the training rows. `list(...)` is also a
# no-op-style copy when the inputs are already plain lists (e.g. from the
# Hugging Face dataset).
protac_smi_list = list(trainset_protac) + list(validationset_protac)
substructures_list = list(trainset_substructures) + list(validationset_substructures)

poi_smi_r_list = []
linker_smi_r_list = []
e3_smi_r_list = []
for substructures in substructures_list:
    # Split one PROTAC's substructure annotation into its POI, linker and E3 parts.
    poi_smi_r, linker_smi_r, e3_smi_r = substructure_split_sort(substructures)
    poi_smi_r_list.append(poi_smi_r)
    linker_smi_r_list.append(linker_smi_r)
    e3_smi_r_list.append(e3_smi_r)

pub_smi_df = pd.DataFrame({
    'Smiles': protac_smi_list,
    'POI': poi_smi_r_list,
    'Linker': linker_smi_r_list,
    'E3': e3_smi_r_list
})

In [None]:
# Remove PROTACs with ambiguous substructure matches, then split the result into a
# PROTAC-SMILES dataframe and a (not yet split) substructure dataframe.
substructure_columns = ('POI', 'Linker', 'E3')
pub_protac_smi_prep_df, pub_substructure_unsplitted_smi_prep_df = prepare_data_set(
    pub_smi_df, 'Smiles', *substructure_columns
)

In [None]:
# Merge the two prepared dataframes back into one dataframe of public PROTACs with
# their POI, Linker and E3 substructures, now without ambiguous substructure matches.
# A substructure row is paired with the PROTAC SMILES that has the same index label.

protac_smi_prep_list = []
poi_smi_r_prep_list = []
linker_smi_r_prep_list = []
e3_smi_r_prep_list = []

for row_label, substructure_row in pub_substructure_unsplitted_smi_prep_df.iterrows():
    protac_smi_prep_list.append(pub_protac_smi_prep_df.loc[row_label, "Smiles"])

    poi_part, linker_part, e3_part = substructure_split_sort(substructure_row["substructures"])
    poi_smi_r_prep_list.append(poi_part)
    linker_smi_r_prep_list.append(linker_part)
    e3_smi_r_prep_list.append(e3_part)

prepared_columns = {
    'Smiles': protac_smi_prep_list,
    'POI': poi_smi_r_prep_list,
    'Linker': linker_smi_r_prep_list,
    'E3': e3_smi_r_prep_list,
}
pub_smi_prep_df = pd.DataFrame(prepared_columns)



# Generate Framework SMILES

In [None]:
# Add Murcko-scaffold ("<col>_MS") and anonymous-scaffold / framework ("<col>_AnonMS")
# columns for the PROTAC and each substructure column.
base_structure_columns = ["Smiles", "POI", "Linker", "E3"]
for base_col in base_structure_columns:
    pub_smi_prep_df = generate_murcko_in_df(pub_smi_prep_df, base_col)
    pub_smi_prep_df = generate_anonymous_murcko_in_df(pub_smi_prep_df, base_col)

# Standardize every SMILES column: the originals followed by, per base column,
# its "_MS" and "_AnonMS" variants.
all_structures = base_structure_columns + [
    f"{structure}{suffix}"
    for structure in base_structure_columns
    for suffix in ("_MS", "_AnonMS")
]
for structure_col in all_structures:
    pub_smi_prep_df[structure_col] = pub_smi_prep_df[structure_col].apply(standardize_smiles)
    

In [None]:
def _distinct_non_null(series):
    """Return the distinct non-null values of *series*, in order of first appearance."""
    return series.dropna().unique().tolist()


# Unique standardized FRAMEWORK (anonymous Murcko scaffold) SMILES per substructure type.
unique_abstract_poi_smi_list = _distinct_non_null(pub_smi_prep_df['POI_AnonMS'])
unique_abstract_linker_smi_list = _distinct_non_null(pub_smi_prep_df['Linker_AnonMS'])
unique_abstract_e3_smi_list = _distinct_non_null(pub_smi_prep_df['E3_AnonMS'])

# Dictionaries mapping each framework SMILES to its group ID; substructures with the
# same framework share a group.
poi_AnonMS_to_group = create_group_index_mapping(unique_abstract_poi_smi_list)
linker_AnonMS_to_group = create_group_index_mapping(unique_abstract_linker_smi_list)
e3_AnonMS_to_group = create_group_index_mapping(unique_abstract_e3_smi_list)

# Tag every row with the framework group of its POI, Linker and E3.
framework_group_maps = (
    ("POI", poi_AnonMS_to_group),
    ("Linker", linker_AnonMS_to_group),
    ("E3", e3_AnonMS_to_group),
)
for part_name, framework_to_group in framework_group_maps:
    pub_smi_prep_df[f'{part_name}_Group'] = pub_smi_prep_df[f'{part_name}_AnonMS'].map(framework_to_group)

# Per substructure type, collect the substructure SMILES belonging to each framework group.
poi_AnonMSgroup_to_smi = collect_unique_substructures(pub_smi_prep_df, 'POI_Group', 'POI')
linker_AnonMSgroup_to_smi = collect_unique_substructures(pub_smi_prep_df, 'Linker_Group', 'Linker')
e3_AnonMSgroup_to_smi = collect_unique_substructures(pub_smi_prep_df, 'E3_Group', 'E3')

In [None]:
# Assign an integer ID to every distinct substructure SMILES. pandas.factorize numbers
# values in order of first appearance and gives missing values the code -1.
POI_ids, Linker_ids, E3_ids = (
    pd.factorize(pub_smi_prep_df[part])[0] for part in ('POI', 'Linker', 'E3')
)
pub_smi_prep_df["POI_ID"] = POI_ids
pub_smi_prep_df["Linker_ID"] = Linker_ids
pub_smi_prep_df["E3_ID"] = E3_ids


In [None]:
# Sanity check of the IDs assigned above. For each distinct substructure SMILES, its ID
# is its position in the first-appearance list; every dataframe row carrying that ID
# must hold exactly that SMILES. If no error is raised, everything is consistent.

unique_poi, unique_linker, unique_e3 = (
    pub_smi_prep_df[part].dropna().unique().tolist() for part in ("POI", "Linker", "E3")
)

for part_col, part_uniques in zip(("POI", "Linker", "E3"), (unique_poi, unique_linker, unique_e3)):
    id_column = pub_smi_prep_df[f"{part_col}_ID"]
    for expected_id, expected_smi in enumerate(part_uniques):
        found_smiles = pub_smi_prep_df.loc[id_column == expected_id, part_col].dropna().unique().tolist()
        if len(found_smiles) > 1:
            raise ValueError("not unique")
        if found_smiles[0] != expected_smi:
            raise ValueError("not matching")


    

# Move E3 in POI list to E3 list

In [None]:
# Compute a fingerprint (compute_countMorgFP) for every row of each substructure and
# scaffold column, and store it in a new "<col>_MorgFP" column.
morgan_fp_columns = ("POI", "Linker", "E3", "POI_MS", "Linker_MS", "E3_MS")
for fp_source_col in morgan_fp_columns:
    column_fps = compute_FP_substructures(
        pub_smi_prep_df,
        [fp_source_col],
        fp_function=compute_countMorgFP,
        return_unique=False,
        convert_to_numpyarray=True,
    )[0]
    pub_smi_prep_df[fp_source_col + "_MorgFP"] = column_fps


In [None]:
# Tanimoto similarities between substructure fingerprints. The POI-vs-E3 matrix is
# used in the following cells to spot E3 ligands that appear in the POI list.
#   similarity_matrix_poi_e3[i, j]: unique POI fingerprint i vs unique E3 fingerprint j
#   similarity_matrix_poi_poi / _e3_e3: all rows vs all rows (not de-duplicated)


def _tanimoto_matrix(row_fps, col_fps):
    """Return the len(row_fps) x len(col_fps) matrix of Tanimoto similarities.

    Each row is filled with one ``DataStructs.BulkTanimotoSimilarity`` call, which
    performs the comparisons in RDKit's C++ code instead of one Python-level
    ``TanimotoSimilarity`` call per pair. The similarity values are the same.
    """
    col_fps = list(col_fps)  # plain sequence, positional order as before
    matrix = np.zeros((len(row_fps), len(col_fps)))
    for row_idx, row_fp in enumerate(row_fps):
        matrix[row_idx, :] = DataStructs.BulkTanimotoSimilarity(row_fp, col_fps)
    return matrix


poi_unique_fps = compute_FP_substructures(pub_smi_prep_df, ["POI"], fp_function=compute_countMorgFP, return_unique=True, convert_to_numpyarray=False)[0]
e3_unique_fps = compute_FP_substructures(pub_smi_prep_df, ["E3"], fp_function=compute_countMorgFP, return_unique=True, convert_to_numpyarray=False)[0]

similarity_matrix_poi_e3 = _tanimoto_matrix(poi_unique_fps, e3_unique_fps)

row_max_values = np.max(similarity_matrix_poi_e3, axis=1)  # best E3 similarity for each unique POI
row_max_column_indices = np.argmax(similarity_matrix_poi_e3, axis=1)  # index of that best-matching E3


poi_fps = compute_FP_substructures(pub_smi_prep_df, ["POI"], fp_function=compute_countMorgFP, return_unique=False, convert_to_numpyarray=False)[0]
e3_fps = compute_FP_substructures(pub_smi_prep_df, ["E3"], fp_function=compute_countMorgFP, return_unique=False, convert_to_numpyarray=False)[0]

similarity_matrix_poi_poi = _tanimoto_matrix(poi_fps, poi_fps)
similarity_matrix_e3_e3 = _tanimoto_matrix(e3_fps, e3_fps)



In [None]:
# Visualization only: this cell does not modify any data.
# It lists the 10 unique POI substructures whose best match in the E3 list is most
# similar. For each pair whose similarity exceeds 0.4 and whose members share the same
# graph framework (anonymous Murcko scaffold), it also draws both molecules; such a
# "POI" is suspected to be an E3 ligand.
# NOTE(review): `idx` indexes the unique-POI fingerprint rows and is also compared
# with POI_ID and used on unique_poi. This assumes compute_FP_substructures(...,
# return_unique=True) keeps first-appearance order -- confirm.

sorted_row_indices_top10 = np.argsort(row_max_values)[::-1][:10]  # top-10 rows by best E3 similarity

show_structures = True
if show_structures:
    for idx in sorted_row_indices_top10:
        best_e3_idx = row_max_column_indices[idx]
        best_similarity = row_max_values[idx]
        poi_id_mask = pub_smi_prep_df["POI_ID"] == idx
        e3_id_mask = pub_smi_prep_df["E3_ID"] == best_e3_idx

        print(f"E3s in POI list:   POI Index: {idx}, Max Value: {best_similarity}, E3 Index: {best_e3_idx}")
        print(f'The suspected E3 occurs {sum(poi_id_mask)} times among the POI in all PROTACs, whereas the E3 it matched best against occurs {sum(e3_id_mask)} times among the E3 in all PROTACs')

        # Framework SMILES of the suspected E3 (via its POI group) and of its best E3 match.
        poi_group_of_poi_id = pub_smi_prep_df.loc[poi_id_mask, 'POI_Group'].unique().tolist()[0]
        poi_framework_smi = pub_smi_prep_df.loc[pub_smi_prep_df["POI_Group"] == poi_group_of_poi_id, "POI_AnonMS"].tolist()[0]
        e3_group_of_e3_id = pub_smi_prep_df.loc[e3_id_mask, 'E3_Group'].unique().tolist()[0]
        e3_framework_smi = pub_smi_prep_df.loc[pub_smi_prep_df["E3_Group"] == e3_group_of_e3_id, "E3_AnonMS"].tolist()[0]

        same_framework = e3_framework_smi == poi_framework_smi
        if not (best_similarity > 0.4 and same_framework):
            continue

        print("E3 and POI (suspected E3) have the same graph framework:")
        print("POI:")
        poi_smi = unique_poi[idx]
        display(Chem.MolFromSmiles(poi_smi))
        print(poi_smi)
        print("E3:")
        e3_smi = unique_e3[best_e3_idx]
        display(Chem.MolFromSmiles(e3_smi))
        print(e3_smi)

        print("\n")

In [None]:
# Detect E3 ligands in the POI list. Every unique POI whose best E3 match exceeds
# `pairing_cutoff` is shown next to that E3. If the two also share the same graph
# framework (anonymous Murcko scaffold), the POI index is recorded in `e3_among_poi`;
# the next cell then moves those substructures to the E3 list.
# Run this cell and the next one only once: the next cell rewrites unique_poi and
# unique_e3, after which the row indices used here no longer line up with them.
# NOTE(review): `SVG` (used when save_fig=True) is not imported anywhere in this
# notebook; it presumably needs `from IPython.display import SVG` -- confirm.

save_fig = False

e3_among_poi = []  # indices into unique_poi of substructures that look like E3 ligands
poi_among_e3 = []  # reverse direction; not populated by this cell
sorted_row_indices = np.argsort(row_max_values)[::-1]  # unique POIs, most E3-like first
pairing_cutoff = 0.4  # kept low so that all plausible POI/E3 pairs are reported
move_id = 0

for poi_idx in sorted_row_indices:
    tanimoto_similatity = row_max_values[poi_idx]
    if not (tanimoto_similatity > pairing_cutoff):
        continue

    e3_idx = row_max_column_indices[poi_idx]

    # Framework SMILES of this POI (via its POI group) and of its best-matching E3.
    poi_group_of_poi_id = pub_smi_prep_df.loc[pub_smi_prep_df["POI_ID"] == poi_idx, 'POI_Group'].tolist()[0]
    poi_framework_smi = pub_smi_prep_df.loc[pub_smi_prep_df["POI_Group"] == poi_group_of_poi_id, "POI_AnonMS"].tolist()[0]
    e3_group_of_e3_id = pub_smi_prep_df.loc[pub_smi_prep_df["E3_ID"] == e3_idx, 'E3_Group'].tolist()[0]
    e3_framework_smi = pub_smi_prep_df.loc[pub_smi_prep_df["E3_Group"] == e3_group_of_e3_id, "E3_AnonMS"].tolist()[0]

    poi_smi = unique_poi[poi_idx]
    e3_smi = unique_e3[e3_idx]
    poi_mol = Chem.MolFromSmiles(poi_smi)
    e3_mol = Chem.MolFromSmiles(e3_smi)

    same_framework = e3_framework_smi == poi_framework_smi
    framework_str = "e3_among_poi_same_framework" if same_framework else "e3_(maybe)_among_poi_different_frameworks"

    print(f"Pair of Warhead and E3 matches to {tanimoto_similatity} (above {pairing_cutoff}) {framework_str}")
    print("POI")
    display(poi_mol)
    print("E3")
    display(e3_mol)

    if same_framework:
        e3_among_poi.append(poi_idx)

    svg_name = f'{framework_str}_{move_id}_sim{round(tanimoto_similatity,4)}'
    svg_path = f"{os.getcwd()}/fig_method/{svg_name}.svg"
    e3_mol = align_molecules_2D(poi_mol, e3_mol)
    substructures_img = Draw.MolsToGridImage([poi_mol, e3_mol], subImgSize=(500, 500), useSVG=True)
    move_id += 1

    if save_fig:
        save_as_svg(substructures_img, svg_path, num_mols=2)
        display(SVG(filename=svg_path))

    print("\n")


In [None]:

# Move the substructures flagged in the previous cell between the POI and E3 lists.
# Run this cell only once: it rewrites unique_poi / unique_e3, so the indices stored
# in e3_among_poi / poi_among_e3 no longer point at the right entries afterwards.
#
# De-duplication uses dict.fromkeys (insertion-ordered) rather than set(). The
# iteration order of a set of strings changes between interpreter runs (hash
# randomization), which made the lists -- and the splits computed from them --
# irreproducible.

e3_among_poi_smi = [unique_poi[idx] for idx in e3_among_poi]
poi_among_e3_smi = [unique_e3[idx] for idx in poi_among_e3]

moved_out_of_poi = set(e3_among_poi_smi)
moved_out_of_e3 = set(poi_among_e3_smi)
unique_poi = list(dict.fromkeys(smi for smi in unique_poi if smi not in moved_out_of_poi))
unique_e3 = list(dict.fromkeys(smi for smi in unique_e3 if smi not in moved_out_of_e3))

# POI substructures carry the attachment label [*:1] and E3 substructures [*:2];
# relabel before moving a substructure to the other list.
e3_among_poi_smi = [smi.replace("[*:1]", "[*:2]") for smi in e3_among_poi_smi]
poi_among_e3_smi = [smi.replace("[*:2]", "[*:1]") for smi in poi_among_e3_smi]

unique_e3 = list(dict.fromkeys(unique_e3 + e3_among_poi_smi))  # add E3s found in the POI list, without duplicates
unique_poi = list(dict.fromkeys(unique_poi + poi_among_e3_smi))  # add POIs found in the E3 list, without duplicates

# TODO: consider only deleting from unique_poi, or marking these rows in the dataframe
# (e.g. replacing their "POI" entry with 'E3'). Until then, do not re-extract POIs
# directly from the dataframe: it still lists the moved substructures as POIs.
    

# (ETC) Analyze distribution of boundary bonds

In [None]:
# Count the bond types at the substructure boundaries (via get_boundary_bondtype)
# over all unique substructures, separately for POI, LINKER and E3.
unique_substructures_with_attachments = {'POI': unique_poi, 'LINKER': unique_linker, 'E3': unique_e3}
bondtype_count_substructures = {}
for sub_type, smiles_with_attachment in unique_substructures_with_attachments.items():
    type_counts = {'SINGLE': 0, 'DOUBLE': 0, 'TRIPLE': 0}
    for smi_with_attachment in smiles_with_attachment:
        mol_with_attachment = Chem.MolFromSmiles(smi_with_attachment)
        type_counts = get_boundary_bondtype(mol=mol_with_attachment, bondtype_count=type_counts)
    bondtype_count_substructures[sub_type] = type_counts

print("Count of each bond type for all unique substructures, grouped by substructure type:")
print(bondtype_count_substructures)

# Get substructures without attachment points

In [None]:
# Remove the attachment points (dummy atoms) from every unique substructure.
#   unique_substructures_without_attachments[type]: one attachment-free SMILES per
#       input SMILES (duplicates kept; later cells de-duplicate it themselves).
#   unique_substructures_without_attachments_dict_to_with_attachments[type]: maps each
#       attachment-free SMILES to all with-attachment SMILES that produce it.
unique_substructures_with_attachments = {'POI': unique_poi, 'LINKER': unique_linker, 'E3': unique_e3}
unique_substructures_without_attachments = {'POI': [], 'LINKER': [], 'E3': []}

unique_substructures_without_attachments_dict_to_with_attachments = {'POI': {}, 'LINKER': {}, 'E3': {}}

for substruct in unique_substructures_without_attachments:
    without_list = unique_substructures_without_attachments[substruct]
    without_to_with = unique_substructures_without_attachments_dict_to_with_attachments[substruct]
    for substructure_smi_with_attachment in unique_substructures_with_attachments[substruct]:
        substruct_mol_with_attachment = Chem.MolFromSmiles(substructure_smi_with_attachment)
        substruct_smi_without_attachment = remove_dummy_atom_from_mol(mol=substruct_mol_with_attachment, output="smiles")
        without_list.append(substruct_smi_without_attachment)
        without_to_with.setdefault(substruct_smi_without_attachment, []).append(substructure_smi_with_attachment)

# Distinct attachment-free SMILES per type. dict.fromkeys keeps first-seen order, so
# these lists (and the similarity-based splits computed from them) are reproducible
# across runs; list(set(...)) order over strings varies with hash randomization.
unique_poi_without_attachment = list(dict.fromkeys(unique_substructures_without_attachments['POI']))
unique_linker_without_attachment = list(dict.fromkeys(unique_substructures_without_attachments['LINKER']))
unique_e3_without_attachment = list(dict.fromkeys(unique_substructures_without_attachments['E3']))

# _-----------Data split------------------------_

# Datasplit via Tanimoto similarity (HDBSCAN)

In [None]:
# Split the attachment-free POI substructures into test split(s) and a train/val set,
# based on Tanimoto similarity of compute_countMorgFP fingerprints.
poi_split_settings = dict(
    fp_function=compute_countMorgFP,
    max_allowed_tanimoto_similarity=0.45,
    test_set_minfraction=0.15,
    test_set_maxfraction=0.30,
    binwidth_plural=[5, 1],
    substructure_type='poi',
)
poi_without_attachment_test_trainval_smi_dict = get_test_trainval_smi(
    smi_list=unique_poi_without_attachment, **poi_split_settings
)

poi_test_smi_set_without_attachment_splits = poi_without_attachment_test_trainval_smi_dict["test"]
poi_trainval_smi_set_without_attachment = poi_without_attachment_test_trainval_smi_dict["trainval"]

for split_idx, test_split in poi_test_smi_set_without_attachment_splits.items():
    split_counts = count_unique_SMILES_and_MurckoScaffolds(test_split)
    print(f' Test counts of POI, without attachment, split idx {split_idx}: {split_counts}')
trainval_counts = count_unique_SMILES_and_MurckoScaffolds(poi_trainval_smi_set_without_attachment)
print(f' Trainval counts of POI: {trainval_counts}')

print("\n ----------------------------------------------------------------- \n")


In [None]:


# Drop the zero-size "linker" that directly joins POI and E3 ("[*:1][*:2]" in either
# order, "[H][H]" once the attachment points are removed) before splitting.
# The lists are filtered in place ([:] assignment), so every other reference to the
# same list object, e.g. unique_substructures_with_attachments['LINKER'], sees the
# filtered contents too.
direct_bond_linkers = {"[*:1][*:2]", "[*:2][*:1]"}
unique_linker[:] = [smi for smi in unique_linker if smi not in direct_bond_linkers]
unique_linker_without_attachment[:] = [smi for smi in unique_linker_without_attachment if smi != "[H][H]"]

# Split the attachment-free linkers into test split(s) and a train/val set, based on
# Tanimoto similarity of compute_RDKitFP fingerprints.
linker_split_settings = dict(
    fp_function=compute_RDKitFP,
    max_allowed_tanimoto_similarity=0.45,
    test_set_minfraction=0.15,
    test_set_maxfraction=0.30,
    binwidth_plural=[5, 1],
    substructure_type='linker',
)
linker_without_attachment_test_trainval_smi_dict = get_test_trainval_smi(
    smi_list=unique_linker_without_attachment, **linker_split_settings
)

linker_test_smi_set_without_attachment_splits = linker_without_attachment_test_trainval_smi_dict["test"]
linker_trainval_smi_set_without_attachment = linker_without_attachment_test_trainval_smi_dict["trainval"]

for split_idx, test_split in linker_test_smi_set_without_attachment_splits.items():
    split_counts = count_unique_SMILES_and_MurckoScaffolds(test_split)
    print(f' Test counts of linker, without attachment, split idx {split_idx}: {split_counts}')
trainval_counts = count_unique_SMILES_and_MurckoScaffolds(linker_trainval_smi_set_without_attachment)
print(f' Trainval counts of linker: {trainval_counts}')

print("\n ----------------------------------------------------------------- \n")



In [None]:

# Split the attachment-free E3 substructures into test split(s) and a train/val set,
# based on Tanimoto similarity of compute_countMorgFP fingerprints.
e3_split_settings = dict(
    fp_function=compute_countMorgFP,
    max_allowed_tanimoto_similarity=0.5,
    test_set_minfraction=0.2,
    test_set_maxfraction=0.30,
    binwidth_plural=[5, 1],
    substructure_type='e3',
)
e3_without_attachment_test_trainval_smi_dict = get_test_trainval_smi(
    smi_list=unique_e3_without_attachment, **e3_split_settings
)

e3_test_smi_set_without_attachment_splits = e3_without_attachment_test_trainval_smi_dict["test"]
e3_trainval_smi_set_without_attachment = e3_without_attachment_test_trainval_smi_dict["trainval"]

for split_idx, test_split in e3_test_smi_set_without_attachment_splits.items():
    split_counts = count_unique_SMILES_and_MurckoScaffolds(test_split)
    print(f' Test counts of e3, without attachment, split idx {split_idx}: {split_counts}')
trainval_counts = count_unique_SMILES_and_MurckoScaffolds(e3_trainval_smi_set_without_attachment)
print(f' Trainval counts of e3: {trainval_counts}')

# Generate recombined PROTACs and test splits

In [None]:
# Recombine substructures into PROTACs for every split and write the train,
# validation and test sets as CSV files to ../data/augmented/.
fixed_cutoff = 0.33  # Butina clustering cutoff; also recorded in the output file names
split_dict = generate_splits(poi_test_smi_set_without_attachment_splits, linker_test_smi_set_without_attachment_splits, e3_test_smi_set_without_attachment_splits,
                    poi_trainval_smi_set_without_attachment, linker_trainval_smi_set_without_attachment, e3_trainval_smi_set_without_attachment,
                    unique_substructures_without_attachments_dict_to_with_attachments, unique_substructures_with_attachments, fixed_cutoff)

butina_cluster_cutoff = f'ButinaClusterCutoff_{fixed_cutoff}'
output_dir = '../data/augmented'
os.makedirs(output_dir, exist_ok=True)  # DataFrame.to_csv does not create missing directories

# File-name tag for each test set returned by generate_splits (written in this order).
test_set_file_tags = {
    "Test POI": "poi",
    "Test Linker": "linker",
    "Test E3": "e3",
    "Test POILINKER": "poilinker",
    "Test POIE3": "poie3",
    "Test E3Linker": "e3linker",
    "Test PROTAC": "protac",
}

for split_idx, datasets in split_dict.items():
    # NOTE(review): train/val file names carry no split index, so splits whose train
    # (or validation) sets have the same number of rows overwrite each other's file.
    if "Train" in datasets:
        train_df = datasets["Train"]
        train_df.to_csv(f'{output_dir}/train_{len(train_df)}_{butina_cluster_cutoff}.csv', index=False)

    if "Validation" in datasets:
        val_df = datasets["Validation"]
        val_df.to_csv(f'{output_dir}/val_{len(val_df)}_{butina_cluster_cutoff}.csv', index=False)

    # Look up every test set first, so a missing key fails before any test file is written.
    test_frames = {tag: datasets[key] for key, tag in test_set_file_tags.items()}
    for tag, test_df in test_frames.items():
        test_df.to_csv(f'{output_dir}/test_{tag}_split{split_idx}_{butina_cluster_cutoff}.csv', index=False)


## Save various train & val dataset sizes

In [None]:
# Build additional, larger augmented train/validation sets from the train/val
# substructures only (the test substructures stay excluded).
generate_larger_trainsizes = True

if generate_larger_trainsizes:
    trainval_smi_set_without_attachment = [
        poi_trainval_smi_set_without_attachment,
        linker_trainval_smi_set_without_attachment,
        e3_trainval_smi_set_without_attachment,
    ]

    # Butina-cluster each substructure type with the same cutoff used for the splits.
    trainval_clusters_for_substructures = butina_clustering_substructures_with_fixed_cutoff(
        smi_sets_without_attachment=trainval_smi_set_without_attachment,
        cutoff=fixed_cutoff,
        plot_top_clusters=[],
        plot=True,
        yscale='linear',
        test_or_trainval="trainval",
    )
    trainval_clusters_substructures_attachments_dict = get_clusters_substructures_attachments_dict(
        trainval_clusters_for_substructures,
        trainval_smi_set_without_attachment,
        unique_substructures_without_attachments_dict_to_with_attachments,
    )

    # Despite its name, trainval_fraction is the share of generated PROTACs that goes to
    # the validation set; the remaining (1 - trainval_fraction) becomes training data.
    trainval_fraction = 0.2
    bond_type = 'rand_uniform'
    butina_cluster_cutoff = f'ButinaClusterCutoff_{fixed_cutoff}'

    for num_train_protacs_factor in [3, 10]:
        train_share = 1 - trainval_fraction
        num_trainval_protacs = num_train_protacs_factor / train_share
        # Train part plus validation part, i.e. algebraically num_trainval_protacs.
        # Intended meaning: the number of new augmented unique PROTACs is num_protacs_factor
        # times the size of the largest substructure set.
        # NOTE(review): that rule lives in generate_protacs_indices -- confirm there.
        num_protacs_factor = num_trainval_protacs * train_share + num_trainval_protacs * trainval_fraction
        trainval_augmented_protac_substructureindices_list, trainval_clusters_substructures_attachments_dict_out = generate_protacs_indices(
            clusters_dict=trainval_clusters_substructures_attachments_dict,
            num_protacs_factor=num_protacs_factor,
            force_use_all_attachment_points=True,
        )

        trainval_smiles_with_attachments = generate_protac_from_indices_list(
            trainval_clusters_substructures_attachments_dict_out,
            trainval_augmented_protac_substructureindices_list,
            bond_type=bond_type,
        )

        # Random split of the generated PROTACs into training and validation rows.
        trainval_df = pd.DataFrame(trainval_smiles_with_attachments)
        train_df = trainval_df.sample(frac=train_share)
        val_df = trainval_df.drop(train_df.index)

        train_df.to_csv(f'../data/augmented/train_{len(train_df)}_{butina_cluster_cutoff}.csv', index=False)
        val_df.to_csv(f'../data/augmented/val_{len(val_df)}_{butina_cluster_cutoff}.csv', index=False)




# PROTACs recombined from ALL substructures

Intended for training on all available data, to obtain a better final model.

In [None]:
# Recombine PROTACs from ALL substructures (train/val and test alike), to train a
# final model on all available data.
generate_larger_trainsizes = True

if generate_larger_trainsizes:
    # Distinct attachment-free SMILES per type, in POI, LINKER, E3 order. dict.fromkeys
    # keeps first-seen order, so the clustering input is reproducible across runs;
    # list(set(...)) order over strings varies with hash randomization.
    trainval_smi_set_without_attachment = [
        list(dict.fromkeys(unique_substructures_without_attachments[sub_type]))
        for sub_type in ('POI', 'LINKER', 'E3')
    ]

    # Drop the zero-size linker ("[H][H]" after attachment removal) from the LINKER entry.
    linker_position = 1
    trainval_smi_set_without_attachment[linker_position] = [
        smi for smi in trainval_smi_set_without_attachment[linker_position] if smi != "[H][H]"
    ]

    trainval_clusters_for_substructures = butina_clustering_substructures_with_fixed_cutoff(
        smi_sets_without_attachment=trainval_smi_set_without_attachment,
        cutoff=fixed_cutoff,
        plot_top_clusters=[],
        plot=True,
        yscale='linear',
        test_or_trainval="trainval",
    )
    trainval_clusters_substructures_attachments_dict = get_clusters_substructures_attachments_dict(
        trainval_clusters_for_substructures,
        trainval_smi_set_without_attachment,
        unique_substructures_without_attachments_dict_to_with_attachments,
    )

    # trainval_fraction is the share of generated PROTACs that goes to the validation
    # set; the remaining (1 - trainval_fraction) becomes training data.
    trainval_fraction = 0.2
    bond_type = 'rand_uniform'
    butina_cluster_cutoff = f'ButinaClusterCutoff_{fixed_cutoff}'

    for num_train_protacs_factor in [10]:
        train_share = 1 - trainval_fraction
        num_trainval_protacs = num_train_protacs_factor / train_share
        # Train part plus validation part, i.e. algebraically num_trainval_protacs.
        # Intended meaning: the number of new augmented unique PROTACs is num_protacs_factor
        # times the size of the largest substructure set.
        # NOTE(review): that rule lives in generate_protacs_indices -- confirm there.
        num_protacs_factor = num_trainval_protacs * train_share + num_trainval_protacs * trainval_fraction
        trainval_augmented_protac_substructureindices_list, trainval_clusters_substructures_attachments_dict_out = generate_protacs_indices(
            clusters_dict=trainval_clusters_substructures_attachments_dict,
            num_protacs_factor=num_protacs_factor,
            force_use_all_attachment_points=True,
        )

        trainval_smiles_with_attachments = generate_protac_from_indices_list(
            trainval_clusters_substructures_attachments_dict_out,
            trainval_augmented_protac_substructureindices_list,
            bond_type=bond_type,
        )

        # Random split of the generated PROTACs into training and validation rows.
        trainval_df = pd.DataFrame(trainval_smiles_with_attachments)
        train_df = trainval_df.sample(frac=train_share)
        val_df = trainval_df.drop(train_df.index)

        train_df.to_csv(f'../data/augmented/train_all_{len(train_df)}_{butina_cluster_cutoff}.csv', index=False)
        val_df.to_csv(f'../data/augmented/val_all_{len(val_df)}_{butina_cluster_cutoff}.csv', index=False)


