In [1]:
# !pip install lime
# !pip install shap

In [2]:
import torch
import torch.nn as nn
import numpy as np
import shap
import lime
import lime.lime_tabular
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import pickle
import os
from torch_geometric.data import Data, DataLoader
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
from tqdm import tqdm

from typing import Dict
import seaborn as sns


In [3]:

class EmbeddingAnalyzer:
    """Load pickled embeddings and produce summary statistics and plots.

    Attributes:
        embeddings: 2-D NumPy array; rows are samples, columns are embedding
            dimensions.
        graphs: Whatever was stored under the ``'labels'`` key of the pickle
            (name kept for backward compatibility).
        labels: Alias of ``graphs`` under a name that matches the pickle key.
    """
    def __init__(self, embedding_path: str):
        """Load embeddings and labels from a pickle file.

        Args:
            embedding_path: Path to a pickle holding a mapping with the keys
                ``'embeddings'`` and ``'labels'``.

        Raises:
            FileNotFoundError: If ``embedding_path`` does not exist.
            ValueError: If the stored embeddings are not 2-dimensional.
        """
        if not os.path.exists(embedding_path):
            raise FileNotFoundError(f"Embedding file not found: {embedding_path}")

        print(f"Loading embeddings from {embedding_path}")
        # NOTE(review): pickle.load can execute arbitrary code while
        # deserializing -- only point this at files you trust.
        with open(embedding_path, 'rb') as f:
            data = pickle.load(f)

        embeddings = data['embeddings']
        # Tensor-like objects (anything exposing .detach) are moved to host
        # memory first; lists and tuples are handled by np.asarray below.
        if hasattr(embeddings, 'detach'):
            embeddings = embeddings.detach().cpu().numpy()
        self.embeddings = np.asarray(embeddings)
        if self.embeddings.ndim != 2:
            raise ValueError(
                "Expected 2-D embeddings (n_samples, embedding_dim), "
                f"got array with shape {self.embeddings.shape}"
            )

        self.graphs = data['labels']
        self.labels = self.graphs

        print(f"Loaded {len(self.embeddings)} embeddings of dimension {self.embeddings.shape[1]}")

    def analyze_embedding_space(self, n_clusters: int = 5,
                                output_dir: str = 'embedding_analysis',
                                random_state: int = 42) -> Dict:
        """Compute per-dimension statistics, clusters, t-SNE and correlations.

        Three PNG files are written to ``output_dir``:
        ``dimension_variances.png``, ``tsne_visualization.png`` and
        ``dimension_correlations.png``.

        Args:
            n_clusters: Requested number of KMeans clusters. Capped at the
                number of samples, because KMeans cannot fit more clusters
                than samples.
            output_dir: Directory for the generated figures (created if
                missing).
            random_state: Seed passed to both KMeans and t-SNE.

        Returns:
            Dict with the keys ``'dimension_stats'`` (``mean``, ``std``,
            ``variances``), ``'clustering'`` (``cluster_labels``,
            ``cluster_centers``), ``'tsne_coords'`` and
            ``'correlation_matrix'``.

        Raises:
            ValueError: If fewer than two embeddings are available or
                ``n_clusters`` is smaller than 1.
        """
        n_samples, embedding_dim = self.embeddings.shape
        # Validate before doing any I/O so bad input leaves no partial output.
        if n_samples < 2:
            raise ValueError(
                f"Need at least 2 embeddings to analyze, got {n_samples}"
            )
        if n_clusters < 1:
            raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")

        # Imported here (as in the original) so that constructing the
        # analyzer does not require scikit-learn.
        from sklearn.cluster import KMeans
        from sklearn.manifold import TSNE

        os.makedirs(output_dir, exist_ok=True)

        # 1. Dimensionality analysis: variance of every embedding dimension.
        dim_variances = np.var(self.embeddings, axis=0)

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.barplot(x=list(range(embedding_dim)), y=dim_variances, ax=ax)
        ax.set_title('Variance in Each Embedding Dimension')
        ax.set_xlabel('Dimension')
        ax.set_ylabel('Variance')
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'dimension_variances.png'))
        plt.close(fig)

        # 2. Clustering analysis.
        kmeans = KMeans(n_clusters=min(n_clusters, n_samples),
                        random_state=random_state)
        clusters = kmeans.fit_predict(self.embeddings)

        # 3. t-SNE visualization. Perplexity must stay below the number of
        # samples; 30 is the library default and is kept whenever it fits.
        perplexity = float(min(30, n_samples - 1))
        tsne = TSNE(n_components=2, random_state=random_state,
                    perplexity=perplexity)
        embeddings_2d = tsne.fit_transform(self.embeddings)

        fig, ax = plt.subplots(figsize=(12, 8))
        scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                             c=clusters, cmap='viridis', alpha=0.6)
        fig.colorbar(scatter, ax=ax)
        ax.set_title('t-SNE Visualization of Embedding Space\nColored by Clusters')
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'tsne_visualization.png'))
        plt.close(fig)

        # 4. Correlation analysis between dimensions (columns). A dimension
        # with zero variance yields NaN entries in its row and column.
        correlation_matrix = np.corrcoef(self.embeddings.T)

        fig, ax = plt.subplots(figsize=(12, 12))
        sns.heatmap(correlation_matrix, cmap='coolwarm', center=0, ax=ax)
        ax.set_title('Correlation Between Embedding Dimensions')
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'dimension_correlations.png'))
        plt.close(fig)

        return {
            'dimension_stats': {
                'mean': np.mean(self.embeddings, axis=0),
                'std': np.std(self.embeddings, axis=0),
                'variances': dim_variances
            },
            'clustering': {
                'cluster_labels': clusters,
                'cluster_centers': kmeans.cluster_centers_
            },
            'tsne_coords': embeddings_2d,
            'correlation_matrix': correlation_matrix
        }

