# Compute tmap projections of generated molecules conditioned on several cluster pairs
- This notebook uses a different environment (`environment_tmap.yml`) because RDKit requires an older version of pandas that is not compatible with faerun.
- Output figures are provided in HTML format, which might not be rendered correctly in some IDEs. Open the HTML files in your browser for better visualization.

In [1]:
import pickle
import numpy as np
import pandas as pd
import tmap

from tqdm import tqdm
import os
from joblib import Parallel, delayed

import cpmolgan.utils as utils
import cpmolgan.visualization as vi

import tmap_visualization as tv


In [2]:
# Notebook configuration: input location, reference file-name pattern and figure folder
args = dict(
    generated_mols_dir="results/generated_mols",
    generated_ref_filename="CLUSTER__15000_Valid.csv",
    figures_dir='figures',
)

# The computed t-map coordinates live next to the generated molecules,
# in a folder whose name is derived from 'generated_mols_dir'
args['results_dir'] = args['generated_mols_dir'].replace('generated_mols','tmap_projections')

# Create the coordinates folder first, then the figures folder, when they do not exist yet
for output_dir in (args['results_dir'], args["figures_dir"]):
    os.makedirs(output_dir, exist_ok=True)

## 1. DMSO close vs DMSO distant
This section generates plots from Figure 2f

#### 1.1 Read and select data and Compute trees

In [3]:
N_per_cluster = 5000
exp_id = 'generated_cluster_DMSO_close_vs_DMSO_distant'
output_file_coord = os.path.join( args['results_dir'], exp_id+'__tmap_coordinates.pkl')
output_file_data = os.path.join( args['results_dir'], exp_id+'__tmap_data_and_labels.csv')

# The cached branch below reads BOTH files, so the cache is only usable when the
# coordinates and the matching data/labels table are present together.
cache_available = os.path.isfile(output_file_coord) and os.path.isfile(output_file_data)

if not cache_available:

    # Read the files of all clusters. os.listdir returns names in arbitrary order and
    # drop_duplicates keeps the first occurrence, so the names are sorted to make the
    # selection reproducible across machines.
    filenames = sorted(f for f in os.listdir(args['generated_mols_dir']) if f.endswith('.csv'))
    if not filenames:
        raise FileNotFoundError('No .csv files found in %s' % args['generated_mols_dir'])
    # Concatenate once instead of growing the frame inside a loop
    gen_clust = pd.concat(
        [pd.read_csv(os.path.join(args['generated_mols_dir'], filename), index_col=0) for filename in filenames]
    )

    # Keep rows labelled 'Cpd', drop repeated SMILES and draw a seeded random
    # sample of N_per_cluster rows from every cluster
    gen_clust = gen_clust.loc[gen_clust.label=='Cpd'].drop(columns='label')
    gen_clust = gen_clust.drop_duplicates(subset='SMILES_standard')
    gen_clust = gen_clust.groupby(by='cluster').sample(N_per_cluster, random_state=10)

    # Assign label columns: cluster numbers below 10 are 'DMSO-close' (label 0),
    # every other cluster is 'DMSO-distant' (label 1)
    gen_clust['int_cluster'] = gen_clust['cluster'].apply(lambda x: int(x.replace('Cluster','')))
    dmso_close_idx = gen_clust.int_cluster < 10
    gen_clust['label'] = 1
    gen_clust.loc[dmso_close_idx, 'label'] = 0
    gen_clust['label_name'] = np.where(dmso_close_idx, 'DMSO-close', 'DMSO-distant')
    gen_clust = gen_clust.sort_values(by='label')

    # Remove the SMILES for which utils.smile_to_mol returned None
    print("Checking molecules")
    mols = Parallel(n_jobs=16)( delayed(utils.smile_to_mol)(s) for s in tqdm(gen_clust.SMILES_standard))
    # Explicit identity test: does not rely on element-wise '== None' over an object array
    invalid_mol_idx = np.array([mol is None for mol in mols], dtype=bool)
    gen_clust = gen_clust.loc[~invalid_mol_idx].reset_index(drop=True)
    print('Removed %i invalid mols'%invalid_mol_idx.sum())

    # Compute forest and coordinates
    print('computing LSH Forest')
    lf_gen_clust = tv.lsh_forest_from_smiles( gen_clust.SMILES_standard )
    print('computing and saving coordinates')
    x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, _ = tmap.layout_from_lsh_forest(lf_gen_clust, tv.config )

    # Save results so that later runs can skip the computation
    tv.save_coordinates( x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, output_file_coord)
    gen_clust.to_csv(output_file_data)
else:
    # Reuse the coordinates and the data table stored by a previous run
    x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust = tv.load_coordinates(output_file_coord)
    gen_clust = pd.read_csv(output_file_data,index_col=0)
    print('Loaded coordinates')


Loaded coordinates


#### 1.2 Plot trees

In [4]:
# Target of the HTML figure inside the figures folder. The file name carries no
# leading blank, which would otherwise become part of the name of the file on disk.
output_filename = os.path.join(args['figures_dir'], 'tmap_DMSO_close_vs_DMSO_distant')
# Black and light grey; presumably indexed by the 'label' column
# (0 = DMSO-close, 1 = DMSO-distant) -- TODO confirm in tv.plot_tree
colors = [ [0,0,0,1], [0.7,0.7,0.7]]

