# Molecular generation conditioned on over-expression morphological profiles of selected genes

In [7]:
import os
import numpy as np
import pandas as pd
import pickle

import tensorflow as tf
tf.compat.v1.logging.set_verbosity (tf.compat.v1.logging.ERROR)

import logging
logging.basicConfig(level=logging.INFO, format ='%(levelname)s - %(message)s')
tf.logging.set_verbosity(tf.logging.ERROR)

from sklearn.preprocessing import QuantileTransformer
from rdkit import Chem
from rdkit.Chem.Descriptors import qed
import sys
sys.path.append(os.path.join(Chem.RDConfig.RDContribDir, 'SA_Score'))
import sascorer

import cpmolgan.utils as utils
import cpmolgan.inference as infr
import pkg_resources
WEIGHTS_PATH = pkg_resources.resource_filename('cpmolgan','model_weights')


### Arguments

In [8]:
# Pick one repetition. This will define output naming and random seed
repetition = ['n1','n2','n3'][0]

args = {
    'use_gpu': True,
    'gpu_device':'2',
    "PassPhysChemFilter": True,
    "PhysChemFilter_alerts_file": "../../data/sure_chembl_alerts.txt",
    "N_valid_per_gene":20000,
    'filename_oe_profiles':'../../data/test_set_overexpression_normalized_profiles.csv',
    "output_dir":f"results/{repetition}_FILTERS/generated_mols",
}

# Define random seed
seed_dict = {"n1":0, "n2":100, "n3":200}
seed_addition = seed_dict[repetition] 

# Automatic naming according to inputs
filters_str = "Valid"
if args["PassPhysChemFilter"]: filters_str = filters_str+"_PassPhysChemFilter"
args["output_dir"] = args["output_dir"].replace("FILTERS", filters_str)

# List of genes
top_10_diff_genes = ['RAF1', 'JUN', 'ATF4', 'BRAF', 'CEBPA', 'RELB', 'MEK1', 'PIK3CD','AKT3', 'WWTR1']
excape_genes = ["TP53","BRCA1","NFKB1","HSPA5", "CREBBP", "STAT1", "STAT3","HIF1A", "NFKBIA","JUN","PRKAA1","PDPK1"]
args["selected_genes"] = excape_genes+top_10_diff_genes 

if not os.path.isdir(args["output_dir"]):
    os.makedirs(args["output_dir"])


## 1. Set compute environment 

In [9]:
if args['use_gpu']:
    os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu_device']
    gpu_options = tf.GPUOptions(visible_device_list='0')
    tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    tf.config.set_soft_device_placement(True)
    tf.debugging.set_log_device_placement(True)
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

## 2. Load inference model

In [10]:
model_weigth_paths = {
    'autoencoder': os.path.join(WEIGHTS_PATH,'autoencoder.h5'),
    'wgan':{
            'C': os.path.join(WEIGHTS_PATH,'gan_C.h5'),
            'D': os.path.join(WEIGHTS_PATH,'gan_D.h5'),
            'G': os.path.join(WEIGHTS_PATH,'gan_G.h5'),
            'condition_encoder':os.path.join(WEIGHTS_PATH,'gan_condition_encoder.h5'),
            'classifier':os.path.join(WEIGHTS_PATH,'gan_classifier.h5')
            }
}

model = infr.InferenceModel( model_weigth_paths ) 

## 3. Read profiles and apply quantile transformer

In [11]:
data_oe = pd.read_csv(args['filename_oe_profiles'], index_col=0 )
data_oe = data_oe.drop(columns=['ORF Sequence', 'Quality Control','Median Replicate Correlation','Morphological Cluster ID'])
dmso_idx = data_oe['Control Type'] =='negative control'
data_oe.loc[dmso_idx,"Gene_Symbol"] = "DMSO"
keep_idx = data_oe.Gene_Symbol.isin(args["selected_genes"])
data_oe = data_oe[keep_idx].reset_index(drop=True)
feature_cols , info_cols = utils.get_feature_cols(data_oe)
logging.info('Number of targets: %i'%len(data_oe.Gene_Symbol.unique()))
logging.info('Targets: %a'%data_oe.Gene_Symbol.unique())

