In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import r2_score
from sklearn.model_selection import cross_val_score, KFold
import umap

class EmbeddingAnalyzer:
    """Analyze molecular embedding spaces for biases and representation quality"""
    
    def __init__(self, metadata_path, embedding_files):
        """
        Initialize analyzer with specific file paths
        
        Args:
            metadata_path: Path to molecule metadata pickle file
            embedding_files: Dictionary mapping embedding types to file paths
        """
        self.metadata_path = metadata_path
        self.embedding_files = embedding_files
        
        # Load metadata
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)
        
        # Initialize dictionaries to store embeddings
        self.embedding_data = {}
        self.results = {}
        
    def load_embeddings(self):
        """Load embedding files from provided paths"""
        print("Loading embedding files...")
        
        for emb_type, filepath in self.embedding_files.items():
            if not os.path.exists(filepath):
                print(f"Warning: File not found: {filepath}")
                continue
                
            try:
                with open(filepath, 'rb') as f:
                    data = pickle.load(f)
                
                # Handle different embedding file formats
                if 'embeddings' in data:
                    embeddings = data['embeddings']
                elif isinstance(data, dict) and 'labels' in data:
                    # Format from original save_embeddings function
                    embeddings = data['embeddings']
                else:
                    print(f"Warning: Couldn't extract embeddings from {filepath}")
                    continue
                
                self.embedding_data[emb_type] = {
                    'embeddings': embeddings,
                    'metadata': data
                }
                print(f"Loaded {emb_type} embeddings: {embeddings.shape}")
            except Exception as e:
                print(f"Error loading {filepath}: {e}")
        
        if not self.embedding_data:
            raise ValueError("No embedding files loaded!")
    
    def prepare_property_data(self):
        """Extract properties from metadata into pandas DataFrame"""
        print("Preparing property data from metadata...")

        # Initialize lists to collect property data
        property_data = []

        for mol_data in self.metadata:
            # Extract properties
            props = mol_data['properties']
            feats = mol_data['features']
            funcs = mol_data['functional_groups']
            rings = mol_data['ring_info']

            # Combine all properties into a single dict
            mol_props = {
                'graph_id': mol_data.get('graph_id', 'unknown')
            }

            # Add molecular properties
            for k, v in props.items():
                if isinstance(v, (int, float, bool)):
                    mol_props[f'prop_{k}'] = v

            # Add structural features
            for k, v in feats.items():
                if isinstance(v, (int, float, bool)):
                    mol_props[f'feat_{k}'] = v

            # Add functional groups
            for k, v in funcs.items():
                if isinstance(v, (int, float, bool)):
                    mol_props[f'func_{k}'] = v

            # Add ring information
            for category, counts in rings.items():
                if isinstance(counts, dict):
                    for k, v in counts.items():
                        if isinstance(v, (int, float)):
                            mol_props[f'ring_{category}_{k}'] = v

            property_data.append(mol_props)

        # Convert to DataFrame
        self.prop_df = pd.DataFrame(property_data)

        # Make sure we have some numeric columns
        numeric_cols = self.prop_df.select_dtypes(include=np.number).columns
        if len(numeric_cols) == 0:
            print("Warning: No numeric properties found in metadata!")
            # Add a dummy numeric column to prevent errors
            self.prop_df['dummy'] = 0

        print(f"Prepared DataFrame with {len(self.prop_df)} molecules and {self.prop_df.shape[1]} properties")

        # Filter out columns with too many missing values or zero variance
        self._clean_property_dataframe()
    
    def _clean_property_dataframe(self):
        """Clean property DataFrame by removing low-information columns"""
        # Exclude non-numeric columns from variance calculation
        numeric_cols = self.prop_df.select_dtypes(include=np.number).columns

        # Remove columns with too many missing values
        missing_thresh = 0.5
        missing_cols = [col for col in numeric_cols 
                        if self.prop_df[col].isna().mean() > missing_thresh]

        # Remove columns with zero variance
        var_thresh = 0.0
        var_cols = [col for col in numeric_cols 
                    if col not in missing_cols and self.prop_df[col].var() <= var_thresh]

        # Remove identified columns
        drop_cols = missing_cols + var_cols
        if drop_cols:
            self.prop_df = self.prop_df.drop(columns=drop_cols)
            print(f"Removed {len(drop_cols)} low-information columns")

        # Fill remaining missing values
        numeric_cols = self.prop_df.select_dtypes(include=np.number).columns
        self.prop_df[numeric_cols] = self.prop_df[numeric_cols].fillna(0)
    
    def analyze_property_prediction(self, properties=None, cv=5):
        """
        Analyze how well embeddings predict molecular properties

        Args:
            properties: List of property column names to analyze (default: analyze all numeric)
            cv: Number of cross-validation folds
        """
        print("\n=== Property Prediction Analysis ===")

        # Select properties to analyze
        if properties is None:
            # Only consider numeric columns, and exclude graph_id
            numeric_cols = self.prop_df.select_dtypes(include=np.number).columns
            properties = list(numeric_cols)

        if len(properties) == 0:
            print("No numeric properties available for prediction analysis.")
            return pd.DataFrame()

        # Initialize results storage
        property_results = {'property': [], 'embedding_type': [], 'r2_score': [], 'model': []}

        # Iterate through available embedding types
        for emb_type, emb_data in self.embedding_data.items():
            print(f"\nAnalyzing {emb_type} embeddings...")
            embeddings = emb_data['embeddings']

            # Analyze each property
            for prop in properties:
                y = self.prop_df[prop].values

                # Skip if property has too little variance
                if np.var(y) < 1e-6:
                    continue

                # Try both linear and non-linear models
                for model_type, model in [('Linear', Ridge(alpha=1.0)), 
                                         ('RandomForest', RandomForestRegressor(n_estimators=50))]:
                    # Perform cross-validation
                    try:
                        cv_scores = cross_val_score(model, embeddings, y, 
                                                  cv=KFold(n_splits=min(cv, len(y)), shuffle=True, random_state=42), 
                                                  scoring='r2')

                        # Store results
                        mean_r2 = np.mean(cv_scores)
                        property_results['property'].append(prop)
                        property_results['embedding_type'].append(emb_type)
                        property_results['r2_score'].append(mean_r2)
                        property_results['model'].append(model_type)

                        print(f"{prop} - {model_type}: R² = {mean_r2:.4f}")
                    except Exception as e:
                        print(f"Error analyzing {prop}: {e}")

        # Convert to DataFrame and store
        results_df = pd.DataFrame(property_results)
        self.results['property_prediction'] = results_df

        return results_df
    
    def analyze_embedding_sensitivity(self, properties=None, percentile_threshold=10):
        """
        Analyze embedding sensitivity to properties
        
        Args:
            properties: List of property column names to analyze (default: analyze all numeric)
            percentile_threshold: Percentile threshold for defining similar/different molecules
        """
        print("\n=== Embedding Sensitivity Analysis ===")
        
        # Select properties to analyze
        if properties is None:
            numeric_cols = self.prop_df.select_dtypes(include=np.number).columns
            properties = [col for col in numeric_cols if col != 'graph_id']
        
        # Initialize results storage
        sensitivity_results = {
            'property': [], 
            'embedding_type': [], 
            'avg_dist_similar': [],
            'avg_dist_different': [],
            'sensitivity_ratio': []
        }
        
        # Iterate through available embedding types
        for emb_type, emb_data in self.embedding_data.items():
            print(f"\nAnalyzing {emb_type} embeddings...")
            embeddings = emb_data['embeddings']
            
            # Analyze each property
            for prop in properties:
                y = self.prop_df[prop].values
                
                # Skip if property has too little variance
                if np.var(y) < 1e-6:
                    continue
                
                # Compute low and high thresholds for the property
                low_thresh = np.percentile(y, percentile_threshold)
                high_thresh = np.percentile(y, 100 - percentile_threshold)
                
                if low_thresh == high_thresh:
                    print(f"Skipping {prop} - insufficient variance")
                    continue
                
                # Define similar and different molecule pairs
                low_indices = np.where(y <= low_thresh)[0]
                high_indices = np.where(y >= high_thresh)[0]
                
                if len(low_indices) < 5 or len(high_indices) < 5:
                    print(f"Skipping {prop} - insufficient samples in partition")
                    continue
                
                # Calculate distances between similar molecules
                similar_dists = []
                
                # Sample pairs from low group
                np.random.seed(42)
                low_pairs = np.random.choice(low_indices, size=(min(1000, len(low_indices) * (len(low_indices) - 1) // 2), 2), replace=True)
                for i, j in low_pairs:
                    if i != j:
                        similar_dists.append(np.linalg.norm(embeddings[i] - embeddings[j]))
                
                # Sample pairs from high group
                high_pairs = np.random.choice(high_indices, size=(min(1000, len(high_indices) * (len(high_indices) - 1) // 2), 2), replace=True)
                for i, j in high_pairs:
                    if i != j:
                        similar_dists.append(np.linalg.norm(embeddings[i] - embeddings[j]))
                
                # Calculate distances between different molecules
                different_dists = []
                
                # Sample pairs between low and high groups
                for _ in range(min(2000, len(low_indices) * len(high_indices))):
                    i = np.random.choice(low_indices)
                    j = np.random.choice(high_indices)
                    different_dists.append(np.linalg.norm(embeddings[i] - embeddings[j]))
                
                # Compute average distances and sensitivity ratio
                avg_similar = np.mean(similar_dists)
                avg_different = np.mean(different_dists)
                sensitivity_ratio = avg_different / avg_similar if avg_similar > 0 else 0
                
                # Store results
                sensitivity_results['property'].append(prop)
                sensitivity_results['embedding_type'].append(emb_type)
                sensitivity_results['avg_dist_similar'].append(avg_similar)
                sensitivity_results['avg_dist_different'].append(avg_different)
                sensitivity_results['sensitivity_ratio'].append(sensitivity_ratio)
                
                print(f"{prop}: Ratio = {sensitivity_ratio:.4f} (Similar: {avg_similar:.4f}, Different: {avg_different:.4f})")
        
        # Convert to DataFrame and store
        results_df = pd.DataFrame(sensitivity_results)
        self.results['embedding_sensitivity'] = results_df
        
        return results_df
    
    def analyze_feature_importance(self, embedding_types=None):
        """
        Analyze importance of molecular features in determining embedding structure
        
        Args:
            embedding_types: List of embedding types to analyze (default: analyze all)
        """
        print("\n=== Feature Importance Analysis ===")
        
        # Select embedding types to analyze
        if embedding_types is None:
            embedding_types = list(self.embedding_data.keys())
        
        # Initialize results storage
        feature_results = {'feature': [], 'embedding_type': [], 'importance': []}
        
        # Get feature columns (exclude graph_id and target properties)
        feature_cols = [col for col in self.prop_df.columns if col != 'graph_id']
        X = self.prop_df[feature_cols].values
        
        # Iterate through embedding types
        for emb_type in embedding_types:
            if emb_type not in self.embedding_data:
                print(f"Warning: {emb_type} not found in embeddings. Skipping.")
                continue
                
            print(f"\nAnalyzing {emb_type} embeddings...")
            embeddings = self.embedding_data[emb_type]['embeddings']
                
            # Train a RandomForest model for each embedding dimension
            emb_dim = embeddings.shape[1]
            importance_matrix = np.zeros((len(feature_cols), emb_dim))
            
            for dim in range(emb_dim):
                # Extract target embedding dimension
                y = embeddings[:, dim]
                
                # Train model
                model = RandomForestRegressor(n_estimators=50, max_depth=6, random_state=42)
                model.fit(X, y)
                
                # Store feature importances
                importance_matrix[:, dim] = model.feature_importances_
            
            # Aggregate importance across dimensions (mean importance)
            mean_importance = np.mean(importance_matrix, axis=1)
            
            # Store results for each feature
            for i, feature in enumerate(feature_cols):
                feature_results['feature'].append(feature)
                feature_results['embedding_type'].append(emb_type)
                feature_results['importance'].append(mean_importance[i])
        
        # Convert to DataFrame and store
        results_df = pd.DataFrame(feature_results)
        self.results['feature_importance'] = results_df
        
        return results_df
    
    def visualize_embeddings(self, properties=None, method='umap', save_dir='./figures'):
        """
        Create visualizations of embedding spaces colored by properties
        
        Args:
            properties: List of properties to color points by (default: top 4 most predictable)
            method: Dimension reduction method ('pca', 'tsne', or 'umap')
            save_dir: Directory to save visualization figures
        """
        print("\n=== Embedding Visualization ===")
        os.makedirs(save_dir, exist_ok=True)
        
        # If no properties specified, use top 4 from property prediction results
        if properties is None and 'property_prediction' in self.results:
            top_props = (self.results['property_prediction']
                         .groupby('property')['r2_score']
                         .max()
                         .sort_values(ascending=False)
                         .head(4)
                         .index.tolist())
            properties = top_props
        elif properties is None:
            # Use a few common properties if available
            potential_props = ['prop_num_nodes', 'prop_avg_node_degree', 
                              'prop_clustering_coefficient', 'feat_is_connected']
            properties = [p for p in potential_props if p in self.prop_df.columns][:4]
        
        print(f"Visualizing with properties: {properties}")
        
        # Iterate through embedding types
        for emb_type, emb_data in self.embedding_data.items():
            print(f"\nVisualizing {emb_type} embeddings...")
            embeddings = emb_data['embeddings']
            
            # Apply dimension reduction
            if method == 'pca':
                reducer = PCA(n_components=2, random_state=42)
                emb_2d = reducer.fit_transform(embeddings)
                method_name = 'PCA'
            elif method == 'tsne':
                reducer = TSNE(n_components=2, random_state=42, perplexity=30)
                emb_2d = reducer.fit_transform(embeddings)
                method_name = 't-SNE'
            else:  # umap
                reducer = umap.UMAP(n_components=2, random_state=42)
                emb_2d = reducer.fit_transform(embeddings)
                method_name = 'UMAP'
            
            # Create DataFrame for visualization
            vis_df = pd.DataFrame({
                'x': emb_2d[:, 0],
                'y': emb_2d[:, 1]
            })
            
            # Create plots for each property
            for prop in properties:
                if prop not in self.prop_df.columns:
                    print(f"Warning: Property {prop} not found in data. Skipping.")
                    continue
                
                # Add property values
                vis_df['property'] = self.prop_df[prop].values
                
                # Create figure
                plt.figure(figsize=(10, 8))
                
                # Create scatter plot colored by property
                if prop.startswith('feat_') and np.all(np.isin(vis_df['property'].unique(), [0, 1])):
                    # Categorical (boolean) property
                    sns.scatterplot(data=vis_df, x='x', y='y', hue='property', 
                                   palette='viridis', alpha=0.7)
                else:
                    # Continuous property
                    plt.scatter(vis_df['x'], vis_df['y'], c=vis_df['property'], 
                               cmap='viridis', alpha=0.7)
                    plt.colorbar(label=prop)
                
                # Set labels and title
                plt.xlabel(f'{method_name} Dimension 1')
                plt.ylabel(f'{method_name} Dimension 2')
                plt.title(f'{emb_type} Embeddings - Colored by {prop}')
                
                # Save figure
                plt.tight_layout()
                filename = f'{emb_type}_{method}_{prop}.png'
                plt.savefig(os.path.join(save_dir, filename), dpi=300)
                plt.close()
                
        print(f"Saved visualization figures to {save_dir}")
    
    def compare_property_prediction(self):
        """Compare property prediction quality between pre and post training"""
        if 'property_prediction' not in self.results:
            print("Run analyze_property_prediction first!")
            return
        
        results = self.results['property_prediction']
        
        # Filter to just pre and post training (not intermediate epochs)
        compare_df = results[results['embedding_type'].isin(['pre_training', 'post_training'])]
        
        # Pivot to get pre-post comparison
        pivot_df = compare_df.pivot_table(
            index=['property', 'model'], 
            columns='embedding_type', 
            values='r2_score'
        ).reset_index()
        
        # Calculate improvement
        if 'pre_training' in pivot_df.columns and 'post_training' in pivot_df.columns:
            pivot_df['improvement'] = pivot_df['post_training'] - pivot_df['pre_training']
            pivot_df['relative_improvement'] = pivot_df['improvement'] / (pivot_df['pre_training'].abs() + 1e-6)
            
            # Sort by absolute improvement
            pivot_df = pivot_df.sort_values('improvement', ascending=False)
            
            # Print results
            print("\n=== Property Prediction Comparison ===")
            print(pivot_df)
            
            # Visualize improvements
            plt.figure(figsize=(12, 8))
            
            # Filter to just RandomForest results for better visualization
            rf_df = pivot_df[pivot_df['model'] == 'RandomForest']
            
            # Get top 15 by absolute change for plotting
            plot_df = rf_df.head(15)
            
            # Create bar chart
            bars = plt.barh(plot_df['property'], plot_df['improvement'], color='skyblue')
            
            # Highlight negative improvements in red
            for i, imp in enumerate(plot_df['improvement']):
                if imp < 0:
                    bars[i].set_color('salmon')
            
            plt.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            plt.xlabel('Improvement in R² Score (Post - Pre)')
            plt.title('Changes in Property Predictability After Training')
            plt.tight_layout()
            
            # Save figure
            plt.savefig('property_prediction_improvement.png', dpi=300)
            plt.close()
            
            return pivot_df
        else:
            print("Missing either pre-training or post-training embeddings")
            return None
    
    def compare_feature_importance(self):
        """Compare feature importance between pre and post training"""
        if 'feature_importance' not in self.results:
            print("Run analyze_feature_importance first!")
            return
        
        results = self.results['feature_importance']
        
        # Filter to just pre and post training (not intermediate epochs)
        compare_df = results[results['embedding_type'].isin(['pre_training', 'post_training'])]
        
        # Pivot to get pre-post comparison
        pivot_df = compare_df.pivot_table(
            index='feature', 
            columns='embedding_type', 
            values='importance'
        ).reset_index()
        
        # Calculate changes
        if 'pre_training' in pivot_df.columns and 'post_training' in pivot_df.columns:
            pivot_df['abs_change'] = pivot_df['post_training'] - pivot_df['pre_training']
            pivot_df['rel_change'] = pivot_df['abs_change'] / (pivot_df['pre_training'] + 1e-6)
            
            # Sort by absolute change
            pivot_df = pivot_df.sort_values('abs_change', ascending=False)
            
            # Print results
            print("\n=== Feature Importance Comparison ===")
            print(pivot_df)
            
            # Visualize changes
            plt.figure(figsize=(12, 10))
            
            # Get top and bottom 10 by absolute change for plotting
            top_df = pivot_df.head(10)
            bottom_df = pivot_df.tail(10)
            plot_df = pd.concat([top_df, bottom_df])
            
            # Create bar chart
            bars = plt.barh(plot_df['feature'], plot_df['abs_change'], color='skyblue')
            
            # Highlight negative changes in red
            for i, change in enumerate(plot_df['abs_change']):
                if change < 0:
                    bars[i].set_color('salmon')
            
            plt.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            plt.xlabel('Change in Feature Importance (Post - Pre)')
            plt.title('Changes in Feature Importance After Training')
            plt.tight_layout()
            
            # Save figure
            plt.savefig('feature_importance_change.png', dpi=300)
            plt.close()
            
            return pivot_df
        else:
            print("Missing either pre-training or post-training embeddings")
            return None

    def compare_sensitivity(self):
        """Compare embedding sensitivity between pre and post training"""
        if 'embedding_sensitivity' not in self.results:
            print("Run analyze_embedding_sensitivity first!")
            return
        
        results = self.results['embedding_sensitivity']
        
        # Filter to just pre and post training (not intermediate epochs)
        compare_df = results[results['embedding_type'].isin(['pre_training', 'post_training'])]
        
        # Pivot to get pre-post comparison for sensitivity ratio
        pivot_df = compare_df.pivot_table(
            index='property', 
            columns='embedding_type', 
            values='sensitivity_ratio'
        ).reset_index()
        
        # Calculate changes
        if 'pre_training' in pivot_df.columns and 'post_training' in pivot_df.columns:
            pivot_df['abs_change'] = pivot_df['post_training'] - pivot_df['pre_training']
            pivot_df['rel_change'] = pivot_df['abs_change'] / (pivot_df['pre_training'] + 1e-6)
            
            # Sort by absolute change
            pivot_df = pivot_df.sort_values('abs_change', ascending=False)
            
            # Print results
            print("\n=== Sensitivity Ratio Comparison ===")
            print(pivot_df)
            
            # Visualize changes
            plt.figure(figsize=(12, 8))
            
            # Get top 15 by absolute change for plotting
            plot_df = pivot_df.head(15)
            
            # Create bar chart
            bars = plt.barh(plot_df['property'], plot_df['abs_change'], color='skyblue')
            
            # Highlight negative changes in red
            for i, change in enumerate(plot_df['abs_change']):
                if change < 0:
                    bars[i].set_color('salmon')
            
            plt.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            plt.xlabel('Change in Sensitivity Ratio (Post - Pre)')
            plt.title('Changes in Embedding Sensitivity After Training')
            plt.tight_layout()
            
            # Save figure
            plt.savefig('sensitivity_change.png', dpi=300)
            plt.close()
            
            return pivot_df
        else:
            print("Missing either pre-training or post-training embeddings")
            return None
            
    def run_all_analyses(self):
        """Run all analysis methods and generate a comprehensive report"""
        print("Starting comprehensive embedding analysis...")
        
        # Prepare data
        self.prepare_property_data()
        
        # Run all analyses
        self.analyze_property_prediction()
        self.analyze_embedding_sensitivity()
        self.analyze_feature_importance()
        
        # Generate comparison reports
        self.compare_property_prediction()
        self.compare_feature_importance()
        self.compare_sensitivity()
        
        # Create visualizations
        self.visualize_embeddings()
        
        print("\nAnalysis complete! Results and visualizations have been saved.")

In [2]:
def create_enhanced_visualizations(analyzer, output_dir='./enhanced_figures'):
    """Create enhanced visualizations for better analysis communication"""
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.decomposition import PCA
    
    os.makedirs(output_dir, exist_ok=True)
    
    # ---- 1. Property Prediction Power Chart ----
    if 'property_prediction' in analyzer.results:
        plt.figure(figsize=(12, 10))
        
        # Filter to important properties and organize by category
        pred_df = analyzer.results['property_prediction']
        
        # Filter to RandomForest results for better visualization
        rf_df = pred_df[pred_df['model'] == 'RandomForest']
        
        # Create pivot table for pre vs post
        pivot_df = rf_df.pivot_table(
            index='property', 
            columns='embedding_type', 
            values='r2_score'
        ).reset_index()
        
        # Add improvement column
        if 'pre_training' in pivot_df.columns and 'post_training' in pivot_df.columns:
            pivot_df['improvement'] = pivot_df['post_training'] - pivot_df['pre_training']
            
            # Sort by improvement
            pivot_df = pivot_df.sort_values('improvement', ascending=False)
            
            # Create a categorical column for property type
            def categorize_property(prop):
                if 'mw' in prop.lower() or 'weight' in prop.lower():
                    return 'Molecular Weight'
                elif 'logp' in prop.lower():
                    return 'Lipophilicity'
                elif 'ring' in prop.lower():
                    return 'Ring Structure'
                elif 'aromatic' in prop.lower():
                    return 'Aromaticity'
                elif 'func_' in prop.lower():
                    return 'Functional Group'
                elif 'contrib' in prop.lower():
                    return 'Atom Contribution'
                else:
                    return 'Other Properties'
            
            pivot_df['category'] = pivot_df['property'].apply(categorize_property)
            
            # Filter to top 20 properties by absolute improvement
            plot_df = pivot_df.copy()
            plot_df['abs_improvement'] = plot_df['improvement'].abs()
            plot_df = plot_df.sort_values('abs_improvement', ascending=False).head(20)
            
            # Create plot
            fig, ax = plt.subplots(figsize=(12, 10))
            
            # Plot bars
            bar_width = 0.35
            x = np.arange(len(plot_df))
            
            # Pre-training bars
            rects1 = ax.barh(x - bar_width/2, plot_df['pre_training'], bar_width, 
                             label='Pre-training', color='skyblue', alpha=0.7)
            
            # Post-training bars
            rects2 = ax.barh(x + bar_width/2, plot_df['post_training'], bar_width,
                             label='Post-training', color='orange', alpha=0.7)
            
            # Add property names
            property_labels = [p.replace('prop_', '').replace('func_', '').replace('ring_', '') 
                              for p in plot_df['property']]
            ax.set_yticks(x)
            ax.set_yticklabels(property_labels)
            
            # Add a line at R²=0
            ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            
            # Add labels and title
            ax.set_xlabel('R² Score (Higher is Better)')
            ax.set_title('Property Prediction Power: Pre vs. Post Training', fontsize=14)
            ax.legend()
            
            # Add color-coded categories (FIXED: the indexing error was here)
            categories = plot_df['category'].unique()
            for i, cat in enumerate(categories):
                # Get indices in the x array where the category matches
                idx_in_plot_df = plot_df[plot_df['category'] == cat].index
                # Map these indices to positions in the x array
                positions_in_x = [j for j, idx in enumerate(plot_df.index) if idx in idx_in_plot_df]
                
                if positions_in_x:  # If we have positions for this category
                    min_idx = min(positions_in_x)
                    max_idx = max(positions_in_x)
                    ax.axhspan(min_idx - 0.5, max_idx + 0.5, alpha=0.1, color=plt.cm.tab10(i))
                    ax.text(ax.get_xlim()[0], (min_idx + max_idx) / 2, cat, 
                            ha='left', va='center', fontsize=9, fontweight='bold')
            
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'property_prediction_power.png'), dpi=300)
            plt.close()
    
    # ---- 2. Property Sensitivity Heatmap ----
    if 'embedding_sensitivity' in analyzer.results:
        sens_df = analyzer.results['embedding_sensitivity']
        
        # Pivot to get property x embedding type matrix of sensitivity ratios
        pivot_sens = sens_df.pivot_table(
            index='property', 
            columns='embedding_type', 
            values='sensitivity_ratio'
        )
        
        # Add a delta column (percentage change)
        if 'pre_training' in pivot_sens.columns and 'post_training' in pivot_sens.columns:
            delta = ((pivot_sens['post_training'] - pivot_sens['pre_training']) / 
                    pivot_sens['pre_training'] * 100)
            pivot_sens['delta_pct'] = delta
            
            # Sort by delta
            pivot_sens = pivot_sens.sort_values('delta_pct', ascending=False)
            
            # Select top properties by absolute change percentage
            plot_sens = pivot_sens.head(20)
            
            # Create heatmap
            plt.figure(figsize=(12, 10))
            
            # Format for better display
            heatmap_df = plot_sens[['pre_training', 'post_training']].copy()
            # Clean property names for display
            heatmap_df.index = [idx.replace('prop_', '').replace('func_', '').replace('ring_', '') 
                              for idx in heatmap_df.index]
            
            # Create diverging colormap centered at 1.0
            cmap = sns.diverging_palette(220, 20, as_cmap=True)
            
            # Plot heatmap
            ax = sns.heatmap(heatmap_df, annot=True, fmt=".2f", cmap=cmap, center=1.0,
                           linewidths=0.5, cbar_kws={"label": "Sensitivity Ratio"})
            
            # Add delta values
            for i, idx in enumerate(plot_sens.index):
                delta_val = plot_sens.loc[idx, 'delta_pct']
                color = 'green' if delta_val > 0 else 'red'
                plt.text(2.5, i + 0.5, f"{delta_val:.1f}%", 
                        ha='center', va='center', fontweight='bold', color=color)
            
            plt.title('Property Sensitivity: Pre vs. Post Training', fontsize=14)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'property_sensitivity_heatmap.png'), dpi=300)
            plt.close()
    
    # ---- 3. Feature Importance Radar Chart ----
    if 'feature_importance' in analyzer.results:
        feat_df = analyzer.results['feature_importance']
        
        # Filter to pre and post training
        radar_df = feat_df[feat_df['embedding_type'].isin(['pre_training', 'post_training'])]
        
        # Group features by category
        def group_feature(feat):
            if 'ring' in feat:
                return 'Ring Structure'
            elif 'aromatic' in feat:
                return 'Aromaticity'
            elif 'func_' in feat:
                return 'Functional Group'
            elif any(term in feat for term in ['hybridization', 'valence', 'charge']):
                return 'Electronic Properties'
            elif any(term in feat for term in ['path', 'diameter', 'degree']):
                return 'Topological Properties'
            else:
                return 'Other'
                
        radar_df['category'] = radar_df['feature'].apply(group_feature)
        
        # Calculate average importance by category
        radar_pivot = radar_df.pivot_table(
            index='category',
            columns='embedding_type',
            values='importance',
            aggfunc='mean'
        ).reset_index()
        
        if not radar_pivot.empty and 'pre_training' in radar_pivot.columns and 'post_training' in radar_pivot.columns:
            # Create radar chart
            categories = radar_pivot['category']
            pre_values = radar_pivot['pre_training']
            post_values = radar_pivot['post_training']
            
            # Compute angles for the radar chart
            N = len(categories)
            if N > 0:  # Only proceed if we have categories
                angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
                angles += angles[:1]  # Close the loop
                
                # Add the values to complete the loop
                pre_values = pre_values.tolist()
                pre_values += pre_values[:1]
                post_values = post_values.tolist()
                post_values += post_values[:1]
                
                # Create figure
                fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(polar=True))
                
                # Plot pre-training values
                ax.plot(angles, pre_values, 'o-', linewidth=2, label='Pre-training', color='blue', alpha=0.7)
                ax.fill(angles, pre_values, alpha=0.1, color='blue')
                
                # Plot post-training values
                ax.plot(angles, post_values, 'o-', linewidth=2, label='Post-training', color='orange', alpha=0.7)
                ax.fill(angles, post_values, alpha=0.1, color='orange')
                
                # Set category labels
                categories = categories.tolist()
                categories += categories[:1]  # Complete the loop
                ax.set_xticks(angles)
                ax.set_xticklabels(categories, fontsize=10, fontweight='bold')
                
                # Add legend and title
                ax.legend(loc='upper right')
                plt.title('Feature Importance by Category: Pre vs. Post Training', fontsize=14)
                
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'feature_importance_radar.png'), dpi=300)
                plt.close()
    
    # ---- 4. Ring Structure Analysis ----
    # Extract ring-related properties
    ring_props = [prop for prop in analyzer.prop_df.columns if 'ring' in prop.lower()]
    
    if ring_props and 'embedding_sensitivity' in analyzer.results:
        # Filter sensitivity results to ring properties
        ring_sens = analyzer.results['embedding_sensitivity']
        ring_sens = ring_sens[ring_sens['property'].isin(ring_props)]
        
        # Create a comparison of ring sensitivity
        ring_pivot = ring_sens.pivot_table(
            index='property',
            columns='embedding_type',
            values='sensitivity_ratio'
        ).reset_index()
        
        # Add improvement column
        if 'pre_training' in ring_pivot.columns and 'post_training' in ring_pivot.columns:
            ring_pivot['pct_change'] = ((ring_pivot['post_training'] - ring_pivot['pre_training']) / 
                                      ring_pivot['pre_training'] * 100)
            
            # Categorize by ring type
            def ring_category(prop):
                if 'single' in prop:
                    return 'Single Rings'
                elif 'fused' in prop:
                    return 'Fused Rings'
                elif 'spiro' in prop:
                    return 'Spiro Rings'
                elif 'bridged' in prop:
                    return 'Bridged Rings'
                elif 'sizes' in prop:
                    # Extract ring size
                    size = prop.split('_')[-1]
                    return f'Ring Size {size}'
                else:
                    return 'General Ring Properties'
                    
            ring_pivot['category'] = ring_pivot['property'].apply(ring_category)
            
            if not ring_pivot.empty:  # Only proceed if we have data
                # Create grouped bar chart
                plt.figure(figsize=(12, 8))
                
                # Group by category and sort
                ring_pivot = ring_pivot.sort_values(['category', 'pct_change'])
                
                # Set positions and width
                pos = np.arange(len(ring_pivot))
                width = 0.35
                
                # Create bars
                fig, ax = plt.subplots(figsize=(12, 8))
                
                # Pre-training bars
                pre_bars = ax.barh(pos - width/2, ring_pivot['pre_training'], width,
                                  label='Pre-training', color='skyblue')
                
                # Post-training bars
                post_bars = ax.barh(pos + width/2, ring_pivot['post_training'], width,
                                   label='Post-training', color='orange')
                
                # Add property labels
                ax.set_yticks(pos)
                # Clean up property names for display
                ax.set_yticklabels([p.replace('ring_ring_', '') for p in ring_pivot['property']])
                
                # Add a reference line at 1.0
                ax.axvline(x=1.0, color='gray', linestyle='--', alpha=0.7)
                
                # Add labels and title
                ax.set_xlabel('Sensitivity Ratio (Higher = Better Separation)')
                ax.set_title('Ring Structure Sensitivity: Pre vs. Post Training', fontsize=14)
                ax.legend()
                
                # Add percentage change annotations
                for i, (_, row) in enumerate(ring_pivot.iterrows()):
                    pct = row['pct_change']
                    color = 'green' if pct > 0 else 'red'
                    ax.text(max(row['pre_training'], row['post_training']) + 0.05, i, 
                           f"{pct:.1f}%", va='center', color=color, fontweight='bold')
                
                # Highlight categories with background colors
                categories = ring_pivot['category'].unique()
                for i, cat in enumerate(categories):
                    # Get indices in the current plot
                    positions = [j for j, (_, row) in enumerate(ring_pivot.iterrows()) if row['category'] == cat]
                    
                    if positions:  # If we have positions for this category
                        min_pos = min(positions)
                        max_pos = max(positions)
                        ax.axhspan(min_pos - 0.5, max_pos + 0.5, color=plt.cm.Pastel1(i), alpha=0.3)
                        # Add category label
                        ax.text(ax.get_xlim()[0] - 0.1, (min_pos + max_pos) / 2, cat,
                               ha='right', va='center', fontweight='bold', fontsize=10)
                        
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'ring_structure_sensitivity.png'), dpi=300)
                plt.close()
    
    # ---- 5. Functional Group Impact ----
    # Extract functional group related properties
    func_props = [prop for prop in analyzer.prop_df.columns if 'func_' in prop]
    
    if func_props and 'embedding_sensitivity' in analyzer.results and 'feature_importance' in analyzer.results:
        # Get sensitivities for functional groups
        func_sens = analyzer.results['embedding_sensitivity']
        func_sens = func_sens[func_sens['property'].isin(func_props)]
        
        # Get importance for functional groups
        func_imp = analyzer.results['feature_importance']
        func_imp = func_imp[func_imp['feature'].isin(func_props)]
        
        # Create a combined dataframe
        func_data = []
        
        for prop in func_props:
            # Get sensitivity data
            sens_pre = func_sens[(func_sens['property'] == prop) & 
                                (func_sens['embedding_type'] == 'pre_training')]['sensitivity_ratio'].values
            sens_post = func_sens[(func_sens['property'] == prop) & 
                                 (func_sens['embedding_type'] == 'post_training')]['sensitivity_ratio'].values
            
            # Get importance data
            imp_pre = func_imp[(func_imp['feature'] == prop) & 
                              (func_imp['embedding_type'] == 'pre_training')]['importance'].values
            imp_post = func_imp[(func_imp['feature'] == prop) & 
                               (func_imp['embedding_type'] == 'post_training')]['importance'].values
            
            # Only add if we have all data
            if len(sens_pre) > 0 and len(sens_post) > 0 and len(imp_pre) > 0 and len(imp_post) > 0:
                func_data.append({
                    'property': prop,
                    'sensitivity_pre': sens_pre[0],
                    'sensitivity_post': sens_post[0],
                    'importance_pre': imp_pre[0],
                    'importance_post': imp_post[0],
                    'sensitivity_change': (sens_post[0] - sens_pre[0]) / sens_pre[0] * 100,
                    'importance_change': (imp_post[0] - imp_pre[0]) / imp_pre[0] * 100
                })
                
        # Create dataframe
        if func_data:
            func_df = pd.DataFrame(func_data)
            
            if not func_df.empty:  # Only proceed if we have data
                # Create a quad chart (importance vs sensitivity)
                fig, ax = plt.subplots(figsize=(10, 8))
                
                # Clean property names
                func_df['display_name'] = func_df['property'].str.replace('func_', '')
                
                # Create arrows from pre to post
                for _, row in func_df.iterrows():
                    ax.arrow(row['importance_pre'], row['sensitivity_pre'],
                            row['importance_post'] - row['importance_pre'],
                            row['sensitivity_post'] - row['sensitivity_pre'],
                            head_width=0.01, head_length=0.02, fc='black', ec='black', length_includes_head=True)
                    
                    # Add property label at midpoint
                    mid_x = (row['importance_pre'] + row['importance_post']) / 2
                    mid_y = (row['sensitivity_pre'] + row['sensitivity_post']) / 2
                    ax.text(mid_x, mid_y, row['display_name'], fontsize=9, 
                           ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7, boxstyle='round'))
                
                # Add quadrant lines
                imp_min = min(func_df[['importance_pre', 'importance_post']].values.flatten())
                imp_max = max(func_df[['importance_pre', 'importance_post']].values.flatten())
                sens_min = min(func_df[['sensitivity_pre', 'sensitivity_post']].values.flatten())
                sens_max = max(func_df[['sensitivity_pre', 'sensitivity_post']].values.flatten())
                
                imp_mid = (imp_min + imp_max) / 2
                sens_mid = (sens_min + sens_max) / 2
                
                ax.axhline(y=sens_mid, color='gray', linestyle='--', alpha=0.5)
                ax.axvline(x=imp_mid, color='gray', linestyle='--', alpha=0.5)
                
                # Set axis limits with some padding
                ax.set_xlim(imp_min - 0.01, imp_max + 0.01)
                ax.set_ylim(sens_min - 0.05, sens_max + 0.05)
                
                # Add quadrant labels
                ax.text(imp_max * 0.9, sens_max * 0.9, 
                       "High Importance\nHigh Sensitivity", ha='center', va='center', 
                       bbox=dict(facecolor='lightyellow', alpha=0.7))
                
                ax.text(imp_min * 1.1, sens_max * 0.9, 
                       "Low Importance\nHigh Sensitivity", ha='center', va='center',
                       bbox=dict(facecolor='lightblue', alpha=0.7))
                
                ax.text(imp_max * 0.9, sens_min * 1.1, 
                       "High Importance\nLow Sensitivity", ha='center', va='center',
                       bbox=dict(facecolor='lightgreen', alpha=0.7))
                
                ax.text(imp_min * 1.1, sens_min * 1.1, 
                       "Low Importance\nLow Sensitivity", ha='center', va='center',
                       bbox=dict(facecolor='lightgray', alpha=0.7))
                
                # Add labels and title
                ax.set_xlabel('Feature Importance')
                ax.set_ylabel('Sensitivity Ratio')
                ax.set_title('Functional Group Representation: Pre vs. Post Training', fontsize=14)
                
                # Add legend
                import matplotlib.lines as mlines
                arrow = mlines.Line2D([], [], color='black', marker='>', linestyle='-',
                                     markersize=10, label='Pre → Post')
                ax.legend(handles=[arrow], loc='upper left')
                
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'functional_group_impact.png'), dpi=300)
                plt.close()
    
    print(f"Enhanced visualizations saved to {output_dir}")

