In [1]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from rdkit import Chem
from rdkit.Chem import Descriptors as Des
from rdkit.Chem import MACCSkeys, DataStructs

In [None]:
# The two test files carry no usable split label of their own here, so each
# is tagged with a constant "split" value naming the file it came from.
Test_1 = pd.read_csv("Test_1.csv").assign(split="Test_1")
Test_2 = pd.read_csv("Test_2.csv").assign(split="Test_2")

# Train_Val.csv already has a "split" column; its 'train' / 'test' values are
# relabelled to 'Train' / 'Validation'.
Train_Val = pd.read_csv("Train_Val.csv").assign(
    split=lambda frame: frame['split'].replace({'train': 'Train', 'test': 'Validation'})
)

In [None]:
# Stack the rows of Train_Val labelled "Train" on top of all of Test_2.
# ignore_index=True gives the result a fresh 0..n-1 index; without it the row
# labels of both inputs are carried over and can repeat, which makes any
# label-based selection or alignment on `data` ambiguous.
data = pd.concat([Train_Val[Train_Val["split"] == "Train"], Test_2], ignore_index=True)

In [48]:
def _exact_mol_wt(smiles):
    """Return the exact molecular weight for a SMILES string, or NaN if it cannot be parsed."""
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles yields None for input it cannot parse; passing None on to
    # ExactMolWt would raise and abort the whole column computation.
    if mol is None:
        return np.nan
    return Des.ExactMolWt(mol)


# "MV" holds the exact molecular weight of each row's SMILES. Unparsable
# SMILES become NaN, which fails any numeric range comparison and is therefore
# excluded by a later range filter on this column.
data["MV"] = data["smiles"].apply(_exact_mol_wt)

In [49]:
# Keep only rows whose "MV" value lies in [300, 500], both bounds inclusive.
# .copy() makes the filtered result an independent DataFrame, so adding new
# columns to `data` afterwards does not raise pandas' SettingWithCopyWarning.
data = data[(data['MV'] >= 300) & (data['MV'] <= 500)].copy()

In [None]:

def calculate_maccs(smiles):
    """Compute the MACCS keys fingerprint for a SMILES string.

    Returns the value of ``MACCSkeys.GenMACCSKeys`` for the parsed molecule,
    or ``None`` when ``Chem.MolFromSmiles`` does not produce a molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    # Guard clause: there is nothing to fingerprint without a parsed molecule.
    if molecule is None:
        return None
    return MACCSkeys.GenMACCSKeys(molecule)

# Fingerprint every SMILES; an entry is None where no molecule could be parsed.
data['MACCS'] = data['smiles'].apply(calculate_maccs)

# Drop the None entries, and pick the split labels of exactly the rows that
# have a fingerprint so both sequences stay position-aligned.
maccs_fps = [fingerprint for fingerprint in data['MACCS'] if fingerprint is not None]
split_labels = data.loc[data['MACCS'].notnull(), 'split'].values

def calculate_similarity_matrix(fps):
    """Build the pairwise Tanimoto similarity matrix for a set of fingerprints.

    Args:
        fps: Iterable of RDKit fingerprint objects accepted by
            ``DataStructs.BulkTanimotoSimilarity``.

    Returns:
        A square NumPy float array with one row and column per fingerprint,
        where entry ``[i, j]`` is the Tanimoto similarity between the i-th and
        j-th fingerprint. Empty input gives a ``(0, 0)`` array.
    """
    # Materialise as a list so positional indexing and slicing below work for
    # any iterable the caller passes in.
    fps = list(fps)
    num_fps = len(fps)
    similarity_matrix = np.zeros((num_fps, num_fps))
    for i in range(num_fps):
        # One bulk call per row replaces a Python-level call per pair. Tanimoto
        # similarity is symmetric, so only the upper triangle (j >= i) is
        # computed and then mirrored into the lower triangle.
        row = DataStructs.BulkTanimotoSimilarity(fps[i], fps[i:])
        similarity_matrix[i, i:] = row
        similarity_matrix[i:, i] = row
    return similarity_matrix

# Pairwise Tanimoto similarity between all fingerprints collected in maccs_fps.
similarity_matrix = calculate_similarity_matrix(maccs_fps)

# Explicit figure/axes handles: every call below targets the heatmap axes
# directly instead of whichever axes pyplot currently treats as "current".
fig, ax = plt.subplots(figsize=(12, 10), dpi=300)

# One distinct colour per split label present in split_labels.
unique_splits = np.unique(split_labels)
colors = sns.color_palette("husl", len(unique_splits))
split_color_map = dict(zip(unique_splits, colors))

sns.heatmap(similarity_matrix, cmap='viridis', xticklabels=False, yticklabels=False, cbar=True, ax=ax)

# Annotation strip: one unit-wide patch per heatmap column, coloured by the
# split of the molecule in that column. seaborn draws row 0 at the top, so the
# y-range [-0.5, 0] sits just above the first row; clip_on=False allows the
# patch to be drawn outside the axes area.
for i, label in enumerate(split_labels):
    color = split_color_map[label]
    ax.add_patch(plt.Rectangle((i, -0.5), 1, 0.5, color=color, clip_on=False))

ax.set_title('Heatmap of MACCS Fingerprint Similarities with Split Annotations')
ax.set_xlabel('MACCS Fingerprint Index')
ax.set_ylabel('MACCS Fingerprint Index')

# Proxy artists: the strip patches carry no labels, so the legend is built from
# one coloured marker per split instead.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=split_color, markersize=10, label=split)
    for split, split_color in split_color_map.items()
]
ax.legend(handles=legend_elements, title='Split', bbox_to_anchor=(1.05, 1), loc='upper right')

fig.tight_layout()
plt.show()
