In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import StandardScaler
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcTPSA
import os
import json
from tqdm import tqdm
from scipy.stats import pearsonr
import matplotlib.gridspec as gridspec
import argparse


class EmbeddingAnalyzer:
    """Class for analyzing molecular embeddings before and after GAN-CL training"""
    
    def __init__(self, output_dir='./analysis_results'):
        """Initialize the analyzer
        
        Args:
            output_dir: Directory to save analysis results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Set plotting style
        sns.set_style('whitegrid')
        plt.rcParams.update({
            'font.family': 'sans-serif',
            'font.size': 12,
            'axes.labelsize': 14,
            'axes.titlesize': 16,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 12
        })
        
        # Set colors for before/after
        self.colors = {
            'before': '#1f77b4',  # blue
            'after': '#ff7f0e'    # orange
        }
    
    def load_embeddings(self, before_path, after_path):
        """Load embeddings and SMILES from .npz files
        
        Args:
            before_path: Path to before-training embeddings
            after_path: Path to after-training embeddings
            
        Returns:
            Tuple of (before_embeddings, before_smiles, after_embeddings, after_smiles)
        """
        # Load before-training data
        before_data = np.load(before_path, allow_pickle=True)
        before_embeddings = before_data['embeddings']
        before_smiles = before_data['smiles']
        
        # Load after-training data
        after_data = np.load(after_path, allow_pickle=True)
        after_embeddings = after_data['embeddings']
        after_smiles = after_data['smiles']
        
        print(f"Loaded {len(before_smiles)} molecules from before-training embeddings")
        print(f"Loaded {len(after_smiles)} molecules from after-training embeddings")
        
        return before_embeddings, before_smiles, after_embeddings, after_smiles
    
    def extract_molecular_properties(self, smiles_list):
        """Extract molecular properties from SMILES strings
        
        Args:
            smiles_list: List of SMILES strings
            
        Returns:
            DataFrame with molecular properties
        """
        print("Extracting molecular properties...")
        properties = []
        valid_smiles = []
        
        for smiles in tqdm(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            valid_smiles.append(smiles)
            properties.append({
                'MW': Descriptors.ExactMolWt(mol),
                'LogP': Descriptors.MolLogP(mol),
                'TPSA': CalcTPSA(mol),
                'NumHAcceptors': Lipinski.NumHAcceptors(mol),
                'NumHDonors': Lipinski.NumHDonors(mol),
                'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
                'NumRings': mol.GetRingInfo().NumRings(),
                'NumHeavyAtoms': mol.GetNumHeavyAtoms(),
                'NumAtoms': mol.GetNumAtoms(),
                'NumBonds': mol.GetNumBonds(),
                'IsAromatic': 1 if any(atom.GetIsAromatic() for atom in mol.GetAtoms()) else 0
            })
        
        properties_df = pd.DataFrame(properties, index=valid_smiles)
        return properties_df
    
    def extract_functional_groups(self, smiles_list):
        """Extract functional group features from SMILES strings
        
        Args:
            smiles_list: List of SMILES strings
            
        Returns:
            DataFrame with functional group features
        """
        print("Extracting functional group features...")
        features = []
        valid_smiles = []
        
        # Define SMARTS patterns for functional groups
        fg_patterns = {
            'Alcohol': '[OX2H]',
            'Amine': '[NX3;H2,H1,H0]',
            'Carboxyl': '[CX3](=O)[OX2H1]',
            'Carbonyl': '[CX3]=O',
            'Ether': '[OD2][!O]',
            'Ester': '[#6][CX3](=O)[OX2H0][#6]',
            'Amide': '[NX3][CX3](=[OX1])',
            'Halogen': '[F,Cl,Br,I]',
            'Nitro': '[NX3](=O)=O',
            'Nitrile': '[NX1]#[CX2]',
            'Sulfone': '[#16X4](=[OX1])(=[OX1])',
            'Sulfonamide': '[#16X4]([NX3])=O',
            'Phosphate': '[PX4](=[OX1])([OX2])([OX2])[OX2]'
        }
        
        # Compile the patterns
        compiled_patterns = {name: Chem.MolFromSmarts(smarts) for name, smarts in fg_patterns.items()}
        
        # Extract features for each molecule
        for smiles in tqdm(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            valid_smiles.append(smiles)
            
            # Count each functional group
            fg_counts = {}
            for name, pattern in compiled_patterns.items():
                if pattern is not None:
                    matches = mol.GetSubstructMatches(pattern)
                    fg_counts[name] = len(matches)
                else:
                    fg_counts[name] = 0
            
            features.append(fg_counts)
        
        features_df = pd.DataFrame(features, index=valid_smiles)
        return features_df
    
    def extract_structural_features(self, smiles_list):
        """Extract structural features from SMILES strings
        
        Args:
            smiles_list: List of SMILES strings
            
        Returns:
            DataFrame with structural features
        """
        print("Extracting structural features...")
        features = []
        valid_smiles = []
        
        for smiles in tqdm(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            valid_smiles.append(smiles)
            
            # Get ring information
            ri = mol.GetRingInfo()
            rings = ri.AtomRings()
            
            # Check for aromatic atoms
            aromatic_atoms = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
            
            # Check for heterocycles
            heterocycles = 0
            for ring in rings:
                if any(mol.GetAtomWithIdx(idx).GetAtomicNum() != 6 for idx in ring):
                    heterocycles += 1
            
            # Check for fused rings
            fused_rings = 0
            if len(rings) >= 2:
                for i in range(len(rings)):
                    for j in range(i+1, len(rings)):
                        if len(set(rings[i]).intersection(set(rings[j]))) > 1:
                            fused_rings += 1
            
            # Check for spiro rings
            spiro_rings = 0
            if len(rings) >= 2:
                for i in range(len(rings)):
                    for j in range(i+1, len(rings)):
                        if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                            spiro_rings += 1
            
            # Check for bridged rings (simplified)
            bridged_pattern = Chem.MolFromSmarts('[D4R]')
            bridged_rings = 1 if bridged_pattern and mol.HasSubstructMatch(bridged_pattern) else 0
            
            # Check for macrocycles (ring size >= 8)
            macrocycles = 0
            for ring in rings:
                if len(ring) >= 8:
                    macrocycles += 1
            
            # Linear chain (no branching)
            linear_chain = 1 if mol.GetNumBonds() == mol.GetNumAtoms() - 1 and len(rings) == 0 else 0
            
            # Branching (atoms with > 2 neighbors)
            branched = sum(1 for atom in mol.GetAtoms() if atom.GetDegree() > 2)
            
            features.append({
                'IsAromatic': 1 if aromatic_atoms > 0 else 0,
                'NumAromaticAtoms': aromatic_atoms,
                'NumHeterocycles': heterocycles,
                'NumFusedRings': fused_rings,
                'NumSpiroRings': spiro_rings,
                'HasBridgedRings': bridged_rings,
                'NumMacrocycles': macrocycles,
                'IsLinearChain': linear_chain,
                'NumBranches': branched,
                'NumRings': len(rings)
            })
        
        features_df = pd.DataFrame(features, index=valid_smiles)
        return features_df
    
    def property_prediction_analysis(self, before_emb, after_emb, properties_df, 
                                    props_to_analyze=None):
        """Analyze how well embeddings predict molecular properties
        
        Args:
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            properties_df: DataFrame with molecular properties
            props_to_analyze: List of properties to analyze (default: None = all numeric)
            
        Returns:
            DataFrame with prediction performance metrics
        """
        print("Performing property prediction analysis...")
        
        # Get properties to analyze
        if props_to_analyze is None:
            props_to_analyze = []
            for col in properties_df.columns:
                if pd.api.types.is_numeric_dtype(properties_df[col]):
                    props_to_analyze.append(col)
        
        # Standardize embeddings
        scaler_before = StandardScaler()
        scaler_after = StandardScaler()
        
        before_emb_scaled = scaler_before.fit_transform(before_emb)
        after_emb_scaled = scaler_after.fit_transform(after_emb)
        
        # Initialize results
        results = []
        
        # For each property, train models and evaluate
        for prop in props_to_analyze:
            print(f"Analyzing property: {prop}")
            
            if prop in properties_df.columns:
                y = properties_df[prop].values
                
                # Ridge regression with cross-validation
                model = Ridge(alpha=1.0)
                
                # Evaluate on before embeddings
                before_scores = cross_val_score(model, before_emb_scaled, y, 
                                              cv=5, scoring='r2')
                before_r2 = before_scores.mean()
                
                # Evaluate on after embeddings
                after_scores = cross_val_score(model, after_emb_scaled, y, 
                                             cv=5, scoring='r2')
                after_r2 = after_scores.mean()
                
                results.append({
                    'Property': prop,
                    'Before_R2': before_r2,
                    'After_R2': after_r2,
                    'Change': after_r2 - before_r2,
                    'Percent_Change': (after_r2 - before_r2) / (abs(before_r2) + 1e-10) * 100
                })
        
        # Create DataFrame and sort by absolute change
        results_df = pd.DataFrame(results)
        results_df['Abs_Change'] = results_df['Change'].abs()
        results_df = results_df.sort_values('Abs_Change', ascending=False)
        
        return results_df
    
    def embedding_sensitivity_analysis(self, before_emb, after_emb, properties_df, 
                                      props_to_analyze=None):
        """Analyze how sensitive embeddings are to different properties
        
        Args:
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            properties_df: DataFrame with molecular properties
            props_to_analyze: List of properties to analyze (default: None = all numeric)
            
        Returns:
            DataFrame with sensitivity metrics
        """
        print("Performing embedding sensitivity analysis...")
        
        # Get properties to analyze
        if props_to_analyze is None:
            props_to_analyze = []
            for col in properties_df.columns:
                if pd.api.types.is_numeric_dtype(properties_df[col]):
                    props_to_analyze.append(col)
        
        # Standardize embeddings
        scaler_before = StandardScaler()
        scaler_after = StandardScaler()
        
        before_emb_scaled = scaler_before.fit_transform(before_emb)
        after_emb_scaled = scaler_after.fit_transform(after_emb)
        
        # Initialize results
        results = []
        
        # For each property, analyze sensitivity
        for prop in props_to_analyze:
            print(f"Analyzing sensitivity to: {prop}")
            
            if prop in properties_df.columns:
                values = properties_df[prop].values
                
                # For each molecule, find molecules with similar and different property values
                similar_dists_before = []
                different_dists_before = []
                similar_dists_after = []
                different_dists_after = []
                
                # Define similarity threshold (10% of range)
                prop_range = np.max(values) - np.min(values)
                threshold = 0.1 * prop_range
                
                for i in range(len(values)):
                    for j in range(i+1, len(values)):
                        # Check if property values are similar
                        is_similar = abs(values[i] - values[j]) < threshold
                        
                        # Calculate distances in embedding spaces
                        before_dist = np.linalg.norm(before_emb_scaled[i] - before_emb_scaled[j])
                        after_dist = np.linalg.norm(after_emb_scaled[i] - after_emb_scaled[j])
                        
                        if is_similar:
                            similar_dists_before.append(before_dist)
                            similar_dists_after.append(after_dist)
                        else:
                            different_dists_before.append(before_dist)
                            different_dists_after.append(after_dist)
                
                # Calculate average distances
                avg_similar_before = np.mean(similar_dists_before) if similar_dists_before else np.nan
                avg_different_before = np.mean(different_dists_before) if different_dists_before else np.nan
                avg_similar_after = np.mean(similar_dists_after) if similar_dists_after else np.nan
                avg_different_after = np.mean(different_dists_after) if different_dists_after else np.nan
                
                # Calculate sensitivity ratios
                sensitivity_before = avg_different_before / avg_similar_before if avg_similar_before else np.nan
                sensitivity_after = avg_different_after / avg_similar_after if avg_similar_after else np.nan
                
                results.append({
                    'Property': prop,
                    'Sensitivity_Before': sensitivity_before,
                    'Sensitivity_After': sensitivity_after,
                    'Change': sensitivity_after - sensitivity_before,
                    'Percent_Change': (sensitivity_after - sensitivity_before) / (abs(sensitivity_before) + 1e-10) * 100
                })
        
        # Create DataFrame and sort by absolute change
        results_df = pd.DataFrame(results)
        results_df['Abs_Change'] = results_df['Change'].abs()
        results_df = results_df.sort_values('Abs_Change', ascending=False)
        
        return results_df
    
    def feature_importance_analysis(self, before_emb, after_emb, features_df):
        """Analyze which molecular features are most important for embeddings
        
        Args:
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            features_df: DataFrame with molecular features (binary or counts)
            
        Returns:
            DataFrame with feature importance scores
        """
        print("Performing feature importance analysis...")
        
        # Standardize embeddings
        scaler_before = StandardScaler()
        scaler_after = StandardScaler()
        
        before_emb_scaled = scaler_before.fit_transform(before_emb)
        after_emb_scaled = scaler_after.fit_transform(after_emb)
        
        # Initialize importance scores
        before_importance = np.zeros(len(features_df.columns))
        after_importance = np.zeros(len(features_df.columns))
        
        # For a subset of dimensions in the embeddings, train models and get feature importance
        n_dims = min(5, before_emb.shape[1])  # Use top 5 dimensions or fewer
        
        for i in range(n_dims):
            print(f"Analyzing dimension {i+1}/{n_dims}...")
            
            # Target is a single embedding dimension
            y_before = before_emb_scaled[:, i]
            y_after = after_emb_scaled[:, i]
            
            # Train random forest model for before embeddings
            model_before = RandomForestRegressor(n_estimators=100, random_state=42)
            model_before.fit(features_df.values, y_before)
            before_importance += model_before.feature_importances_
            
            # Train random forest model for after embeddings
            model_after = RandomForestRegressor(n_estimators=100, random_state=42)
            model_after.fit(features_df.values, y_after)
            after_importance += model_after.feature_importances_
        
        # Average importance scores across dimensions
        before_importance /= n_dims
        after_importance /= n_dims
        
        # Create DataFrame with results
        results = pd.DataFrame({
            'Feature': features_df.columns,
            'Importance_Before': before_importance,
            'Importance_After': after_importance,
            'Change': after_importance - before_importance,
            'Percent_Change': (after_importance - before_importance) / (before_importance + 1e-10) * 100
        })
        
        # Sort by absolute change
        results['Abs_Change'] = results['Change'].abs()
        results = results.sort_values('Abs_Change', ascending=False)
        
        return results
    
    def visualize_property_prediction(self, prediction_results):
        """Visualize property prediction performance
        
        Args:
            prediction_results: DataFrame with prediction performance metrics
            
        Returns:
            Path to saved figure
        """
        print("Visualizing property prediction results...")
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 7))
        
        # Get data from results
        properties = prediction_results['Property']
        before_scores = prediction_results['Before_R2']
        after_scores = prediction_results['After_R2']
        changes = prediction_results['Change']
        
        # Set up bar positions
        x = np.arange(len(properties))
        width = 0.35
        
        # Create bars
        rects1 = ax.bar(x - width/2, before_scores, width, 
                      label='Before Training', color=self.colors['before'], alpha=0.7)
        rects2 = ax.bar(x + width/2, after_scores, width, 
                      label='After Training', color=self.colors['after'], alpha=0.7)
        
        # Add details
        ax.set_xlabel('Molecular Property')
        ax.set_ylabel('R² Score (Higher = Better)')
        ax.set_title('How Well Embeddings Capture Molecular Properties')
        ax.set_xticks(x)
        ax.set_xticklabels(properties, rotation=45, ha='right')
        ax.legend()
        
        # Add value labels and change indicators
        for i, (before, after, change) in enumerate(zip(before_scores, after_scores, changes)):
            # Add value labels
            ax.annotate(f'{before:.2f}', xy=(i - width/2, before + 0.01), 
                      ha='center', va='bottom')
            ax.annotate(f'{after:.2f}', xy=(i + width/2, after + 0.01), 
                      ha='center', va='bottom')
            
            # Add change arrow
            if abs(change) > 0.01:  # Only show meaningful changes
                color = 'green' if change > 0 else 'red'
                style = 'solid' if change > 0 else 'dashed'
                ax.annotate('', xy=(i + width/2, after), xytext=(i - width/2, before),
                          arrowprops=dict(arrowstyle='->', color=color, linestyle=style))
                # Add percentage change
                percent = change * 100 / max(0.01, abs(before))  # Avoid division by zero
                ax.annotate(f'{percent:+.1f}%', 
                          xy=(i, (before + after)/2 + 0.05),
                          ha='center', va='bottom', color=color, fontweight='bold')
        
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, 'property_prediction.png')
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        return fig_path
    
    def visualize_embedding_sensitivity(self, sensitivity_results):
        """Visualize embedding sensitivity to properties
        
        Args:
            sensitivity_results: DataFrame with sensitivity metrics
            
        Returns:
            Path to saved figure
        """
        print("Visualizing embedding sensitivity results...")
        
        # Filter out NaN results
        filtered_results = sensitivity_results.dropna(subset=['Sensitivity_Before', 'Sensitivity_After'])
        
        if len(filtered_results) == 0:
            print("No valid sensitivity results to visualize")
            return None
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 7))
        
        # Get data from results
        properties = filtered_results['Property']
        before_sens = filtered_results['Sensitivity_Before']
        after_sens = filtered_results['Sensitivity_After']
        changes = filtered_results['Change']
        
        # Set up bar positions
        x = np.arange(len(properties))
        width = 0.35
        
        # Create bars
        rects1 = ax.bar(x - width/2, before_sens, width, 
                      label='Before Training', color=self.colors['before'], alpha=0.7)
        rects2 = ax.bar(x + width/2, after_sens, width, 
                      label='After Training', color=self.colors['after'], alpha=0.7)
        
        # Add details
        ax.set_xlabel('Molecular Property')
        ax.set_ylabel('Sensitivity Ratio (Higher = Better Separation)')
        ax.set_title('How Well Embeddings Separate Molecules by Properties')
        ax.set_xticks(x)
        ax.set_xticklabels(properties, rotation=45, ha='right')
        ax.legend()
        
        # Add value labels and change indicators
        for i, (before, after, change) in enumerate(zip(before_sens, after_sens, changes)):
            # Add value labels
            ax.annotate(f'{before:.2f}', xy=(i - width/2, before + 0.05), 
                      ha='center', va='bottom')
            ax.annotate(f'{after:.2f}', xy=(i + width/2, after + 0.05), 
                      ha='center', va='bottom')
            
            # Add change arrow
            if abs(change) > 0.01:  # Only show meaningful changes
                color = 'green' if change > 0 else 'red'
                style = 'solid' if change > 0 else 'dashed'
                ax.annotate('', xy=(i + width/2, after), xytext=(i - width/2, before),
                          arrowprops=dict(arrowstyle='->', color=color, linestyle=style))
                # Add percentage change
                percent = change * 100 / max(0.01, abs(before))  # Avoid division by zero
                ax.annotate(f'{percent:+.1f}%', 
                          xy=(i, (before + after)/2 + 0.1),
                          ha='center', va='bottom', color=color, fontweight='bold')
        
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, 'embedding_sensitivity.png')
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        return fig_path
    
    def visualize_feature_importance(self, importance_results, top_n=15):
        """Visualize feature importance for embeddings
        
        Args:
            importance_results: DataFrame with feature importance metrics
            top_n: Number of top features to show
            
        Returns:
            Path to saved figure
        """
        print("Visualizing feature importance results...")
        
        # Get top N features by absolute change
        top_features = importance_results.head(top_n).copy()
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 10))
        
        # Sort by before importance
        top_features = top_features.sort_values('Importance_Before')
        
        # Get data
        features = top_features['Feature']
        before_imp = top_features['Importance_Before']
        after_imp = top_features['Importance_After']
        changes = top_features['Change']
        
        # Set up bar positions
        y = np.arange(len(features))
        height = 0.35
        
        # Create bars
        ax.barh(y - height/2, before_imp, height, 
              label='Before Training', color=self.colors['before'], alpha=0.7)
        ax.barh(y + height/2, after_imp, height, 
              label='After Training', color=self.colors['after'], alpha=0.7)
        
        # Add details
        ax.set_xlabel('Relative Importance')
        ax.set_title('Importance of Molecular Features in Embedding Space')
        ax.set_yticks(y)
        ax.set_yticklabels(features)
        ax.legend()
        
        # Add change indicators
        for i, (before, after, change) in enumerate(zip(before_imp, after_imp, changes)):
            # Add value labels
            ax.annotate(f'{before:.3f}', xy=(before + 0.002, i - height/2), 
                      ha='left', va='center')
            ax.annotate(f'{after:.3f}', xy=(after + 0.002, i + height/2), 
                      ha='left', va='center')
            
            # Add change arrow
            if abs(change) > 0.005:  # Only show meaningful changes
                color = 'green' if change > 0 else 'red'
                style = 'solid' if change > 0 else 'dashed'
                ax.annotate('', xy=(after, i + height/2), xytext=(before, i - height/2),
                          arrowprops=dict(arrowstyle='->', color=color, linestyle=style))
                
                # Add percentage change
                percent = change * 100 / max(0.01, before)  # Avoid division by zero
                ax.annotate(f'{percent:+.1f}%', 
                          xy=((before + after)/2, i),
                          ha='center', va='center', color=color, fontweight='bold')
        
        plt.tight_layout()
        
        # Save figure
        fig_path = os.path.join(self.output_dir, 'feature_importance.png')
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        
        return fig_path
    
    def visualize_embedding_space(self, before_emb, after_emb, properties_df, 
                                prop_to_color='LogP'):
        """Visualize embedding spaces colored by property
        
        Args:
            before_emb: Embeddings before training
            after_emb: Embeddings after training
            properties_df: DataFrame with properties
            prop_to_color: Property to use for coloring points
            
        Returns:
            Path to saved figure
        """
        print(f"Visualizing embedding spaces colored by {prop_to_color}...")
        
        # Apply dimensionality reduction
        pca = PCA(n_components=2)
        before_pca = pca.fit_transform(before_emb)
        pca = PCA(n_components=2)
        after_pca = pca.fit_transform(after_emb)
        
        # Get property values
        if prop_to_color in properties_df.columns:
            colors = properties_df[prop_to_color].values
            
            # Create figure
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
            
            # Before embedding
            scatter1 = ax1.scatter(before_pca[:, 0], before_pca[:, 1], c=colors, 
                                 cmap='viridis', alpha=0.8, s=50)
            ax1.set_title(f'Before Training Embedding Space\nColored by {prop_to_color}')
            ax1.set_xlabel('Principal Component 1')
            ax1.set_ylabel('Principal Component 2')
            
            # After embedding
            scatter2 = ax2.scatter(after_pca[:, 0], after_pca[:, 1], c=colors, 
                                 cmap='viridis', alpha=0.8, s=50)
            ax2.set_title(f'After Training Embedding Space\nColored by {prop_to_color}')
            ax2.set_xlabel('Principal Component 1')
            ax2.set_ylabel('Principal Component 2')
            
            # Add colorbar
            cbar = plt.colorbar(scatter1, ax=[ax1, ax2])
            cbar.set_label(prop_to_color)
            
            plt.tight_layout()
            
            # Save figure
            fig_path = os.path.join(self.output_dir, f'embedding_space_{prop_to_color}.png')
            plt.savefig(fig_path, dpi=300, bbox_inches='tight')
            
            return fig_path
        
        return None

    def run_complete_analysis(self, before_path, after_path):
        """Run a complete analysis workflow on embedding files
        
        Args:
            before_path: Path to before-training embeddings
            after_path: Path to after-training embeddings
            
        Returns:
            Dictionary with results and paths to saved figures
        """
        results = {}
        
        # Load embeddings
        before_emb, before_smiles, after_emb, after_smiles = self.load_embeddings(before_path, after_path)
        
        # Make sure embeddings have the same molecules (by index) for proper comparison
        # Get common SMILES
        common_smiles = list(set(before_smiles).intersection(set(after_smiles)))
        
        if len(common_smiles) == 0:
            print("No common molecules found in before and after embeddings!")
            return None
        
        print(f"Found {len(common_smiles)} common molecules for analysis")
        
        # Get indices for common SMILES
        before_indices = [i for i, s in enumerate(before_smiles) if s in common_smiles]
        after_indices = [i for i, s in enumerate(after_smiles) if s in common_smiles]
        
        # Filter embeddings
        before_emb_common = before_emb[before_indices]
        after_emb_common = after_emb[after_indices]
        common_smiles = [before_smiles[i] for i in before_indices]  # Keep order consistent
        
        # Extract properties and features
        properties_df = self.extract_molecular_properties(common_smiles)
        functional_groups_df = self.extract_functional_groups(common_smiles)
        structural_features_df = self.extract_structural_features(common_smiles)
        
        # Save the extracted data
        properties_df.to_csv(os.path.join(self.output_dir, 'properties.csv'))
        functional_groups_df.to_csv(os.path.join(self.output_dir, 'functional_groups.csv'))
        structural_features_df.to_csv(os.path.join(self.output_dir, 'structural_features.csv'))
        
        # Run property prediction analysis
        prop_prediction = self.property_prediction_analysis(before_emb_common, after_emb_common, properties_df)
        prop_prediction.to_csv(os.path.join(self.output_dir, 'property_prediction_results.csv'))
        results['property_prediction'] = prop_prediction
        
        # Run embedding sensitivity analysis
        sensitivity = self.embedding_sensitivity_analysis(before_emb_common, after_emb_common, properties_df)
        sensitivity.to_csv(os.path.join(self.output_dir, 'sensitivity_results.csv'))
        results['sensitivity'] = sensitivity
        
        # Run feature importance analysis for functional groups
        func_importance = self.feature_importance_analysis(before_emb_common, after_emb_common, functional_groups_df)
        func_importance.to_csv(os.path.join(self.output_dir, 'functional_group_importance.csv'))
        results['functional_importance'] = func_importance
        
        # Run feature importance analysis for structural features
        struct_importance = self.feature_importance_analysis(before_emb_common, after_emb_common, structural_features_df)
        struct_importance.to_csv(os.path.join(self.output_dir, 'structural_feature_importance.csv'))
        results['structural_importance'] = struct_importance
        
        # Create visualizations
        results['figures'] = {}
        
        # Property prediction visualization
        prop_pred_fig = self.visualize_property_prediction(prop_prediction)
        results['figures']['property_prediction'] = prop_pred_fig
        
        # Embedding sensitivity visualization
        sensitivity_fig = self.visualize_embedding_sensitivity(sensitivity)
        results['figures']['sensitivity'] = sensitivity_fig
        
        # Feature importance visualizations
        func_imp_fig = self.visualize_feature_importance(func_importance, top_n=10)
        results['figures']['functional_importance'] = func_imp_fig
        
        struct_imp_fig = self.visualize_feature_importance(struct_importance, top_n=10)
        results['figures']['structural_importance'] = struct_imp_fig
        
        # Embedding space visualizations for key properties
        for prop in ['LogP', 'MW', 'TPSA']:
            if prop in properties_df.columns:
                emb_space_fig = self.visualize_embedding_space(before_emb_common, after_emb_common, 
                                                              properties_df, prop)
                results['figures'][f'embedding_space_{prop}'] = emb_space_fig
        
        print("Analysis complete! Results saved to:", self.output_dir)
        return results

def main():
    """Main function to analyze embeddings"""
    # Set file paths directly in the code
    before_path = "./embeddings/before_training_20250301_193202.npz"
    after_path = "./embeddings/after_training_20250301_193202.npz"
    output_dir = "./analysis_results"
    
    print(f"Analyzing embeddings:")
    print(f"  - Before: {before_path}")
    print(f"  - After: {after_path}")
    print(f"  - Output: {output_dir}")
    
    # Create analyzer and run analysis
    analyzer = EmbeddingAnalyzer(output_dir=output_dir)
    results = analyzer.run_complete_analysis(before_path, after_path)
    
    if results:
        print("\nAnalysis Summary:")
        print("------------------")
        
        # Print top improved properties
        improved_props = results['property_prediction'][results['property_prediction']['Change'] > 0]
        if not improved_props.empty:
            print("\nProperties better captured after training:")
            for idx, row in improved_props.iterrows():
                print(f"  - {row['Property']}: {row['Percent_Change']:.1f}% improvement in R²")
        
        # Print top degraded properties
        degraded_props = results['property_prediction'][results['property_prediction']['Change'] < 0]
        if not degraded_props.empty:
            print("\nProperties less well captured after training:")
            for idx, row in degraded_props.iterrows():
                print(f"  - {row['Property']}: {row['Percent_Change']:.1f}% decrease in R²")
        
        # Print top features with increased importance
        increased_features = results['functional_importance'][results['functional_importance']['Change'] > 0].head(5)
        if not increased_features.empty:
            print("\nFunctional groups with increased importance:")
            for idx, row in increased_features.iterrows():
                print(f"  - {row['Feature']}: {row['Percent_Change']:.1f}% increase in importance")
        
        # Print top features with decreased importance
        decreased_features = results['functional_importance'][results['functional_importance']['Change'] < 0].head(5)
        if not decreased_features.empty:
            print("\nFunctional groups with decreased importance:")
            for idx, row in decreased_features.iterrows():
                print(f"  - {row['Feature']}: {row['Percent_Change']:.1f}% decrease in importance")
        
        print("\nVisualization files:")
        for name, path in results['figures'].items():
            if path:
                print(f"  - {name}: {path}")
    else:
        print("Analysis failed!")
    
    return results


if __name__ == "__main__":
    main()

Analyzing embeddings:
  - Before: ./embeddings/before_training_20250301_193202.npz
  - After: ./embeddings/after_training_20250301_193202.npz
  - Output: ./analysis_results
Loaded 16 molecules from before-training embeddings
Loaded 16 molecules from after-training embeddings
Found 1 common molecules for analysis
Extracting molecular properties...


  0%|                                                                                           | 0/16 [00:00<?, ?it/s][20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around p

Extracting functional group features...


  0%|                                                                                           | 0/16 [00:00<?, ?it/s][20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around p

Extracting structural features...


  0%|                                                                                           | 0/16 [00:00<?, ?it/s][20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around position 1:
[20:03:55] unknown
[20:03:55] ^
[20:03:55] SMILES Parse Error: Failed parsing SMILES 'unknown' for input: 'unknown'
[20:03:55] SMILES Parse Error: syntax error while parsing: unknown
[20:03:55] SMILES Parse Error: check for mistakes around p

Performing property prediction analysis...





KeyError: 'Change'