In [3]:
def create_additional_visualizations(analyzer, output_dir='./enhanced_figures'):
    """Create additional visualizations focusing on property distribution, 
    feature correlations, functional group clustering, and ring structure mapping"""
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.decomposition import PCA
    from matplotlib.colors import Normalize
    import networkx as nx
    from sklearn.metrics import silhouette_score
    from scipy.spatial import ConvexHull
    
    os.makedirs(output_dir, exist_ok=True)
    
    # ---- Visualization 3: Property Distribution in Embedding Space ----
    if any(emb_type in analyzer.embedding_data for emb_type in ['pre_training', 'post_training']):
        # Get pre and post training embeddings
        pre_emb = analyzer.embedding_data.get('pre_training', {}).get('embeddings')
        post_emb = analyzer.embedding_data.get('post_training', {}).get('embeddings')
        
        if pre_emb is not None and post_emb is not None:
            # Select a few key properties
            key_properties = [
                # Try to get molecular weight and LogP related properties
                next((c for c in analyzer.prop_df.columns if 'mw' in c.lower()), None),
                next((c for c in analyzer.prop_df.columns if 'logp' in c.lower()), None),
                # Try other common properties
                next((c for c in analyzer.prop_df.columns if 'aromatic' in c.lower()), None),
                next((c for c in analyzer.prop_df.columns if 'ring' in c.lower()), None)
            ]
            
            # Filter out None values
            key_properties = [p for p in key_properties if p is not None]
            
            # If we don't have any of the target properties, select the first few numeric ones
            if not key_properties:
                numeric_cols = analyzer.prop_df.select_dtypes(include=np.number).columns
                key_properties = list(numeric_cols)[:4]  # Take up to 4 properties
            
            if key_properties:  # Only proceed if we have properties
                # Calculate PCA for embeddings
                pca_pre = PCA(n_components=2)
                pca_post = PCA(n_components=2)
                
                pre_2d = pca_pre.fit_transform(pre_emb)
                post_2d = pca_post.fit_transform(post_emb)
                
                # Create grid of plots
                n_props = len(key_properties)
                fig, axes = plt.subplots(n_props, 2, figsize=(14, 4 * n_props))
                
                # If only one property, wrap axes in list
                if n_props == 1:
                    axes = [axes]
                
                for i, prop in enumerate(key_properties):
                    # Get property values
                    prop_values = analyzer.prop_df[prop].values
                    
                    # Create color map
                    norm = Normalize(vmin=np.min(prop_values), vmax=np.max(prop_values))
                    
                    # Plot pre-training
                    ax_pre = axes[i][0]
                    sc_pre = ax_pre.scatter(pre_2d[:, 0], pre_2d[:, 1], 
                                         c=prop_values, cmap='viridis', 
                                         alpha=0.8, norm=norm)
                    ax_pre.set_title(f'Pre-training: {prop}')
                    ax_pre.set_xlabel('PC1')
                    ax_pre.set_ylabel('PC2')
                    
                    # Add contour lines if we have enough points
                    if len(pre_2d) > 10:
                        try:
                            x = pre_2d[:, 0]
                            y = pre_2d[:, 1]
                            
                            # Try to create a 2D histogram and then contour
                            hist, x_edges, y_edges = np.histogram2d(x, y, bins=10)
                            x_centers = (x_edges[:-1] + x_edges[1:]) / 2
                            y_centers = (y_edges[:-1] + y_edges[1:]) / 2
                            
                            X, Y = np.meshgrid(x_centers, y_centers)
                            ax_pre.contour(X, Y, hist.T, colors='black', alpha=0.3, levels=3)
                        except:
                            pass  # Skip contours if they fail
                    
                    # Plot post-training
                    ax_post = axes[i][1]
                    sc_post = ax_post.scatter(post_2d[:, 0], post_2d[:, 1], 
                                           c=prop_values, cmap='viridis', 
                                           alpha=0.8, norm=norm)
                    ax_post.set_title(f'Post-training: {prop}')
                    ax_post.set_xlabel('PC1')
                    ax_post.set_ylabel('PC2')
                    
                    # Add contour lines if we have enough points
                    if len(post_2d) > 10:
                        try:
                            x = post_2d[:, 0]
                            y = post_2d[:, 1]
                            
                            # Try to create a 2D histogram and then contour
                            hist, x_edges, y_edges = np.histogram2d(x, y, bins=10)
                            x_centers = (x_edges[:-1] + x_edges[1:]) / 2
                            y_centers = (y_edges[:-1] + y_edges[1:]) / 2
                            
                            X, Y = np.meshgrid(x_centers, y_centers)
                            ax_post.contour(X, Y, hist.T, colors='black', alpha=0.3, levels=3)
                        except:
                            pass  # Skip contours if they fail
                    
                    # Add colorbar
                    cbar = fig.colorbar(sc_post, ax=[ax_pre, ax_post], orientation='horizontal')
                    cbar.set_label(prop)
                
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'property_distribution.png'), dpi=300)
                plt.close()
    
    # ---- Visualization 5: Feature Correlation Network ----
    if 'feature_importance' in analyzer.results:
        # Get feature importance data
        feat_df = analyzer.results['feature_importance']
        
        # Filter to relevant features and embedding types
        pre_feats = feat_df[feat_df['embedding_type'] == 'pre_training']
        post_feats = feat_df[feat_df['embedding_type'] == 'post_training']
        
        if not pre_feats.empty and not post_feats.empty:
            # Get top features by importance
            top_n = 10  # Number of top features to include
            
            pre_top = pre_feats.sort_values('importance', ascending=False).head(top_n)
            post_top = post_feats.sort_values('importance', ascending=False).head(top_n)
            
            # Combine unique features from both sets
            all_features = set(pre_top['feature']).union(set(post_top['feature']))
            
            if analyzer.prop_df is not None and len(all_features) > 1:
                # Calculate correlations between these features
                feature_corr = analyzer.prop_df[[f for f in all_features if f in analyzer.prop_df.columns]].corr()
                
                # Create two network plots
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
                
                # Function to create network
                def create_correlation_network(ax, features, importances_df, title):
                    # Create graph
                    G = nx.Graph()
                    
                    # Map feature to importance
                    feat_imp = {f: i for f, i in zip(importances_df['feature'], importances_df['importance'])}
                    
                    # Add nodes
                    for feature in feature_corr.columns:
                        if feature in feat_imp:
                            # Scale node size by importance
                            size = feat_imp[feature] * 3000  # Scale factor
                            G.add_node(feature, size=size)
                    
                    # Add edges for correlations
                    for i, feat1 in enumerate(feature_corr.columns):
                        for j, feat2 in enumerate(feature_corr.columns):
                            if i < j and feat1 in feat_imp and feat2 in feat_imp:
                                corr = abs(feature_corr.loc[feat1, feat2])
                                if corr > 0.3:  # Only show stronger correlations
                                    G.add_edge(feat1, feat2, weight=corr)
                    
                    # Draw network
                    if len(G.nodes) > 1:  # Only draw if we have at least 2 nodes
                        pos = nx.spring_layout(G, seed=42)
                        
                        # Draw nodes
                        node_sizes = [G.nodes[n]['size'] for n in G.nodes]
                        nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_sizes, 
                                               node_color='skyblue', alpha=0.8)
                        
                        # Draw edges with weights
                        edge_weights = [G[u][v]['weight'] * 3 for u, v in G.edges]
                        nx.draw_networkx_edges(G, pos, ax=ax, width=edge_weights, 
                                               alpha=0.5, edge_color='gray')
                        
                        # Draw labels
                        labels = {n: n.replace('prop_', '').replace('func_', '').replace('ring_', '') 
                                 for n in G.nodes}
                        nx.draw_networkx_labels(G, pos, ax=ax, labels=labels, font_size=8)
                        
                        # Set title
                        ax.set_title(title)
                        ax.axis('off')
                
                # Create networks
                create_correlation_network(ax1, pre_top, pre_top, "Pre-training Feature Correlation Network")
                create_correlation_network(ax2, post_top, post_top, "Post-training Feature Correlation Network")
                
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'feature_correlation_network.png'), dpi=300)
                plt.close()
    
    # ---- Visualization 7: Functional Group Clustering ----
    func_props = [prop for prop in analyzer.prop_df.columns if 'func_' in prop]
    
    if func_props and 'pre_training' in analyzer.embedding_data and 'post_training' in analyzer.embedding_data:
        pre_emb = analyzer.embedding_data['pre_training']['embeddings']
        post_emb = analyzer.embedding_data['post_training']['embeddings']
        
        # Get top 4 functional groups
        top_func_groups = []
        for prop in func_props:
            # Check if the property has any True values
            if analyzer.prop_df[prop].sum() > 0:
                top_func_groups.append(prop)
                if len(top_func_groups) >= 4:
                    break
        
        if top_func_groups:  # Only proceed if we have functional groups
            # Create PCA projections
            pca_pre = PCA(n_components=2).fit_transform(pre_emb)
            pca_post = PCA(n_components=2).fit_transform(post_emb)
            
            # Create grid of plots
            fig, axes = plt.subplots(len(top_func_groups), 2, figsize=(12, 4 * len(top_func_groups)))
            
            # If only one functional group, wrap axes in list
            if len(top_func_groups) == 1:
                axes = [axes]
            
            for i, func in enumerate(top_func_groups):
                # Get molecules with this functional group
                has_func = analyzer.prop_df[func] > 0
                
                # Plot pre-training
                ax_pre = axes[i][0]
                
                # Plot all points
                ax_pre.scatter(pca_pre[:, 0], pca_pre[:, 1], color='gray', alpha=0.3)
                
                # Highlight molecules with the functional group
                if has_func.any():
                    ax_pre.scatter(pca_pre[has_func, 0], pca_pre[has_func, 1], 
                                  color='red', label=f'Has {func}')
                    
                    # Try to add convex hull
                    try:
                        points = pca_pre[has_func]
                        if len(points) >= 3:  # Need at least 3 points for convex hull
                            hull = ConvexHull(points)
                            for simplex in hull.simplices:
                                ax_pre.plot(points[simplex, 0], points[simplex, 1], 'r-', alpha=0.5)
                    except:
                        pass  # Skip hull if it fails
                
                ax_pre.set_title(f'Pre-training: {func}')
                ax_pre.legend()
                
                # Plot post-training
                ax_post = axes[i][1]
                
                # Plot all points
                ax_post.scatter(pca_post[:, 0], pca_post[:, 1], color='gray', alpha=0.3)
                
                # Highlight molecules with the functional group
                if has_func.any():
                    ax_post.scatter(pca_post[has_func, 0], pca_post[has_func, 1], 
                                   color='red', label=f'Has {func}')
                    
                    # Try to add convex hull
                    try:
                        points = pca_post[has_func]
                        if len(points) >= 3:  # Need at least 3 points for convex hull
                            hull = ConvexHull(points)
                            for simplex in hull.simplices:
                                ax_post.plot(points[simplex, 0], points[simplex, 1], 'r-', alpha=0.5)
                    except:
                        pass  # Skip hull if it fails
                    
                    # Calculate silhouette score if possible
                    try:
                        if sum(has_func) >= 2 and sum(~has_func) >= 2:  # Need at least 2 points in each class
                            pre_score = silhouette_score(pca_pre, has_func)
                            post_score = silhouette_score(pca_post, has_func)
                            
                            # Add silhouette scores to plot
                            ax_pre.text(0.05, 0.95, f'Silhouette: {pre_score:.3f}', 
                                       transform=ax_pre.transAxes, fontsize=9,
                                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
                            
                            ax_post.text(0.05, 0.95, f'Silhouette: {post_score:.3f}', 
                                        transform=ax_post.transAxes, fontsize=9,
                                        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
                    except:
                        pass  # Skip silhouette score if it fails
                
                ax_post.set_title(f'Post-training: {func}')
                ax_post.legend()
            
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'functional_group_clustering.png'), dpi=300)
            plt.close()
    
    # ---- Visualization 9: Ring Structure Embedding Map ----
    ring_props = [prop for prop in analyzer.prop_df.columns if 'ring' in prop.lower()]
    
    if ring_props and 'pre_training' in analyzer.embedding_data and 'post_training' in analyzer.embedding_data:
        pre_emb = analyzer.embedding_data['pre_training']['embeddings']
        post_emb = analyzer.embedding_data['post_training']['embeddings']
        
        # Calculate PCA projections
        pca = PCA(n_components=2)
        # Use the combined embeddings to get consistent components
        combined_emb = np.vstack([pre_emb, post_emb])
        pca_combined = pca.fit_transform(combined_emb)
        
        # Split back into pre and post
        pre_2d = pca_combined[:len(pre_emb)]
        post_2d = pca_combined[len(pre_emb):]
        
        # Categorize ring types
        ring_categories = {
            'single': next((p for p in ring_props if 'single' in p), None),
            'fused': next((p for p in ring_props if 'fused' in p), None),
            'bridged': next((p for p in ring_props if 'bridged' in p), None),
            'spiro': next((p for p in ring_props if 'spiro' in p), None),
            'size_5': next((p for p in ring_props if 'size' in p and '5' in p), None),
            'size_6': next((p for p in ring_props if 'size' in p and '6' in p), None)
        }
        
        # Filter out None values
        ring_categories = {k: v for k, v in ring_categories.items() if v is not None}
        
        if ring_categories:  # Only proceed if we have ring categories
            # Get total ring count if available
            total_rings = next((p for p in ring_props if 'total' in p), None)
            
            # Create figure
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
            
            # Function to plot rings
            def plot_ring_structures(ax, embeddings_2d, title):
                # Plot all molecules as background
                ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], color='lightgray', alpha=0.3)
                
                # Plot each ring category with different marker
                markers = ['o', 's', '^', 'D', 'p', '*']
                colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan']
                
                for i, (category, prop) in enumerate(ring_categories.items()):
                    # Get molecules with this ring type
                    has_ring = analyzer.prop_df[prop] > 0
                    
                    if has_ring.any():
                        # Size by ring count if available
                        if total_rings is not None:
                            # Use total ring count as size, but ensure minimum size
                            sizes = analyzer.prop_df[total_rings].values * 20 + 30
                        else:
                            sizes = 50  # Default size
                        
                        # Plot points with this ring type
                        ax.scatter(embeddings_2d[has_ring, 0], embeddings_2d[has_ring, 1],
                                  marker=markers[i % len(markers)], 
                                  s=sizes if isinstance(sizes, int) else sizes[has_ring],
                                  color=colors[i % len(colors)], 
                                  alpha=0.7, label=category)
                        
                        # Try to add convex hull
                        try:
                            points = embeddings_2d[has_ring]
                            if len(points) >= 3:  # Need at least 3 points for convex hull
                                hull = ConvexHull(points)
                                for simplex in hull.simplices:
                                    ax.plot(points[simplex, 0], points[simplex, 1], 
                                           color=colors[i % len(colors)], alpha=0.3)
                        except:
                            pass  # Skip hull if it fails
                
                ax.set_title(title)
                ax.legend(loc='best')
            
            # Plot pre and post training embeddings
            plot_ring_structures(ax1, pre_2d, "Pre-training Ring Structures")
            plot_ring_structures(ax2, post_2d, "Post-training Ring Structures")
            
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'ring_structure_embedding_map.png'), dpi=300)
            plt.close()
    
    print(f"Additional visualizations saved to {output_dir}")

