In [None]:
import joblib
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import random
import time
import datetime
from sklearn.metrics import r2_score
import numpy as np
from tqdm.notebook import tqdm


sys.path.append("../lib/")
from MoleculeRegressor import MoleculeRegressor,FP
from rbm_util import train_rbm


sys.path.append("../REINVENT/lib/")
from reinvent_wrapper import run_reinvent,run_reinvent_parallel
from anneal_util import r_qubo_sampling,random_state_sampling


# Regression
- Calculate the slope of a linear regression model

In [None]:
def get_model(cond):
    """Fit a ``MoleculeRegressor`` on the dataset described by ``cond``.

    Parameters
    ----------
    cond : dict
        Must provide ``"database_path"`` (a CSV file read with pandas) and
        the column names ``"smiles_column"`` and ``"target_param_name"``.

    Returns
    -------
    MoleculeRegressor
        The model after ``fit`` has been called on the cleaned data.

    Notes
    -----
    As a side effect a 4x4 inch matplotlib figure is opened showing a
    scatter of the target values (x axis) against the model's predictions
    for the same rows (y axis).
    """
    raw_table = pd.read_csv(cond["database_path"])
    smiles_col = cond["smiles_column"]
    target_col = cond["target_param_name"]

    # Keep only the two relevant columns and discard incomplete rows.
    table = raw_table[[smiles_col, target_col]].dropna()

    # Delete broken SMILES: the second value returned by FP.calc_fingerprint
    # is used to index the frame.
    # NOTE(review): presumably a boolean mask aligned with the rows -- confirm.
    _, usable_rows = FP.calc_fingerprint(table[smiles_col].values)
    table = table[usable_rows]

    smiles = table[smiles_col]
    targets = np.array(table[target_col])

    regressor = MoleculeRegressor()
    regressor.fit(smiles, targets)
    predicted = regressor.predict(smiles)

    # Quick visual check of predictions against targets on the fitted data.
    plt.figure(figsize=(4, 4))
    plt.scatter(targets, predicted, s=3)

    return regressor

# RBM
- Calculate the user-preference potential

In [None]:
def run_rbm(start_smiles):
    """Train an RBM on the fingerprints of ``start_smiles``.

    The fingerprints computed by ``FP.calc_fingerprint`` are dumped to
    ``data/fp.bin`` and ``train_rbm`` is then called with that file as
    ``fp_path`` and ``data/rbm_J.bin`` as ``out_path``.

    Parameters
    ----------
    start_smiles : sequence of str
        SMILES strings passed to ``FP.calc_fingerprint``.

    Returns
    -------
    None
    """
    fingerprints, _ = FP.calc_fingerprint(start_smiles)
    joblib.dump(fingerprints, "data/fp.bin")

    # Total unit count handed to train_rbm: start at 2048 and keep doubling
    # until it is strictly larger than the fingerprint length.
    n_units = 1024 * 2
    fp_length = len(fingerprints[0])
    while not n_units > fp_length:
        n_units *= 2

    print("RBM dimension:", n_units)
    plt.figure()
    train_rbm(
        fp_path="data/fp.bin",
        out_path="data/rbm_J.bin",
        batch_size=1,
        all_units=n_units,
        use_gpu=False,
        epochs=300,
        k=3,
        check_reconstruction=3,
    )


# Annealing & DRL

In [None]:
def auto_search(model,save_path):
    """Sample QUBO solutions and generate molecules for a sweep of ``r`` values.

    For each ``r`` in ``2**-4 ... 2**4`` whose result file
    ``<save_path>/<r>.bin`` does not exist yet, the function samples states
    with ``r_qubo_sampling``, selects ``cond["sample_num"]`` of them with
    ``random_state_sampling``, keeps the first 512 columns of each selected
    state and passes them to ``run_reinvent_parallel``.

    Parameters
    ----------
    model : object
        Fitted regression model; only its ``coef_`` attribute is read.
    save_path : str
        Directory used to build the per-``r`` result path that is checked
        for existence.

    Returns
    -------
    dict or None
        Result dictionary with keys ``"r"``, ``"anneal_result"``,
        ``"sel_fp_list"``, ``"rbm_qubo"``, ``"time"`` and ``"df"`` for the
        first ``r`` that is not skipped, or ``None`` if every ``r`` already
        has a result file.

    Notes
    -----
    Reads the module-level ``cond`` dictionary (key ``"sample_num"``) and
    the file ``data/rbm_J.bin``.
    """
    # QUBO ingredients: the stored RBM matrix and the regression coefficients.
    rbm_qubo = joblib.load("data/rbm_J.bin")
    model_qubo = model.coef_

    for r in [2**i for i in range(-4, 5, 1)]:

        print(r)
        # os.path.join avoids a doubled separator when save_path already
        # ends with "/" and a root-level path when save_path is empty.
        result_path = os.path.join(save_path, str(r) + ".bin")
        if os.path.exists(result_path):
            print("already done! skipped")
            continue

        # anneal and drl
        print("start anneal")
        state_list, eg_list, c_list = r_qubo_sampling(r, rbm_qubo, model_qubo)

        sel_id_list = random_state_sampling(state_list, eg_list, n_sampling=cond["sample_num"])
        # dtype=int keeps this in line with the interactive sampling cell of
        # this notebook, so both paths pass the same kind of array on.
        sel_fp_list = np.array(state_list, dtype=int)[sel_id_list]
        # NOTE(review): 512 is presumably the fingerprint length -- confirm
        # against FP.calc_fingerprint.
        sel_fp_list = list(sel_fp_list[:, :512])

        print("start DQN")
        integ_df = run_reinvent_parallel(sel_fp_list,
                                         rein_dir='../REINVENT/',
                                         original_dir='../4_compound_extraction/',
                                         n_parallel=cond["sample_num"],
                                         gpu_num=2)

        # dump
        result_dict = {
            "r": r,
            "anneal_result": (state_list, eg_list, c_list),
            "sel_fp_list": sel_fp_list,
            "rbm_qubo": rbm_qubo,
            "time": time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()),
            "df": integ_df,
        }
        #joblib.dump(result_dict,result_path,compress=9)
        # NOTE(review): this return sits inside the loop, so the sweep stops
        # after the first r that is not skipped; with the dump above disabled
        # no result file is ever written here. Kept as-is because changing it
        # would alter the return contract -- confirm the intended behaviour.
        return result_dict

In [None]:
#auto_search
# Run configuration shared by the cells below.
# Other start_smiles settings that were tried:
#["Cc1ccccc1"],
#["Cc1ccccc1","FC(F)(F)F","O=C(O)C"],
cond = {
    "type": "anneal_results/",
    "smiles_column": "SMILES",            # column holding the SMILES strings
    "start_smiles": ["Cc1ccccc1", "FC(F)(F)F"],
    "target_param_name": "MolLogP",       # regression target column
    "database_path": "../database/BradleyDoublePlusGoodMeltingPointDataset_logP.csv",
    "sample_num": 4,                      # used as n_sampling and n_parallel
}

target_param_list = "MolLogP"
#"TPSA"


In [None]:
# regression: fit the property model on the dataset configured in `cond`.
# get_model also opens a scatter plot of target vs. predicted values.
model=get_model(cond)

In [None]:
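# rbm: uncomment the call below to (re)train the RBM. run_rbm passes
# "data/rbm_J.bin" as the output path, and the sampling cell loads that file.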
#run_rbm(cond["start_smiles"])

# Solution sampling and molecule generation
- Solutions are sampled by MCMC in this demo code

In [None]:
# Load the stored RBM matrix ("data/rbm_J.bin" is the out_path that run_rbm
# hands to train_rbm) and take the regression coefficients from the model.
rbm_qubo=joblib.load("data/rbm_J.bin")
model_qubo=model.coef_

#sample minimums
# NOTE(review): r is presumably the weight balancing the two QUBO terms --
# confirm in anneal_util.r_qubo_sampling.
r=1
state_list,eg_list,c_list=r_qubo_sampling(r,rbm_qubo,model_qubo)
        

In [None]:
# Choose cond["sample_num"] of the sampled states, then keep the first 512
# columns of each chosen state as an integer array.
# NOTE(review): 512 is presumably the fingerprint length -- confirm.
sel_id_list = random_state_sampling(state_list, eg_list, n_sampling=cond["sample_num"])
sel_fp_list = list(np.array(state_list, dtype=int)[sel_id_list][:, :512])

In [None]:
#drl
# Run REINVENT on the selected fingerprints; n_parallel is set to the same
# cond["sample_num"] that was used to select them.
# NOTE(review): gpu_num=2 is hard-coded -- confirm it matches the machine.
integ_df=run_reinvent_parallel(sel_fp_list, 
                               rein_dir='../REINVENT/',
                               original_dir='../4_compound_extraction/',
                               n_parallel=cond["sample_num"],
                               gpu_num=2)

In [None]:
# Show the result returned by run_reinvent_parallel as the cell output.
integ_df