In [1]:
from pathlib import Path as pt
from uncertainties import ufloat, ufloat_fromstr
import umap
import plotly.express as px

import pandas as pd
import numpy as np
from rdkit import Chem
from collections import Counter
import umap
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
import plotly.graph_objects as go
from typing import List, Dict, Tuple
import warnings
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

# from tqdm.notebook import tqdm as tqdm_notebook

# Silence every warning for the whole session (this also hides, e.g., UMAP /
# scikit-learn warnings). NOTE(review): consider narrowing to specific
# categories so genuine problems are not masked.
warnings.filterwarnings('ignore')
# Enable tqdm progress bars on pandas objects (`progress_apply` and friends).
tqdm.pandas()

In [6]:
# Project-local helpers: per-property processed-data directories, the plot
# output directory, display metadata (names / units / titles) and the
# chemical cluster analyzer.
from load_data import processed_data_dirs, plots_dir, property_units, property_names, titles
from umda_viz import ChemicalClusterAnalyzer
# Show the processed-data directory for each property index used by compute_umap.
processed_data_dirs

[PosixPath('/Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/[PHYSICAL CONSTANTS OF ORGANIC COMPOUNDS]/tmp_C_processed_data/analysis_data/filtered/tmpC_topelements_processed_data'),
 PosixPath('/Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/[PHYSICAL CONSTANTS OF ORGANIC COMPOUNDS]/tbp_C_processed_data/analysis_data/filtered/tbp_topelements_processed_data'),
 PosixPath('/Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/[PHYSICAL CONSTANTS OF ORGANIC COMPOUNDS]/vp_kPa_25C_filtered_ydata_processed_data/analysis_data/filtered/vp_kPa_25C_topelements_processed_data'),
 PosixPath('/Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/[CRITICAL CONSTANTS OF ORGANIC COMPOUNDS]/Pc_MPa_processed_data'),
 PosixPath('/Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology

In [22]:
def add_cluster_annotations(
    fig: "go.Figure",
    df_plot: pd.DataFrame,
    labels: np.ndarray,
    stat_label: str = 'MP',
) -> "go.Figure":
    """Annotate each non-noise cluster centroid with its property mean ± std.

    Parameters
    ----------
    fig : plotly Figure
        Figure to annotate; modified in place through ``fig.add_annotation``.
    df_plot : pd.DataFrame
        Must contain the columns ``'UMAP1'``, ``'UMAP2'``, ``'Cluster'`` and
        ``'molecular_property'`` (as built by ``plot_figure``).
    labels : np.ndarray
        Cluster label per point; label ``-1`` marks noise and is skipped.
    stat_label : str, default 'MP'
        Prefix shown before the statistics in each annotation. The default
        keeps the historical melting-point label; pass e.g. ``'Tc'`` when
        plotting another property.

    Returns
    -------
    The same figure object, for chaining.
    """
    # Sorted so annotations are added in a reproducible order.
    for cluster_id in sorted(set(labels)):
        if cluster_id == -1:  # noise points have no meaningful centre
            continue

        cluster_data = df_plot[df_plot['Cluster'] == cluster_id]

        # Cluster centre in UMAP space and summary of the plotted property.
        center_x = cluster_data['UMAP1'].mean()
        center_y = cluster_data['UMAP2'].mean()
        prop_values = cluster_data['molecular_property']
        mean_val = prop_values.mean()
        # pandas' sample std (ddof=1) is NaN for a single-member cluster;
        # show 0.0 rather than "nan" in the annotation.
        std_val = prop_values.std()
        if pd.isna(std_val):
            std_val = 0.0

        fig.add_annotation(
            x=center_x,
            y=center_y,
            text=f"Cluster {cluster_id}<br>{stat_label}: {mean_val:.1f}±{std_val:.1f}",
            showarrow=True,
            arrowhead=1,
            bgcolor='rgba(255,255,255,0.8)',
            bordercolor='black',
            borderwidth=1
        )
    return fig

def _cluster_property_stats(df_plot: pd.DataFrame) -> pd.DataFrame:
    """Summarise ``molecular_property`` per cluster, excluding noise (label -1).

    Returns one row per cluster (ascending label) with the columns
    ``Cluster, Size, Mean_MP, Std_MP, Min_MP, Max_MP``. ``Std_MP`` is the
    pandas sample standard deviation (ddof=1). The ``_MP`` suffixes are kept
    for continuity with previously printed tables even though the property
    is generic.
    """
    clustered = df_plot[df_plot['Cluster'] != -1]
    return (
        clustered
        .groupby('Cluster', sort=True)['molecular_property']
        .agg(Size='size', Mean_MP='mean', Std_MP='std', Min_MP='min', Max_MP='max')
        .reset_index()
    )


def plot_figure(
        reduced_embeddings: np.ndarray, 
        labels: np.ndarray, 
        y: np.ndarray, 
        smiles_list: List[str], 
        property_name_with_unit: str, 
        property_name:str, 
        fname: str,
        save: bool = False,
        show: bool = True,
    ) -> go.Figure:
    """Interactive UMAP scatter coloured by a molecular property.

    Prints a per-cluster statistics table, builds a plotly scatter of the 2-D
    embedding coloured by ``y`` with SMILES / value / cluster hover text, and
    adds a box with the global statistics of ``y``.

    Parameters
    ----------
    reduced_embeddings : np.ndarray
        Embedding whose first two columns are plotted as UMAP1 / UMAP2.
    labels : np.ndarray
        Cluster label per point (``-1`` = noise, excluded from the table).
    y : np.ndarray
        Property value per point, used for colouring and statistics.
    smiles_list : list of str
        SMILES string per point (hover text).
    property_name_with_unit, property_name : str
        Used in the colorbar/title and the hover text respectively.
    fname : str
        Stem of the HTML file written when ``save`` is true.
    save : bool, default False
        Write ``plots_dir / f'{fname}_umap_property.html'``.
    show : bool, default True
        Call ``fig.show()``; set False in batch loops.

    Returns
    -------
    go.Figure
    """
    df_plot = pd.DataFrame({
        'UMAP1': reduced_embeddings[:, 0],
        'UMAP2': reduced_embeddings[:, 1],
        'Cluster': labels,
        'SMILES': smiles_list,
        'molecular_property': y
    })

    print(f"\nStatistical Analysis of {property_name_with_unit} by Cluster:")
    stats_df = _cluster_property_stats(df_plot)
    print("\nCluster Statistics:")
    print(stats_df.round(2))

    fig = go.Figure()

    colorscale = 'Viridis'  # alternatives: 'RdBu', 'Jet', 'Turbo', ...

    # A single trace so the whole data set shares one continuous colour scale.
    fig.add_trace(go.Scatter(
        x=df_plot['UMAP1'],
        y=df_plot['UMAP2'],
        mode='markers',
        marker=dict(
            size=8,
            color=df_plot['molecular_property'],
            colorscale=colorscale,
            colorbar=dict(title=property_name_with_unit),
            showscale=True
        ),
        text=[
            f"SMILES: {s}<br>" +
            f"{property_name}: {mp:.2f}<br>" +
            f"Cluster: {c}"
            for s, mp, c in zip(df_plot['SMILES'], df_plot['molecular_property'], df_plot['Cluster'])
        ],
        hoverinfo='text'
    ))

    fig.update_layout(
        title=f'Chemical Structure Space Colored by {property_name_with_unit}',
        template='plotly_white',
        width=1200,
        height=800,
        showlegend=False,
        hovermode='closest',
        xaxis_title='UMAP1',
        yaxis_title='UMAP2'
    )

    # Plotly annotations only break lines on "<br>" (newlines are not
    # rendered as line breaks), so join the lines explicitly.
    # Note: y.std() is numpy's population std (ddof=0), whereas the per-cluster
    # table above uses the pandas sample std (ddof=1).
    global_stats = "<br>".join([
        f"{property_name_with_unit} Statistics:",
        f"Mean: {y.mean():.2f}",
        f"Std: {y.std():.2f}",
        f"Min: {y.min():.2f}",
        f"Max: {y.max():.2f}",
    ])

    fig.add_annotation(
        x=0.02,
        y=0.98,
        xref='paper',
        yref='paper',
        text=global_stats,
        showarrow=False,
        font=dict(size=12),
        bgcolor='white',
        bordercolor='black',
        borderwidth=1,
        align='left'
    )

    if save:
        save_path = plots_dir / f'{fname}_umap_property.html'
        fig.write_html(save_path)
        print(f"Property visualization saved to: {save_path}")

    if show:
        fig.show()

    return fig

In [None]:
def compute_umap(
    ind: int,
    embeddings: str = 'mol2vec',
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    random_state=None,
    cluster_eps: float = 0.5,
    cluster_min_samples: int = 5,
):
    """Load one property's embeddings, project them with UMAP and cluster them.

    Parameters
    ----------
    ind : int
        Index into ``processed_data_dirs`` / ``property_names`` /
        ``property_units`` / ``titles``.
    embeddings : str, default 'mol2vec'
        Embedding family; selects
        ``embedded_vectors/processed_{embeddings}_embeddings``.
    n_neighbors, min_dist, n_components :
        Forwarded to ``umap.UMAP`` (defaults reproduce the original settings).
    random_state : int or None, default None
        Seed for ``umap.UMAP``. ``None`` keeps the original unseeded,
        non-reproducible behaviour. NOTE: umap-learn may warn and fall back to
        a single thread when seeded.
    cluster_eps, cluster_min_samples :
        Forwarded to ``ChemicalClusterAnalyzer.analyze_cluster_chemistry``.

    Returns
    -------
    tuple
        ``(reduced_embeddings, labels, cluster_analysis, y, smiles_list,
        property_name_with_unit, property_name, fname)``

    Raises
    ------
    ValueError
        If the number of embedding rows, target values and SMILES differ.
    """
    property_name = property_names[ind]
    property_unit = property_units[ind]
    title = titles[ind]
    property_name_with_unit = f'{property_name} ({property_unit})'
    print(property_name_with_unit, title)

    current_dir = processed_data_dirs[ind]
    fname = current_dir.name.replace('_processed_data', '')
    csv_file = current_dir.parent / f'{fname}.csv'
    print(csv_file.exists(), csv_file.name)

    df = pd.read_csv(csv_file)
    smiles_list = df['SMILES'].to_list()
    print(len(smiles_list), 'smiles')

    vec_dir = current_dir / f'embedded_vectors/processed_{embeddings}_embeddings'
    print(vec_dir.exists(), vec_dir.name)

    # allow_pickle=True unpickles object arrays, which can execute arbitrary
    # code: only acceptable for trusted files (presumably produced by this
    # project's preprocessing -- never point this at external data).
    X = np.load(vec_dir / 'processed.X.npy', allow_pickle=True)
    y = np.load(vec_dir / 'processed.y.npy', allow_pickle=True)
    print(X.shape, y.shape)

    # Embedding rows, targets and SMILES must line up one-to-one; fail here with
    # a clear message instead of deep inside clustering/plotting.
    if not (len(X) == len(y) == len(smiles_list)):
        raise ValueError(
            f"Length mismatch for {fname}: X={len(X)}, y={len(y)}, "
            f"SMILES={len(smiles_list)}"
        )

    print("Scaling embeddings...")
    scaled_embeddings = StandardScaler().fit_transform(X)

    print("Performing UMAP...")
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        n_jobs=-1,
        random_state=random_state,
    )
    reduced_embeddings = reducer.fit_transform(scaled_embeddings)
    print(reduced_embeddings.shape)

    print("Analyzing chemical clusters...")
    analyzer = ChemicalClusterAnalyzer()
    labels, cluster_analysis = analyzer.analyze_cluster_chemistry(
        reduced_embeddings,
        smiles_list,
        eps=cluster_eps,
        min_samples=cluster_min_samples
    )

    return reduced_embeddings, labels, cluster_analysis, y, smiles_list, property_name_with_unit, property_name, fname

# Build (and display) the interactive property plot for each of the first
# five properties.
for ind in tqdm(range(5)):
    print(f'Processing {ind}')
    results = compute_umap(ind)
    (
        reduced_embeddings,
        labels,
        cluster_analysis,
        y,
        smiles_list,
        property_name_with_unit,
        property_name,
        fname,
    ) = results
    fig = plot_figure(
        reduced_embeddings=reduced_embeddings,
        labels=labels,
        y=y,
        smiles_list=smiles_list,
        property_name_with_unit=property_name_with_unit,
        property_name=property_name,
        fname=fname,
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Processing 0
Melting Point (K) MP
True tmpC_topelements.csv
7476 smiles
True processed_mol2vec_embeddings
(7476, 300) (7476,)
Scaling embeddings...
Performing UMAP...
(7476, 2)
Analyzing chemical clusters...


100%|██████████| 23/23 [00:01<00:00, 18.32it/s]


Statistical Analysis of Melting Point (K) by Cluster:

Cluster Statistics:
    Cluster  Size  Mean_MP  Std_MP  Min_MP  Max_MP
0         0  1583   146.50   84.55  -30.43   492.0
1         1  1203   117.86  101.66 -177.20   400.0
2         2  1633    91.14   90.14 -101.30   410.0
3         3    82    67.85   63.66  -56.40   309.0
4         4  1980    28.69   99.18 -205.10   328.0
5         5    70   140.62   87.78    4.50   357.0
6         6   223    36.11   94.74 -137.36   335.0
7         7    70   107.60   46.58   25.80   205.0
8         8    44   152.21   40.73   80.00   255.0
9         9    36   206.86   76.33   37.00   420.0
10       10    28   197.84   32.23  121.00   240.0
11       11    16   172.09  111.70   40.00   485.0
12       12    17   131.25   51.28   65.00   230.4
13       13    18    37.76   38.20  -14.87   168.5
14       14    65    77.22   67.49  -45.10   224.0
15       15   211   101.46   70.98  -92.70   303.0
16       16    20   138.15   78.35   46.50   329.0
17    




100%|██████████| 1/1 [00:03<00:00,  3.77s/it]


In [28]:
# Recompute property index 4 and save its interactive HTML plot.
(
    reduced_embeddings,
    labels,
    cluster_analysis,
    y,
    smiles_list,
    property_name_with_unit,
    property_name,
    fname,
) = compute_umap(4)
fig = plot_figure(
    reduced_embeddings,
    labels,
    y,
    smiles_list,
    property_name_with_unit,
    property_name,
    fname,
    save=True,
)

Critical Temperature (K) CT
True Tc_K.csv
819 smiles
True processed_mol2vec_embeddings
(819, 300) (819,)
Scaling embeddings...
Performing UMAP...
(819, 2)
Analyzing chemical clusters...


100%|██████████| 9/9 [00:00<00:00, 69.14it/s]


Statistical Analysis of Critical Temperature (K) by Cluster:

Cluster Statistics:
   Cluster  Size  Mean_MP  Std_MP  Min_MP  Max_MP
0        0   451   573.06  131.55  190.56   872.0
1        1   105   702.96   69.32  490.20   886.0
2        2    24   822.58   78.53  750.00  1115.0
3        3    40   617.23   86.66  398.20   787.0
4        4    41   599.10   47.14  516.40   685.7
5        5   105   577.44   62.61  407.81   850.0
6        6    43   554.80   52.16  433.71   708.0
7        7     9   428.03   48.44  361.80   486.5
Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/Tc_K_umap_property.html





In [41]:
def plot_figure_static(
    reduced_embeddings: np.ndarray, 
    labels: np.ndarray, 
    y: np.ndarray, 
    smiles_list: List[str], 
    property_name_with_unit: str, 
    property_name: str, 
    fname: str,
    save: bool = False,
    fig_size: tuple = (12, 8),
    point_size: int = 50,
    alpha: float = 0.6,
    cmap: str = 'viridis',
) -> plt.Figure:
    """Static (matplotlib/seaborn) UMAP scatter coloured by a molecular property.

    Parameters
    ----------
    reduced_embeddings : np.ndarray
        Embedding whose first two columns are plotted as UMAP1 / UMAP2.
    labels : np.ndarray
        Cluster label per point (stored in the plotting frame only).
    y : np.ndarray
        Property value per point; drives the colours and the statistics box.
    smiles_list : list of str
        SMILES per point (stored in the plotting frame only).
    property_name_with_unit : str
        Used for the colorbar label, title and statistics box.
    property_name : str
        Unused; kept for signature parity with ``plot_figure``.
    fname : str
        Stem of the PNG written when ``save`` is true.
    save : bool, default False
        Write ``plots_dir / f'{fname}_umap_property_static.png'`` at 300 dpi.
    fig_size, point_size, alpha :
        Figure size (inches), marker size and marker opacity.
    cmap : str, default 'viridis'
        Colormap used for both the points and the colorbar.

    Returns
    -------
    plt.Figure
        The newly created figure. The caller is responsible for closing it
        (e.g. ``plt.close(fig)``) in loops.
    """
    df_plot = pd.DataFrame({
        'UMAP1': reduced_embeddings[:, 0],
        'UMAP2': reduced_embeddings[:, 1],
        'Cluster': labels,
        'SMILES': smiles_list,
        'molecular_property': y
    })

    # Work on an explicit new figure. (Calling plt.clf() first would clear --
    # or create -- an unrelated "current" figure that nobody closes.)
    fig, ax = plt.subplots(figsize=fig_size)

    # One normalisation shared by the points and the colorbar, so the
    # colorbar is guaranteed to describe the plotted colours.
    norm = plt.Normalize(y.min(), y.max())

    sns.scatterplot(
        data=df_plot,
        x='UMAP1',
        y='UMAP2',
        hue='molecular_property',
        palette=cmap,
        hue_norm=norm,
        s=point_size,
        alpha=alpha,
        ax=ax,
        legend=False,  # a discrete hue legend is meaningless for a continuous value
    )

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    colorbar = fig.colorbar(sm, ax=ax)
    colorbar.set_label(property_name_with_unit, fontsize=12)

    ax.set_title(f'Chemical Structure Space Colored by {property_name_with_unit}', 
                 fontsize=14, pad=20)
    ax.set_xlabel('UMAP1', fontsize=12)
    ax.set_ylabel('UMAP2', fontsize=12)

    stats_text = (f"{property_name_with_unit} Statistics:\n"
                  f"Mean: {y.mean():.2f}\n"
                  f"Std: {y.std():.2f}\n"
                  f"Min: {y.min():.2f}\n"
                  f"Max: {y.max():.2f}")

    # Axes coordinates: x=1.2 places the box to the right of the axes (beyond
    # the colorbar); bbox_inches='tight' on save keeps it in the image.
    props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    ax.text(1.2, 0.98, stats_text,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment='top',
            bbox=props)

    fig.tight_layout()

    if save:
        save_path = plots_dir / f'{fname}_umap_property_static.png'
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Property visualization saved to: {save_path}")

    return fig


# results = compute_umap(0)
# reduced_embeddings, labels, cluster_analysis, y, smiles_list, property_name_with_unit, property_name, fname = results
# static_fig = plot_figure_static(
#     reduced_embeddings=reduced_embeddings,
#     labels=labels,
#     y=y,
#     smiles_list=smiles_list,
#     property_name_with_unit=property_name_with_unit,
#     property_name=property_name,
#     fname=fname,
#     save=True
# )
# plt.show()



In [43]:
# Render and save the static property plot for each of the first five properties.
for ind in range(5):
    results = compute_umap(ind)
    reduced_embeddings, labels, cluster_analysis, y, smiles_list, property_name_with_unit, property_name, fname = results
    static_fig = plot_figure_static(
        reduced_embeddings=reduced_embeddings,
        labels=labels,
        y=y,
        smiles_list=smiles_list,
        property_name_with_unit=property_name_with_unit,
        property_name=property_name,
        fname=fname,
        save=True
    )
    # Close exactly the figure just created; a bare plt.close() only closes
    # whichever figure pyplot considers "current".
    plt.close(static_fig)

Melting Point (K) MP
True tmpC_topelements.csv
7476 smiles
True processed_mol2vec_embeddings
(7476, 300) (7476,)
Scaling embeddings...
Performing UMAP...
(7476, 2)
Analyzing chemical clusters...


100%|██████████| 23/23 [00:01<00:00, 19.01it/s]


Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/tmpC_topelements_umap_property_static.png
Boiling Point (K) BP
True tbp_topelements.csv
4915 smiles
True processed_mol2vec_embeddings
(4915, 300) (4915,)
Scaling embeddings...
Performing UMAP...
(4915, 2)
Analyzing chemical clusters...


100%|██████████| 20/20 [00:00<00:00, 26.87it/s]


Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/tbp_topelements_umap_property_static.png
Vapor Pressure (kPa at 25°C ) VP
True vp_kPa_25C_topelements.csv
398 smiles
True processed_mol2vec_embeddings
(398, 300) (398,)
Scaling embeddings...
Performing UMAP...
(398, 2)
Analyzing chemical clusters...


100%|██████████| 7/7 [00:00<00:00, 111.77it/s]


Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/vp_kPa_25C_topelements_umap_property_static.png
Critical Pressure (MPa) CP
True Pc_MPa.csv
777 smiles
True processed_mol2vec_embeddings
(777, 300) (777,)
Scaling embeddings...
Performing UMAP...
(777, 2)
Analyzing chemical clusters...


100%|██████████| 10/10 [00:00<00:00, 84.10it/s]


Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/Pc_MPa_umap_property_static.png
Critical Temperature (K) CT
True Tc_K.csv
819 smiles
True processed_mol2vec_embeddings
(819, 300) (819,)
Scaling embeddings...
Performing UMAP...
(819, 2)
Analyzing chemical clusters...


100%|██████████| 9/9 [00:00<00:00, 68.05it/s]


Property visualization saved to: /Users/aravindhnivas/Library/CloudStorage/OneDrive-MassachusettsInstituteofTechnology/ML-properties/plots/Tc_K_umap_property_static.png


<Figure size 640x480 with 0 Axes>

In [None]:
# Interactive map of the chemical clusters: one trace per cluster label, with
# the three most frequent functional groups of each cluster in the hover text.
# NOTE: uses reduced_embeddings / labels / smiles_list / cluster_analysis /
# fname from the most recent compute_umap(...) call above.
df_plot = pd.DataFrame({
    'UMAP1': reduced_embeddings[:, 0],
    'UMAP2': reduced_embeddings[:, 1],
    'Cluster': labels,
    'SMILES': smiles_list
})

fig = go.Figure()

# Sorted so the trace/legend order is stable between runs.
for cluster_id in sorted(set(labels)):
    cluster_data = df_plot[df_plot['Cluster'] == cluster_id]

    if cluster_id in cluster_analysis:
        top_groups = sorted(
            cluster_analysis[cluster_id]['functional_groups'].items(),
            key=lambda x: x[1],
            reverse=True
        )[:3]
        # Identical for every point of the cluster, so build it once.
        groups_text = "<br>".join([f"{g}: {v:.1%}" for g, v in top_groups])
        hover_text = [
            f"SMILES: {s}<br>Cluster: {cluster_id}<br>" + groups_text
            for s in cluster_data['SMILES']
        ]
    else:
        hover_text = [f"SMILES: {s}<br>Cluster: Noise" for s in cluster_data['SMILES']]

    # Name the -1 (noise) trace 'Noise' rather than 'Cluster -1'.
    trace_name = 'Noise' if cluster_id == -1 else f'Cluster {cluster_id}'

    fig.add_trace(go.Scatter(
        x=cluster_data['UMAP1'],
        y=cluster_data['UMAP2'],
        mode='markers',
        name=trace_name,
        text=hover_text,
        hoverinfo='text',
        marker=dict(size=8)
    ))

fig.update_layout(
    title='Chemical Structure Clusters Analysis',
    template='plotly_white',
    width=1200,
    height=800,
    showlegend=True,
    hovermode='closest',
    xaxis_title='UMAP1',
    yaxis_title='UMAP2'
)
save_path = plots_dir / f'{fname}_umap_clusters.html'
fig.write_html(save_path)
print(f"Cluster analysis plot saved to: {save_path}")
fig.show()

In [None]:
# Optional: Create a second visualization showing both property and clusters
fig2 = go.Figure()

# Create subplots for different clusters with continuous color mapping
for cluster_id in set(labels):
    cluster_data = df_plot[df_plot['Cluster'] == cluster_id]
    
    name = 'Noise' if cluster_id == -1 else f'Cluster {cluster_id}'
    
    fig2.add_trace(go.Scatter(
        x=cluster_data['UMAP1'],
        y=cluster_data['UMAP2'],
        mode='markers',
        name=name,
        marker=dict(
            size=8,
            color=cluster_data['Melting_Point'],
            colorscale=colorscale,
            showscale=True if cluster_id == list(set(labels))[-1] else False,  # Show colorbar only once
            colorbar=dict(title=property_name)
        ),
        text=[
            f"SMILES: {s}<br>" +
            f"{property_name}: {mp:.2f}<br>" +
            f"Cluster: {c}"
            for s, mp, c in zip(cluster_data['SMILES'], 
                              cluster_data['Melting_Point'], 
                              cluster_data['Cluster'])
        ],
        hoverinfo='text'
    ))

fig2.update_layout(
    title=f'Chemical Structure Clusters with {property_name} Distribution',
    template='plotly_white',
    width=1200,
    height=800,
    showlegend=True,
    hovermode='closest',
    xaxis_title='UMAP1',
    yaxis_title='UMAP2'
)

# fig2.show()

# Add melting point statistics per cluster
stats_text = f"{property_name} by Cluster:\n"
for cluster_id in sorted(set(labels)):
    if cluster_id != -1:
        cluster_data = df_plot[df_plot['Cluster'] == cluster_id]
        stats_text += f"\nCluster {cluster_id}:\n"
        stats_text += f"Mean: {cluster_data['Melting_Point'].mean():.2f}\n"
        stats_text += f"Std: {cluster_data['Melting_Point'].std():.2f}\n"

fig2.add_annotation(
    x=1.15,
    y=0.5,
    xref='paper',
    yref='paper',
    text=stats_text,
    showarrow=False,
    font=dict(size=10),
    align='left'
)

# save_path2 = plots_dir / f'{fname}_umap_property_clusters.html'
# fig2.write_html(save_path2)
# print(f"Property and clusters visualization saved to: {save_path2}")
# fig2.show()

In [None]:
import matplotlib.colors as mcolors

# Static cluster map: one XKCD colour per cluster label, with each legend
# entry listing the cluster's three most frequent functional groups.
colors = list(mcolors.XKCD_COLORS.keys())

fig, ax = plt.subplots(figsize=(15, 8), dpi=200)

for cluster_id, cluster_info in cluster_analysis.items():
    top_groups = sorted(
        cluster_info['functional_groups'].items(),
        key=lambda item: item[1],
        reverse=True,
    )[:3]

    group_summary = ", ".join(f"{group} ({freq:.1%})" for group, freq in top_groups)
    legend_text = f"{cluster_id}: {group_summary}" if top_groups else f"{cluster_id}"

    member_xy = reduced_embeddings[labels == cluster_id]
    sns.scatterplot(
        x=member_xy[:, 0],
        y=member_xy[:, 1],
        color=colors[cluster_id],
        label=legend_text,
        s=25,
        alpha=0.6,
        ax=ax,
    )

    # Mark the cluster centre with its numeric label.
    centre = cluster_info['center']
    ax.text(centre[0], centre[1], f'{cluster_id}', fontsize=12, color='black')

ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.set_title('UMAP Visualization of Chemical Structure Embeddings')
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')

plt.tight_layout()
# save_path = plots_dir / f'{fname}_umap_clusters.pdf'
# fig.savefig(save_path, bbox_inches='tight')
# print(f"Saved plot to {save_path.name}")
plt.show()

In [None]:
# Path under plots_dir for a PNG version of the cluster plot.
plots_dir.joinpath(f'{fname}_umap_clusters.png')

In [None]:
# Reset so the next cell recomputes the UMAP projection instead of reusing the
# one left over from compute_umap.
reduced_embeddings = None

In [None]:
# Re-run the projection with a fixed seed (random_state=42) and visualise the
# chemical cluster analysis with ChemicalClusterAnalyzer.
# NOTE(review): `X` is not defined at notebook scope -- compute_umap() keeps it
# local -- so on a fresh kernel it must be loaded first (e.g. from the
# processed.X.npy file that compute_umap reads) before a projection can be
# recomputed here. smiles_list must also match the rows of X.
n_neighbors: int = 15
min_dist: float = 0.1
cluster_eps: float = 0.5
cluster_min_samples: int = 5

# Scaling is only needed when the projection has to be (re)computed, so it
# lives inside the guard; reusing an existing projection no longer needs `X`.
if reduced_embeddings is None:
    print("Scaling embeddings...")
    scaler = StandardScaler()
    scaled_embeddings = scaler.fit_transform(X)

    print("Performing UMAP reduction...")
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=2,
        random_state=42
    )
    reduced_embeddings = reducer.fit_transform(scaled_embeddings)

print("Analyzing chemical clusters...")
analyzer = ChemicalClusterAnalyzer()
labels, cluster_analysis = analyzer.analyze_cluster_chemistry(
    reduced_embeddings,
    smiles_list,
    eps=cluster_eps,
    min_samples=cluster_min_samples
)

print("Creating visualization...")
analyzer.plot_cluster_analysis(
    reduced_embeddings,
    smiles_list,
    labels,
    cluster_analysis,
    # output_filepath=vec_dir / f'{fname}_cluster_analysis.html'
)

print("Analysis complete!")

In [None]:
df_plot = pd.DataFrame(
    {
        'UMAP1': reduced_embeddings[:, 0],
        'UMAP2': reduced_embeddings[:, 1],
        'Cluster': labels,
        'SMILES': smiles_list,
    }
)

# Print the three most frequent functional groups of every analysed cluster;
# labels without an entry in cluster_analysis are skipped silently.
for cluster_id in set(labels):
    if cluster_id not in cluster_analysis:
        continue
    group_freqs = cluster_analysis[cluster_id]['functional_groups']
    top_groups = sorted(group_freqs.items(), key=lambda kv: kv[1], reverse=True)[:3]
    print(cluster_id, top_groups)

In [None]:
# Inspect the cluster label assigned to each molecule; -1 marks noise points,
# which the per-cluster plots and statistics above skip.
labels