In [None]:
!pip install kaleido

In [None]:
import kaleido

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Import molecular libraries
from deepchem.molnet import load_bace_classification, load_bbbp, load_clintox, load_delaney, load_qm9
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# DALEX for fairness
import dalex as dx

# For visualizations
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, mean_squared_error, r2_score
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import train_test_split

import os
from tqdm import tqdm
import traceback

# Graph neural network imports (from your code)
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Set style for visualizations
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

class MolecularFeatureExtractor:
    """Compute RDKit descriptors and functional-group counts from SMILES strings.

    ``process_molecule`` is the entry point used by the fairness analysis.
    """

    # Functional groups counted by process_molecule, as SMARTS. Each entry yields
    # a '<name>_count' feature, in this order. The patterns are not mutually
    # exclusive: e.g. '[OH]' (aliphatic oxygen carrying exactly one H) also
    # matches the hydroxyls counted by 'phenol' and 'carboxyl'.
    FUNCTIONAL_GROUP_SMARTS = {
        'primary_amine': '[NX3H2]',
        'secondary_amine': '[NX3H1]',
        'tertiary_amine': '[NX3H0]',
        'amide': '[NX3][CX3](=[OX1])',
        'chloro': '[Cl]',
        'bromo': '[Br]',
        'fluoro': '[F]',
        'iodo': '[I]',
        'alcohol': '[OH]',
        'phenol': '[OH1][c]',
        'ether': '[OD2]([#6])[#6]',
        'ester': '[#6][CX3](=O)[OX2H0][#6]',
        'carbonyl': '[CX3]=O',
        'aldehyde': '[CX3H1](=O)',
        'ketone': '[#6][CX3](=O)[#6]',
        'carboxyl': '[CX3](=O)[OX2H1]',
        'sulfate': '[$(S(=O)(=O)(O)[O-,OH])]',
        'nitro': '[N+](=O)[O-]',
    }

    def __init__(self):
        # Atomic numbers 1-118 plus RDKit chirality and bond-type vocabularies.
        # process_molecule does not read them.
        # NOTE(review): presumably kept for graph featurization elsewhere;
        # confirm before removing.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        # Compile every SMARTS once here rather than once per molecule.
        # MolFromSmarts returns None for a pattern it cannot parse; that group's
        # count is then reported as 0.
        self.functional_group_patterns = {
            name: Chem.MolFromSmarts(smarts)
            for name, smarts in self.FUNCTIONAL_GROUP_SMARTS.items()
        }

    def process_molecule(self, smiles: str) -> Optional[Dict]:
        """Extract descriptor features for one molecule.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            A dict with keys 'MW' (exact molecular weight), 'LogP', 'HBD', 'HBA',
            'TPSA', 'RotBonds', 'AromaticRings', 'SP3_Fraction', 'QED', followed
            by one '<group>_count' entry per FUNCTIONAL_GROUP_SMARTS item.
            Returns None if RDKit cannot parse ``smiles``, or if any descriptor
            computation raises (the error is printed, not propagated).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None

            # Basic molecular descriptors
            features = {
                'MW': Descriptors.ExactMolWt(mol),
                'LogP': Descriptors.MolLogP(mol),
                'HBD': Descriptors.NumHDonors(mol),
                'HBA': Descriptors.NumHAcceptors(mol),
                'TPSA': Descriptors.TPSA(mol),
                'RotBonds': Descriptors.NumRotatableBonds(mol),
                'AromaticRings': Descriptors.NumAromaticRings(mol),
                'SP3_Fraction': Descriptors.FractionCSP3(mol),
                'QED': Descriptors.qed(mol)
            }

            # Functional-group substructure counts (patterns precompiled in __init__).
            for name, pattern in self.functional_group_patterns.items():
                if pattern is None:
                    features[f'{name}_count'] = 0
                else:
                    features[f'{name}_count'] = len(mol.GetSubstructMatches(pattern))

            return features

        except Exception as e:
            # Best-effort per molecule: report and skip instead of aborting a batch.
            print(f"Error processing molecule {smiles}: {e}")
            return None

class SimpleWrapper:
    """Minimal model stand-in that hands precomputed predictions to DALEX.

    The content of ``X`` is never inspected. For a DataFrame or ndarray, the
    first ``len(X)`` stored predictions are returned. For any other input, the
    full stored sequence is returned unchanged.
    """

    def __init__(self, predictions):
        # Precomputed predictions served (possibly truncated) by predict().
        self.predictions = predictions

    def predict(self, X, *args, **kwargs):
        """Return stored predictions; extra positional/keyword args are ignored."""
        is_tabular = isinstance(X, (pd.DataFrame, np.ndarray))
        if not is_tabular:
            return self.predictions
        n_rows = len(X)
        return self.predictions[:n_rows]

class MultiDatasetFairnessAnalyzer:
    """Comprehensive fairness analyzer for multiple molecular datasets"""

    def __init__(self, output_dir='./DALEX-Fairness-Analysis'):
        """Set up the output directory, descriptor extractor and compliance rules.

        Args:
            output_dir: root directory for generated reports and plots. It is
                created (including parents) if it does not already exist.

        ``self.compliance_criteria`` maps each dataset key to a dict with two entries:
        ``'name'`` is the label of the rule. ``'check'`` is a callable that takes a
        descriptor DataFrame and returns an element-wise boolean mask, True where
        a molecule satisfies every threshold. All bounds are inclusive.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.feature_extractor = MolecularFeatureExtractor()

        def bace_favorable(df):
            # 300 <= MW <= 600 and 1 <= LogP <= 5
            return df['MW'].between(300, 600) & df['LogP'].between(1, 5)

        def bbb_favorable(df):
            # MW <= 400, LogP <= 5, HBD <= 3, TPSA <= 90
            return ((df['MW'] <= 400) & (df['LogP'] <= 5)
                    & (df['HBD'] <= 3) & (df['TPSA'] <= 90))

        def rule_of_five(df):
            # Lipinski thresholds: MW <= 500, LogP <= 5, HBD <= 5, HBA <= 10
            return ((df['MW'] <= 500) & (df['LogP'] <= 5)
                    & (df['HBD'] <= 5) & (df['HBA'] <= 10))

        def esol_favorable(df):
            # MW <= 500, LogP <= 5, TPSA <= 140, RotBonds <= 10
            return ((df['MW'] <= 500) & (df['LogP'] <= 5)
                    & (df['TPSA'] <= 140) & (df['RotBonds'] <= 10))

        def qm9_compliant(df):
            # MW <= 250, RotBonds <= 10
            return (df['MW'] <= 250) & (df['RotBonds'] <= 10)

        self.compliance_criteria = {
            'bace': {'name': 'BACE_favorable', 'check': bace_favorable},
            'bbbp': {'name': 'BBB_favorable', 'check': bbb_favorable},
            'clintox': {'name': 'RO5_compliant', 'check': rule_of_five},
            'esol': {'name': 'ESOL_favorable', 'check': esol_favorable},
            'qm9': {'name': 'QM9_compliant', 'check': qm9_compliant},
        }

    def load_dataset(self, dataset_name):
        """Load a MoleculeNet dataset through DeepChem and re-split it 80/20.

        Args:
            dataset_name: one of 'bace', 'bbbp', 'clintox', 'esol', 'qm9'.

        Returns:
            A dict with keys 'train_smiles', 'test_smiles', 'train_y', 'test_y',
            'train_y_binary', 'test_y_binary', 'task_type' ('classification' or
            'regression') and 'task_name'. Returns None on any failure, including an
            unknown name; the error is printed, not raised.

        Processing steps:
        - Only the first label column is kept for multi-task targets.
        - DeepChem's train/valid/test splits are concatenated in that order.
        - NaN labels are dropped, and the first 80% of what remains becomes train.
        - Regression labels are also binarized at the median of all kept labels
          (train and test together).
        """
        print(f"\nLoading {dataset_name.upper()} dataset...")

        try:
            # dataset key -> (DeepChem loader, extra loader kwargs, task type)
            loaders = {
                'bace': (load_bace_classification, {}, 'classification'),
                'bbbp': (load_bbbp, {}, 'classification'),
                'clintox': (load_clintox, {}, 'classification'),
                'esol': (load_delaney, {}, 'regression'),
                'qm9': (load_qm9, {'reload': False}, 'regression'),
            }
            if dataset_name not in loaders:
                raise ValueError(f"Unknown dataset: {dataset_name}")
            loader, extra_kwargs, task_type = loaders[dataset_name]
            tasks, datasets, _ = loader(featurizer='ECFP', splitter='random', **extra_kwargs)

            train_part, valid_part, test_part = datasets
            parts = (train_part, valid_part, test_part)

            smiles = [s for part in parts for s in part.ids.tolist()]
            if np.ndim(train_part.y) > 1:
                labels = np.concatenate([part.y[:, 0] for part in parts])
            else:
                labels = np.concatenate([part.y for part in parts]).flatten()

            # Drop molecules whose label is missing.
            keep = ~np.isnan(labels)
            smiles = [s for s, k in zip(smiles, keep) if k]
            labels = labels[keep]

            if task_type == 'regression':
                labels_binary = (labels > np.nanmedian(labels)).astype(int)
            else:
                labels_binary = labels.astype(int)

            cut = int(0.8 * len(smiles))
            return {
                'train_smiles': smiles[:cut],
                'test_smiles': smiles[cut:],
                'train_y': labels[:cut],
                'test_y': labels[cut:],
                'train_y_binary': labels_binary[:cut],
                'test_y_binary': labels_binary[cut:],
                'task_type': task_type,
                'task_name': tasks[0] if tasks else dataset_name
            }

        except Exception as e:
            print(f"Error loading {dataset_name}: {e}")
            return None

    def extract_features(self, smiles_list):
        """Compute descriptor features for every SMILES string.

        Args:
            smiles_list: iterable of SMILES strings.

        Returns:
            A tuple of two items. The first is a DataFrame with one row per molecule
            the feature extractor processed successfully. The second is the list of
            positions in ``smiles_list`` those rows came from, so callers can realign
            labels. Molecules for which the extractor returns None are skipped.
        """
        processed = [
            self.feature_extractor.process_molecule(smiles)
            for smiles in tqdm(smiles_list, desc="Extracting features")
        ]
        valid_indices = [pos for pos, feat in enumerate(processed) if feat is not None]
        rows = [processed[pos] for pos in valid_indices]
        return pd.DataFrame(rows), valid_indices

    def train_model(self, X_train, y_train, task_type):
        """Fit a random forest on the given data and return the fitted estimator.

        The forest uses 100 trees, max depth 10 and random_state 42. A classifier
        is built when ``task_type == 'classification'``; any other value selects
        a regressor.
        """
        estimator_cls = (RandomForestClassifier if task_type == 'classification'
                         else RandomForestRegressor)
        forest_params = dict(n_estimators=100, max_depth=10, random_state=42)
        model = estimator_cls(**forest_params)
        model.fit(X_train, y_train)
        return model

    def create_dalex_fairness_plots(self, explainer, protected, dataset_name, criteria_name):
        """Compute DALEX group-fairness metrics for one dataset and write reports/plots.

        Outputs go under ``<output_dir>/<dataset_name>/``:
          * ``<dataset_name>_fairness_metrics.txt``: ``fobject.result`` plus the
            text report printed by DALEX's fairness check;
          * ``<dataset_name>_fairness_bar_chart.html/.png``;
          * ``<dataset_name>_performance_radar.html/.png``;
          * ``<dataset_name>_dalex_fairness_check.html/.png`` from DALEX's own plot,
            or, if that fails, the files written by ``create_fairness_check_plot``.
        PNG export needs kaleido; a PNG failure is only reported.

        Args:
            explainer: DALEX explainer whose ``model_fairness`` method is used.
            protected: per-sample subgroup labels (anything ``pd.Series`` accepts).
                ``'compliant'`` is the privileged subgroup when present; otherwise
                the first distinct value is.
            dataset_name: dataset key, used for directory/file names and titles.
            criteria_name: compliance-criterion name, used in titles and labels.

        Returns:
            ``(fobject, fobject.result)`` on success. Some errors escape the
            individual plot handlers, e.g. while building the fairness object or
            writing the metrics file. In that case the result of
            ``create_dalex_fairness_plots_fallback`` is returned instead.
        """
        print(f"\nGenerating DALEX fairness plots for {dataset_name}...")

        # Create dataset-specific directory
        dataset_dir = os.path.join(self.output_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)

        try:
            # DALEX expects a positionally indexed Series aligned with the explainer data.
            if not isinstance(protected, pd.Series):
                protected = pd.Series(protected)
            protected = protected.reset_index(drop=True)

            print(f"Protected attribute shape: {protected.shape}")
            print(f"Protected attribute unique values: {protected.unique()}")

            subgroups = protected.unique()
            privileged = 'compliant' if 'compliant' in subgroups else subgroups[0]
            fobject = explainer.model_fairness(protected=protected, privileged=privileged)

            print("\nRunning fairness check...")
            fairness_check = self._capture_fairness_check(fobject, epsilon=0.8)

            fairness_results = fobject.result
            print(f"\nFairness Metrics for {dataset_name}:")
            print(fairness_results)

            metrics_file = os.path.join(dataset_dir, f'{dataset_name}_fairness_metrics.txt')
            with open(metrics_file, 'w', encoding='utf-8') as f:
                f.write(f"DALEX Fairness Metrics for {dataset_name.upper()}\n")
                f.write(f"Criteria: {criteria_name}\n")
                f.write("=" * 50 + "\n")
                f.write(str(fairness_results))
                f.write("\n\nFairness Check Summary:\n")
                f.write(fairness_check)
            print(f"Saved metrics to: {metrics_file}")

            import plotly
            if hasattr(plotly, '__version__'):
                print(f"Plotly version: {plotly.__version__}")

            print("\nGenerating DALEX fairness visualizations...")

            # Each plot is best-effort: one failure must not prevent the others.
            try:
                print("Creating fairness metrics bar plot...")
                self._plot_fairness_bar_chart(fairness_results, privileged, dataset_name,
                                              criteria_name, dataset_dir)
            except Exception as e:
                print(f"Could not generate fairness bar chart: {e}")

            try:
                print("Creating performance radar plot...")
                self._plot_performance_radar(fairness_results, dataset_name,
                                             criteria_name, dataset_dir)
            except Exception as e:
                print(f"Could not generate performance radar: {e}")

            try:
                print("Attempting DALEX native plot...")
                self._plot_dalex_native(fobject, fairness_results, fairness_check,
                                        dataset_name, criteria_name, dataset_dir)
            except Exception as e:
                print(f"Could not generate DALEX native plot: {e}")
                self.create_fairness_check_plot(fairness_results, fairness_check,
                                                dataset_name, criteria_name, dataset_dir)

            return fobject, fairness_results

        except Exception as e:
            print(f"Error creating DALEX fairness plots: {e}")
            traceback.print_exc()

            # Use fallback only if DALEX completely fails
            print("\nUsing fallback visualization...")
            return self.create_dalex_fairness_plots_fallback(explainer, protected, dataset_name, criteria_name)

    @staticmethod
    def _capture_fairness_check(fobject, epsilon=0.8):
        """Run ``fobject.fairness_check`` and return its report as text.

        DALEX's fairness check reports its verdict by printing.
        NOTE(review): its return value appears to be None in current DALEX;
        confirm for the installed version. The printed report is captured and
        echoed to stdout, and any non-None return value is appended.
        """
        import contextlib
        import io

        buffer = io.StringIO()
        returned = None
        try:
            with contextlib.redirect_stdout(buffer):
                returned = fobject.fairness_check(epsilon=epsilon, verbose=True)
        finally:
            # Still show the report (or partial output) on the console.
            print(buffer.getvalue(), end='')
        summary = buffer.getvalue()
        if returned is not None:
            summary += str(returned)
        return summary

    @staticmethod
    def _save_plotly_figure(fig, dataset_dir, stem, width, height, html_label='Saved'):
        """Write ``fig`` to ``<dataset_dir>/<stem>.html`` and, best-effort, ``.png``.

        HTML write errors propagate. PNG export needs kaleido; its failure is only printed.
        """
        html_file = os.path.join(dataset_dir, f'{stem}.html')
        fig.write_html(html_file)
        print(f"{html_label}: {html_file}")

        try:
            png_file = os.path.join(dataset_dir, f'{stem}.png')
            fig.write_image(png_file, width=width, height=height)
            print(f"Saved PNG: {png_file}")
        except Exception as e:
            print(f"Could not save PNG (install kaleido with: pip install kaleido): {e}")

    def _plot_fairness_bar_chart(self, fairness_results, privileged, dataset_name,
                                 criteria_name, dataset_dir):
        """Horizontal bar chart of unprivileged-minus-privileged metric values.

        Compares the first subgroup in ``fairness_results.index`` that is not
        ``privileged`` against the privileged row. A metric column absent from
        ``fairness_results`` is plotted as 0.
        """
        unprivileged = [g for g in fairness_results.index if g != privileged]
        if not unprivileged:
            print("No unprivileged subgroup in the fairness results; skipping bar chart.")
            return
        unpriv_group = unprivileged[0]

        # NOTE(review): DALEX documents ``result`` as ratios against the privileged
        # subgroup, so these bars are effectively (ratio - 1) rather than absolute
        # rate differences. If the privileged row is absent, 1.0 is used as the
        # baseline, since that is the privileged subgroup's ratio to itself.
        has_privileged_row = privileged in fairness_results.index

        def difference(column):
            if column not in fairness_results.columns:
                return 0
            if has_privileged_row:
                baseline = fairness_results.loc[privileged, column]
            else:
                baseline = 1.0
            return fairness_results.loc[unpriv_group, column] - baseline

        # (display name, DALEX column, description shown under the name)
        metric_specs = [
            ('Accuracy Equality', 'ACC', 'Difference in overall accuracy'),
            ('Predictive Parity', 'PPV', 'Difference in Positive Predictive Values'),
            ('Predictive Equality', 'FPR', 'Difference in False Positive Rates'),
            ('Equal Opportunity', 'TPR', 'Difference in True Positive Rates'),
            ('Statistical Parity', 'STP', 'Difference in overall prediction rates'),
        ]
        metrics_names = [name for name, _, _ in metric_specs]
        metrics_values = [difference(column) for _, column, _ in metric_specs]

        # Widen the default +/-0.5 window so large differences are not clipped.
        finite_magnitudes = [abs(float(v)) for v in metrics_values if np.isfinite(v)]
        half_width = max([0.5] + [1.1 * v for v in finite_magnitudes])

        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=metrics_names,
            x=metrics_values,
            orientation='h',
            marker_color='#2E4057',
            text=[f'{v:.3f}' for v in metrics_values],
            textposition='outside',
            hovertemplate='<b>%{y}</b><br>%{x:.3f}<br><extra></extra>'
        ))

        # Dashed reference line at 0 (no difference between the groups).
        fig.add_vline(x=0, line_dash="dash", line_color="red",
                      annotation_text="0 = Perfect Fairness")

        fig.update_layout(
            title=dict(
                text=(f'Fairness Metrics - {dataset_name.upper()}<br>'
                      f'Subgroups: {criteria_name} Compliant vs Non-compliant'),
                font=dict(size=16)
            ),
            xaxis_title='Difference (Non-compliant - Compliant)',
            yaxis_title='Metric',
            template='plotly_white',
            showlegend=False,
            height=400,
            margin=dict(l=200),
            xaxis=dict(range=[-half_width, half_width])
        )

        # Category positions 0..n-1 get a two-line label: name plus description.
        fig.update_yaxes(
            ticktext=[f"{name}<br><sub>{desc}</sub>" for name, _, desc in metric_specs],
            tickvals=list(range(len(metric_specs)))
        )

        self._save_plotly_figure(fig, dataset_dir, f'{dataset_name}_fairness_bar_chart',
                                 width=1200, height=600)

    def _plot_performance_radar(self, fairness_results, dataset_name, criteria_name, dataset_dir):
        """Radar chart of TPR/TNR/ACC/PPV per subgroup row of ``fairness_results``.

        The chart is skipped unless the table has more than one row. A metric
        column absent from the table is drawn as 0.
        NOTE(review): with DALEX's ``result``, these are ratios to the privileged
        subgroup rather than raw rates, so values can exceed 1. The radial axis is
        extended to show them.
        """
        if len(fairness_results) <= 1:
            print("Fewer than two subgroups in the fairness results; skipping performance radar.")
            return

        radar_metrics = ['TPR', 'TNR', 'ACC', 'PPV']
        radar_labels = ['True Positive Rate', 'True Negative Rate', 'Accuracy', 'Precision']
        # Repeat the first label so each polygon closes.
        theta = radar_labels + radar_labels[:1]

        fig = go.Figure()
        plotted = []
        for group in fairness_results.index:
            values = [
                fairness_results.loc[group, metric] if metric in fairness_results.columns else 0
                for metric in radar_metrics
            ]
            values.append(values[0])
            plotted.extend(values)

            fig.add_trace(go.Scatterpolar(
                r=values,
                theta=theta,
                fill='toself',
                name=f'{criteria_name} {str(group).replace("_", "-")}',
                opacity=0.6
            ))

        finite_values = [float(v) for v in plotted if np.isfinite(v)]
        radial_max = max([1.0] + finite_values)

        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True, range=[0, radial_max])),
            title=f'Performance Metrics by Group - {dataset_name.upper()}',
            showlegend=True,
            template='plotly_white'
        )

        self._save_plotly_figure(fig, dataset_dir, f'{dataset_name}_performance_radar',
                                 width=800, height=800)

    def _plot_dalex_native(self, fobject, fairness_results, fairness_check,
                           dataset_name, criteria_name, dataset_dir):
        """Save DALEX's own fairness plot, falling back to ``create_fairness_check_plot``.

        ``plotly.graph_objs.Layout`` is temporarily replaced by a subclass that
        rewrites the legacy ``titlefont`` keyword into ``title.font``. The original
        class is restored in ``finally``, even if the fallback plot raises.
        NOTE(review): whether DALEX resolves ``Layout`` through this module
        attribute depends on the DALEX/plotly versions; confirm the patch has effect.
        """
        import plotly.graph_objs as graph_objs

        original_layout = graph_objs.Layout

        class PatchedLayout(original_layout):
            def __init__(self, *args, **kwargs):
                # Translate the removed 'titlefont' property into 'title.font'.
                if 'titlefont' in kwargs:
                    title = kwargs.get('title')
                    if title is None:
                        title = {}
                    elif isinstance(title, str):
                        title = {'text': title}
                    title['font'] = kwargs.pop('titlefont')
                    kwargs['title'] = title
                super().__init__(*args, **kwargs)

        graph_objs.Layout = PatchedLayout
        try:
            try:
                plot = fobject.plot(show=False)
                if plot is not None:
                    self._save_plotly_figure(plot, dataset_dir,
                                             f'{dataset_name}_dalex_fairness_check',
                                             width=1200, height=800,
                                             html_label='Saved DALEX fairness check')
            except Exception as e:
                print(f"DALEX plot generation failed: {e}")
                print("Creating manual fairness check visualization...")
                self.create_fairness_check_plot(fairness_results, fairness_check,
                                                dataset_name, criteria_name, dataset_dir)
        finally:
            # Never leave the global plotly class patched.
            graph_objs.Layout = original_layout

    def create_fairness_check_plot(self, fairness_results, fairness_check,
                                   dataset_name, criteria_name, dataset_dir,
                                   epsilon=0.8):
        """Draw a fairness-check chart with independence / separation / sufficiency panels.

        This is the substitute used when DALEX's own ``fobject.plot`` cannot be
        produced. Each panel shows one metric column for every subgroup in
        ``fairness_results`` other than ``'compliant'``:

            independence -> 'STP', separation -> 'TPR', sufficiency -> 'PPV'

        A missing column is shown as 1.0. Values are drawn unscaled, as bars
        starting at 1.0. The band ``[epsilon, 1/epsilon]`` is shaded green, bars
        outside it are red, and a dashed line marks 1.0.
        NOTE(review): this presumes ``fairness_results`` is DALEX's ``fobject.result``
        (ratios against the privileged subgroup); confirm at call sites.

        Args:
            fairness_results: DataFrame indexed by subgroup, with metric columns.
            fairness_check: unused; accepted for call compatibility.
            dataset_name: used in the title and the output file names.
            criteria_name: prefix for the subgroup labels on the y-axis.
            dataset_dir: existing directory that receives the HTML/PNG files.
            epsilon: lower edge of the acceptable band, in (0, 1]; the upper edge
                is 1/epsilon. 0.8 matches the epsilon passed to DALEX's
                fairness_check elsewhere in this file.
        """
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        unprivileged = [g for g in fairness_results.index if g != 'compliant']
        if not unprivileged:
            print(f"No non-privileged subgroup to plot for {dataset_name}; "
                  "skipping fairness check plot.")
            return

        # (panel title, DALEX metric column)
        panels = [('independence', 'STP'), ('separation', 'TPR'), ('sufficiency', 'PPV')]
        upper = 1.0 / epsilon
        labels = [f'{criteria_name}-{group}' for group in unprivileged]

        fig = make_subplots(
            rows=len(panels), cols=1,
            subplot_titles=[title for title, _ in panels],
            vertical_spacing=0.15
        )

        for row, (_, column) in enumerate(panels, start=1):
            if column in fairness_results.columns:
                values = [float(fairness_results.loc[g, column]) for g in unprivileged]
            else:
                values = [1.0] * len(unprivileged)

            colors = ['teal' if epsilon <= v <= upper else 'red' for v in values]
            fig.add_trace(
                go.Bar(
                    # Bars run from 1.0 (parity) to the value itself.
                    x=[v - 1.0 for v in values],
                    base=1.0,
                    y=labels,
                    orientation='h',
                    marker_color=colors,
                    showlegend=False,
                    text=[f'{v:.2f}' for v in values],
                    textposition='outside',
                    customdata=values,
                    hovertemplate='%{y}: %{customdata:.3f}<extra></extra>'
                ),
                row=row, col=1
            )
            # Shade the acceptable band and mark exact parity.
            fig.add_vrect(x0=epsilon, x1=upper, fillcolor='green', opacity=0.1,
                          line_width=0, row=row, col=1)
            fig.add_vline(x=1.0, line_dash='dash', line_color='black', row=row, col=1)
            fig.update_xaxes(title_text='score', row=row, col=1)
            fig.update_yaxes(title_text='subgroup', row=row, col=1)

        fig.update_layout(
            title=dict(
                text=f'Fairness Check - {dataset_name.upper()}',
                font=dict(size=20, color='blue'),
                x=0.5,
                xanchor='center'
            ),
            height=600,
            showlegend=False,
            template='plotly_white',
            plot_bgcolor='rgba(255, 192, 203, 0.1)'  # Light pink background
        )

        html_file = os.path.join(dataset_dir, f'{dataset_name}_fairness_check.html')
        fig.write_html(html_file)
        print(f"Saved fairness check: {html_file}")

        # PNG export needs kaleido; treat it as optional.
        try:
            png_file = os.path.join(dataset_dir, f'{dataset_name}_fairness_check.png')
            fig.write_image(png_file, width=800, height=600)
            print(f"Saved PNG: {png_file}")
        except Exception as e:
            print(f"Could not save PNG (install kaleido): {e}")

    def create_dalex_fairness_plots_fallback(self, explainer, protected, dataset_name, criteria_name):
        """Fallback method for creating fairness visualizations using matplotlib"""
        print("Using fallback method for fairness visualizations...")

        dataset_dir = os.path.join(self.output_dir, dataset_name)

        # Get predictions and true values
        y_true = explainer.y
        y_pred = explainer.predict(explainer.data)

        # Ensure we have two groups
        unique_groups = protected.unique()
        if len(unique_groups) < 2:
            print("Cannot perform fairness analysis with only one group")
            return None, None

        # For regression, convert to binary for fairness metrics
        if explainer.model_type == 'regression':
            y_pred_binary = (y_pred > np.median(y_pred)).astype(int)
            y_true_binary = (y_true > np.median(y_true)).astype(int)
        else:
            y_pred_binary = (y_pred > 0.5).astype(int)
            y_true_binary = y_true

        # Calculate metrics for each group
        metrics_by_group = {}

        for group in unique_groups:
            group_mask = (protected == group).values

            if group_mask.sum() > 0:
                from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

                y_true_group = y_true_binary[group_mask]
                y_pred_group = y_pred_binary[group_mask]

                # Calculate confusion matrix
                tn, fp, fn, tp = confusion_matrix(y_true_group, y_pred_group).ravel() if len(np.unique(y_true_group)) > 1 else (0, 0, 0, 0)

                # Calculate metrics
                metrics_by_group[group] = {
                    'accuracy': accuracy_score(y_true_group, y_pred_group),
                    'precision': precision_score(y_true_group, y_pred_group, zero_division=0),
                    'recall': recall_score(y_true_group, y_pred_group, zero_division=0),
                    'selection_rate': np.mean(y_pred_group),
                    'tpr': tp / (tp + fn) if (tp + fn) > 0 else 0,
                    'fpr': fp / (fp + tn) if (fp + tn) > 0 else 0,
                    'tnr': tn / (tn + fp) if (tn + fp) > 0 else 0,
                    'group_size': group_mask.sum()
                }

        # Calculate fairness metrics (using first two groups)
        groups = list(metrics_by_group.keys())[:2]
        if len(groups) == 2:
            fairness_metrics = {
                'Statistical_Parity': metrics_by_group[groups[1]]['selection_rate'] - metrics_by_group[groups[0]]['selection_rate'],
                'Equal_Opportunity': metrics_by_group[groups[1]]['tpr'] - metrics_by_group[groups[0]]['tpr'],
                'Predictive_Equality': metrics_by_group[groups[1]]['fpr'] - metrics_by_group[groups[0]]['fpr'],
                'Accuracy_Equality': metrics_by_group[groups[1]]['accuracy'] - metrics_by_group[groups[0]]['accuracy'],
                'Disparate_Impact': metrics_by_group[groups[1]]['selection_rate'] / (metrics_by_group[groups[0]]['selection_rate'] + 1e-10)
            }
        else:
            fairness_metrics = {}

        # Create comprehensive visualization
        fig = plt.figure(figsize=(16, 12))

        # 1. Fairness Metrics Bar Chart
        ax1 = plt.subplot(3, 3, 1)
        if fairness_metrics:
            metrics_names = list(fairness_metrics.keys())
            metrics_values = list(fairness_metrics.values())
            colors = ['red' if abs(v) > 0.1 else 'green' for v in metrics_values]
            ax1.barh(metrics_names, metrics_values, color=colors, alpha=0.7)
            ax1.axvline(0, color='black', linestyle='--', alpha=0.5)
            ax1.set_xlabel('Difference (0 = Fair)')
            ax1.set_title('Fairness Metrics', fontweight='bold')
            ax1.grid(True, alpha=0.3)

        # 2. Group Performance Comparison
        ax2 = plt.subplot(3, 3, 2)
        group_names = list(metrics_by_group.keys())
        x = np.arange(len(['Accuracy', 'Precision', 'Recall']))
        width = 0.35

        for i, group in enumerate(group_names[:2]):
            values = [
                metrics_by_group[group]['accuracy'],
                metrics_by_group[group]['precision'],
                metrics_by_group[group]['recall']
            ]
            ax2.bar(x + i*width, values, width, label=f'{group} (n={metrics_by_group[group]["group_size"]})')

        ax2.set_ylabel('Score')
        ax2.set_title('Performance by Group', fontweight='bold')
        ax2.set_xticks(x + width/2)
        ax2.set_xticklabels(['Accuracy', 'Precision', 'Recall'])
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim([0, 1])

        # 3. Selection Rate Comparison
        ax3 = plt.subplot(3, 3, 3)
        selection_rates = [metrics_by_group[g]['selection_rate'] for g in group_names]
        ax3.bar(group_names, selection_rates, color=['blue', 'orange'])
        ax3.set_ylabel('Selection Rate')
        ax3.set_title('Selection Rate by Group', fontweight='bold')
        ax3.set_ylim([0, 1])
        ax3.grid(True, alpha=0.3)

        # 4. ROC Space
        ax4 = plt.subplot(3, 3, 4)
        for group in group_names[:2]:
            fpr = metrics_by_group[group]['fpr']
            tpr = metrics_by_group[group]['tpr']
            ax4.scatter(fpr, tpr, s=100, label=group)
        ax4.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        ax4.set_xlabel('False Positive Rate')
        ax4.set_ylabel('True Positive Rate')
        ax4.set_title('ROC Space', fontweight='bold')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        ax4.set_xlim([-0.05, 1.05])
        ax4.set_ylim([-0.05, 1.05])

        # 5. Confusion Matrix Comparison
        for idx, group in enumerate(group_names[:2]):
            ax = plt.subplot(3, 3, 5 + idx)
            group_mask = (protected == group).values

            if group_mask.sum() > 0:
                cm = confusion_matrix(y_true_binary[group_mask], y_pred_binary[group_mask])
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
                ax.set_title(f'Confusion Matrix - {group}', fontweight='bold')
                ax.set_xlabel('Predicted')
                ax.set_ylabel('Actual')

        # 7. Sample Distribution
        ax7 = plt.subplot(3, 3, 7)
        sizes = [metrics_by_group[g]['group_size'] for g in group_names]
        ax7.pie(sizes, labels=group_names, autopct='%1.1f%%', startangle=90)
        ax7.set_title('Group Distribution', fontweight='bold')

        # 8. Metric Differences Table
        ax8 = plt.subplot(3, 3, 8)
        ax8.axis('tight')
        ax8.axis('off')

        if len(groups) == 2:
            table_data = []
            for metric in ['accuracy', 'precision', 'recall', 'tpr', 'fpr']:
                diff = metrics_by_group[groups[1]][metric] - metrics_by_group[groups[0]][metric]
                table_data.append([
                    metric.capitalize(),
                    f"{metrics_by_group[groups[0]][metric]:.3f}",
                    f"{metrics_by_group[groups[1]][metric]:.3f}",
                    f"{diff:+.3f}"
                ])

            table = ax8.table(cellText=table_data,
                            colLabels=[' Metric', groups[0][:15], groups[1][:15], 'Difference'],
                            cellLoc='center',
                            loc='center')
            table.auto_set_font_size(False)
            table.set_fontsize(9)
            table.scale(1.2, 1.5)

        # 9. Fairness Score Summary
        ax9 = plt.subplot(3, 3, 9)
        ax9.axis('off')

        if fairness_metrics:
            # Calculate overall fairness score
            fairness_violations = sum([
                abs(fairness_metrics.get('Statistical_Parity', 0)) > 0.1,
                abs(fairness_metrics.get('Equal_Opportunity', 0)) > 0.1,
                abs(fairness_metrics.get('Predictive_Equality', 0)) > 0.1,
                abs(fairness_metrics.get('Accuracy_Equality', 0)) > 0.1,
                fairness_metrics.get('Disparate_Impact', 1) < 0.8 or fairness_metrics.get('Disparate_Impact', 1) > 1.25
            ])

            fairness_score = 1 - (fairness_violations / 5)
            color = 'green' if fairness_score > 0.8 else 'orange' if fairness_score > 0.6 else 'red'

            ax9.text(0.5, 0.7, 'Overall Fairness Score', ha='center', fontsize=14, fontweight='bold')
            ax9.text(0.5, 0.4, f'{fairness_score:.1%}', ha='center', fontsize=24, color=color, fontweight='bold')
            ax9.text(0.5, 0.1, f'{5-fairness_violations}/5 metrics within threshold', ha='center', fontsize=10)

        plt.suptitle(f'Fairness Analysis - {dataset_name.upper()}\nCriteria: {criteria_name}',
                    fontsize=16, fontweight='bold')
        plt.tight_layout()

        # Save plot
        plot_file = os.path.join(dataset_dir, f'{dataset_name}_fairness_analysis.png')
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved comprehensive fairness analysis to: {plot_file}")

        # Save metrics to file
        metrics_file = os.path.join(dataset_dir, f'{dataset_name}_fairness_metrics.json')
        import json
        with open(metrics_file, 'w') as f:
            json.dump({
                'group_metrics': metrics_by_group,
                'fairness_metrics': fairness_metrics,
                'criteria': criteria_name
            }, f, indent=2)
        print(f"Saved metrics to: {metrics_file}")

        # Create results DataFrame
        fairness_results = pd.DataFrame({
            'metric': list(fairness_metrics.keys()),
            'value': list(fairness_metrics.values())
        })

        return None, fairness_results

    def create_functional_group_heatmap(self, X_test, y_test, y_pred, dataset_name):
        """Create a fairness heatmap over functional-group count features.

        For every column of ``X_test`` whose name ends in ``_count``, the rows
        are split at that column's median: rows with a count ``<=`` the median
        form the reference ("privileged") group, the remaining rows form the
        comparison group. For each split, the disparate impact (DI),
        statistical parity difference (SPD) and equal opportunity difference
        (EOD) of the binarized predictions are computed. The resulting table is
        saved as a PNG heatmap under ``<output_dir>/<dataset_name>/``.

        Args:
            X_test: DataFrame of molecular features; only ``*_count`` columns
                are used.
            y_test: Binary ground-truth labels, aligned positionally with
                ``X_test``.
            y_pred: Model outputs, aligned positionally with ``X_test``.
                Outputs lying entirely in [0, 1] are treated as probabilities
                and thresholded at 0.5; anything else (e.g. regression output)
                is thresholded at its median.
            dataset_name: Used for the output sub-directory and file name.

        Returns:
            Dict mapping functional-group name to ``{'DI', 'SPD', 'EOD'}``, or
            ``None`` when no ``*_count`` column yields two non-empty groups.
        """
        print(f"\nGenerating functional group heatmap for {dataset_name}...")

        dataset_dir = os.path.join(self.output_dir, dataset_name)
        suffix = '_count'

        y_pred = np.asarray(y_pred, dtype=float)
        y_true = np.asarray(y_test)

        # Binarize once: the predictions do not depend on the feature column.
        # Only outputs that lie entirely within [0, 1] are treated as
        # probabilities -- regression outputs (e.g. log-solubility) can be
        # negative while still having a maximum <= 1.
        if y_pred.size and y_pred.min() >= 0 and y_pred.max() <= 1:
            y_pred_binary = (y_pred > 0.5).astype(int)
        else:
            y_pred_binary = (y_pred > np.median(y_pred)).astype(int)
        pos_mask = y_true == 1

        fg_metrics = {}
        for col in X_test.columns:
            if not (isinstance(col, str) and col.endswith(suffix)):
                continue

            counts = X_test[col]
            # Positional boolean masks; rows at or below the median are the
            # reference group.
            privileged = (counts <= counts.median()).to_numpy()
            unprivileged = ~privileged
            if not privileged.any() or not unprivileged.any():
                # A split with an empty side carries no fairness signal.
                continue

            # Selection rates per group.
            prob_priv = y_pred_binary[privileged].mean()
            prob_unpriv = y_pred_binary[unprivileged].mean()

            # Disparate Impact (epsilon avoids division by zero).
            di = prob_unpriv / (prob_priv + 1e-6)
            # Statistical Parity Difference.
            spd = prob_unpriv - prob_priv

            # Equal Opportunity Difference: TPR gap among true positives.
            priv_pos = privileged & pos_mask
            unpriv_pos = unprivileged & pos_mask
            if priv_pos.any() and unpriv_pos.any():
                eod = y_pred_binary[unpriv_pos].mean() - y_pred_binary[priv_pos].mean()
            else:
                eod = 0.0

            fg_name = col[:-len(suffix)]
            fg_metrics[fg_name] = {'DI': di, 'SPD': spd, 'EOD': eod}

        if not fg_metrics:
            return None

        metrics_df = pd.DataFrame(fg_metrics).T
        os.makedirs(dataset_dir, exist_ok=True)
        heatmap_file = os.path.join(dataset_dir, f'{dataset_name}_functional_group_heatmap.png')

        fig = plt.figure(figsize=(10, max(8, len(metrics_df) * 0.4)))
        try:
            sns.heatmap(metrics_df,
                        annot=True,
                        fmt='.2f',
                        cmap='RdBu_r',
                        # SPD/EOD are fair at 0 (DI is fair at 1, shown on the same scale).
                        center=0,
                        cbar_kws={'label': 'Metric Value'},
                        vmin=-1, vmax=2)

            plt.title(f'Functional Group Fairness - {dataset_name.upper()}', fontsize=14, fontweight='bold')
            plt.xlabel('Fairness Metrics', fontsize=12)
            plt.ylabel('Functional Groups', fontsize=12)
            plt.tight_layout()

            plt.savefig(heatmap_file, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        print(f"Saved: {heatmap_file}")

        return fg_metrics

    def analyze_dataset(self, dataset_name):
        """Run the complete fairness analysis pipeline for a single dataset.

        Steps:
        1. Load the dataset and extract molecular features.
        2. Standardize the features and train a model.
        3. Wrap the test-set predictions in a DALEX explainer.
        4. Split the test molecules into groups by the dataset's compliance
           criterion (RO5 by default). When only one group results, fall back
           to an MW median split, then to a positional half/half split.
        5. Generate the fairness plots and the functional-group heatmap.

        Args:
            dataset_name: Key understood by ``self.load_dataset`` and
                ``self.compliance_criteria`` (e.g. ``'bace'``, ``'esol'``).

        Returns:
            Dict with keys ``dataset``, ``fairness_object``,
            ``fairness_results``, ``functional_group_metrics``, ``task_type``
            and ``performance``. ``performance`` holds ``auc`` for
            classification (absent when undefined) and ``rmse``/``r2`` for
            regression. Returns ``None`` if the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print(f"Analyzing {dataset_name.upper()}")
        print(f"{'='*60}")

        # Load dataset
        data = self.load_dataset(dataset_name)
        if data is None:
            return None

        # Extract features (invalid SMILES are dropped; *_valid_idx index the labels)
        print("Extracting molecular features...")
        train_features, train_valid_idx = self.extract_features(data['train_smiles'])
        test_features, test_valid_idx = self.extract_features(data['test_smiles'])

        is_classification = data['task_type'] == 'classification'

        # Filter labels. Fairness metrics always need binary labels; for
        # regression the binarized target is used for them.
        if is_classification:
            train_y = data['train_y_binary'][train_valid_idx]
            test_y = data['test_y_binary'][test_valid_idx]
            y_fair = test_y
        else:
            train_y = data['train_y'][train_valid_idx]
            test_y = data['test_y'][test_valid_idx]
            y_fair = data['test_y_binary'][test_valid_idx]

        # Standardize features
        scaler = StandardScaler()
        X_train = scaler.fit_transform(train_features.fillna(0))
        X_test = scaler.transform(test_features.fillna(0))

        # Train model
        print("Training model...")
        model = self.train_model(X_train, train_y, data['task_type'])

        # Get predictions (positive-class probability for classification)
        if is_classification:
            y_pred = model.predict_proba(X_test)[:, 1]
        else:
            y_pred = model.predict(X_test)

        # Create DALEX explainer around the precomputed predictions
        print("Creating DALEX explainer...")
        wrapper = SimpleWrapper(y_pred)
        X_test_df = pd.DataFrame(X_test, columns=train_features.columns)

        explainer = dx.Explainer(
            model=wrapper,
            data=X_test_df,
            y=y_fair,
            model_type=data['task_type'],
            label=f"{dataset_name.upper()} Model",
            verbose=False
        )

        # Get compliance criteria
        criteria = self.compliance_criteria.get(dataset_name)
        if criteria:
            compliant = criteria['check'](test_features)
            criteria_name = criteria['name']
        else:
            # Default to RO5
            compliant = (test_features['MW'] <= 500) & (test_features['LogP'] <= 5)
            criteria_name = 'RO5_compliant'

        # Create protected attribute
        protected = pd.Series('non_compliant', index=test_features.index)
        protected[compliant] = 'compliant'

        print(f"\nGroup distribution for {criteria_name}:")
        print(protected.value_counts())
        # Scale before rounding so the printout has no float artifacts.
        print(f"Percentage: {(protected.value_counts(normalize=True) * 100).round(1)}")

        # Fairness comparisons need two groups.
        if protected.nunique() < 2:
            print(f"\n⚠ Warning: Only one group found for {dataset_name}. Creating artificial split for fairness analysis...")

            # Create artificial split based on median of a key feature
            if 'MW' in test_features.columns:
                median_mw = test_features['MW'].median()
                protected = pd.Series('below_median_MW', index=test_features.index)
                protected[test_features['MW'] > median_mw] = 'above_median_MW'
                criteria_name = 'MW_median_split'

            # Last resort when MW is missing or constant (the median split then
            # still yields one group): split by row position into two halves.
            if protected.nunique() < 2:
                mid_point = len(protected) // 2
                protected = pd.Series('group_A', index=test_features.index)
                protected.iloc[mid_point:] = 'group_B'
                criteria_name = 'random_split'

            print(f"Created artificial groups based on {criteria_name}:")
            print(protected.value_counts())

        # Generate DALEX fairness plots
        fairness_object, fairness_results = self.create_dalex_fairness_plots(
            explainer, protected, dataset_name, criteria_name
        )

        # Generate functional group heatmap
        fg_metrics = self.create_functional_group_heatmap(
            test_features, y_fair, y_pred, dataset_name
        )

        # Calculate performance metrics
        performance = {}
        if is_classification:
            from sklearn.metrics import roc_auc_score
            try:
                performance['auc'] = roc_auc_score(test_y, y_pred)
                print(f"\nTest AUC: {performance['auc']:.4f}")
            except ValueError as err:
                # AUC is undefined when the test labels contain a single class;
                # keep the fairness results instead of failing the dataset.
                print(f"\n⚠ Could not compute test AUC: {err}")
        else:
            performance['rmse'] = np.sqrt(mean_squared_error(test_y, y_pred))
            performance['r2'] = r2_score(test_y, y_pred)
            print(f"\nTest RMSE: {performance['rmse']:.4f}")
            print(f"Test R²: {performance['r2']:.4f}")

        return {
            'dataset': dataset_name,
            'fairness_object': fairness_object,
            'fairness_results': fairness_results,
            'functional_group_metrics': fg_metrics,
            'task_type': data['task_type'],
            'performance': performance
        }

    def run_all_analyses(self, datasets=None):
        """Analyze each requested dataset and write a combined summary.

        A dataset whose analysis returns a falsy result, or raises, is reported
        and skipped so the remaining datasets still run.

        Args:
            datasets: Dataset names to analyze; ``None`` selects all five
                supported datasets.

        Returns:
            Dict mapping each successfully analyzed dataset name to the
            result dict produced by ``analyze_dataset``.
        """
        names = ['bace', 'bbbp', 'clintox', 'esol', 'qm9'] if datasets is None else datasets
        banner = '=' * 60

        print(f"\n{banner}")
        print("STARTING MULTI-DATASET FAIRNESS ANALYSIS")
        print(banner)
        print(f"Datasets to analyze: {names}")

        collected = {}
        for name in names:
            try:
                outcome = self.analyze_dataset(name)
                if not outcome:
                    print(f"\n✗ Failed to analyze {name}")
                    continue
                collected[name] = outcome
                print(f"\n✓ Successfully analyzed {name}")
            except Exception as exc:
                # Top-level boundary: report with traceback, then move on.
                print(f"\n✗ Error analyzing {name}: {exc}")
                traceback.print_exc()

        self.create_summary_report(collected)
        return collected

    def create_summary_report(self, all_results):
        """Write ``fairness_summary.txt`` to ``self.output_dir`` and print a recap.

        Args:
            all_results: Dict mapping dataset name to the result dict produced
                by ``analyze_dataset``. Keys read here: ``task_type``,
                ``fairness_results`` (a ``metric``/``value`` DataFrame or
                ``None``) and ``functional_group_metrics`` (dict or ``None``).
        """
        print(f"\n{'='*60}")
        print("FAIRNESS ANALYSIS SUMMARY")
        print(f"{'='*60}")

        os.makedirs(self.output_dir, exist_ok=True)
        summary_file = os.path.join(self.output_dir, 'fairness_summary.txt')

        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("MULTI-DATASET FAIRNESS ANALYSIS SUMMARY\n")
            f.write("="*60 + "\n\n")

            for dataset_name, results in all_results.items():
                f.write(f"\n{dataset_name.upper()}\n")
                f.write("-"*30 + "\n")
                f.write(f"Task Type: {results['task_type']}\n")

                fairness_results = results['fairness_results']
                if fairness_results is None:
                    f.write("Fairness Metrics: Failed\n")
                elif getattr(fairness_results, 'empty', False):
                    # An empty table means no pairwise metrics were computed;
                    # do not report it as available.
                    f.write("Fairness Metrics: None computed\n")
                else:
                    f.write("Fairness Metrics: Available\n")
                    columns = set(getattr(fairness_results, 'columns', ()))
                    if {'metric', 'value'} <= columns:
                        for metric, value in zip(fairness_results['metric'], fairness_results['value']):
                            f.write(f"  {metric}: {float(value):.4f}\n")

                fg_metrics = results['functional_group_metrics']
                if fg_metrics is not None:
                    f.write(f"Functional Groups Analyzed: {len(fg_metrics)}\n")
                else:
                    f.write("Functional Groups: Not analyzed\n")

                f.write("\n")

        print(f"\nSummary saved to: {summary_file}")

        # Print summary to console
        print("\nDatasets Successfully Analyzed:")
        for dataset_name in all_results:
            print(f"  ✓ {dataset_name.upper()}")

        print(f"\nAll results saved to: {self.output_dir}")
        print("\nGenerated files for each dataset:")
        print("  - DALEX fairness plots (HTML)")
        print("  - Fairness analysis dashboard (PNG)")
        print("  - Fairness metrics (JSON)")
        print("  - Functional group heatmap (PNG)")

# Main execution function
def main(datasets=None, output_dir='./DALEX-Fairness-Results'):
    """Run the multi-dataset fairness analysis and print where results went.

    Args:
        datasets: Dataset names to analyze. ``None`` (the default) analyzes
            only ``'esol'``. Pass e.g.
            ``['bace', 'bbbp', 'clintox', 'esol', 'qm9']`` for the full set.
        output_dir: Directory that receives all generated files.

    Returns:
        Dict mapping each successfully analyzed dataset to its results, as
        returned by ``MultiDatasetFairnessAnalyzer.run_all_analyses``.
    """
    # Initialize analyzer
    analyzer = MultiDatasetFairnessAnalyzer(output_dir=output_dir)

    if datasets is None:
        datasets = ['esol']

    # Run analysis for all requested datasets
    results = analyzer.run_all_analyses(datasets)

    print(f"\n{'='*60}")
    print("ALL ANALYSES COMPLETE!")
    print(f"{'='*60}")

    # Folder name only (e.g. "DALEX-Fairness-Results") for the hint below.
    folder = os.path.basename(os.path.normpath(output_dir))
    print("\nTo view results:")
    print(f"1. Check the {folder} folder")
    print("2. Open HTML files in a browser for interactive plots")
    print("3. View PNG files for static visualizations")
    print("4. Read the JSON metrics files and fairness_summary.txt for detailed metrics")

    return results

if __name__ == "__main__":
    # Script/notebook entry point. Binding the return value to the module-level
    # name `results` keeps the analysis dict available for interactive use.
    results = main()