In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
from datetime import datetime
from typing import Tuple, List, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Dataset, Data, DataLoader
from torch_geometric.nn import (
    GCNConv,
    GINConv,
    global_add_pool,
    global_mean_pool,
    global_max_pool,
    MessagePassing
)

import deepchem as dc
from deepchem.molnet import load_qm7

from rdkit import Chem
from rdkit.Chem import (
    RemoveHs,
    AllChem,
    Descriptors,
    ChemicalFeatures,
    Draw
)
from rdkit import RDLogger

import dalex as dx
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.metrics import (
    mean_squared_error,
    r2_score,
    mean_absolute_error
)

from tqdm import tqdm  # For progress bars
import numpy as np
import pandas as pd
import dalex as dx
from rdkit import Chem
from rdkit.Chem import Descriptors

import traceback

# Suppress RDKit warnings. Only the 'rdApp.warning' log channel is disabled
# here; other RDKit log channels are left as they are.
RDLogger.DisableLog('rdApp.warning')


In [None]:
class MolecularFeatureExtractor:
    """Turn SMILES strings into PyTorch Geometric ``Data`` graphs.

    Node features are split into ``x_cat`` (``NUM_ATOM_CAT_FEATURES`` integer
    indices: atomic number and chirality) and ``x_phys`` (``NUM_PHYS_FEATURES``
    floats). Every bond contributes two directed edges, each carrying
    ``NUM_BOND_FEATURES`` floats.
    """

    # Feature widths. Every fallback/placeholder row below is built from these
    # so that per-atom / per-bond rows can always be stacked into a
    # rectangular tensor (a mismatched fallback row makes torch.tensor fail).
    NUM_ATOM_CAT_FEATURES = 2
    NUM_PHYS_FEATURES = 8
    NUM_BOND_FEATURES = 5

    def __init__(self):
        # Vocabularies for the categorical features: a feature value is the
        # index of the RDKit number/enum inside the corresponding list.
        self.atom_list = list(range(1, 119))  # atomic numbers 1..118
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
            Chem.rdchem.BondType.DATIVE
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    @staticmethod
    def _isolated_atom_descriptor(symbol: str, descriptor) -> float:
        """Evaluate ``descriptor`` on a one-atom molecule ``[symbol]``.

        Returns 0.0 when the one-atom SMILES cannot be parsed or the
        descriptor raises.
        """
        try:
            return descriptor(Chem.MolFromSmiles(f'[{symbol}]'))
        except Exception:
            return 0.0

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute categorical and physical features for a single atom.

        Args:
            atom: RDKit atom to featurize.

        Returns:
            ``(atom_feat, phys_feat)``. ``atom_feat`` is
            ``[atomic-number index, chirality index]``. ``phys_feat`` has
            ``NUM_PHYS_FEATURES`` entries:
            - exact mass of the bare atom,
            - MolLogP of the bare atom,
            - formal charge,
            - hybridization code,
            - aromatic flag,
            - total H count,
            - total valence,
            - degree.
            On any error, a zero row of the same widths is returned and a
            message is printed.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            symbol = atom.GetSymbol()
            phys_feat = [
                # Descriptors of the bare atom in isolation (e.g. "[C]"), not
                # its contribution inside this particular molecule.
                self._isolated_atom_descriptor(symbol, Descriptors.ExactMolWt),
                self._isolated_atom_descriptor(symbol, Descriptors.MolLogP),
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0] * self.NUM_ATOM_CAT_FEATURES, [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract node features for every atom of ``mol``.

        Returns:
            ``(x, phys)``: a long tensor ``[num_atoms, NUM_ATOM_CAT_FEATURES]``
            and a float tensor ``[num_atoms, NUM_PHYS_FEATURES]``. If ``mol`` is
            None, a single all-zero placeholder row of each is returned.
        """
        if mol is None:
            return (torch.tensor([[0] * self.NUM_ATOM_CAT_FEATURES], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return a copy of ``mol`` with hydrogens removed, including degree-zero (unbonded) ones."""
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def _placeholder_edges(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single self-loop on node 0 with an all-zero feature row."""
        return (torch.tensor([[0], [0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build ``edge_index`` and ``edge_attr`` for all bonds of ``mol``.

        Each bond becomes two directed edges (start->end and end->start). Both
        edges share the feature vector
        ``[bond-type index, bond-direction index, conjugated flag,
        rotatable flag, 3D bond length]``. Bonds whose features cannot be
        computed are skipped entirely, so ``edge_index`` and ``edge_attr``
        always describe the same edges.

        Returns:
            ``(edge_index, edge_attr)`` of shapes ``[2, E]`` and
            ``[E, NUM_BOND_FEATURES]``. If ``mol`` is None or no bond could be
            featurized, a placeholder self-loop on node 0 with zero features
            is returned.
        """
        if mol is None:
            return self._placeholder_edges()

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Only record the edges once the features are known to be valid;
            # otherwise edge_index would gain edges that edge_attr lacks.
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # If no valid bonds were processed
            return self._placeholder_edges()

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Single, non-ring bond whose two end atoms each have more than one neighbor."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Distance between atoms ``start`` and ``end`` in the 3D conformer.

        Returns 0.0 when there is no conformer, the conformer is not 3D, or
        the lookup fails.
        """
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            pass  # best effort: a missing/invalid conformer yields 0.0
        return 0.0

    @staticmethod
    def _relax_conformer(mol: Chem.Mol) -> None:
        """Optimize the embedded conformer in place with MMFF, falling back to UFF.

        ``MMFFOptimizeMolecule`` reports missing MMFF parameters by returning
        -1 rather than raising, so the return code is checked in addition to
        exceptions. Errors from UFF propagate to the caller.
        """
        try:
            mmff_status = AllChem.MMFFOptimizeMolecule(mol)
        except Exception:
            mmff_status = -1
        if mmff_status == -1:
            AllChem.UFFOptimizeMolecule(mol)

    def process_molecule(self, smiles: str) -> Data:
        """Convert a SMILES string into a graph ``Data`` object.

        Steps:
        1. Parse the SMILES.
        2. Strip and then re-add explicit hydrogens.
        3. Sanitize the molecule.
        4. Embed a 3D conformer with ETKDG.
        5. Relax the conformer with MMFF (UFF when MMFF cannot be set up).
        6. Extract node and edge features.

        Returns:
            A ``Data`` carrying ``x_cat``, ``x_phys``, ``edge_index``,
            ``edge_attr`` and ``num_nodes``, with the input SMILES stored as
            ``smiles``. Returns None if any step fails; a message is printed.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules

            # Drop explicit hydrogens from the input, then add every hydrogen
            # back explicitly so all molecules are featurized the same way.
            mol = RemoveHs(mol)
            mol = Chem.AddHs(mol, addCoords=True)

            Chem.SanitizeMol(mol)

            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # 3D coordinates feed the bond-length edge feature.
            # NOTE: no random seed is passed to ETKDG, so the coordinates (and
            # therefore the bond lengths) may differ between runs.
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules
                self._relax_conformer(mol)

            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            data = Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

            # Store the original SMILES string
            data._store.smiles = smiles

            return data

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [None]:
class GraphDiscriminator(nn.Module):
    """Graph-level encoder: node MLP -> 3x GCNConv -> mean pooling -> projection MLP.

    Reimplements the architecture of the original (pre-training)
    discriminator so that its saved checkpoint can be loaded.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        """
        Args:
            node_dim: width of the concatenated node features (``x_cat`` + ``x_phys``).
            edge_dim: width of ``edge_attr``; only used to size ``edge_encoder``.
            hidden_dim: width of the hidden layers.
            output_dim: width of the final graph embedding.
        """
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Kept so that state dicts saved from this architecture (which contain
        # ``edge_encoder.*`` entries) still load with the default strict
        # ``load_state_dict``. GCNConv does not consume edge features, so this
        # module is not used in ``forward``.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, data):
        """Embed a batch of graphs.

        Args:
            data: object exposing ``x_cat``, ``x_phys``, ``edge_index`` and
                ``batch`` (e.g. a PyG ``Batch``).

        Returns:
            One ``output_dim``-wide embedding per graph in ``data.batch``.
        """
        edge_index = data.edge_index

        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        x = self.node_encoder(x)

        # Edge features are deliberately not encoded: the GCN layers only use
        # the connectivity, so encoding ``edge_attr`` here was dead work.
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, data.batch)

        return self.projection(x)

# Load Encoder
def load_encoder(model_path, device='cpu'):
    """Rebuild a trained ``GraphDiscriminator`` from a checkpoint file.

    The checkpoint must contain ``'model_info'`` (the constructor dims;
    ``hidden_dim``/``output_dim`` default to 128 when absent) and
    ``'encoder_state_dict'``.

    Returns:
        The encoder, switched to eval mode and moved to ``device``.

    Note: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    ckpt = torch.load(model_path, map_location=device)
    info = ckpt['model_info']

    model = GraphDiscriminator(
        node_dim=info.get('node_dim'),
        edge_dim=info.get('edge_dim'),
        hidden_dim=info.get('hidden_dim', 128),
        output_dim=info.get('output_dim', 128),
    )
    model.load_state_dict(ckpt['encoder_state_dict'])
    model.eval()
    return model.to(device)

# Load Embeddings
def load_embeddings(filepath):
    """Read a pickled dict and return its ``'embeddings'`` and ``'labels'`` entries.

    Security: ``pickle.load`` can execute arbitrary code stored in the file,
    so only call this on files you produced or otherwise trust.
    """
    with open(filepath, 'rb') as handle:
        payload = pickle.load(handle)
    embeddings = payload['embeddings']
    labels = payload['labels']
    return embeddings, labels

# Paths from your saved model: timestamped artifacts from a previous
# pre-training run, resolved relative to the notebook's working directory.
encoder_path = '../checkpoints/encoders/final_encoder_20250216_111050.pt'
embedding_path = '../embeddings/final_embeddings_20250216_111005.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load encoder (eval mode, on `device`) and the saved embeddings.
# NOTE(review): `graph_data` receives the pickle's 'labels' entry returned by
# load_embeddings, not graph objects — consider renaming once downstream
# cells have been checked.
encoder = load_encoder(encoder_path, device)
embeddings, graph_data = load_embeddings(embedding_path)


In [None]:
class GINLayer(torch.nn.Module):
    """Edge-aware message-passing layer used by ``HybridGNNRegressor``.

    For every directed edge ``src -> dst``:
    - the source node's features are concatenated with a linearly transformed
      edge feature vector;
    - the result is passed through an MLP;
    - the message is summed into ``dst``.
    The node's own features are not added here; the caller adds a residual.
    """
    def __init__(self, in_channels: int, out_channels: int, edge_dim: int):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels + edge_dim, out_channels),
            torch.nn.BatchNorm1d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels)
        )
        self.edge_encoder = torch.nn.Linear(edge_dim, edge_dim)

    def forward(self, x, edge_index, edge_attr):
        """Aggregate edge messages into their destination nodes.

        Args:
            x: node features, 2D with ``in_channels`` columns.
            edge_index: two rows; row 0 holds source and row 1 destination
                node indices.
            edge_attr: edge features, one row per edge, ``edge_dim`` columns.

        Returns:
            Tensor with one row per node and ``out_channels`` columns. Nodes
            without incoming edges get zeros.
        """
        row, col = edge_index
        edge_embedding = self.edge_encoder(edge_attr)
        messages = self.mlp(torch.cat([x[row], edge_embedding], dim=1))

        # Size the output by the message width (out_channels), not by x's
        # width, so the layer also works when out_channels != in_channels.
        output = messages.new_zeros((x.size(0), messages.size(-1)))
        output.index_add_(0, col, messages)
        return output

