# Basic GTM Training for ChEMBL database

## Importing libraries

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from sklearn.datasets import make_s_curve

# GTM and utils
from gtmkit.gtm import GTM
from gtmkit.utils.molecules import calculate_latent_coords
from gtmkit.utils.regression import get_reg_density_matrix, reg_density_to_table
from gtmkit.utils.density import density_to_table

# Plotting
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  # needed for 3D subplot

# Needed for plotly plots
import plotly.graph_objects as go
import ipywidgets as widgets
from ipywidgets import GridspecLayout, Box

import altair as alt
from gtmkit.plots.altair_landscapes import (
    altair_points_chart,
    altair_discrete_regression_landscape,
    altair_discrete_density_landscape
)
from gtmkit.plots.plotly_landscapes import plotly_smooth_density_landscape, plotly_smooth_regression_landscape


## Creating the dataframe and loading precomputed descriptors

In [2]:
# Read the dataset from its parquet file; the displayed table shows that the
# fingerprint column `mfp_r2_1024` is already stored in it.
parquet_path = 'ChEMBLdata/CHEMBL279_CLS.parquet'
df = pd.read_parquet(parquet_path)
df

Unnamed: 0,smi,ChEMBL id,class,mfp_r2_1024
0,Cc1ccc(F)c(NC(=O)Nc2ccc(cc2)-c2cccc3snc(N)c23)c1,CHEMBL271795,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,FC(F)C(=O)NCCOc1cc2ncnc(Nc3ccc(Br)cc3F)c2cc1NC...,CHEMBL3671495,1,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
2,Cc1n[nH]c2nccc(-c3ccc(NC(=O)Nc4cccc(Br)c4)cc3)c12,CHEMBL4079338,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,COc1cc2c(Nc3ccc(Cl)c(Cl)c3)ncnc2cc1OCC1CN(CCO1...,CHEMBL3982666,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4,Cc1cc(C(=O)Nc2cccc(c2)C(=O)c2ccc3c(C=Cc4ccccn4...,CHEMBL3974462+CHEMBL3891895,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
...,...,...,...,...
7992,OC(=O)CCCC(=O)Nc1cc(Nc2ccccc2Br)ncn1,CHEMBL4293070,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
7993,COc1cc2c(Nc3ccc(Cl)c(Cl)c3)ncnc2cc1OCC1CN(CCO1...,CHEMBL3940839,1,"[0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
7994,CONC(=O)c1ccccc1Sc1ccc2c(C=Cc3ccccn3)n[nH]c2c1,CHEMBL3960830,2,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
7995,C(CN1CCOCC1)Oc1ccc(cc1)-c1cc(ccn1)-c1c[nH]nc1-...,CHEMBL205158,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, ..."


In [3]:
# Number of rows for each distinct value of the 'class' column.
class_counts = df['class'].value_counts()
class_counts

class
1    6844
2    1153
Name: count, dtype: int64

## Building the GTM

In [4]:
import sys
import traceback

# One seed shared by the row sampling, the fallback jitter and the GTM itself,
# so that Restart & Run All trains on the same rows every time.
SEED = 1234

desc_col = 'mfp_r2_1024'
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Train on a random (but seeded) subset of 5000 rows of the descriptor column.
X_svm = np.stack(df[desc_col].sample(5000, random_state=SEED).values)
fingerprint_tensor = torch.tensor(X_svm, dtype=torch.float64, device=device)

gtm = GTM(
    num_nodes=36**2,
    num_basis_functions=31**2,
    basis_width=5.726809,
    reg_coeff=543.61223,
    max_iter=300,
    tolerance=0.001,
    standardize=False,
    seed=SEED,
    # Same device as the training tensor; a hard-coded "cuda:0" fails on CPU-only machines.
    device=device,
    pca_scale=True,
)

try:
    gtm.fit(fingerprint_tensor)
except Exception as e1:
    print("There was an error, will try to fit by adding a small noise")
    print(f"Original error: {type(e1).__name__}: {e1}")
    print(traceback.format_exc())

    # Retry once on a slightly jittered copy of the data (mean 0, stddev 1e-6).
    noise_rng = np.random.default_rng(SEED)
    A_noisy = X_svm + noise_rng.normal(0, 1e-6, X_svm.shape)
    # Keep dtype and device identical to the first attempt.
    fingerprint_tensor_exception = torch.tensor(A_noisy, dtype=torch.float64, device=device)

    try:
        gtm.fit(fingerprint_tensor_exception)
    except Exception as e2:
        print(f"Still couldn't fit it. {type(e2).__name__}: {e2}")
        # Stop here: the following cells need a fitted model.
        raise

  0%|          | 0/300 [00:00<?, ?it/s]

## Projecting full dataset

In [5]:
# Project every row of the dataframe (not only the training subset) onto the fitted map.
desc_col = 'mfp_r2_1024'
X_svm = np.stack(df[desc_col].values)
fingerprint_tensor = torch.tensor(X_svm, dtype=torch.float64, device=device)

# `project` yields two tensors (presumably responsibilities and log-likelihoods,
# judging by the names -- confirm against gtmkit); move both to host memory as
# NumPy arrays for the plotting helpers.
resps, llh = (out.cpu().numpy() for out in gtm.project(fingerprint_tensor))

### First, we plot using plotly to get a nice smooth image

In [9]:
from gtmkit.utils.classification import get_class_density_matrix, class_density_to_table
from gtmkit.plots.plotly_landscapes import plotly_discrete_class_landscape

# Class names as strings; the first entry is labelled "Inactive" and the
# second "Active" in the class landscape below.
classes_str = ['1', '2']

density, class_density, class_prob = get_class_density_matrix(
    resps,
    class_labels=df["class"].to_list(),
    class_name=classes_str,
    normalize=True,
)

source_class = class_density_to_table(
    density=density,
    class_density=class_density,
    class_prob=class_prob,
    class_name=classes_str,
    normalized=True,
)

source = density_to_table(
    density=density,
)

plotly_density = plotly_smooth_density_landscape(source, title="Density landscape")

plotly_CLASS = plotly_discrete_class_landscape(
    source_class,
    # Build the title from the actual class names so it cannot drift from the data.
    title=f'Class landscape (Inactive={classes_str[0]}, Active={classes_str[1]})',
    first_class_density_column_name=classes_str[0]+"_norm_density",
    first_class_prob_column_name=classes_str[0]+"_norm_prob",
    second_class_density_column_name=classes_str[1]+"_norm_density",
    second_class_prob_column_name=classes_str[1]+"_norm_prob",
    first_class_label='Inactive ',
    second_class_label='Active ',
    min_density=0.1,
)

# Wrap both figures as widgets so they can be shown side by side.
fig1_widget = go.FigureWidget(plotly_density)
fig2_widget = go.FigureWidget(plotly_CLASS)

widgets_box = widgets.HBox([fig1_widget, fig2_widget])
widgets_box

HBox(children=(FigureWidget({
    'data': [{'colorbar': {'title': {'text': 'Density'}},
              'colorsc…

### To get more details, we can plot the same densities, but using altair

In [11]:
from gtmkit.plots.altair_landscapes import altair_discrete_class_landscape

# Same density tables as for the plotly figures, but built with node_threshold=0.1.
sourceAlt = density_to_table(
    density=density,
    node_threshold=0.1
)
sourceAlt_class = class_density_to_table(
    density=density,
    class_density=class_density,
    class_prob=class_prob,
    class_name=classes_str,
    normalized=True,
    node_threshold=0.1
)

chart_density = altair_discrete_density_landscape(sourceAlt, title='Density landscape')

chart_class = altair_discrete_class_landscape(
    sourceAlt_class,
    # Build the title from the actual class names so it cannot drift from the data.
    title=f'Class landscape (Inactive={classes_str[0]}, Active={classes_str[1]})',
    first_class_density_column_name=classes_str[0]+"_norm_density",
    first_class_prob_column_name=classes_str[0]+"_norm_prob",
    second_class_density_column_name=classes_str[1]+"_norm_density",
    second_class_prob_column_name=classes_str[1]+"_norm_prob",
    use_density=True,
    colorset='redblue',
    reverse=True
).properties(
    width=600,
    height=600,
)

# Altair chart methods return a new chart instead of modifying in place, and
# `configure_*` is only accepted on the top-level chart, so the legend styling
# is attached to the combined chart. NOTE(review): this styles the legends of
# both panels, not only the density one.
combined_chart = (
    chart_density.properties(width=600, height=600) | chart_class.properties(width=600, height=600)
).resolve_scale(
    x='independent',
    y='independent',
    color='independent'
).configure_legend(
    labelFontSize=20,
    gradientVerticalMaxLength=600,
    gradientThickness=30,
    tickCount=6
)

combined_chart

---