## Setup

In [None]:
# Utils
import os
from dateutil.relativedelta import relativedelta
import pickle

# Data
import numpy as np
import pandas as pd

# Viz
import matplotlib.pyplot as plt
import seaborn as sns
import tmap as tm
from faerun import Faerun
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches

# Cheminformatics
from skfp.fingerprints import MHFPFingerprint

from utils import load_common_config, load_pickle, save_pickle, load_tmap_coord, get_pmas_palette

# Project colour palette, shared by every figure in this notebook
pmas = get_pmas_palette()

# Two-colour map: black followed by the first palette colour
binary_cmap = ListedColormap(['#000000', pmas[0]])
# Discrete colormap over the full palette
pmas_cmap = ListedColormap(pmas)

## Code

### Generate TMAP coordinates

In [None]:
# Determine the dataset name from the current working directory.
# os.path.basename honours the platform's path separator, unlike splitting on '/'.
DATASET = os.path.basename(os.getcwd())

# Define base paths for this dataset
data_path = f'../../data/{DATASET}'
figures_path = f'../../figures/{DATASET}'

# Load common config file
config = load_common_config('../../data/common/datasets_config.json')

INITIAL_DATE = pd.to_datetime(config[DATASET]['initial_date'])
FINAL_DATE = pd.to_datetime(config[DATASET]['final_date'])
TIMESTEP = config[DATASET]['timestep']  # window length, in months (see relativedelta below)

# A non-positive timestep would never advance `current_date` and the loop below would not terminate.
if TIMESTEP <= 0:
    raise ValueError(f"'timestep' for dataset '{DATASET}' must be a positive number of months, got {TIMESTEP!r}")

# Load the aggregated data ordered by DATE.
# NOTE(review): the sort happens on the raw column, before parsing to datetime; this presumably
# relies on the dates being stored in a lexicographically sortable format (e.g. ISO) — confirm.
df = pd.read_csv(f'{data_path}/data_aggregated.csv').sort_values(by='DATE').reset_index(drop=True)
df['DATE'] = pd.to_datetime(df['DATE'])

# Label every row with the index (1-based) of the time window its molecule was measured in.
# Rows whose SMILES never fall inside [INITIAL_DATE, FINAL_DATE) keep iteration 0.
# All rows sharing a SMILES are labelled together, so a molecule seen in several windows
# ends up with the iteration of the latest one.
df['iteration'] = 0
iteration = 0
current_date = INITIAL_DATE
while current_date < FINAL_DATE:
    iteration += 1
    next_date = current_date + relativedelta(months=TIMESTEP)
    in_window = (df['DATE'] >= current_date) & (df['DATE'] < next_date)
    next_smiles = df.loc[in_window, 'SMILES'].unique()
    df.loc[df['SMILES'].isin(next_smiles), 'iteration'] = iteration
    current_date = next_date

In [None]:
# Encode every SMILES as a 2048-element MHFP fingerprint (raw hash variant)
mhfp_encoder = MHFPFingerprint(
    fp_size=2048,
    radius=2,
    isomeric_smiles=True,
    variant='raw_hashes',
)
fps = mhfp_encoder.transform(df["SMILES"].values)

# Wrap each fingerprint in the vector type that tmap's LSH forest is fed with
fingerprints = list(map(tm.VectorUint, fps))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.axis('off')

# Index the fingerprints in an LSH forest.
# The first argument (2048) matches the fingerprint size used when encoding the molecules;
# the second (512) is presumably the number of prefix trees — see the tmap LSHForest docs.
lf = tm.LSHForest(2048, 512)
lf.batch_add(fingerprints)
lf.index()

# Lay out the tree: x/y are node coordinates, s/t the endpoints (node indices) of each edge
x, y, s, t, _ = tm.layout_from_lsh_forest(lf)
connections = list(zip(s, t))

# Draw on the explicit axes rather than relying on pyplot's implicit "current axes",
# so the density contours and the tree edges are guaranteed to share one plot.
sns.kdeplot(x=x, y=y, fill=False, color='black', levels=2, bw_adjust=0.1, zorder=0, alpha=0.3, gridsize=200, ax=ax)
for a, b in connections:
    ax.plot([x[a], x[b]], [y[a], y[b]], color='black', lw=0.5, alpha=0.5, zorder=0)

# Plain numpy copies of the layout, ready to be pickled
tmap_coordinates = {
    'x': np.array(x),
    'y': np.array(y),
    's': np.array(s),
    't': np.array(t),
}

In [None]:
# Saving is opt-in (off by default, like the previously commented-out code):
# set the flag to True to write the freshly computed coordinates to disk.
SAVE_TMAP_COORDINATES = False

if SAVE_TMAP_COORDINATES:
    output_path = f'../../data/{DATASET}/tmap_coordinates.pkl'
    save_pickle(tmap_coordinates, output_path)

### Generate interactive TMAP

In [None]:
# Stored layout: node coordinates (x, y) and edge endpoints (s, t)
x, y, s, t = load_tmap_coord(dataset=DATASET)
connections = list(zip(s, t))