class HybridGNNRegressor(torch.nn.Module):
    """Regressor combining a frozen pretrained graph encoder with a trainable GIN branch.

    The GIN branch works as follows:
    - embeds the concatenated node features;
    - applies ``num_layers`` residual ``GINLayer`` updates;
    - pools the nodes with attention-weighted mean + max + sum pooling.
    The pooled vector is concatenated with the encoder's graph embedding and
    mapped through an MLP to ``num_tasks`` unbounded outputs.
    """
    def __init__(self, encoder, node_dim: int, edge_dim: int, hidden_dim: int = 128,
                 num_layers: int = 3, num_tasks: int = 1, dropout: float = 0.3):
        """
        Args:
            encoder: pretrained graph encoder. It is frozen in place: its
                parameters stop requiring gradients and it is kept in eval mode.
                Its ``projection[-1].out_features`` sets the embedding width.
            node_dim: width of ``cat([x_cat, x_phys])``.
            edge_dim: width of ``edge_attr``.
            hidden_dim: width of the GIN branch.
            num_layers: number of residual GIN layers.
            num_tasks: number of regression outputs.
            dropout: dropout probability used in the MLPs.
        """
        super().__init__()
        self.encoder = encoder
        # Freeze the encoder: no gradients for its weights and eval mode. The
        # eval mode is re-applied in ``train`` because Module.train() recurses
        # into children and would otherwise switch the encoder back.
        for param in self.encoder.parameters():
            param.requires_grad_(False)
        self.encoder.eval()

        # Initial node embedding
        self.node_embedding = torch.nn.Sequential(
            torch.nn.Linear(node_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout)
        )

        # GIN layers
        self.gin_layers = torch.nn.ModuleList(
            GINLayer(hidden_dim, hidden_dim, edge_dim) for _ in range(num_layers)
        )

        # Per-node weights for the three pooling operators (mean, max, sum)
        self.pool_attention = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 3),
            torch.nn.Softmax(dim=1)
        )

        # Combine pretrained embeddings with GNN output
        encoder_out_dim = encoder.projection[-1].out_features
        self.combination_layer = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim + encoder_out_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout)
        )

        # Regression head
        self.regression_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim // 2, hidden_dim // 4),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim // 4, num_tasks)  # No activation for regression
        )

    def train(self, mode: bool = True):
        """Set train/eval mode on the trainable parts; the encoder always stays in eval mode."""
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, data):
        """Predict ``num_tasks`` values per graph.

        Args:
            data: object exposing ``x_cat``, ``x_phys``, ``edge_index``,
                ``edge_attr`` and ``batch`` (e.g. a PyG ``Batch``).

        Returns:
            Tensor with one row per graph and ``num_tasks`` columns.
        """
        with torch.no_grad():
            pretrained_emb = self.encoder(data)

        x = torch.cat([data.x_cat.float(), data.x_phys], dim=-1)
        x = self.node_embedding(x)
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        for gin_layer in self.gin_layers:
            x = x + gin_layer(x, edge_index, edge_attr)  # residual update

        pool_attention = self.pool_attention(x)

        x_mean = global_mean_pool(x * pool_attention[:, 0:1], batch)
        x_max = global_max_pool(x * pool_attention[:, 1:2], batch)
        x_sum = global_add_pool(x * pool_attention[:, 2:3], batch)

        x_pooled = x_mean + x_max + x_sum

        combined = self.combination_layer(
            torch.cat([x_pooled, pretrained_emb], dim=1)
        )

        # Regression output
        return self.regression_head(combined)

