# Imports

In [None]:
%env OE_LICENSE=/Users/alexpayne/oe_license.txt

In [None]:
from openeye import oechem
from openeye import oedepict
from sklearn.cluster import KMeans
import numpy as np
import importlib
from covid_moonshot_ml.docking import docking as d
from covid_moonshot_ml.data import openeye as oe
from covid_moonshot_ml import schema
from covid_moonshot_ml.datasets import utils

In [None]:
importlib.reload(d)
importlib.reload(schema)
importlib.reload(utils)

# Exploring kmeans clustering

In [None]:
X = np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]])

In [None]:
kmeans = KMeans(n_clusters=2, random_state=0).fit(X)

In [None]:
kmeans.labels_

In [None]:
kmeans.predict([[0,0], [10,0]])

In [None]:
kmeans.cluster_centers_

# Load in MCSS pickle

In [None]:
import pickle as pkl

In [None]:
pickle_path = "/Volumes/Rohirrim/local_test/chemiinformatics/mcs_sort_index.pkl"

In [None]:
# Load the pickled object stored at pickle_path into `data`.
# NOTE(review): pickle can execute arbitrary code while loading, so only open
# files that come from a trusted source.
with open(pickle_path, 'rb') as f:
    data = pkl.load(f)

In [None]:
len(data[0])

In [None]:
len(data[1])

In [None]:
len(data[2])

In [None]:
len(data[2][0])

In [None]:
582-220

In [None]:
data[2][0]

In [None]:
data[2][110]

In [None]:
X = data[2]

In [None]:
kmeans = KMeans(n_clusters=10, random_state=0).fit(X)

In [None]:
kmeans.labels_

In [None]:
kmeans.cluster_centers_

In [None]:
cluster_1 = np.array(data[0])[kmeans.labels_==1]

In [None]:
# Map every compound ID in data[0] to the k-means label at the same position.
cmpd_to_label_dict = {
    cmpd_id: kmeans.labels_[idx] for idx, cmpd_id in enumerate(data[0])
}

In [None]:
cmpd_to_label_dict

In [None]:
cluster_0 = np.array(data[0])[kmeans.labels_==0]

## Get SMILES for each molecule

In [None]:
cmpd_tracker = "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/Mpro_compound_tracker_csv.csv"
xtal_compounds = d.parse_xtal(cmpd_tracker, "")

In [None]:
len(xtal_compounds)

## examine compound tracker

In [None]:
import pandas as pd

In [None]:
df = pd.read_csv(cmpd_tracker)

In [None]:
df.index = df["Compound ID"]

In [None]:
df.loc['ALP-POS-c3a96089-4', "Moonshot Series"]

## load cmpd_to_frag

In [None]:
import yaml

In [None]:
with open("../data/cmpd_to_frag.yaml") as f:
    cmpd_to_frag_dict = yaml.safe_load(f)

In [None]:
frag_to_cmpd = {frag: cmpd for cmpd, frag in cmpd_to_frag_dict.items()}

In [None]:
# Annotate every entry returned by d.parse_xtal with a compound ID and its
# "Moonshot Series" value from the compound tracker.
for xtal in xtal_compounds:
    # frag_to_cmpd is the inverse of cmpd_to_frag_dict, so it is keyed by the
    # values of that YAML mapping; a dataset missing from it raises KeyError.
    xtal.compound_id=frag_to_cmpd[xtal.dataset]
    # df is indexed by "Compound ID", so .loc looks the series up by compound.
    xtal.series=df.loc[xtal.compound_id, "Moonshot Series"]

In [None]:
len(xtal_compounds)

In [None]:
frag_data = utils.parse_fragalysis_data(cmpd_tracker, "", xtals_only=False)

In [None]:
df.loc["AGN-NEW-891393a6-1"]

In [None]:
set(df["Moonshot Series"])

In [None]:
cmpd_info = [frag_data.get(cmpd_id) for cmpd_id in data[0]]

In [None]:
# Count how many compound IDs were found in frag_data (frag_data.get returns
# None for a missing ID). An explicit identity check avoids relying on NumPy's
# element-wise `!= None` comparison over an object array.
sum(info is not None for info in cmpd_info)