class EncoderAnalyzer:
    """Load an encoder checkpoint and summarize its weight tensors.

    Attributes:
        state_dict: The mapping stored under ``'encoder_state_dict'`` in the
            checkpoint.
        model_info: The mapping stored under ``'model_info'``, or ``{}`` when
            the checkpoint has no such key.
    """
    def __init__(self, encoder_path: str):
        """Load the checkpoint from disk onto the CPU.

        Args:
            encoder_path: Path to a file readable by ``torch.load`` that
                contains a dict with an ``'encoder_state_dict'`` entry.

        Raises:
            FileNotFoundError: If ``encoder_path`` does not exist.
            KeyError: If the checkpoint has no ``'encoder_state_dict'`` entry.
        """
        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder file not found: {encoder_path}")

        print(f"Loading encoder from {encoder_path}")
        # NOTE(review): torch.load deserializes with pickle -- only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(encoder_path, map_location='cpu')

        if not isinstance(checkpoint, dict) or 'encoder_state_dict' not in checkpoint:
            if isinstance(checkpoint, dict):
                found = list(checkpoint.keys())
            else:
                found = type(checkpoint).__name__
            raise KeyError(
                f"Checkpoint {encoder_path} has no 'encoder_state_dict' entry "
                f"(found: {found})"
            )

        self.state_dict = checkpoint['encoder_state_dict']
        self.model_info = checkpoint.get('model_info', {})

    def analyze_weights(self, output_dir: str = 'encoder_analysis') -> Dict:
        """Compute summary statistics and a histogram for every weight tensor.

        Only entries whose name contains ``'weight'`` are analyzed. One
        ``<name>_distribution.png`` histogram per analyzed entry is written
        to ``output_dir``.

        Args:
            output_dir: Directory for the histogram images (created if
                missing).

        Returns:
            Dict mapping each analyzed entry name to a dict with the float
            values ``'mean'``, ``'std'``, ``'min'`` and ``'max'``.
        """
        os.makedirs(output_dir, exist_ok=True)

        weight_stats = {}
        for name, param in self.state_dict.items():
            if 'weight' not in name:
                continue

            # detach() first: .numpy() refuses tensors that require grad.
            weights = param.detach().cpu().numpy()

            # min/max are undefined for zero-element tensors, so skip them
            # instead of aborting the whole analysis.
            if weights.size == 0:
                print(f"Skipping {name}: tensor has no elements")
                continue

            weight_stats[name] = {
                'mean': float(np.mean(weights)),
                'std': float(np.std(weights)),
                'min': float(np.min(weights)),
                'max': float(np.max(weights))
            }

            # Keep the histogram inside output_dir even if the entry name
            # contains a path separator.
            safe_name = name.replace('/', '_').replace(os.sep, '_')

            fig, ax = plt.subplots(figsize=(10, 6))
            sns.histplot(weights.flatten(), bins=50, ax=ax)
            ax.set_title(f'Weight Distribution for {name}')
            ax.set_xlabel('Weight Value')
            ax.set_ylabel('Count')
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, f'{safe_name}_distribution.png'))
            plt.close(fig)

        return weight_stats