tv.plot_tree(gen_clust, x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, 
             html_output_file=output_filename, colors=colors, plot_name='generated', max_point_size=3)

<matplotlib.colors.ListedColormap object at 0x7ff97a3b8be0>


## 2. Cluster_subsets
This section generates plots from Figure 2d-e

#### 2.1 Select a set of clusters 

In [5]:
# Pick the cluster pair to project: keep exactly one assignment active and
# leave the others commented out.

# Close cluster pairs shown in Figure 2d
#selected_clusters_int = [ 0,2]
#selected_clusters_int = [ 3,4]
selected_clusters_int = [ 5,7]
#selected_clusters_int = [ 9,11]
#selected_clusters_int = [ 8,13]

# Far cluster pairs shown in Figure 2e
#selected_clusters_int = [ 0,7]
#selected_clusters_int = [ 6,9]
#selected_clusters_int = [ 4,10]
#selected_clusters_int = [ 1,8]
#selected_clusters_int = [ 3,19]

# Build the 'Cluster<N>' names and order them by increasing cluster number
sorted_idx = np.argsort( selected_clusters_int )
cluster_names = np.array([ 'Cluster{}'.format(number) for number in selected_clusters_int ])
selected_clusters = cluster_names[sorted_idx]

#### 2.2 Read data and Compute trees

In [6]:
N_per_cluster = 5000
exp_id = 'generated_cluster_'+'_'.join(selected_clusters)
output_file_coord =  os.path.join( args['results_dir'],exp_id+'__tmap_coordinates.pkl')
output_file_data =  os.path.join( args['results_dir'],exp_id+'__tmap_data_and_labels.csv')

# The cached branch below reads BOTH files, so the cache is only usable when the
# coordinates and the matching data/labels table are present together.
cache_available = os.path.isfile(output_file_coord) and os.path.isfile(output_file_data)

if not cache_available:

    # Read the file of every selected cluster; the file name is obtained by replacing
    # the 'CLUSTER' placeholder of the reference file name with the cluster name.
    cluster_files = [
        os.path.join(args['generated_mols_dir'], args['generated_ref_filename'].replace('CLUSTER',cluster))
        for cluster in selected_clusters
    ]
    # Concatenate once instead of growing the frame inside a loop
    gen_clust = pd.concat([pd.read_csv(cluster_file, index_col=0) for cluster_file in cluster_files])

    # Keep rows labelled 'Cpd' and drop repeated SMILES
    gen_clust = gen_clust.loc[gen_clust.label=='Cpd'].drop(columns='label')
    gen_clust = gen_clust.drop_duplicates(subset='SMILES_standard')
    # Restrict to the selected clusters BEFORE sampling, so that an unrelated cluster found
    # in the files can neither consume random draws nor abort the sampling when it holds
    # fewer than N_per_cluster rows
    gen_clust = gen_clust.loc[ gen_clust.cluster.isin(selected_clusters)]
    # Draw a seeded random sample of N_per_cluster rows from every cluster
    gen_clust = gen_clust.groupby(by='cluster').sample(N_per_cluster, random_state=10)

    # Assign label columns: 'label_name' is the cluster name and 'label' its
    # position in selected_clusters
    sequential_cluster_dict = {cluster: position for position, cluster in enumerate(selected_clusters)}
    gen_clust['label_name'] =  gen_clust.cluster
    gen_clust['label'] = gen_clust['label_name'].map(sequential_cluster_dict)
    gen_clust = gen_clust.sort_values(by='label')

    # Remove the SMILES for which utils.smile_to_mol returned None
    print("Checking molecules")
    mols = Parallel(n_jobs=16)( delayed(utils.smile_to_mol)(s) for s in tqdm(gen_clust.SMILES_standard))
    # Explicit identity test: does not rely on element-wise '== None' over an object array
    invalid_mol_idx = np.array([mol is None for mol in mols], dtype=bool)
    gen_clust = gen_clust.loc[~invalid_mol_idx].reset_index(drop=True)
    print('Removed %i invalid mols'%invalid_mol_idx.sum())

    # Compute forest and coordinates
    print('computing LSH Forest')
    lf_gen_clust = tv.lsh_forest_from_smiles( gen_clust.SMILES_standard )
    print('computing and saving coordinates')
    x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, _ = tmap.layout_from_lsh_forest(lf_gen_clust, tv.config )

    # Save results so that later runs can skip the computation
    tv.save_coordinates( x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, output_file_coord)
    gen_clust.to_csv(output_file_data)
else:
    # Reuse the coordinates and the data table stored by a previous run
    x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust = tv.load_coordinates(output_file_coord)
    gen_clust = pd.read_csv(output_file_data,index_col=0)
    print('Loaded coordinates')


Loaded coordinates


#### 2.3 Display tree

In [7]:
# Target of the HTML figure inside the figures folder. The file name carries no
# leading blank, which would otherwise become part of the name of the file on disk.
output_filename = os.path.join(args['figures_dir'], 'tmap_'+exp_id)
# One colour per selected cluster, looked up in the shared cluster colour table
# and listed in the same order as selected_clusters
colors = [ vi.cluster_colors_dict[c] for c in selected_clusters]

tv.plot_tree(gen_clust, x_gen_clust, y_gen_clust, s_gen_clust, t_gen_clust, 
             html_output_file=output_filename, colors=colors, plot_name='generated', max_point_size=4)

<matplotlib.colors.ListedColormap object at 0x7ff974d6bf98>