In [4]:
if __name__ == "__main__":
    # Set paths to your specific files
    metadata_path = "./embeddings/metadata/molecule_metadata_20250303_112449.pkl"
    
    # Define embedding files with their types
    embedding_files = {
        'pre_training': "./embeddings/pre_training_embeddings_20250303_112449.pkl",
        'post_training': "./embeddings/post_training_embeddings_20250303_112449.pkl",
        'epoch_50': "./embeddings/epoch_50_embeddings_20250303_112449.pkl",
        'final': "./embeddings/final_embeddings_20250303_112446.pkl"
    }
    
    # Initialize analyzer
    analyzer = EmbeddingAnalyzer(metadata_path, embedding_files)
    
    # Load embeddings
    analyzer.load_embeddings()
       
    # After running the full analysis
    analyzer.run_all_analyses()

    # Create all visualizations
    create_enhanced_visualizations(analyzer)
    create_additional_visualizations(analyzer)

Loading embedding files...
Loaded pre_training embeddings: (41, 128)
Loaded post_training embeddings: (41, 128)
Loaded epoch_50 embeddings: (41, 128)
Loaded final embeddings: (41, 128)
Starting comprehensive embedding analysis...
Preparing property data from metadata...
Prepared DataFrame with 41 molecules and 44 properties
Removed 7 low-information columns

