# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Import molecular libraries
from deepchem.molnet import load_bace_classification, load_bbbp, load_clintox, load_delaney, load_qm9
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# XAI and ML libraries
import shap
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, r2_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import os
from tqdm import tqdm
import pickle
from scipy import stats

# Graph neural network imports
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

# Set style for publication-quality figures.
# Both calls change process-wide matplotlib/seaborn defaults, so they apply to
# every figure created after this point in the session.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

class MolecularFeatureExtractor:
    """Feature extractor that turns SMILES strings into graph tensors.

    ``process_molecule`` is the entry point: it parses a SMILES string with
    RDKit and returns a ``Data`` object holding categorical atom features
    (``x_cat``), physical atom features (``x_phys``), a bidirectional edge
    list (``edge_index``) and bond features (``edge_attr``).
    """

    # Length of the physical feature vector built by ``calc_atom_features``:
    # 2 element-level descriptors followed by 6 per-atom properties.
    NUM_PHYS_FEATURES = 8
    # Length of each bond feature vector built by ``get_bond_features``.
    NUM_BOND_FEATURES = 5

    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        # element symbol -> (exact weight, logP) of the bare single-atom
        # molecule. Both values depend only on the symbol, so they are
        # computed once per element instead of once per atom.
        self._element_cache = {}

    @staticmethod
    def _single_atom_descriptor(descriptor_fn, symbol: str) -> float:
        """Evaluate ``descriptor_fn`` on the one-atom molecule ``[symbol]``.

        Returns 0.0 when the molecule cannot be built or the descriptor
        raises, matching the best-effort behaviour expected by the caller.
        """
        try:
            return descriptor_fn(Chem.MolFromSmiles(f'[{symbol}]'))
        except Exception:
            return 0.0

    def _element_descriptors(self, symbol: str) -> Tuple[float, float]:
        """Return the cached ``(exact weight, logP)`` pair for an element symbol."""
        cached = self._element_cache.get(symbol)
        if cached is None:
            cached = (
                self._single_atom_descriptor(Descriptors.ExactMolWt, symbol),
                self._single_atom_descriptor(Descriptors.MolLogP, symbol),
            )
            self._element_cache[symbol] = cached
        return cached

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Build the feature lists for a single atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` is
            ``[atomic-number index, chirality index]`` and ``phys_feat`` has
            ``NUM_PHYS_FEATURES`` entries. If any lookup fails (for example an
            atomic number outside ``atom_list``), all-zero placeholders of the
            same lengths are returned so the per-molecule tensors stay
            rectangular.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]
            elem_weight, elem_logp = self._element_descriptors(atom.GetSymbol())
            phys_feat = [
                elem_weight,
                elem_logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]
            return atom_feat, phys_feat

        except Exception:
            # The placeholder must have the same width as a real feature
            # vector; a different width makes the tensor construction in
            # get_atom_features fail for the whole molecule.
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack the per-atom features of ``mol`` into two tensors.

        Returns:
            ``(x, phys)``: a long tensor with 2 columns and a float tensor
            with ``NUM_PHYS_FEATURES`` columns, one row per atom. For
            ``mol is None`` a single all-zero row is returned for each.
        """
        if mol is None:
            return (
                torch.zeros((1, 2), dtype=torch.long),
                torch.zeros((1, self.NUM_PHYS_FEATURES), dtype=torch.float),
            )

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    def _placeholder_edges(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the single 0->0 edge with zero features used when no bonds exist."""
        return (
            torch.tensor([[0], [0]], dtype=torch.long),
            torch.zeros((1, self.NUM_BOND_FEATURES), dtype=torch.float),
        )

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the edge list and bond feature matrix of ``mol``.

        Every bond is emitted in both directions with identical features.
        Bonds whose type is not in ``bond_list`` are skipped entirely.

        Returns:
            ``(edge_index, edge_attr)``. When ``mol`` is ``None`` or no bond
            could be encoded, a single placeholder edge 0->0 with zero
            features is returned.
        """
        if mol is None:
            return self._placeholder_edges()

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                # Everything that can raise happens before the lists are
                # touched, so a skipped bond never leaves edge_index and
                # edge_attr with different numbers of edges.
                bond_type = self.bond_list.index(bond.GetBondType())
                feat = [bond_type, 0, int(bond.GetIsConjugated()), 0, 0]
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            except Exception:
                continue

            row += [start, end]
            col += [end, start]
            edge_feat += [feat, feat]

        if not row:
            return self._placeholder_edges()

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Convert a SMILES string into a graph ``Data`` object.

        Returns:
            The populated ``Data`` object, or ``None`` when the SMILES cannot
            be parsed or sanitized, the molecule has no atoms, or feature
            construction fails.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None

            Chem.SanitizeMol(mol)
            if mol.GetNumAtoms() == 0:
                return None

            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception:
            # Best-effort: callers treat None as "skip this molecule".
            return None

class GraphDiscriminator(nn.Module):
    """Graph encoder: node MLP, three GCN layers, mean pooling, projection MLP.

    ``edge_encoder`` is constructed (so its parameters are part of the state
    dict) but ``forward`` does not call it.
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        def two_layer_mlp(in_features: int, out_features: int) -> nn.Sequential:
            # Linear -> ReLU -> Linear with `hidden_dim` units in the middle.
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_features),
            )

        # Sub-modules are created in a fixed order: node encoder, edge
        # encoder, the three convolutions, then the projection head.
        self.node_encoder = two_layer_mlp(node_dim, hidden_dim)
        self.edge_encoder = two_layer_mlp(edge_dim, hidden_dim)

        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        self.projection = two_layer_mlp(output_dim, output_dim)

    def normalize_features(self, x_cat, x_phys):
        """Cast both inputs to float and standardize ``x_phys`` column-wise.

        Standardization (subtract the mean over dim 0, divide by the std over
        dim 0 plus 1e-5) is skipped when ``x_phys`` has a single row.
        """
        cat_float = x_cat.float()
        phys_float = x_phys.float()
        if phys_float.size(0) <= 1:
            return cat_float, phys_float

        centered = phys_float - phys_float.mean(0)
        return cat_float, centered / (phys_float.std(0) + 1e-5)

    def forward(self, data):
        """Encode a graph batch into one projected embedding per graph.

        Reads ``x_cat``, ``x_phys``, ``edge_index`` and ``batch`` from ``data``.
        """
        cat_feats, phys_feats = self.normalize_features(data.x_cat, data.x_phys)
        hidden = self.node_encoder(torch.cat([cat_feats, phys_feats], dim=-1))

        edges = data.edge_index
        hidden = F.relu(self.conv1(hidden, edges))
        hidden = F.relu(self.conv2(hidden, edges))
        hidden = self.conv3(hidden, edges)

        pooled = global_mean_pool(hidden, data.batch)
        return self.projection(pooled)

class MemorizationBiasSHAPAnalyzer:
    """Generate SHAP visualizations for memorization bias"""

    def __init__(self, encoder_path: str, output_dir: str = './Memorization-SHAP'):
        """Load the encoder checkpoint and prepare the output directory.

        Args:
            encoder_path: Path of the checkpoint read by ``load_encoder``.
            output_dir: Directory the figures are written to; created if missing.
        """
        self.encoder_path = encoder_path
        self.output_dir = output_dir
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        os.makedirs(output_dir, exist_ok=True)

        # Load encoder (needs encoder_path and device to be set already)
        loaded_encoder, loaded_info = self.load_encoder()
        self.encoder = loaded_encoder
        self.model_info = loaded_info
        self.extractor = MolecularFeatureExtractor()

        # Molecular descriptors, in the column order produced by
        # compute_molecular_descriptors
        self.descriptor_names = [
            'MolWt',
            'LogP',
            'TPSA',
            'NumRotatableBonds',
            'NumHDonors',
            'NumHAcceptors',
            'NumAromaticRings',
            'NumSaturatedRings',
            'NumAliphaticRings',
            'BertzCT',
        ]

    def load_encoder(self):
        """Rebuild the encoder from the checkpoint at ``self.encoder_path``.

        The checkpoint must contain ``'model_info'`` (with ``node_dim`` and
        ``edge_dim``; ``hidden_dim`` and ``output_dim`` default to 128) and
        ``'encoder_state_dict'``.

        Returns:
            ``(encoder, model_info)`` with the encoder on ``self.device`` in
            eval mode.
        """
        # NOTE: torch.load deserializes the file; only load trusted checkpoints.
        checkpoint = torch.load(self.encoder_path, map_location=self.device)
        info = checkpoint['model_info']

        encoder = GraphDiscriminator(
            node_dim=info['node_dim'],
            edge_dim=info['edge_dim'],
            hidden_dim=info.get('hidden_dim', 128),
            output_dim=info.get('output_dim', 128),
        )
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder.to(self.device)
        encoder.eval()

        return encoder, info

    def compute_molecular_descriptors(self, smiles_list):
        """Compute molecular descriptors"""
        descriptors = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                descriptors.append([0] * len(self.descriptor_names))
                continue

            desc = [
                Descriptors.MolWt(mol),
                Descriptors.MolLogP(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumRotatableBonds(mol),
                Descriptors.NumHDonors(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.NumAromaticRings(mol),
                Descriptors.NumSaturatedRings(mol),
                Descriptors.NumAliphaticRings(mol),
                Descriptors.BertzCT(mol)
            ]
            descriptors.append(desc)

        return np.array(descriptors)

    def compute_tanimoto_similarity(self, test_smiles, train_smiles, k=5):
        """Mean Tanimoto similarity of each test molecule to its k nearest training molecules.

        Molecules are represented by Morgan fingerprints (radius 2, 2048 bits).

        Args:
            test_smiles: SMILES strings to score.
            train_smiles: SMILES strings of the reference (training) set.
            k: Number of most-similar training molecules averaged per test
                molecule; fewer are used when fewer valid ones exist.

        Returns:
            1-D array with one score per entry of ``test_smiles``. Test
            molecules that cannot be parsed, or any test molecule when no
            training molecule is valid, receive the default score 0.5.
        """
        from rdkit import DataStructs
        from rdkit.Chem import AllChem

        print("Computing Tanimoto similarities...")

        def fingerprint(smiles):
            # Returns None for SMILES that RDKit cannot parse.
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                return None
            return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)

        # Invalid training molecules never contribute, so drop them once here
        # instead of re-checking for every test molecule.
        train_fps = [fp for fp in map(fingerprint, train_smiles) if fp is not None]

        nearest_k_similarities = []
        for smiles in test_smiles:
            test_fp = fingerprint(smiles)
            if test_fp is None or not train_fps:
                nearest_k_similarities.append(0.5)  # Default when no comparison is possible
                continue

            # One bulk call computes the similarity to every training fingerprint.
            similarities = DataStructs.BulkTanimotoSimilarity(test_fp, train_fps)
            top_k = sorted(similarities, reverse=True)[:k]
            nearest_k_similarities.append(np.mean(top_k))

        return np.array(nearest_k_similarities)

    def extract_embeddings(self, smiles_list):
        """Encode every parsable molecule with the loaded encoder.

        Args:
            smiles_list: Sequence of SMILES strings.

        Returns:
            ``(embeddings, valid_indices)``: a 2-D array with one row per
            successfully encoded molecule, and the positions of those
            molecules in ``smiles_list``. When nothing could be encoded the
            array has zero rows and as many columns as the model's
            ``output_dim`` (128 if the checkpoint does not record it).
        """
        embeddings = []
        valid_indices = []
        n_failed = 0

        for idx, smiles in enumerate(tqdm(smiles_list, desc="Extracting embeddings")):
            try:
                data = self.extractor.process_molecule(smiles)
                if data is None:
                    continue

                # A single graph: every node belongs to batch element 0.
                data.batch = torch.zeros(data.x_cat.size(0), dtype=torch.long)
                data = data.to(self.device)

                with torch.no_grad():
                    emb = self.encoder(data)
                vector = emb.cpu().numpy().squeeze()
            except Exception:
                # Best-effort: skip the molecule, but keep count so a
                # systematic failure (e.g. a feature-width mismatch with the
                # checkpoint) does not go unnoticed.
                n_failed += 1
                continue

            embeddings.append(vector)
            valid_indices.append(idx)

        if n_failed:
            print(f"Warning: encoder failed on {n_failed} molecule(s); they were skipped")

        if embeddings:
            return np.vstack(embeddings), valid_indices
        return np.empty((0, self.model_info.get('output_dim', 128))), []

    def compute_similarity_scores(self, test_embeddings, train_embeddings, k=5):
        """Mean cosine similarity of each test embedding to its k closest training embeddings.

        Returns a 1-D array with one value per row of ``test_embeddings``.
        """
        pairwise = cosine_similarity(test_embeddings, train_embeddings)
        # After an ascending sort the k largest similarities are the last k entries.
        per_sample = [np.mean(np.sort(row)[-k:]) for row in pairwise]
        return np.array(per_sample)

    def categorize_by_similarity(self, similarity_scores):
        """Categorize into quartiles"""
        quartiles = np.percentile(similarity_scores, [25, 50, 75])

        categories = np.zeros(len(similarity_scores), dtype=int)
        categories[similarity_scores <= quartiles[0]] = 0  # Q1: Novel
        categories[(similarity_scores > quartiles[0]) & (similarity_scores <= quartiles[1])] = 1
        categories[(similarity_scores > quartiles[1]) & (similarity_scores <= quartiles[2])] = 2
        categories[similarity_scores > quartiles[2]] = 3  # Q4: Similar

        return categories, quartiles

    def visualization_1_split_summary(self, shap_values, X_test, categories, dataset_name):
        """Option 1: side-by-side SHAP scatter panels for novel (category 0) and similar (category 3) samples.

        Features are drawn in the fixed order of ``self.descriptor_names``.
        Points are coloured by the feature value min-max scaled within the
        panel's subset. The figure is saved to
        ``<output_dir>/<dataset_name>_1_split_summary.png`` and closed.
        """
        fig = plt.figure(figsize=(16, 8))

        # Two plot columns plus a narrow third column reserved for the colorbar
        grid = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.05], wspace=0.3)
        panel_axes = [fig.add_subplot(grid[0, 0]), fig.add_subplot(grid[0, 1])]
        cbar_ax = fig.add_subplot(grid[0, 2])

        n_features = len(self.descriptor_names)
        panels = [
            (categories == 0, 'Novel Samples (Low Similarity to Training)'),
            (categories == 3, 'Similar Samples (High Similarity to Training)'),
        ]

        for ax, (mask, title) in zip(panel_axes, panels):
            # A panel with no samples is left empty
            if np.count_nonzero(mask) == 0:
                continue

            shap_subset = shap_values[mask]
            X_subset = X_test[mask]

            # One horizontal strip of points per feature, at y == feature index
            for feat_idx in range(n_features):
                strip_shap = shap_subset[:, feat_idx]
                strip_feat = X_subset[:, feat_idx]

                # Min-max scale to [0, 1] for colouring; the epsilon avoids
                # division by zero for constant features
                value_range = strip_feat.max() - strip_feat.min() + 1e-10
                colour_values = (strip_feat - strip_feat.min()) / value_range

                ax.scatter(strip_shap, [feat_idx] * len(strip_shap),
                           c=colour_values, cmap='coolwarm',
                           alpha=0.6, s=20, vmin=0, vmax=1)

            ax.set_yticks(range(n_features))
            ax.set_yticklabels(self.descriptor_names)
            ax.set_xlabel('SHAP value (impact on model output)', fontsize=11)
            ax.set_ylabel('Feature value', fontsize=11)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.axvline(0, color='black', linestyle='-', linewidth=0.5)

        # Shared colorbar in its dedicated axis, labelled qualitatively
        mappable = plt.cm.ScalarMappable(cmap='coolwarm', norm=plt.Normalize(vmin=0, vmax=1))
        mappable.set_array([])
        cbar = plt.colorbar(mappable, cax=cbar_ax)
        cbar.set_label('Feature value', rotation=270, labelpad=15)
        cbar.set_ticks([0, 0.5, 1])
        cbar.set_ticklabels(['Low', '', 'High'])

        plt.suptitle(f'{dataset_name.upper()} - SHAP Patterns: Novel vs Similar Samples\n'
                     'Scattered patterns (left) indicate uncertainty; Concentrated patterns (right) indicate memorization',
                     fontsize=16, fontweight='bold')

        save_path = os.path.join(self.output_dir, f'{dataset_name}_1_split_summary.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {save_path}")

    def visualization_2_variance_analysis(self, shap_values, similarity_scores, dataset_name):
        """Option 2: scatter of per-sample SHAP variance against similarity, with a linear trend.

        A 3-D ``shap_values`` array is reduced to 2-D first (class index 1
        when the last axis has length 2, otherwise the mean over the last
        axis); a 1-D array is treated as a single feature column. The figure
        is saved to ``<output_dir>/<dataset_name>_2_variance_analysis.png``
        and closed.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        # Bring SHAP values into (n_samples, n_features) form
        n_dims = len(shap_values.shape)
        if n_dims == 3:
            if shap_values.shape[2] == 2:
                shap_values = shap_values[:, :, 1]
            else:
                shap_values = shap_values.mean(axis=2)
        elif n_dims == 1:
            shap_values = shap_values.reshape(-1, 1)

        # Variance across features, one value per sample
        per_sample_variance = np.var(shap_values, axis=1)
        ax.scatter(similarity_scores, per_sample_variance, alpha=0.5, s=30)

        # Least-squares line through the scatter
        coeffs = np.polyfit(similarity_scores, per_sample_variance, 1)
        trend = np.poly1d(coeffs)
        xs = np.linspace(similarity_scores.min(), similarity_scores.max(), 100)
        ax.plot(xs, trend(xs), "r-", linewidth=2, label=f'Trend (slope={coeffs[0]:.2f})')

        # Vertical markers at the quartile boundaries of the similarity scores
        boundaries = np.percentile(similarity_scores, [25, 50, 75])
        boundary_labels = ['Q1/Q2', 'Q2/Q3', 'Q3/Q4']
        for position, label in zip(boundaries, boundary_labels):
            ax.axvline(position, color='gray', linestyle='--', alpha=0.5, label=label)

        ax.set_xlabel('Similarity to Training Set', fontsize=12, fontweight='bold')
        ax.set_ylabel('SHAP Value Variance', fontsize=12, fontweight='bold')
        ax.set_title(f'{dataset_name.upper()} - Model Uncertainty vs Sample Novelty\n'
                     'High variance indicates uncertain feature importance',
                     fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

        save_path = os.path.join(self.output_dir, f'{dataset_name}_2_variance_analysis.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {save_path}")

    def visualization_3_cohesion_score(self, shap_values, X_test, similarity_scores, dataset_name):
        """Option 3: scatter of a per-sample "cohesion" score against similarity.

        Cohesion is the Pearson correlation, across features, between the
        sample's sum-normalised absolute feature values and its
        sum-normalised absolute SHAP values; undefined correlations are
        replaced by 0. The figure is saved to
        ``<output_dir>/<dataset_name>_3_cohesion_score.png`` and closed.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        per_sample = []
        for i, sample_shap in enumerate(shap_values):
            abs_feat = np.abs(X_test[i])
            abs_shap = np.abs(sample_shap)
            # Each vector is scaled to sum to 1; the epsilon guards all-zero rows
            feat_share = abs_feat / (abs_feat.sum() + 1e-10)
            shap_share = abs_shap / (abs_shap.sum() + 1e-10)

            corr = np.corrcoef(feat_share, shap_share)[0, 1]
            per_sample.append(0 if np.isnan(corr) else corr)

        cohesion_scores = np.array(per_sample)

        # Points coloured by their similarity score
        points = ax.scatter(similarity_scores, cohesion_scores,
                            c=similarity_scores, cmap='RdYlBu',
                            alpha=0.6, s=50, edgecolors='black', linewidth=0.5)

        # Linear trend over the non-NaN scores (needs at least two points)
        usable = ~np.isnan(cohesion_scores)
        if np.count_nonzero(usable) > 1:
            xs_fit = similarity_scores[usable]
            ys_fit = cohesion_scores[usable]
            coeffs = np.polyfit(xs_fit, ys_fit, 1)
            trend = np.poly1d(coeffs)
            xs = np.linspace(similarity_scores.min(), similarity_scores.max(), 100)
            pearson_r = np.corrcoef(xs_fit, ys_fit)[0, 1]
            ax.plot(xs, trend(xs), "k-", linewidth=2,
                    label=f'Trend (correlation={pearson_r:.3f})')

        plt.colorbar(points, label='Similarity Score')

        ax.set_xlabel('Similarity to Training Set', fontsize=12, fontweight='bold')
        ax.set_ylabel('SHAP Cohesion Score', fontsize=12, fontweight='bold')
        ax.set_title(f'{dataset_name.upper()} - Explanation Coherence vs Sample Familiarity\n'
                     'Higher cohesion means model explanations align with feature magnitudes',
                     fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

        save_path = os.path.join(self.output_dir, f'{dataset_name}_3_cohesion_score.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {save_path}")

    def visualization_4_interaction_heatmap(self, shap_values, categories, dataset_name):
        """Option 4: heatmaps of feature-by-feature SHAP correlations for novel vs similar samples.

        The heatmaps are drawn only when both category 0 and category 3
        contain more than five samples; otherwise the two panels stay empty.
        The figure is saved to
        ``<output_dir>/<dataset_name>_4_interaction_heatmap.png`` and closed.
        """
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        is_novel = categories == 0
        is_similar = categories == 3

        enough_samples = np.count_nonzero(is_novel) > 5 and np.count_nonzero(is_similar) > 5
        if enough_samples:
            # Correlation between feature columns of the SHAP matrix, per subset
            panels = [
                (axes[0], np.corrcoef(shap_values[is_novel].T),
                 'Novel Samples\nWeak/Inconsistent Interactions'),
                (axes[1], np.corrcoef(shap_values[is_similar].T),
                 'Similar Samples\nStrong/Memorized Interactions'),
            ]

            for panel_ax, corr_matrix, panel_title in panels:
                sns.heatmap(corr_matrix, ax=panel_ax, cmap='coolwarm', center=0,
                            xticklabels=self.descriptor_names,
                            yticklabels=self.descriptor_names,
                            cbar_kws={'label': 'Correlation'})
                panel_ax.set_title(panel_title, fontsize=12, fontweight='bold')

            # Rotate labels once both heatmaps exist
            for panel_ax in axes:
                panel_ax.set_xticklabels(panel_ax.get_xticklabels(), rotation=45, ha='right')
                panel_ax.set_yticklabels(panel_ax.get_yticklabels(), rotation=0)

        plt.suptitle(f'{dataset_name.upper()} - Feature Interaction Patterns\n'
                     'Stronger correlations in similar samples indicate memorized feature combinations',
                     fontsize=14, fontweight='bold')
        plt.tight_layout()

        save_path = os.path.join(self.output_dir, f'{dataset_name}_4_interaction_heatmap.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {save_path}")

    def load_dataset(self, dataset_name):
        """Load dataset"""
        print(f"\nLoading {dataset_name.upper()} dataset...")

        try:
            if dataset_name == 'bace':
                tasks, datasets, _ = load_bace_classification(featurizer='ECFP', splitter='scaffold')
                task_type = 'classification'
            elif dataset_name == 'bbbp':
                tasks, datasets, _ = load_bbbp(featurizer='ECFP', splitter='scaffold')
                task_type = 'classification'
            elif dataset_name == 'clintox':
                tasks, datasets, _ = load_clintox(featurizer='ECFP', splitter='random')
                task_type = 'classification'
            elif dataset_name == 'esol':
                try:
                    tasks, datasets, _ = load_delaney(featurizer='ECFP', splitter='scaffold')
                except:
                    tasks, datasets, _ = load_delaney(featurizer='ECFP', splitter='random')
                task_type = 'regression'
            elif dataset_name == 'qm9':
                tasks, datasets, _ = load_qm9(featurizer='ECFP', splitter='random')
                task_type = 'regression'
            else:
                raise ValueError(f"Unknown dataset: {dataset_name}")

            train_dataset, _, test_dataset = datasets

            train_smiles = train_dataset.ids.tolist()
            test_smiles = test_dataset.ids.tolist()

            if len(train_dataset.y.shape) > 1:
                train_y = train_dataset.y[:, 0].flatten()
                test_y = test_dataset.y[:, 0].flatten()
            else:
                train_y = train_dataset.y.flatten()
                test_y = test_dataset.y.flatten()

            return {
                'train_smiles': train_smiles,
                'test_smiles': test_smiles,
                'train_y': train_y,
                'test_y': test_y,
                'task_type': task_type
            }
        except Exception as e:
            print(f"Error loading {dataset_name}: {e}")
            return None

    def analyze_dataset(self, dataset_name):
        """Generate all SHAP visualizations for a dataset"""
        print(f"\n{'='*60}")
        print(f"Generating SHAP Visualizations for {dataset_name.upper()}")
        print(f"{'='*60}")

        # Load dataset
        data = self.load_dataset(dataset_name)
        if data is None:
            return None

        # Extract embeddings for filtering only
        print("\nExtracting embeddings for validation...")
        train_embeddings, train_valid_idx = self.extract_embeddings(data['train_smiles'])
        test_embeddings, test_valid_idx = self.extract_embeddings(data['test_smiles'])

        # Filter data
        train_y = data['train_y'][train_valid_idx]
        test_y = data['test_y'][test_valid_idx]
        train_smiles_valid = [data['train_smiles'][i] for i in train_valid_idx]
        test_smiles_valid = [data['test_smiles'][i] for i in test_valid_idx]

        # Remove NaN
        train_mask = ~np.isnan(train_y)
        test_mask = ~np.isnan(test_y)

        train_y = train_y[train_mask]
        test_y = test_y[test_mask]
        train_smiles_valid = [s for s, m in zip(train_smiles_valid, train_mask) if m]
        test_smiles_valid = [s for s, m in zip(test_smiles_valid, test_mask) if m]

        print(f"Valid samples - Train: {len(train_y)}, Test: {len(test_y)}")

        # Compute molecular descriptors
        print("\nComputing molecular descriptors...")
        test_descriptors = self.compute_molecular_descriptors(test_smiles_valid)

        # Ensure all arrays have same length by creating a common mask
        # valid_desc_mask = np.all(test_descriptors != 0, axis=1)  # Remove zero descriptor rows
        valid_desc_mask = np.sum(test_descriptors == 0, axis=1) < 5

        # Apply mask to all arrays
        test_descriptors = test_descriptors[valid_desc_mask]
        test_y = test_y[valid_desc_mask]
        test_smiles_valid = [s for i, s in enumerate(test_smiles_valid) if valid_desc_mask[i]]

        print(f"After descriptor validation: {len(test_y)} samples")

        # Standardize
        scaler = StandardScaler()
        test_descriptors_scaled = scaler.fit_transform(test_descriptors)

        # Compute Tanimoto similarity instead of cosine
        print("\nComputing Tanimoto similarity scores...")
        similarity_scores = self.compute_tanimoto_similarity(test_smiles_valid, train_smiles_valid, k=5)

        print(f"Similarity range: {similarity_scores.min():.3f} to {similarity_scores.max():.3f}")
        print(f"Mean similarity: {similarity_scores.mean():.3f} Â± {similarity_scores.std():.3f}")

        # Categorize by similarity
        categories, quartiles = self.categorize_by_similarity(similarity_scores)
        print(f"Similarity quartiles: {[f'{q:.3f}' for q in quartiles]}")
        print(f"Samples per quartile: {[sum(categories == q) for q in range(4)]}")

        # Train interpretable model for SHAP
        print("\nTraining interpretable model...")
        if data['task_type'] == 'classification':
            model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)
        else:
            model = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)

        model.fit(test_descriptors_scaled, test_y)

        # Compute SHAP values
        print("\nComputing SHAP values...")
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(test_descriptors_scaled)

        # Normalize SHAP values format
        if isinstance(shap_values, list):
            # For multiclass, take the positive class
            shap_values = shap_values[1] if len(shap_values) == 2 else np.array(shap_values).mean(axis=0)
        elif len(shap_values.shape) == 3:
            # For binary classification with 3D array
            shap_values = shap_values[:, :, 1] if shap_values.shape[2] == 2 else shap_values.mean(axis=2)

        # Ensure 2D shape
        if len(shap_values.shape) == 1:
            shap_values = shap_values.reshape(-1, 1)

        print(f"Final SHAP shape: {shap_values.shape}")

        # Generate all visualizations
        print("\nGenerating visualizations...")
        self.visualization_1_split_summary(shap_values, test_descriptors_scaled, categories, dataset_name)
        self.visualization_2_variance_analysis(shap_values, similarity_scores, dataset_name)
        self.visualization_3_cohesion_score(shap_values, test_descriptors_scaled, similarity_scores, dataset_name)
        self.visualization_4_interaction_heatmap(shap_values, categories, dataset_name)

        print(f"\nAll visualizations saved for {dataset_name}")

        return {
            'dataset': dataset_name,
            'shap_values': shap_values,
            'similarity_scores': similarity_scores,
            'categories': categories
        }

    def run_all_analyses(self):
        """Run analysis for all datasets"""
        # datasets = ['bace', 'bbbp', 'clintox', 'esol', 'qm9']
        datasets = ['esol']
        all_results = {}

        for dataset_name in datasets:
            try:
                results = self.analyze_dataset(dataset_name)
                if results:
                    all_results[dataset_name] = results
            except Exception as e:
                print(f"Error analyzing {dataset_name}: {e}")
                import traceback
                traceback.print_exc()
                continue

        print(f"\n{'='*60}")
        print("ALL SHAP VISUALIZATIONS COMPLETE")
        print(f"{'='*60}")

        return all_results

# Main execution
# Runs the full analysis when the file is executed as a script. `analyzer` and
# `results` are left as module-level names so they remain available afterwards
# (e.g. in an interactive session).
if __name__ == "__main__":
    # Checkpoint read by MemorizationBiasSHAPAnalyzer.load_encoder
    encoder_path = './checkpoints/encoders/final_encoder_20250815_125248.pt'
    # Directory that receives the generated PNG figures
    output_dir = './Memorization-SHAP'

    # Constructing the analyzer loads the encoder and creates output_dir
    analyzer = MemorizationBiasSHAPAnalyzer(encoder_path, output_dir)
    results = analyzer.run_all_analyses()

    print(f"\nAll SHAP visualizations saved to: {output_dir}")