In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, Descriptors
import os
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
from scipy.stats import ttest_ind, pearsonr

class MolecularEmbeddingVisualizer:
    """Visualize and compare embeddings before and after training"""
    
    def __init__(self, output_dir='./visualizations'):
        """Initialize visualizer
        
        Args:
            output_dir: Directory to save visualizations
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Set up nice color schemes
        self.colors = {
            'before': '#1f77b4',  # blue
            'after': '#ff7f0e',   # orange
            'diff': '#2ca02c'     # green
        }
        
        # Set up seaborn style
        sns.set_style('whitegrid')
        plt.rcParams.update({
            'font.size': 12,
            'axes.labelsize': 14,
            'axes.titlesize': 16,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 12,
            'figure.titlesize': 18
        })
    
    def load_analysis_data(self, before_prefix, after_prefix):
        """Load analysis data from CSV files
        
        Args:
            before_prefix: Prefix for before training files (without _properties.csv)
            after_prefix: Prefix for after training files (without _properties.csv)
            
        Returns:
            Dictionary of DataFrames with loaded data
        """
        # Load properties
        before_props = pd.read_csv(f"{before_prefix}_properties.csv", index_col=0)
        after_props = pd.read_csv(f"{after_prefix}_properties.csv", index_col=0)
        
        # Load features
        before_features = pd.read_csv(f"{before_prefix}_features.csv", index_col=0)
        after_features = pd.read_csv(f"{after_prefix}_features.csv", index_col=0)
        
        # Load functional groups
        before_funcs = pd.read_csv(f"{before_prefix}_functional_groups.csv", index_col=0)
        after_funcs = pd.read_csv(f"{after_prefix}_functional_groups.csv", index_col=0)
        
        # Load embeddings if available
        before_embeddings = None
        after_embeddings = None
        
        try:
            before_emb_file = f"{before_prefix.replace('_properties', '')}.npz"
            if os.path.exists(before_emb_file):
                before_data = np.load(before_emb_file)
                before_embeddings = before_data['embeddings']
                before_smiles = before_data['smiles']
        except Exception as e:
            print(f"Error loading before embeddings: {e}")
        
        try:
            after_emb_file = f"{after_prefix.replace('_properties', '')}.npz"
            if os.path.exists(after_emb_file):
                after_data = np.load(after_emb_file)
                after_embeddings = after_data['embeddings']
                after_smiles = after_data['smiles']
        except Exception as e:
            print(f"Error loading after embeddings: {e}")
        
        return {
            'before_props': before_props,
            'after_props': after_props,
            'before_features': before_features,
            'after_features': after_features,
            'before_funcs': before_funcs,
            'after_funcs': after_funcs,
            'before_embeddings': before_embeddings,
            'after_embeddings': after_embeddings
        }
    
    def compare_properties_distributions(self, before_props, after_props, 
                                        props_to_compare=None, figsize=(16, 14)):
        """Compare molecular property distributions before and after training
        
        Args:
            before_props: DataFrame with properties before training
            after_props: DataFrame with properties after training
            props_to_compare: List of properties to compare (default: MW, LogP, TPSA)
            figsize: Figure size (width, height)
            
        Returns:
            Figure with property distributions
        """
        if props_to_compare is None:
            props_to_compare = ['MW', 'LogP', 'TPSA', 'NumRotatableBonds', 'NumHeavyAtoms']
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        
        # Create grid layout
        nrows = (len(props_to_compare) + 1) // 2
        gs = gridspec.GridSpec(nrows, 2, figure=fig)
        
        # Plot each property
        for i, prop in enumerate(props_to_compare):
            if prop in before_props.columns and prop in after_props.columns:
                ax = fig.add_subplot(gs[i // 2, i % 2])
                
                # Get data
                before_data = before_props[prop].dropna()
                after_data = after_props[prop].dropna()
                
                # Histogram for before training
                sns.histplot(before_data, ax=ax, color=self.colors['before'], 
                            alpha=0.6, label='Before Training', kde=True)
                
                # Histogram for after training
                sns.histplot(after_data, ax=ax, color=self.colors['after'], 
                            alpha=0.6, label='After Training', kde=True)
                
                # Set labels
                ax.set_title(f"{prop} Distribution")
                ax.set_xlabel(prop)
                ax.set_ylabel("Count")
                ax.legend()
                
                # Add p-value if enough data
                if len(before_data) > 1 and len(after_data) > 1:
                    # Run t-test
                    tstat, pval = ttest_ind(before_data, after_data, equal_var=False)
                    ax.annotate(f"p-value: {pval:.4f}", xy=(0.05, 0.95), xycoords='axes fraction',
                               bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, "property_distributions.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Property distribution comparison saved to {fig_path}")
        return fig
    
    def compare_feature_prevalence(self, before_features, after_features, figsize=(14, 8)):
        """Compare structural feature prevalence before and after training
        
        Args:
            before_features: DataFrame with features before training
            after_features: DataFrame with features after training
            figsize: Figure size (width, height)
            
        Returns:
            Figure with feature prevalence
        """
        # Get common feature columns
        common_features = list(set(before_features.columns) & set(after_features.columns))
        
        # Calculate prevalence (as percentage)
        before_prev = before_features[common_features].mean() * 100
        after_prev = after_features[common_features].mean() * 100
        
        # Create DataFrame for plotting
        prev_df = pd.DataFrame({
            'Before': before_prev,
            'After': after_prev
        })
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Plot
        prev_df.plot(kind='bar', ax=ax, color=[self.colors['before'], self.colors['after']])
        
        # Set labels
        ax.set_title("Structural Feature Prevalence")
        ax.set_ylabel("Prevalence (%)")
        ax.set_ylim(0, 100)
        
        # Add value labels
        for container in ax.containers:
            ax.bar_label(container, fmt='%.1f%%', fontsize=10)
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, "feature_prevalence.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Feature prevalence comparison saved to {fig_path}")
        return fig
    
    def compare_functional_groups(self, before_funcs, after_funcs, figsize=(14, 8)):
        """Compare functional group distributions before and after training
        
        Args:
            before_funcs: DataFrame with functional groups before training
            after_funcs: DataFrame with functional groups after training
            figsize: Figure size (width, height)
            
        Returns:
            Figure with functional group distributions
        """
        # Get common functional group columns
        common_funcs = list(set(before_funcs.columns) & set(after_funcs.columns))
        
        # Calculate average counts
        before_avg = before_funcs[common_funcs].mean()
        after_avg = after_funcs[common_funcs].mean()
        
        # Create DataFrame for plotting
        func_df = pd.DataFrame({
            'Before': before_avg,
            'After': after_avg
        })
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Plot
        func_df.plot(kind='bar', ax=ax, color=[self.colors['before'], self.colors['after']])
        
        # Set labels
        ax.set_title("Average Functional Group Counts")
        ax.set_ylabel("Average Count per Molecule")
        
        # Add value labels
        for container in ax.containers:
            ax.bar_label(container, fmt='%.2f', fontsize=10)
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, "functional_group_counts.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Functional group count comparison saved to {fig_path}")
        return fig
    
    def analyze_embedding_differences(self, before_emb, after_emb, before_props, after_props,
                                    prop_of_interest='LogP', figsize=(16, 8)):
        """Analyze differences in embeddings and their relationship to properties
        
        Args:
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            before_props: DataFrame with properties before training
            after_props: DataFrame with properties after training
            prop_of_interest: Property to color points by
            figsize: Figure size (width, height)
            
        Returns:
            Figure with embedding analysis
        """
        if before_emb is None or after_emb is None:
            print("Embeddings not available for analysis")
            return None
        
        # Make sure we have the same number of embeddings
        if len(before_emb) != len(after_emb):
            # Take the smaller size
            min_size = min(len(before_emb), len(after_emb))
            before_emb = before_emb[:min_size]
            after_emb = after_emb[:min_size]
            before_props = before_props.iloc[:min_size]
            after_props = after_props.iloc[:min_size]
        
        # Create figure
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        
        # Apply PCA to both embeddings
        pca = PCA(n_components=2)
        before_pca = pca.fit_transform(before_emb)
        after_pca = pca.fit_transform(after_emb)
        
        # Get property for coloring
        if prop_of_interest in before_props.columns:
            prop_values = before_props[prop_of_interest].values
            vmin, vmax = np.percentile(prop_values, [5, 95])
            
            # Plot before training
            sc1 = axes[0].scatter(before_pca[:, 0], before_pca[:, 1], 
                                c=prop_values, cmap='viridis', alpha=0.8,
                                vmin=vmin, vmax=vmax)
            axes[0].set_title(f"Before Training (PCA)")
            axes[0].set_xlabel("PC1")
            axes[0].set_ylabel("PC2")
            
            # Plot after training
            sc2 = axes[1].scatter(after_pca[:, 0], after_pca[:, 1], 
                                c=prop_values, cmap='viridis', alpha=0.8,
                                vmin=vmin, vmax=vmax)
            axes[1].set_title(f"After Training (PCA)")
            axes[1].set_xlabel("PC1")
            axes[1].set_ylabel("PC2")
            
            # Add colorbar
            cbar = plt.colorbar(sc1, ax=axes)
            cbar.set_label(prop_of_interest)
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, f"embedding_pca_{prop_of_interest}.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Embedding PCA analysis saved to {fig_path}")
        return fig
    
    def property_bias_analysis(self, before_props, after_props, before_emb, after_emb, 
                              props_to_analyze=None, figsize=(16, 12)):
        """Analyze bias in embedding space with respect to molecular properties
        
        Args:
            before_props: DataFrame with properties before training
            after_props: DataFrame with properties after training
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            props_to_analyze: List of properties to analyze
            figsize: Figure size (width, height)
            
        Returns:
            Figure with property bias analysis
        """
        if before_emb is None or after_emb is None:
            print("Embeddings not available for analysis")
            return None
        
        if props_to_analyze is None:
            props_to_analyze = ['MW', 'LogP', 'TPSA', 'NumHAcceptors', 'NumHDonors']
        
        # Make sure properties are in both DataFrames
        props_to_analyze = [p for p in props_to_analyze 
                          if p in before_props.columns and p in after_props.columns]
        
        # Apply PCA to before and after embeddings
        pca_before = PCA(n_components=10)
        pca_after = PCA(n_components=10)
        
        before_pca = pca_before.fit_transform(before_emb)
        after_pca = pca_after.fit_transform(after_emb)
        
        # Calculate correlations between PCA components and properties
        before_corrs = []
        after_corrs = []
        
        for prop in props_to_analyze:
            before_prop_vals = before_props[prop].values
            after_prop_vals = after_props[prop].values
            
            # Correlations with before training PCA components
            before_prop_corrs = [pearsonr(before_pca[:, i], before_prop_vals)[0] 
                               for i in range(before_pca.shape[1])]
            
            # Correlations with after training PCA components
            after_prop_corrs = [pearsonr(after_pca[:, i], after_prop_vals)[0] 
                              for i in range(after_pca.shape[1])]
            
            before_corrs.append(before_prop_corrs)
            after_corrs.append(after_prop_corrs)
        
        # Create figure
        fig, axes = plt.subplots(len(props_to_analyze), 1, figsize=figsize)
        if len(props_to_analyze) == 1:
            axes = [axes]
        
        # Plot correlations for each property
        for i, prop in enumerate(props_to_analyze):
            ax = axes[i]
            
            # Bar plot of correlations
            x = np.arange(len(before_corrs[i]))
            width = 0.35
            
            ax.bar(x - width/2, np.abs(before_corrs[i]), width, label='Before Training',
                  color=self.colors['before'])
            ax.bar(x + width/2, np.abs(after_corrs[i]), width, label='After Training',
                  color=self.colors['after'])
            
            # Set labels
            ax.set_title(f"Correlation of PCA Components with {prop}")
            ax.set_xlabel("PCA Component")
            ax.set_ylabel("|Correlation|")
            ax.set_xticks(x)
            ax.set_xticklabels([f"PC{j+1}" for j in range(len(before_corrs[i]))])
            ax.legend()
            
            # Add correlation values
            for j, (b_corr, a_corr) in enumerate(zip(before_corrs[i], after_corrs[i])):
                ax.annotate(f"{b_corr:.2f}", xy=(j - width/2, np.abs(b_corr) + 0.02), 
                          ha='center', va='bottom', fontsize=8, rotation=90)
                ax.annotate(f"{a_corr:.2f}", xy=(j + width/2, np.abs(a_corr) + 0.02), 
                          ha='center', va='bottom', fontsize=8, rotation=90)
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, "property_bias_analysis.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Property bias analysis saved to {fig_path}")
        return fig
    
    def analyze_correlations(self, before_props, after_props):
        """Analyze correlations between properties before and after training
        
        Args:
            before_props: DataFrame with properties before training
            after_props: DataFrame with properties after training
            
        Returns:
            DataFrame with correlation changes
        """
        # Get common property columns (excluding non-numeric)
        common_props = []
        for col in before_props.columns:
            if col in after_props.columns:
                if np.issubdtype(before_props[col].dtype, np.number) and np.issubdtype(after_props[col].dtype, np.number):
                    common_props.append(col)
        
        # Calculate correlations
        before_corr = before_props[common_props].corr()
        after_corr = after_props[common_props].corr()
        
        # Calculate correlation differences
        corr_diff = after_corr - before_corr
        
        # Create figure with three subplots
        fig, axes = plt.subplots(1, 3, figsize=(24, 8))
        
        # Plot before correlations
        sns.heatmap(before_corr, ax=axes[0], cmap='coolwarm', vmin=-1, vmax=1, 
                  annot=True, fmt='.2f', square=True, cbar=True)
        axes[0].set_title("Property Correlations Before Training")
        
        # Plot after correlations
        sns.heatmap(after_corr, ax=axes[1], cmap='coolwarm', vmin=-1, vmax=1, 
                  annot=True, fmt='.2f', square=True, cbar=True)
        axes[1].set_title("Property Correlations After Training")
        
        # Plot correlation differences
        sns.heatmap(corr_diff, ax=axes[2], cmap='coolwarm', vmin=-0.5, vmax=0.5, 
                  annot=True, fmt='.2f', square=True, cbar=True)
        axes[2].set_title("Correlation Changes (After - Before)")
        
        # Adjust layout
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, "property_correlations.png")
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        print(f"Property correlation analysis saved to {fig_path}")
        
        return corr_diff
    
    def generate_comprehensive_report(self, before_prefix, after_prefix, embedding_before=None, embedding_after=None):
        """Generate a comprehensive report with all analyses
        
        Args:
            before_prefix: Prefix for before training files (without _properties.csv)
            after_prefix: Prefix for after training files (without _properties.csv)
            embedding_before: Path to before training embeddings (optional)
            embedding_after: Path to after training embeddings (optional)
            
        Returns:
            None (all figures saved to output_dir)
        """
        # Load data
        print("Loading data...")
        data = self.load_analysis_data(before_prefix, after_prefix)
        
        # Try to load embeddings if provided
        before_embeddings = data['before_embeddings']
        after_embeddings = data['after_embeddings']
        
        if embedding_before is not None and before_embeddings is None:
            try:
                before_data = np.load(embedding_before)
                before_embeddings = before_data['embeddings']
            except Exception as e:
                print(f"Error loading before embeddings: {e}")
        
        if embedding_after is not None and after_embeddings is None:
            try:
                after_data = np.load(embedding_after)
                after_embeddings = after_data['embeddings']
            except Exception as e:
                print(f"Error loading after embeddings: {e}")
        
        # Generate visualizations
        print("\nGenerating property distribution comparison...")
        self.compare_properties_distributions(
            data['before_props'], data['after_props'],
            props_to_compare=['MW', 'LogP', 'TPSA', 'NumHAcceptors', 'NumHDonors', 'NumRotatableBonds']
        )
        
        print("\nGenerating feature prevalence comparison...")
        self.compare_feature_prevalence(
            data['before_features'], data['after_features']
        )
        
        print("\nGenerating functional group comparison...")
        self.compare_functional_groups(
            data['before_funcs'], data['after_funcs']
        )
        
        print("\nGenerating property correlation analysis...")
        self.analyze_correlations(
            data['before_props'], data['after_props']
        )
        
        if before_embeddings is not None and after_embeddings is not None:
            print("\nGenerating embedding analysis for LogP...")
            self.analyze_embedding_differences(
                before_embeddings, after_embeddings,
                data['before_props'], data['after_props'],
                prop_of_interest='LogP'
            )
            
            print("\nGenerating embedding analysis for MW...")
            self.analyze_embedding_differences(
                before_embeddings, after_embeddings,
                data['before_props'], data['after_props'],
                prop_of_interest='MW'
            )
            
            print("\nGenerating property bias analysis...")
            self.property_bias_analysis(
                data['before_props'], data['after_props'],
                before_embeddings, after_embeddings
            )
        
        print(f"\nAll visualizations saved to {self.output_dir}")

        
def analyze_embedding_property_correlations(self, before_emb, after_emb, properties_df, 
                                           props_to_analyze=None):
    """Analyze how properties correlate with embedding dimensions
    
    Args:
        before_emb: Embeddings before training
        after_emb: Embeddings after training
        properties_df: DataFrame with molecular properties
        props_to_analyze: List of properties to analyze
        
    Returns:
        Figure showing correlation changes
    """
    if props_to_analyze is None:
        props_to_analyze = ['MW', 'LogP', 'TPSA', 'NumHAcceptors', 'NumHDonors']
    
    # Select top principal components
    pca = PCA(n_components=5)
    before_pca = pca.fit_transform(before_emb)
    
    pca = PCA(n_components=5)
    after_pca = pca.fit_transform(after_emb)
    
    # Calculate correlations between properties and PCA components
    before_correlations = {}
    after_correlations = {}
    
    for prop in props_to_analyze:
        if prop in properties_df.columns:
            prop_values = properties_df[prop].values
            
            before_corrs = []
            after_corrs = []
            
            for i in range(5):  # For top 5 PCA components
                before_corr = np.abs(np.corrcoef(before_pca[:, i], prop_values)[0, 1])
                after_corr = np.abs(np.corrcoef(after_pca[:, i], prop_values)[0, 1])
                
                before_corrs.append(before_corr)
                after_corrs.append(after_corr)
            
            before_correlations[prop] = before_corrs
            after_correlations[prop] = after_corrs
    
    # Create figure
    fig, axes = plt.subplots(len(props_to_analyze), 1, figsize=(12, 3*len(props_to_analyze)))
    
    for i, prop in enumerate(props_to_analyze):
        if prop in before_correlations:
            ax = axes[i] if len(props_to_analyze) > 1 else axes
            
            x = np.arange(5)
            width = 0.35
            
            ax.bar(x - width/2, before_correlations[prop], width, label='Before Training')
            ax.bar(x + width/2, after_correlations[prop], width, label='After Training')
            
            ax.set_title(f"Correlation with {prop}")
            ax.set_xlabel("PCA Component")
            ax.set_ylabel("Absolute Correlation")
            ax.set_xticks(x)
            ax.set_xticklabels([f"PC{j+1}" for j in range(5)])
            ax.legend()
    
    plt.tight_layout()
    return fig        

def run_visualization(before_prefix, after_prefix, output_dir='./visualizations',
                     embedding_before=None, embedding_after=None):
    """Run visualization pipeline
    
    Args:
        before_prefix: Prefix for before training files (without _properties.csv)
        after_prefix: Prefix for after training files (without _properties.csv)
        output_dir: Directory to save visualizations
        embedding_before: Path to before training embeddings (optional)
        embedding_after: Path to after training embeddings (optional)
        
    Returns:
        None (all figures saved to output_dir)
    """
    visualizer = MolecularEmbeddingVisualizer(output_dir=output_dir)
    visualizer.generate_comprehensive_report(
        before_prefix, after_prefix, embedding_before, embedding_after
    )


if __name__ == "__main__":
    # Example usage
    run_visualization(
        before_prefix="./analysis/before_training_20250301_140313",
        after_prefix="./analysis/after_training_20250301_140344",
        output_dir="./visualizations",
        embedding_before="./embeddings/before_training_20250301_140313.npz",
        embedding_after="./embeddings/after_training_20250301_140344.npz"
    )

AttributeError: 'MolecularEmbeddingVisualizer' object has no attribute 'analyze_embedding_property_correlations'

In [2]:
def analyze_embedding_property_correlations(self, before_emb, after_emb, properties_df, 
                                           props_to_analyze=None):
    """Analyze how properties correlate with embedding dimensions
    
    Args:
        before_emb: Embeddings before training
        after_emb: Embeddings after training
        properties_df: DataFrame with molecular properties
        props_to_analyze: List of properties to analyze
        
    Returns:
        Figure showing correlation changes
    """
    if props_to_analyze is None:
        props_to_analyze = ['MW', 'LogP', 'TPSA', 'NumHAcceptors', 'NumHDonors']
    
    # Select top principal components
    pca = PCA(n_components=5)
    before_pca = pca.fit_transform(before_emb)
    
    pca = PCA(n_components=5)
    after_pca = pca.fit_transform(after_emb)
    
    # Calculate correlations between properties and PCA components
    before_correlations = {}
    after_correlations = {}
    
    for prop in props_to_analyze:
        if prop in properties_df.columns:
            prop_values = properties_df[prop].values
            
            before_corrs = []
            after_corrs = []
            
            for i in range(5):  # For top 5 PCA components
                before_corr = np.abs(np.corrcoef(before_pca[:, i], prop_values)[0, 1])
                after_corr = np.abs(np.corrcoef(after_pca[:, i], prop_values)[0, 1])
                
                before_corrs.append(before_corr)
                after_corrs.append(after_corr)
            
            before_correlations[prop] = before_corrs
            after_correlations[prop] = after_corrs
    
    # Create figure
    fig, axes = plt.subplots(len(props_to_analyze), 1, figsize=(12, 3*len(props_to_analyze)))
    
    for i, prop in enumerate(props_to_analyze):
        if prop in before_correlations:
            ax = axes[i] if len(props_to_analyze) > 1 else axes
            
            x = np.arange(5)
            width = 0.35
            
            ax.bar(x - width/2, before_correlations[prop], width, label='Before Training')
            ax.bar(x + width/2, after_correlations[prop], width, label='After Training')
            
            ax.set_title(f"Correlation with {prop}")
            ax.set_xlabel("PCA Component")
            ax.set_ylabel("Absolute Correlation")
            ax.set_xticks(x)
            ax.set_xticklabels([f"PC{j+1}" for j in range(5)])
            ax.legend()
    
    plt.tight_layout()
    return fig