=== Property Prediction Analysis ===

Analyzing pre_training embeddings...
prop_num_nodes - Linear: R² = -0.0511
prop_num_nodes - RandomForest: R² = -0.2447
prop_num_edges - Linear: R² = -0.0568
prop_num_edges - RandomForest: R² = -0.4844
prop_avg_node_degree - Linear: R² = -0.2087
prop_avg_node_degree - RandomForest: R² = -0.9227
prop_avg_path_length - Linear: R² = -0.0835
prop_avg_path_length - RandomForest: R² = -0.5899
prop_clustering_coefficient - Linear: R² = -0.0244
prop_clustering_coefficient - RandomForest: R² = -0.0881
prop_graph_diameter - Linear: R² = -0.1670
prop_graph_diameter - RandomForest: R² = -0.7115
prop_assorta

prop_avg_valence - RandomForest: R² = -0.2457
prop_std_valence - Linear: R² = -0.9572
prop_std_valence - RandomForest: R² = -1.4485
prop_avg_degree - Linear: R² = -0.2831
prop_avg_degree - RandomForest: R² = -0.5675
prop_std_degree - Linear: R² = -0.0854
prop_std_degree - RandomForest: R² = -0.4401
feat_density - Linear: R² = -0.0992
feat_density - RandomForest: R² = -0.2238
feat_max_centrality - Linear: R² = -0.1254
feat_max_centrality - RandomForest: R² = -0.1497
feat_avg_centrality - Linear: R² = -0.0992
feat_avg_centrality - RandomForest: R² = -0.3090
func_conjugated_bonds - Linear: R² = -0.3013
func_conjugated_bonds - RandomForest: R² = -0.7953
ring_ring_counts_total - Linear: R² = -0.2567
ring_ring_counts_total - RandomForest: R² = -0.2485
ring_ring_counts_single - Linear: R² = -0.2149
ring_ring_counts_single - RandomForest: R² = -0.3441
ring_ring_counts_fused - Linear: R² = -0.2166
ring_ring_counts_fused - RandomForest: R² = -0.4065
ring_ring_sizes_6 - Linear: R² = -0.2334
ring_