# Apply quantile transformer
quantile_transformer =  pickle.load( open( os.path.join(WEIGHTS_PATH,'quantile_transformer.pkl'), 'rb' ) )
data_oe[feature_cols] = quantile_transformer.transform(data_oe[feature_cols].values)   
logging.info('Number of Morphological profiles: %i'%data_oe.shape[0])

INFO - Number of targets: 21
INFO - Targets: array(['BRCA1', 'AKT3', 'ATF4', 'HSPA5', 'RAF1', 'MEK1', 'BRAF', 'TP53',
       'HIF1A', 'PIK3CD', 'JUN', 'NFKBIA', 'PDPK1', 'CEBPA', 'CREBBP',
       'STAT1', 'STAT3', 'RELB', 'NFKB1', 'WWTR1', 'PRKAA1'], dtype=object)
INFO - Number of Morphological profiles: 216


## 4. Generate a fixed number of valid, pass_physchem_filter = true molecules per gene

In [12]:
physchem_filterer = utils.PhysChemFilters(args['PhysChemFilter_alerts_file'])

for gene in args["selected_genes"]:
    
    logging.info("------ %s ---------"%gene)

    # Define output file and check that it doesnt aleady exits
    output_file = os.path.join(args["output_dir"], gene+"__"+str(args["N_valid_per_gene"])+"_"+filters_str+".csv")
    if os.path.isfile(output_file):
        logging.warning("File %s already exists. Skipping it \n"%output_file)
        continue
    print(output_file)
    # Generate molecules
    gene_data = data_oe.loc[ data_oe["Gene_Symbol"] == gene ].reset_index(drop=True)
    if args["PassPhysChemFilter"]:
        N_per_condition = 2000
        if gene == "DMSO": # we have one order of magnitude more samples for DMSO than all genes
            N_per_condition = 200
    else:
        N_per_condition = 500

    generated_final = pd.DataFrame()
    i = 0    
    while len(generated_final) < args["N_valid_per_gene"]:
        temp_generated = infr.generate_compounds_multiple_conditions( model, gene_data, feature_cols, info_cols, seed=i+seed_addition, nsamples = N_per_condition)
        temp_generated = infr.filter_valid_and_unique(temp_generated, cond_ID_cols=["Gene_Symbol"], select_unique=False)
        if args["PassPhysChemFilter"]:
            temp_generated["pass_physchem_filter"] =  temp_generated["SMILES_standard"].apply(lambda x: physchem_filterer.apply_filters(x))
            temp_generated = temp_generated.loc[ temp_generated["pass_physchem_filter"] ==True ]
        generated_final = pd.concat([generated_final, temp_generated])
        logging.info("%s iteration %i: %i valid molecules, PassPhysChemFilter=%r "%(gene,i,len(generated_final),args["PassPhysChemFilter"]) )
        i = i +1
    generated_final = generated_final.reset_index(drop=True)
    generated_final = generated_final.iloc[0:args["N_valid_per_gene"]]
    
    # Add Synthetic accessibility and Drug likeness scores
    rdkit_mols = [ Chem.MolFromSmiles(x) for x in generated_final.SMILES_standard ]
    generated_final["SA_score"]  = [ sascorer.calculateScore(x) for x in rdkit_mols]
    generated_final["QED_score"] = [ qed(x) for x in rdkit_mols]
    
    # Save results
    generated_final.to_csv(output_file)
    

INFO - ------ TP53 ---------

INFO - ------ BRCA1 ---------

INFO - ------ NFKB1 ---------

INFO - ------ HSPA5 ---------

INFO - ------ CREBBP ---------

INFO - ------ STAT1 ---------

INFO - ------ STAT3 ---------

INFO - ------ HIF1A ---------

INFO - ------ NFKBIA ---------

INFO - ------ JUN ---------

INFO - ------ PRKAA1 ---------

INFO - ------ PDPK1 ---------

INFO - ------ RAF1 ---------

INFO - ------ JUN ---------

INFO - ------ ATF4 ---------

INFO - ------ BRAF ---------

INFO - ------ CEBPA ---------

INFO - ------ RELB ---------

INFO - ------ MEK1 ---------

INFO - ------ PIK3CD ---------

INFO - ------ AKT3 ---------

INFO - ------ WWTR1 ---------