def train_hybrid_model_regression(model, train_loader, val_loader, device,
                                num_epochs=100, lr=1e-3, weight_decay=1e-4):
    """Train ``model`` with MSE loss, LR-on-plateau scheduling and early stopping.

    Each epoch:
    - runs one pass over ``train_loader`` (AdamW, gradient-norm clipping at 1.0);
    - evaluates on ``val_loader``;
    - prints loss / RMSE / R2 / MAE for both splits.

    The learning rate is halved by ``ReduceLROnPlateau`` (patience 5) on
    validation RMSE. Training stops after 10 consecutive epochs without a new
    best validation RMSE.

    Failures are reported and skipped rather than raised:
    - a batch that raises is skipped;
    - an epoch whose metrics cannot be computed is skipped, with no
      scheduler or early-stopping update.

    Args:
        model: module mapping a batch to predictions of shape ``[B, 1]``.
        train_loader / val_loader: iterables of batches exposing ``.to(device)`` and ``.y``.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        lr, weight_decay: AdamW hyper-parameters.

    Returns:
        ``{'state_dict', 'val_rmse', 'epoch'}`` for the epoch with the lowest
        validation RMSE. ``state_dict`` holds detached copies of the weights at
        that epoch. Returns None if no epoch ever improved on the initial
        infinity (e.g. no epoch produced valid metrics).
    """
    print("Initializing training...")
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                patience=5, verbose=True)

    # Use MSE Loss for regression
    criterion = torch.nn.MSELoss()
    best_val_rmse = float('inf')
    best_model = None
    patience = 10
    patience_counter = 0

    print("\nStarting training loop...")
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_predictions = []
        train_targets = []

        for batch_idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()

            try:
                outputs = model(batch)
                targets = batch.y.view(-1, 1)
                loss = criterion(outputs, targets)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_loss += loss.item()
                train_predictions.extend(outputs.detach().cpu().numpy())
                train_targets.extend(targets.cpu().numpy())

            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                continue

        # Validation
        model.eval()
        val_loss = 0
        val_predictions = []
        val_targets = []

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                try:
                    outputs = model(batch)
                    targets = batch.y.view(-1, 1)
                    loss = criterion(outputs, targets)

                    val_loss += loss.item()
                    val_predictions.extend(outputs.cpu().numpy())
                    val_targets.extend(targets.cpu().numpy())

                except Exception as e:
                    print(f"Error in validation: {str(e)}")
                    continue

        # Calculate metrics
        try:
            train_rmse = np.sqrt(mean_squared_error(train_targets, train_predictions))
            train_r2 = r2_score(train_targets, train_predictions)
            train_mae = mean_absolute_error(train_targets, train_predictions)

            val_rmse = np.sqrt(mean_squared_error(val_targets, val_predictions))
            val_r2 = r2_score(val_targets, val_predictions)
            val_mae = mean_absolute_error(val_targets, val_predictions)

            print(f'\nEpoch {epoch+1}/{num_epochs}:')
            print(f'Train Loss: {train_loss/len(train_loader):.4f}')
            print(f'Train RMSE: {train_rmse:.4f}')
            print(f'Train R2: {train_r2:.4f}')
            print(f'Train MAE: {train_mae:.4f}')
            print(f'Val Loss: {val_loss/len(val_loader):.4f}')
            print(f'Val RMSE: {val_rmse:.4f}')
            print(f'Val R2: {val_r2:.4f}')
            print(f'Val MAE: {val_mae:.4f}')
            print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

            # Save best model based on validation RMSE
            if val_rmse < best_val_rmse:
                best_val_rmse = val_rmse
                best_model = {
                    # state_dict() returns references to the live parameter
                    # tensors; clone them so later optimizer steps cannot
                    # overwrite the saved "best" weights.
                    'state_dict': {k: v.detach().clone()
                                   for k, v in model.state_dict().items()},
                    'val_rmse': val_rmse,
                    'epoch': epoch
                }
                patience_counter = 0
                print("New best model saved!")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f'Early stopping triggered after epoch {epoch+1}')
                    break

            # Update learning rate based on validation RMSE
            scheduler.step(val_rmse)

        except Exception as e:
            print(f"Error calculating metrics: {str(e)}")
            continue

    return best_model