def main():
    """Run the embedding and encoder analyses and print a summary of each.

    Any exception raised along the way is reported on stdout and re-raised.
    """
    # Locations of the saved artifacts
    embeddings_file = './embeddings/final_embeddings.pkl'
    encoder_file = './checkpoints/encoders/best_encoder.pt'

    try:
        # Step 1: embeddings
        print("\nAnalyzing embeddings...")
        emb_analyzer = EmbeddingAnalyzer(embeddings_file)
        emb_report = emb_analyzer.analyze_embedding_space()

        print("\nEmbedding Analysis Results:")
        print(f"- Number of embeddings: {len(emb_analyzer.embeddings)}")
        print(f"- Embedding dimension: {emb_analyzer.embeddings.shape[1]}")
        cluster_ids = emb_report['clustering']['cluster_labels']
        print(f"- Number of clusters found: {len(np.unique(cluster_ids))}")

        # Step 2: encoder weights
        print("\nAnalyzing encoder...")
        enc_analyzer = EncoderAnalyzer(encoder_file)
        per_layer = enc_analyzer.analyze_weights()

        print("\nEncoder Analysis Results:")
        print("Layer weight statistics:")
        for layer in per_layer:
            print(f"\n{layer}:")
            for label, number in per_layer[layer].items():
                print(f"  {label}: {number:.4f}")

        print("\nAnalysis complete! Visualizations saved in:")
        for folder in ("embedding_analysis", "encoder_analysis"):
            print(f"- {folder}/")

    except Exception as err:
        # Report the failure, then let it propagate unchanged.
        print(f"Error during analysis: {err!s}")
        raise

# Entry point: run the full analysis when this code is executed as the main
# module rather than imported.
if __name__ == "__main__":
    main()


Analyzing embeddings...
Loading embeddings from ./embeddings/final_embeddings.pkl
Loaded 397 embeddings of dimension 128





Embedding Analysis Results:
- Number of embeddings: 397
- Embedding dimension: 128
- Number of clusters found: 5

Analyzing encoder...
Loading encoder from ./checkpoints/encoders/best_encoder.pt

Encoder Analysis Results:
Layer weight statistics:

node_encoder.0.weight:
  mean: -0.0033
  std: 0.1228
  min: -0.2614
  max: 0.2277

node_encoder.2.weight:
  mean: 0.0006
  std: 0.0894
  min: -0.1946
  max: 0.2081

edge_encoder.0.weight:
  mean: 0.0054
  std: 0.1231
  min: -0.2120
  max: 0.2124

edge_encoder.2.weight:
  mean: -0.0002
  std: 0.0884
  min: -0.1531
  max: 0.1531

conv1.lin.weight:
  mean: -0.0007
  std: 0.0891
  min: -0.1997
  max: 0.2052

conv2.lin.weight:
  mean: -0.0014
  std: 0.0888
  min: -0.2155
  max: 0.2167

conv3.lin.weight:
  mean: -0.0021
  std: 0.0887
  min: -0.2365
  max: 0.2032

projection.0.weight:
  mean: -0.0013
  std: 0.0882
  min: -0.1954
  max: 0.2025

projection.2.weight:
  mean: -0.0005
  std: 0.0884
  min: -0.2278
  max: 0.2516

Analysis complete! Visual