In [None]:
# Collect the frag_data entry for every compound ID in data[0] that is also
# present in the compound tracker, attaching its "Moonshot Series" value.
cmpd_list = []
for cmpd_id in data[0]:
    cmpd_data = frag_data.get(cmpd_id)
    if not cmpd_data:
        # No entry for this compound ID in frag_data.
        continue
    # Keep the try body minimal: only the tracker lookup is expected to raise
    # KeyError (compound ID absent from the tracker index).
    try:
        moonshot_series_name = df.loc[cmpd_id, "Moonshot Series"]
    except KeyError:
        continue
    # A duplicated "Compound ID" makes .loc return a Series; take its first
    # entry by position. `.iloc[0]` is explicit, whereas `series[0]` on a
    # label index relies on the deprecated positional fallback.
    if isinstance(moonshot_series_name, pd.Series):
        moonshot_series_name = moonshot_series_name.iloc[0]
    # Treat the literal string "None" and missing (NaN) values as "no series".
    if pd.isna(moonshot_series_name) or moonshot_series_name == "None":
        moonshot_series_name = None
    cmpd_data.series = moonshot_series_name
    cmpd_list.append(cmpd_data)

In [None]:
cmpd_list[0]

In [None]:
mol = oechem.OEGraphMol()
oechem.OESmilesToMol(mol, cmpd_list[0].smiles)
oe.write_openeye_ligand(mol, "test.png")

In [None]:
# Key each found compound by its compound_id and pair it with the row of
# data[2] that sits at the same position.
mapped_data = {}
for i, cmpd_data in enumerate(cmpd_info):
    if not cmpd_data:
        # frag_data had no entry for this position's compound ID.
        continue
    mapped_data[cmpd_data.compound_id] = (cmpd_data, data[2][i])

In [None]:
len(mapped_data)

In [None]:
mapped_data["AAR-POS-5507155c-1"]

In [None]:
key_filter = [key for key in frag_data.keys() if "BEN-DND" in key]

In [None]:
key_filter.sort()
key_filter

In [None]:
data[0]

## write out first 10 structures for each cluster

In [None]:
from covid_moonshot_ml.data import openeye as oe

In [None]:
# Write a PNG depiction for (up to) the first 10 compounds of one cluster.
cluster = 7
# Select the cluster members once instead of re-masking on every iteration;
# slicing also copes with clusters that have fewer than 10 members, where
# indexing position by position would raise IndexError.
cluster_cmpd_ids = np.array(data[0])[kmeans.labels_ == cluster][:10]
for cmpd_id in cluster_cmpd_ids:
    cmpd_data = mapped_data.get(cmpd_id)
    if cmpd_data:
        # mapped_data values are (compound data, data[2] row) tuples.
        smiles = cmpd_data[0].smiles
        mol = oechem.OEGraphMol()
        oechem.OESmilesToMol(mol, smiles)
        oe.write_openeye_ligand(
            mol,
            out_fn=f"/Volumes/Rohirrim/local_test/chemiinformatics/cluster{cluster}_{cmpd_id}.png",
        )

# Play with options for MCSS based clustering

In [None]:
cmpd_id = np.array(data[0])[kmeans.labels_ == cluster][0]
smiles = mapped_data[cmpd_id][0].smiles
mol1 = oechem.OEGraphMol()
oechem.OESmilesToMol(mol1, smiles)

## Generate molecule fingerprints

In [None]:
from openeye import oegraphsim

In [None]:
mol = oechem.OEGraphMol()
oechem.OESmilesToMol(mol, cmpd_list[0].smiles)
fp = oegraphsim.OEFingerPrint()
oegraphsim.OEMakeFP(fp, mol, oegraphsim.OEFPType_Circular)

In [None]:
fptype = oegraphsim.OEGetFPType(oegraphsim.OEFPType_Circular)

In [None]:
fptype.GetFPTypeString()

In [None]:
fp.IsValid()

In [None]:
# Build one circular fingerprint per compound in cmpd_list, keeping only the
# fingerprints that report themselves as valid.
# NOTE(review): skipped compounds make `fps` shorter than `cmpd_list`; anything
# that later pairs the two by position would then be misaligned — confirm that
# no compound ID is printed here before relying on positional pairing.
fps = []
for cmpd_data in cmpd_list:
    mol = oechem.OEGraphMol()
    oechem.OESmilesToMol(mol, cmpd_data.smiles)
    fp = oegraphsim.OEFingerPrint()
    oegraphsim.OEMakeFP(fp, mol, oegraphsim.OEFPType_Circular)
    if not fp.IsValid():
        # Report the offending compound and leave it out of the list.
        print(cmpd_data.compound_id)
        continue
    fps.append(fp)