In [None]:
class QM7Dataset(Dataset):
    """PyTorch Geometric dataset of QM7 molecules featurized from SMILES.

    Molecules whose featurization fails (``process_molecule`` returns None)
    are dropped together with their labels. This keeps ``smiles_list``,
    ``labels`` and ``processed_data`` index-aligned.
    """
    def __init__(self, smiles_list, labels, feature_extractor):
        """
        Args:
            smiles_list: SMILES strings to featurize.
            labels: per-molecule targets, indexed positionally as ``labels[idx]``.
            feature_extractor: object exposing ``process_molecule(smiles)``.
        """
        super().__init__()
        self.smiles_list = []
        self.labels = []
        self.feature_extractor = feature_extractor
        self.processed_data = []

        # Process all molecules
        print("Processing molecules...")
        for idx, smiles in enumerate(tqdm(smiles_list)):
            data = self.feature_extractor.process_molecule(smiles)
            if data is None:
                continue
            self.processed_data.append(data)
            self.smiles_list.append(smiles)
            self.labels.append(labels[idx])

        # Stack through NumPy first: torch.tensor() on a Python list of
        # ndarrays converts element by element and is very slow.
        self.labels = torch.tensor(np.asarray(self.labels), dtype=torch.float)
        print(f"Successfully processed {len(self.processed_data)} out of {len(smiles_list)} molecules")

    def len(self):
        """Number of successfully featurized molecules."""
        return len(self.processed_data)

    def get(self, idx):
        """Return the cached graph at ``idx`` with ``y`` set to its label as a ``[1, -1]`` row."""
        data = self.processed_data[idx]
        # The cached object is updated in place; repeated calls write the same values.
        data.y = self.labels[idx].view(1, -1)
        # Store SMILES as a property of the Data object
        data._store.smiles = self.smiles_list[idx]

        return data