tmap_df = pd.read_csv(f'{data_path}/data_aggregated.csv')
tmap_df['DATE'] = pd.to_datetime(tmap_df['DATE'])
# reset_index() keeps the original row position in an 'index' column
tmap_df = tmap_df.reset_index()

bp_df = pd.read_csv(f'{data_path}/blueprint.csv')
endpoints = list(bp_df['PROPERTIES'].unique())

# Defined up front so that later cells still find the name (as an empty list)
# when top_molecules.pkl has not been generated yet, instead of raising NameError.
top_categories = []

top_molecules_path = f'{data_path}/top_molecules.pkl'
if os.path.exists(top_molecules_path):
    # Load the definitions of good molecules and create one binary (0/1) column per category
    top_molecules = load_pickle(top_molecules_path)
    for key, tmp_smiles in top_molecules['geometric'].items():
        # Vectorised membership test; the set makes each lookup O(1)
        tmap_df[key] = tmap_df['SMILES'].isin(set(tmp_smiles)).astype(int)

    top_categories = list(top_molecules['geometric'].keys())
else:
    print("⚠️  Please run the preprocessing script ('./scripts/1_preprocessing.sh') to generate the top_molecules.pkl file and then come back to this notebook.")

In [None]:
# Re-order the molecules chronologically; the tree must be re-indexed accordingly.
sorted_tmap_df = tmap_df.sort_values(by='DATE', ascending=True).copy()

# 'index' holds each row's position in the unsorted frame.
# NOTE(review): this assumes the loaded coordinates are aligned with the *unsorted* CSV rows.
# The generation cell above computes them on a frame already sorted by DATE — confirm that
# load_tmap_coord returns coordinates in the row order expected here.
new_index = sorted_tmap_df['index']

# old row position -> new (chronological) position
mapping_dict = {old_pos: new_pos for new_pos, old_pos in enumerate(new_index)}

new_x = x[new_index]
new_y = y[new_index]
# Translate edge endpoints to the new positions (unknown nodes are left untouched)
new_s = [mapping_dict.get(node, node) for node in s]
new_t = [mapping_dict.get(node, node) for node in t]

# One hover label per point, in the same (chronological) order as the coordinates
labels = [
    f"{smiles}__CPD_ID: {cpd_id}__SERIES_ID: {series_id}__DATE: {date}"
    for smiles, cpd_id, series_id, date in zip(
        sorted_tmap_df['SMILES'].values,
        sorted_tmap_df['CPD_ID'].values,
        sorted_tmap_df['SERIES_ID'].values,
        sorted_tmap_df['DATE'].values,
    )
]

# (code, series) pairs; codes follow descending series frequency
series_labels = list(enumerate(tmap_df['SERIES_ID'].value_counts().index))

best_labels = [(0, 'No'), (1, 'Yes')]

# Encode each series as its integer code in a single pass.
# A sequence of .replace(old, new) calls would chain whenever a SERIES_ID value equals a code
# that was already assigned (e.g. integer series ids), silently merging distinct series.
series_to_code = {series_id: code for code, series_id in series_labels}
sorted_tmap_df['SERIES_ID_encoded'] = sorted_tmap_df['SERIES_ID'].map(series_to_code)


# Columns shown as colour series, in display order: series id, top-molecule flags, endpoints
series_columns = ['SERIES_ID_encoded'] + top_categories + endpoints
n_top = len(top_categories)
n_endpoints = len(endpoints)

# Per-series settings, built so that position i always describes series_columns[i].
# The previous lists carried an extra 'Documentation' entry at position 1 that had no
# matching data column, shifting the colormap/title/legend of every following series.
series_colormaps = [pmas_cmap] + [binary_cmap] * n_top + ['RdYlGn'] * n_endpoints
series_categorical = [True] + [True] * n_top + [False] * n_endpoints
series_titles = ['SERIES_ID'] + top_categories + endpoints
series_legend_labels = [series_labels] + [best_labels] * n_top + [[]] * n_endpoints

assert (
    len(series_columns)
    == len(series_colormaps)
    == len(series_categorical)
    == len(series_titles)
    == len(series_legend_labels)
), "per-series settings must line up with the colour series"

faerun = Faerun(view="front", coords=False)
faerun.add_tree("Assay_tree", {"from": new_s, "to": new_t}, point_helper="Assay")
faerun.add_scatter(
    "Assay",
    {
        "x": new_x,
        "y": new_y,
        "c": [sorted_tmap_df[column].values for column in series_columns],
        "labels": labels
    },
    shader="smoothCircle",
    point_scale=3,
    max_point_size=200,
    colormap=series_colormaps,
    has_legend=True,
    categorical=series_categorical,
    min_legend_label=[None],
    max_legend_label=[None],
    series_title=series_titles,
    legend_labels=series_legend_labels,
)

# Write the interactive plot next to the dataset files
faerun.plot(f"{data_path}/tmap", template='smiles', notebook_height=1)