fps_array = np.array(fps)

In [None]:
fp.GetSize()

In [None]:
oebit = oechem.OEBitVector()
oechem.OEParseHex(oebit, fp.ToHexString())

In [None]:
{fp.GetSize() for fp in fps}

In [None]:
fp = fps[1]
print(fp.GetSize())

In [None]:
# `hex_str` was never assigned anywhere in this notebook, so this cell raised
# NameError; derive it from the fingerprint selected in the previous cell.
hex_str = fp.ToHexString()
# Drop the final character before parsing, as fp_to_vector_array does.
# NOTE(review): presumably that character is a non-digit suffix of OpenEye's
# hex encoding — confirm against the OEBitVector documentation.
integer = int(hex_str[:-1], base=16)

In [None]:
len(str(integer))

In [None]:
binary_str = f"{integer:04096b}"

In [None]:
len(binary_str)

In [None]:
vector_array = np.array(list(map(int, binary_str)))

In [None]:
vector_array

In [None]:
fp.GetSize()

In [None]:
# NOTE(review): this scratch expression had no divisor ("4096 / ") and was a
# syntax error, so it is commented out; restore it with the intended divisor.
# 4096 /

In [None]:
smiles = ['']

## actually generate vector arrays

In [None]:
def fp_to_vector_array(fp, n_bits=4096):
    """Convert a fingerprint into a 1-D NumPy array of 0/1 integers.

    The fingerprint's hex string (minus its final character) is parsed as one
    big integer and rendered as a zero-padded binary string, most significant
    bit first; each binary digit becomes one array element.

    Args:
        fp: Object exposing ``ToHexString()``, e.g. an OpenEye fingerprint.
        n_bits: Minimum length of the returned vector. Defaults to 4096, the
            width that was previously hard-coded. If the parsed value needs
            more than ``n_bits`` binary digits the result is longer.

    Returns:
        numpy.ndarray of integers (0 or 1) with at least ``n_bits`` elements.

    Raises:
        ValueError: If the hex string, minus its final character, is empty or
            is not valid hexadecimal.
    """
    # NOTE(review): the last character is dropped before parsing; presumably it
    # is a non-digit suffix of OpenEye's hex encoding — confirm.
    hex_digits = fp.ToHexString()[:-1]
    integer = int(hex_digits, base=16)
    binary_str = f"{integer:0{n_bits}b}"
    # fromiter fills the array directly, without an intermediate Python list.
    return np.fromiter(map(int, binary_str), dtype=int, count=len(binary_str))

In [None]:
vector_arrays = [fp_to_vector_array(fp) for fp in fps]

In [None]:
len(vector_arrays)

## Try clustering with fingerprints

In [None]:
kmeans = KMeans(n_clusters=10, random_state=0).fit(vector_arrays)

In [None]:
# Parallel arrays of compound IDs and series names, in cmpd_list order.
cmpd_ids = np.array([cmpd.compound_id for cmpd in cmpd_list])
cmpd_series = np.array([cmpd.series for cmpd in cmpd_list])

In [None]:
results_df = pd.DataFrame({"Compound_ID":cmpd_ids, "Series":cmpd_series, "Cluster":kmeans.labels_})

In [None]:
results_df[results_df.Cluster == 9]

In [None]:
set(results_df[results_df.Cluster == 0].Series)

In [None]:
# NOTE(review): the cluster index was missing here (syntax error); 1 is a
# placeholder — replace it with the cluster you want to inspect.
set(results_df[results_df.Cluster == 1].Series)

In [None]:
# Write a PNG depiction for every compound, named by its cluster label.
# Iterate over each distinct label once: looping over the Cluster column
# itself visits a cluster once per member and rewrites the same files
# repeatedly.
for cluster in results_df.Cluster.unique():
    filtered_cmpd_ids = np.asarray(cmpd_ids)[kmeans.labels_ == cluster]
    for cmpd_id in filtered_cmpd_ids:
        cmpd_data = mapped_data.get(cmpd_id)
        if cmpd_data:
            # mapped_data values are (compound data, data[2] row) tuples.
            smiles = cmpd_data[0].smiles
            mol = oechem.OEGraphMol()
            oechem.OESmilesToMol(mol, smiles)
            oe.write_openeye_ligand(
                mol,
                out_fn=f"/Volumes/Rohirrim/local_test/chemiinformatics/cluster{cluster}_{cmpd_id}.png",
            )