class RegressionHead(nn.Module):
    """MLP regression head with residual connections on width-preserving stages.

    The input is batch-normalized first. Each hidden stage is
    Linear -> BatchNorm1d -> ReLU -> Dropout. When a stage keeps the feature
    width, its output is added to its input (residual); otherwise the stage
    output replaces it. A final small MLP maps to ``num_tasks`` unbounded
    outputs.
    """
    def __init__(self, input_dim: int, hidden_dims: Optional[List[int]] = None,
                 num_tasks: int = 1, dropout_rate: float = 0.3):
        """
        Args:
            input_dim: width of the input embedding.
            hidden_dims: widths of the hidden stages; defaults to ``[256, 128, 64]``.
            num_tasks: number of regression outputs.
            dropout_rate: dropout probability.

        Raises:
            ValueError: if ``hidden_dims`` is empty.
        """
        super().__init__()
        # None sentinel instead of a mutable list default.
        hidden_dims = [256, 128, 64] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        self.input_bn = nn.BatchNorm1d(input_dim)

        # Shared stages. Residual connections are applied in ``forward``
        # whenever a stage preserves the width, so no extra "identity" entry
        # is needed here (ModuleList only accepts nn.Module instances anyway).
        self.shared_layers = nn.ModuleList()
        current_dim = input_dim
        for hidden_dim in hidden_dims:
            self.shared_layers.append(nn.Sequential(
                nn.Linear(current_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ))
            current_dim = hidden_dim

        # Regression-specific layers
        self.regressor = nn.Sequential(
            nn.Linear(hidden_dims[-1], hidden_dims[-1] // 2),
            nn.BatchNorm1d(hidden_dims[-1] // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dims[-1] // 2, num_tasks)  # Direct regression output
        )

    def forward(self, x):
        """Map embeddings ``[B, input_dim]`` to predictions ``[B, num_tasks]``."""
        x = self.input_bn(x)

        for layer in self.shared_layers:
            # Evaluate each stage exactly once. Calling it repeatedly would redo
            # the work and, in training mode, draw a different dropout mask
            # per call.
            out = layer(x)
            x = x + out if out.shape == x.shape else out

        return self.regressor(x)

def load_pretrained_encoder(encoder_path: str, device: str = 'cuda'):
    """Rebuild the pretrained ``GraphDiscriminator`` from a checkpoint.

    The checkpoint must provide ``'model_info'`` with ``node_dim``,
    ``edge_dim``, ``hidden_dim`` and ``output_dim`` (all required here), plus
    ``'encoder_state_dict'``.

    Returns:
        ``(encoder, model_info)``. The encoder is on ``device`` and is not
        switched to eval mode here; callers are expected to do that.

    Note: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    checkpoint = torch.load(encoder_path, map_location=device)
    model_info = checkpoint['model_info']

    ctor_kwargs = {
        key: model_info[key]
        for key in ('node_dim', 'edge_dim', 'hidden_dim', 'output_dim')
    }
    encoder = GraphDiscriminator(**ctor_kwargs).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])

    return encoder, model_info

def train_regressor_with_scheduler(encoder, regression_head, train_loader, val_loader,
                                 device, num_epochs=100):
    """Train ``regression_head`` on top of a frozen ``encoder``.

    Setup:
    - The encoder is frozen in place: eval mode and ``requires_grad=False``
      on all its parameters.
    - Only the head is optimized (AdamW, lr 1e-3, weight decay 0.01,
      gradient-norm clipping at 1.0) with MSE loss.

    Each epoch prints loss / RMSE / R² / MAE for both splits. The learning
    rate is halved by ``ReduceLROnPlateau`` (patience 5) on validation RMSE.
    Training stops after 10 consecutive epochs without a new best validation
    RMSE.

    Returns:
        ``(best_state_dict, best_val_rmse)``. ``best_state_dict`` holds
        detached copies of the head's weights at its best epoch, or is None
        if no epoch improved on the initial infinity.
    """
    optimizer = torch.optim.AdamW(regression_head.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=0.5, patience=5,
                                                         verbose=True)
    criterion = nn.MSELoss()  # Use MSE loss for regression

    # Freeze encoder
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False

    best_val_rmse = float('inf')
    best_model = None
    patience = 10
    patience_counter = 0

    for epoch in range(num_epochs):
        # Training
        regression_head.train()
        train_loss = 0
        y_true_train = []
        y_pred_train = []

        for batch in train_loader:
            batch = batch.to(device)
            with torch.no_grad():
                embeddings = encoder(batch)

            outputs = regression_head(embeddings)
            targets = batch.y.view(-1, 1)

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(regression_head.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            y_true_train.extend(targets.cpu().numpy())
            y_pred_train.extend(outputs.detach().cpu().numpy())

        # Validation
        regression_head.eval()
        val_loss = 0
        y_true_val = []
        y_pred_val = []

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                embeddings = encoder(batch)
                outputs = regression_head(embeddings)
                targets = batch.y.view(-1, 1)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                y_true_val.extend(targets.cpu().numpy())
                y_pred_val.extend(outputs.cpu().numpy())

        # Calculate regression metrics
        train_rmse = np.sqrt(mean_squared_error(y_true_train, y_pred_train))
        train_r2 = r2_score(y_true_train, y_pred_train)
        train_mae = mean_absolute_error(y_true_train, y_pred_train)

        val_rmse = np.sqrt(mean_squared_error(y_true_val, y_pred_val))
        val_r2 = r2_score(y_true_val, y_pred_val)
        val_mae = mean_absolute_error(y_true_val, y_pred_val)

        # Update learning rate based on validation RMSE
        scheduler.step(val_rmse)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}')
        print(f'Train RMSE: {train_rmse:.4f}')
        print(f'Train R²: {train_r2:.4f}')
        print(f'Train MAE: {train_mae:.4f}')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}')
        print(f'Val RMSE: {val_rmse:.4f}')
        print(f'Val R²: {val_r2:.4f}')
        print(f'Val MAE: {val_mae:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Save best model based on validation RMSE
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            # state_dict() returns references to the live tensors; clone them
            # so later optimizer steps cannot overwrite the saved best weights.
            best_model = {k: v.detach().clone()
                          for k, v in regression_head.state_dict().items()}
            patience_counter = 0
            print("New best model saved!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after epoch {epoch+1}')
                break

    return best_model, best_val_rmse


In [None]:
class SimpleWrapper:
    """Adapter exposing ``predict`` for DALEX, backed by precomputed predictions.

    ``predict`` never runs ``model``. It serves values from
    ``preprocessed_data``: the whole array when given a ``SimpleWrapper``, or
    its first ``len(X)`` entries for a DataFrame/ndarray. Every call prints
    diagnostic information.
    """
    def __init__(self, model, device, preprocessed_data=None):
        print("\nInitializing SimpleWrapper:")
        print(f"Model type: {type(model)}")
        print(f"Device: {device}")
        initial_kind = None if preprocessed_data is None else type(preprocessed_data)
        print(f"Initial preprocessed_data: {initial_kind}")

        self.model, self.device, self.preprocessed_data = model, device, preprocessed_data

    def predict(self, X, *args):
        """Return precomputed predictions for input ``X``.

        Dispatch:
        - ``SimpleWrapper`` input: the full ``preprocessed_data``.
        - DataFrame/ndarray input: ``preprocessed_data[:len(X)]``, or
          ``len(X)`` zeros when no data was supplied.
        - Any other input: ``np.zeros(1)``.
        Errors are printed with a traceback and answered with zeros (``len(X)``
        of them when ``X`` has a length, else one).
        """
        print("\nSimpleWrapper predict called:")
        print(f"Input type: {type(X)}")
        shape_text = X.shape if hasattr(X, 'shape') else 'No shape'
        print(f"Input shape: {shape_text}")
        print(f"Additional args: {args}")

        try:
            if isinstance(X, SimpleWrapper):
                print("Input is SimpleWrapper, returning preprocessed_data")
                return self.preprocessed_data

            if not isinstance(X, (pd.DataFrame, np.ndarray)):
                print(f"Unhandled input type: {type(X)}")
                return np.zeros(1)

            print(f"Input is {type(X)}, returning slice of preprocessed_data")
            if self.preprocessed_data is None:
                print("Warning: No preprocessed_data available")
                return np.zeros(len(X))
            return self.preprocessed_data[:len(X)]

        except Exception as e:
            print(f"Error in predict: {str(e)}")
            traceback.print_exc()
            return np.zeros(len(X)) if hasattr(X, '__len__') else np.zeros(1)

class MolecularRegressionFairnessAnalyzer:
    """Group-fairness analysis of a molecular regression model using DALEX.

    Test-set molecules are described with RDKit descriptors and functional-group
    counts computed from their SMILES. Each descriptor is thresholded to split
    molecules into a privileged 'compliant' group and a 'non_compliant' group,
    and DALEX's regression fairness metrics (independence, separation,
    sufficiency) are computed for every split that yields two non-empty groups.

    Args:
        model: torch module called as ``model(batch)``; assumed to return one
            prediction per graph in the batch -- TODO confirm for multi-task heads.
        test_loader: Iterable of batches exposing ``y`` and, in ``batch._store``,
            a per-graph ``smiles`` list.
        device: torch device that batches are moved to before inference.
    """

    # SMARTS patterns whose match counts become '<name>_count' features.
    FUNCTIONAL_GROUP_SMARTS = {
        'primary_amine': '[NX3H2]',
        'secondary_amine': '[NX3H1]',
        'tertiary_amine': '[NX3H0]',
        'amide': '[NX3][CX3](=[OX1])',
        'chloro': '[Cl]',
        'bromo': '[Br]',
        'fluoro': '[F]',
        'iodo': '[I]',
        'alcohol': '[OH]',
        'phenol': '[OH1][c]',
        'ether': '[OD2]([#6])[#6]',
        'ester': '[#6][CX3](=O)[OX2H0][#6]',
        'carbonyl': '[CX3]=O',
        'aldehyde': '[CX3H1](=O)',
        'ketone': '[#6][CX3](=O)[#6]',
        'carboxyl': '[CX3](=O)[OX2H1]',
        'acyl_halide': '[CX3](=[OX1])[F,Cl,Br,I]',
        'phosphate': '[$(P(=[OX1])([OX2H,OX1-])([OX2H,OX1-])[OX2H,OX1-])]',
        'sulfate': '[$(S(=O)(=O)(O)[O-,OH])]',
        'sulfonamide': '[#16X4]([NX3])(=[OX1])(=[OX1])',
        'nitro': '[N+](=O)[O-]'
    }

    # Fixed descriptor splits: molecules whose value satisfies
    # `value <direction> threshold` form the 'non_compliant' group.
    # NOTE(review): these cut-offs resemble drug-likeness rules (e.g. MW > 500);
    # QM7 molecules are presumably far smaller, so some splits may put every
    # molecule in one group and be skipped -- confirm the intended thresholds.
    DESCRIPTOR_THRESHOLDS = {
        'MW': {'threshold': 500, 'direction': '>'},
        'LogP': {'threshold': 3, 'direction': '>'},
        'TPSA': {'threshold': 75, 'direction': '<='},
        'RotBonds': {'threshold': 10, 'direction': '>'},
    }

    # Fairness metric columns of the DALEX result table, in display order.
    _METRIC_COLUMNS = ('independence', 'separation', 'sufficiency')

    # Compiled FUNCTIONAL_GROUP_SMARTS, built once on first use.
    _functional_group_patterns = None

    def __init__(self, model, test_loader, device):
        self.model = model
        self.test_loader = test_loader
        self.device = device
        # Populated by analyze(): aligned predictions, targets and descriptor table.
        self.predictions = None
        self.true_values = None
        self.mol_features = None

    @classmethod
    def _get_functional_group_patterns(cls) -> Dict:
        """Compile FUNCTIONAL_GROUP_SMARTS once and cache the result on the class."""
        if cls._functional_group_patterns is None:
            compiled = {}
            for name, smarts in cls.FUNCTIONAL_GROUP_SMARTS.items():
                pattern = Chem.MolFromSmarts(smarts)
                if pattern is not None:  # skip, rather than crash on, an invalid SMARTS
                    compiled[name] = pattern
            cls._functional_group_patterns = compiled
        return cls._functional_group_patterns

    def extract_molecular_features(self, smiles: str) -> Optional[Dict]:
        """Compute RDKit descriptors and functional-group counts for one molecule.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            Dict with 'MW', 'LogP', 'TPSA', 'NumAtoms', 'NumHeavyAtoms',
            'RotBonds', 'QM7_violations', 'QM7_score' and one '<group>_count'
            entry per functional group; None if ``smiles`` is empty or RDKit
            cannot parse it.
        """
        if not smiles:
            return None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        mw = Descriptors.ExactMolWt(mol)
        logp = Descriptors.MolLogP(mol)                   # Crippen logP estimate
        tpsa = Descriptors.TPSA(mol)                      # topological polar surface area
        rotatable_bonds = Descriptors.NumRotatableBonds(mol)
        # MolFromSmiles leaves hydrogens implicit; count them so this is the
        # true total atom count rather than a copy of the heavy-atom count.
        num_atoms = mol.GetNumAtoms(onlyExplicit=False)
        num_heavy_atoms = mol.GetNumHeavyAtoms()
        atomic_symbols = {atom.GetSymbol() for atom in mol.GetAtoms()}
        # NOTE(review): QM7 is usually described as C/N/O/S(+H) molecules; this
        # allowed set contains F but not S -- confirm.
        only_allowed_elements = atomic_symbols.issubset({'C', 'H', 'N', 'O', 'F'})

        fg_counts = {
            f'{name}_count': len(mol.GetSubstructMatches(pattern))
            for name, pattern in self._get_functional_group_patterns().items()
        }

        features = {
            'MW': mw,
            'LogP': logp,
            'TPSA': tpsa,
            'NumAtoms': num_atoms,
            'NumHeavyAtoms': num_heavy_atoms,
            'RotBonds': rotatable_bonds,
            # NOTE(review): violations use MW > 200 while the score uses
            # MW <= 250; confirm which MW limit is intended.
            'QM7_violations': sum([
                num_heavy_atoms > 23,
                mw > 200,
                num_atoms > 30,
                not only_allowed_elements,
            ]),
            # Fraction of the 5 criteria below that the molecule satisfies.
            'QM7_score': sum([
                num_heavy_atoms <= 23,
                mw <= 250,
                num_atoms <= 30,
                'C' in atomic_symbols,
                only_allowed_elements,
            ]) / 5.0
        }
        features.update(fg_counts)
        return features

    def _collect_test_data(self) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
        """Run the model over ``test_loader`` and gather row-aligned results.

        The model is put in eval mode for inference (so e.g. dropout is off)
        and its previous training mode is restored afterwards.

        Returns:
            (X_test, y_test, predictions): descriptor DataFrame plus 1-D arrays
            of targets and predictions, one entry per molecule whose SMILES
            could be featurized. Molecules without usable SMILES are dropped
            from all three so rows stay aligned.
        """
        features_list, true_values, predictions = [], [], []
        skipped = 0
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in self.test_loader:
                    batch = batch.to(self.device)
                    batch_preds = self.model(batch).cpu().numpy().reshape(-1)
                    batch_true = batch.y.cpu().numpy().reshape(-1)

                    smiles_list = getattr(getattr(batch, '_store', None), 'smiles', None)
                    if smiles_list is None:
                        skipped += len(batch_preds)
                        continue
                    if isinstance(smiles_list, str):
                        smiles_list = [smiles_list]
                    if not (len(smiles_list) == len(batch_preds) == len(batch_true)):
                        print(f"Warning: batch has {len(smiles_list)} SMILES, {len(batch_preds)} "
                              f"predictions and {len(batch_true)} targets; skipping batch")
                        skipped += len(batch_preds)
                        continue

                    for smiles, pred, target in zip(smiles_list, batch_preds, batch_true):
                        feat = self.extract_molecular_features(smiles)
                        if feat is None:
                            skipped += 1
                            continue
                        features_list.append(feat)
                        predictions.append(pred)
                        true_values.append(target)
        finally:
            self.model.train(was_training)

        if skipped:
            print(f"Skipped {skipped} test molecules without usable SMILES/features")
        return (pd.DataFrame(features_list),
                np.asarray(true_values, dtype=float),
                np.asarray(predictions, dtype=float))

    def _build_feature_thresholds(self, X_test: pd.DataFrame) -> Dict:
        """Return {feature: {'threshold', 'direction'}} for every split to analyze.

        Starts from a copy of DESCRIPTOR_THRESHOLDS and adds every
        '<group>_count' column of ``X_test`` with its median as threshold.
        """
        thresholds = {name: dict(params) for name, params in self.DESCRIPTOR_THRESHOLDS.items()}
        for col in X_test.columns:
            if col.endswith('_count'):
                thresholds[col] = {'threshold': X_test[col].median(), 'direction': '>'}
        return thresholds

    def _evaluate_feature(self, explainer, X_test: pd.DataFrame, feature: str,
                          threshold, direction: str):
        """Split molecules on one feature threshold and compute DALEX fairness.

        Args:
            explainer: DALEX explainer built on ``X_test``.
            X_test: Descriptor table (one row per molecule).
            feature: Column of ``X_test`` to split on.
            threshold: Cut-off value.
            direction: '>' or '<='; rows satisfying `value <direction> threshold`
                are labelled 'non_compliant'.

        Returns:
            The DALEX fairness object, or None when every molecule falls into
            the same group (fairness metrics need two groups).

        Raises:
            ValueError: for an unsupported ``direction``.
        """
        print(f"\nProcessing {feature}:")
        values = X_test[feature]
        if direction == '>':
            non_compliant = values > threshold
        elif direction == '<=':
            non_compliant = values <= threshold
        else:
            raise ValueError(f"Unsupported direction {direction!r}")
        protected = pd.Series(np.where(non_compliant, 'non_compliant', 'compliant'),
                              index=X_test.index)

        print(f"{feature} groups distribution:")
        print(protected.value_counts())
        print("Percentage:", protected.value_counts(normalize=True).round(3) * 100)

        if protected.nunique() < 2:
            print(f"Skipping {feature}: all molecules fall into one group (threshold={threshold})")
            return None

        f_metrics = explainer.model_fairness(protected=protected, privileged='compliant')
        print(f"{feature} fairness check:")
        f_metrics.fairness_check()

        # DALEX renders these plots with plotly; no matplotlib figure is needed.
        f_metrics.plot()
        f_metrics.plot(type='density')
        return f_metrics

    def analyze(self):
        """Compute DALEX fairness metrics for every descriptor split of the test set.

        Also stores the aligned data on ``self.mol_features``, ``self.true_values``
        and ``self.predictions``, and draws the summary heatmap.

        Returns:
            Dict mapping feature name -> DALEX fairness object for each split
            that could be evaluated, or None if no molecule had usable features
            or data collection / explainer construction failed.
        """
        try:
            print("\n=== Starting analyze method ===")

            X_test, y_test, predictions = self._collect_test_data()
            if X_test.empty:
                print("No molecular features could be extracted from the test set; aborting analysis.")
                return None
            self.mol_features, self.true_values, self.predictions = X_test, y_test, predictions

            wrapped_model = SimpleWrapper(self.model, self.device)
            wrapped_model.preprocessed_data = predictions

            explainer = dx.Explainer(
                model=wrapped_model,
                data=X_test,
                y=y_test,
                predict_function=wrapped_model.predict,
                model_type='regression',
                verbose=True
            )

            fairness_results = {}
            print("\nAnalyzing features:")
            for feature, params in self._build_feature_thresholds(X_test).items():
                if feature not in X_test.columns:
                    continue
                try:
                    f_metrics = self._evaluate_feature(explainer, X_test, feature, **params)
                except Exception as e:
                    print(f"Error analyzing {feature}: {str(e)}")
                    continue
                if f_metrics is not None:
                    fairness_results[feature] = f_metrics

            try:
                self.visualize_fairness_metrics(fairness_results)
            except Exception as e:
                print(f"Error in visualization: {str(e)}")

            return fairness_results

        except Exception as e:
            print(f"\nError in analyze method: {str(e)}")
            traceback.print_exc()
            return None

    def _extract_non_compliant_metrics(self, f_obj) -> List[float]:
        """Return [independence, separation, sufficiency] for 'non_compliant'.

        Reads ``f_obj.result`` (a DataFrame indexed by subgroup); returns
        [0.0, 0.0, 0.0] as a placeholder when no usable table is available.
        """
        metrics_data = getattr(f_obj, 'result', None)
        if metrics_data is None:
            print("No result attribute found")
            return [0.0, 0.0, 0.0]

        print("\nMetrics data found:")
        print(metrics_data)
        if not (isinstance(metrics_data, pd.DataFrame)
                and not metrics_data.empty
                and 'non_compliant' in metrics_data.index):
            print("No valid metrics in result")
            return [0.0, 0.0, 0.0]

        row = metrics_data.loc['non_compliant']
        values = [float(row[name]) for name in self._METRIC_COLUMNS]
        print("\nExtracted metrics:")
        for name, value in zip(self._METRIC_COLUMNS, values):
            print(f"{name.title()}: {value:.4f}")
        return values

    def visualize_fairness_metrics(self, fairness_results):
        """Draw a heatmap of the 'non_compliant' fairness ratios per feature.

        Rows are sorted by mean absolute deviation from 1.0 (parity). Features
        whose metrics could not be read appear as all-zero placeholder rows.
        Errors are printed rather than raised.
        """
        try:
            print("\nDebugging fairness results:")
            print(f"Number of features to process: {len(fairness_results)}")

            rows = {}
            for feature, f_obj in fairness_results.items():
                print(f"\n{'='*50}")
                print(f"Processing feature: {feature}")
                feature_name = feature.replace('_count', '').replace('_', ' ').title()
                print(f"Cleaned feature name: {feature_name}")
                rows[feature_name] = self._extract_non_compliant_metrics(f_obj)

            # Build once from records so every column is float (row-wise
            # enlargement of an empty frame yields object dtype).
            metrics_df = pd.DataFrame.from_dict(
                rows, orient='index',
                columns=['Independence', 'Separation', 'Sufficiency'],
                dtype=float,
            )

            print("\n" + "="*50)
            print("Final Metrics DataFrame:")
            print(metrics_df)

            if metrics_df.empty:
                return

            # Mean |metric - 1| across the three metrics; placeholder rows (all 0.0) score 1.0.
            order = (metrics_df - 1.0).abs().mean(axis=1).sort_values(ascending=False).index
            metrics_df = metrics_df.loc[order]

            fig, ax = plt.subplots(figsize=(12, max(3.0, len(metrics_df) * 0.5)))
            sns.heatmap(metrics_df,
                        annot=True,
                        cmap='RdYlBu',
                        center=1.0,
                        fmt='.2f',
                        vmin=0.5,
                        vmax=2.0,
                        ax=ax)
            ax.set_title('QM7 Regression Fairness Metrics')
            ax.set_ylabel('Molecular Features')
            ax.set_xlabel('Fairness Metrics')

            cbar = ax.collections[0].colorbar
            cbar.set_label('Metric Ratio (1.0 = Fair)', rotation=270, labelpad=15)

            fig.tight_layout()
            plt.show()

            print("\nFeatures with significant bias (outside 0.8-1.25 range):")
            bias_mask = (metrics_df > 1.25) | (metrics_df < 0.8)
            biased_features = metrics_df[bias_mask.any(axis=1) & (metrics_df != 0).any(axis=1)]
            print(biased_features.round(3))

        except Exception as e:
            print(f"Error in visualization: {str(e)}")
            traceback.print_exc()

In [None]:
def analyze_regression_fairness(model, test_loader, device):
    """Run the DALEX regression-fairness analysis for ``model`` on ``test_loader``.

    Delegates to ``MolecularRegressionFairnessAnalyzer.analyze``, which already
    renders the fairness heatmap via ``visualize_fairness_metrics``; the heatmap
    is therefore not drawn a second time here.

    Returns:
        Dict of feature name -> DALEX fairness object as returned by ``analyze``,
        or None if the analysis failed.
    """
    try:
        print("\n=== Starting Regression Fairness Analysis ===")

        analyzer = MolecularRegressionFairnessAnalyzer(model, test_loader, device)
        fairness_results = analyzer.analyze()

        if fairness_results:
            print("\nRegression Fairness Analysis Results:")
            print(f"Fairness metrics computed for {len(fairness_results)} feature(s): "
                  f"{', '.join(fairness_results)}")
        else:
            print("\nNo fairness metrics were calculated successfully.")

        return fairness_results

    except Exception as e:
        print(f"\nError during fairness analysis: {str(e)}")
        traceback.print_exc()
        return None


def main(encoder_path: str = '../checkpoints/encoders/final_encoder_20250216_111050.pt',
         batch_size: int = 32,
         seed: Optional[int] = None):
    """Train a hybrid GNN regressor on QM7, report test metrics, run fairness analysis.

    Args:
        encoder_path: Checkpoint of the pretrained encoder passed to
            ``load_pretrained_encoder``.
        batch_size: Mini-batch size for the train/validation/test loaders.
        seed: If given, seeds the NumPy and PyTorch RNGs before the dataset is
            loaded and split, so runs are repeatable.

    Returns:
        The result of ``analyze_regression_fairness`` (feature name -> DALEX
        fairness object, or None if the analysis failed), or None when
        training produced no model.
    """
    # Defined up front so the final return works even when training yields no model.
    fairness_results = None

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Loading QM7 dataset...")
    # NOTE(review): DeepChem documents the 'balancing' transformer for
    # classification; it only reweights `w`, which is not used below -- confirm
    # it is intended for this regression target.
    tasks, datasets, transformers = load_qm7(
        featurizer='Raw',  # graphs are built by MolecularFeatureExtractor instead
        splitter='random',
        transformers=['balancing']
    )
    train_dataset, val_dataset, test_dataset = datasets

    print("Creating feature extractor...")
    feature_extractor = MolecularFeatureExtractor()

    print("Processing training set...")
    train_data = QM7Dataset(train_dataset.ids, train_dataset.y, feature_extractor)
    print("Processing validation set...")
    val_data = QM7Dataset(val_dataset.ids, val_dataset.y, feature_extractor)
    print("Processing test set...")
    test_data = QM7Dataset(test_dataset.ids, test_dataset.y, feature_extractor)

    print("Creating data loaders...")
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    print("Loading pretrained encoder...")
    encoder, model_info = load_pretrained_encoder(encoder_path, device)

    # Infer input dimensions from one processed batch.
    sample_data = next(iter(train_loader))
    node_dim = sample_data.x_cat.size(1) + sample_data.x_phys.size(1)
    edge_dim = sample_data.edge_attr.size(1)
    print(f"Initializing model with node_dim={node_dim}, edge_dim={edge_dim}")

    print("Initializing model...")
    model = HybridGNNRegressor(
        encoder=encoder,
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=128,
        num_layers=3,
        num_tasks=1,
        dropout=0.3
    ).to(device)

    best_model = train_hybrid_model_regression(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=100,
        lr=1e-3,
        weight_decay=1e-4
    )

    if best_model is None:
        print("Training did not produce a model; skipping evaluation and fairness analysis.")
        return fairness_results

    # NOTE(review): this expects a dict with a 'state_dict' key; the training
    # loop ending just above this cell returns a (state_dict, best_val_rmse)
    # tuple -- confirm which contract train_hybrid_model_regression follows.
    model.load_state_dict(best_model['state_dict'])

    model.eval()
    y_true_test = []
    y_pred_test = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            outputs = model(batch)
            targets = batch.y.view(-1, 1)
            y_true_test.extend(targets.cpu().numpy())
            y_pred_test.extend(outputs.cpu().numpy())  # raw outputs: regression, no sigmoid

    test_rmse = np.sqrt(mean_squared_error(y_true_test, y_pred_test))
    test_r2 = r2_score(y_true_test, y_pred_test)
    test_mae = mean_absolute_error(y_true_test, y_pred_test)

    print("\nTest Results:")
    print(f"Test RMSE: {test_rmse:.4f}")
    print(f"Test R²: {test_r2:.4f}")
    print(f"Test MAE: {test_mae:.4f}")

    fairness_results = analyze_regression_fairness(model, test_loader, device)
    return fairness_results

# Entry point: runs the full train -> evaluate -> fairness pipeline when executed
# as a script (and when this cell runs in a notebook kernel, where __name__ is
# "__main__"). The fairness results are kept in a module-level variable for
# interactive inspection.
if __name__ == "__main__":
    fairness_results = main()