Analyzing post_training embeddings...

Analyzing epoch_50 embeddings...

Analyzing final embeddings...

=== Property Prediction Comparison ===
embedding_type                 property         model  post_training  \
47                 prop_std_is_aromatic  RandomForest      -0.180945   
39                  prop_std_contrib_mw  RandomForest      -0.761629   
7                 func_conjugated_bonds  RandomForest      -0.217024   
43               prop_std_formal_charge  RandomForest      -0.027703   
23                 prop_avg_node_degree  RandomForest      -0.254616   
..                                  ...           ...            ...   
53              ring_ring_counts_single  RandomForest      -0.799493   
1                   feat_avg_centrality  RandomForest      -0.822173   
5                   feat_max_centrality  RandomForest      -0.935173   
51               ring_ring_counts_fused  RandomForest      -0.888237   
35                       prop_num_nodes  RandomForest      -1.01


=== Embedding Visualization ===
Visualizing with properties: ['ring_ring_sizes_7', 'ring_ring_sizes_8', 'prop_std_formal_charge', 'prop_std_hybridization']

Visualizing pre_training embeddings...


  warn(



Visualizing post_training embeddings...


  warn(



Visualizing epoch_50 embeddings...


  warn(



Visualizing final embeddings...


  warn(


Saved visualization figures to ./figures

Analysis complete! Results and visualizations have been saved.


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  radar_df['category'] = radar_df['feature'].apply(group_feature)


Enhanced visualizations saved to ./enhanced_figures


  plt.tight_layout()


Additional visualizations saved to ./enhanced_figures


<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x800 with 0 Axes>