In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader, Batch  # Import Batch
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from rdkit import Chem
from rdkit.Chem import Descriptors
import os
import deepchem as dc
from scipy import stats
import dalex as dx

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

class MolecularGraphDataset:
    def __init__(self, smiles_list, labels):
        self.smiles_list = smiles_list
        self.labels = labels
        self.graphs = []
        self.processed_labels = []
        self._process_smiles()
        
    def _process_smiles(self):
        for i, smiles in enumerate(self.smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
            
            try:
                Chem.Kekulize(mol, clearAromaticFlags=True)
                
                # Enhanced atom features
                atom_features_list = []
                for atom in mol.GetAtoms():
                    features = [
                        atom.GetAtomicNum(),
                        atom.GetDegree(),
                        atom.GetFormalCharge(),
                        int(atom.GetHybridization()),
                        int(atom.GetIsAromatic()),
                        int(atom.IsInRing()),
                        atom.GetNumRadicalElectrons(),
                        atom.GetTotalNumHs()
                    ]
                    atom_features_list.append(features)
                    
                x = torch.tensor(atom_features_list, dtype=torch.float)
                
                # Enhanced bond features
                edge_indices = []
                edge_features = []
                for bond in mol.GetBonds():
                    i = bond.GetBeginAtomIdx()
                    j = bond.GetEndAtomIdx()
                    
                    features = [
                        int(bond.GetBondType()),
                        int(bond.GetIsConjugated()),
                        int(bond.IsInRing()),
                        int(bond.GetStereo()),
                        bond.GetValenceContrib(bond.GetBeginAtom()),
                        bond.GetValenceContrib(bond.GetEndAtom())
                    ]
                    
                    edge_indices.extend([[i, j], [j, i]])
                    edge_features.extend([features, features])
                
                if len(edge_indices) == 0:
                    continue
                
                edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
                edge_attr = torch.tensor(edge_features, dtype=torch.float)
                
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
                self.graphs.append(data)
                self.processed_labels.append(int(self.labels[i]))
                
            except Exception as e:
                continue
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return self.graphs[idx], self.processed_labels[idx]

class EnhancedContrastiveEncoder(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim=128, output_dim=128):
        super().__init__()
        
        # Enhanced node encoder with batch normalization
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )
        
        # Enhanced edge encoder
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )
        
        # Multiple graph convolution layers
        self.graph_conv1 = GCNConv(hidden_dim, hidden_dim)
        self.graph_conv2 = GCNConv(hidden_dim, hidden_dim)
        self.graph_conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        
        # Output projection
        self.output = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim)
        )
        
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.graph_attention = nn.MultiheadAttention(hidden_dim, num_heads=4)
        
        # Add regularization
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x, edge_index, edge_attr, batch):
        # Node and edge encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions with residual connections
        x1 = F.relu(self.graph_conv1(x, edge_index))
        x2 = F.relu(self.graph_conv2(x1, edge_index)) + x1
        x3 = self.graph_conv3(x2, edge_index) + x2
        
        # Apply layer normalization
        x3 = self.layer_norm(x3)
        
        # Attention-weighted pooling
        attention_weights = self.attention(x3)
        attention_weights = F.softmax(attention_weights, dim=0)
        attended_features = torch.sum(x3 * attention_weights, dim=0)
        
        # Global mean pooling
        pooled_features = global_mean_pool(x3, batch)
        
        # Concatenate attended and pooled features
        combined_features = torch.cat([attended_features.unsqueeze(0).expand(pooled_features.size(0), -1), 
                                       pooled_features], dim=1)
        
        # Final output
        out = self.output(combined_features)
        return F.normalize(out, dim=1)

class MolecularClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims=(256, 128)):
        super().__init__()
        layers = []
        last_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(last_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))
            last_dim = hidden_dim
        layers.append(nn.Linear(last_dim, 1))
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

def compute_fairness_loss(predictions, labels, sensitive_attributes):
    """
    Compute fairness loss based on demographic parity and equal opportunity.
    """
    epsilon = 1e-6
    privileged_mask = (sensitive_attributes == 1)
    unprivileged_mask = (sensitive_attributes == 0)
    
    # Demographic Parity Difference
    mean_pred_privileged = predictions[privileged_mask].mean()
    mean_pred_unprivileged = predictions[unprivileged_mask].mean()
    dp_difference = torch.abs(mean_pred_privileged - mean_pred_unprivileged)
    
    # Equal Opportunity Difference (True Positive Rates)
    true_labels_privileged = labels[privileged_mask]
    true_labels_unprivileged = labels[unprivileged_mask]
    
    predictions_privileged = predictions[privileged_mask]
    predictions_unprivileged = predictions[unprivileged_mask]
    
    tpr_privileged = ((predictions_privileged >= 0.5) & (true_labels_privileged == 1)).sum() / (true_labels_privileged == 1).sum().float().clamp(min=epsilon)
    tpr_unprivileged = ((predictions_unprivileged >= 0.5) & (true_labels_unprivileged == 1)).sum() / (true_labels_unprivileged == 1).sum().float().clamp(min=epsilon)
    eo_difference = torch.abs(tpr_privileged - tpr_unprivileged)
    
    fairness_loss = dp_difference + eo_difference
    return fairness_loss

def visualize_embedding_distribution(embeddings, labels, title="Embedding Distribution"):
    """Visualize the distribution of embeddings using t-SNE."""
    from sklearn.manifold import TSNE
    
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)
    
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels, cmap='viridis')
    plt.colorbar(scatter)
    plt.title(title)
    plt.xlabel("t-SNE dimension 1")
    plt.ylabel("t-SNE dimension 2")
    plt.show()

def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix"):
    """Plot confusion matrix with enhanced visualization."""
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

def plot_roc_curve(y_true, y_proba, title="ROC Curve"):
    """Plot ROC curve with additional metrics."""
    from sklearn.metrics import roc_curve, auc
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

def create_molecular_groups_validated(test_dataset):
    """Enhanced molecular feature groups with additional chemical criteria."""
    smiles_list = test_dataset.smiles_list
    
    properties = {
        'MW': [], 'LogP': [], 'HBD': [], 'HBA': [], 'TPSA': [], 
        'RotBonds': [], 'AromaticRings': [], 'HeteroAtoms': []
    }
    
    thresholds = {
        'MW': 500, 'LogP': 5, 'HBD': 5, 'HBA': 10, 
        'TPSA': 90, 'RotBonds': 8
    }
    
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            properties['MW'].append(Descriptors.ExactMolWt(mol))
            properties['LogP'].append(Descriptors.MolLogP(mol))
            properties['HBD'].append(Descriptors.NumHDonors(mol))
            properties['HBA'].append(Descriptors.NumHAcceptors(mol))
            properties['TPSA'].append(Descriptors.TPSA(mol))
            properties['RotBonds'].append(Descriptors.NumRotatableBonds(mol))
            properties['AromaticRings'].append(len(Chem.GetSymmSSSR(mol)))
            properties['HeteroAtoms'].append(len([atom for atom in mol.GetAtoms() 
                                                if atom.GetAtomicNum() not in [6, 1]]))
        else:
            for key in properties:
                properties[key].append(None)
    
    df = pd.DataFrame(properties)
    
    # Enhanced feature groups
    feature_groups = {}
    
    # Ro5 compliance
    ro5_violations = np.zeros(len(df))
    for prop, threshold in thresholds.items():
        if prop in ['MW', 'LogP', 'HBD', 'HBA']:
            violations = (df[prop] > threshold).astype(int)
            ro5_violations += violations
    
    feature_groups['Ro5_compliant'] = (ro5_violations < 2).astype(int)
    
    # BBB penetration
    feature_groups['BBB_favorable'] = ((df['MW'] <= 400) & 
                                     (df['TPSA'] <= 90) & 
                                     (df['LogP'] <= 5) & 
                                     (df['HBD'] <= 3)).astype(int)
    
    # BACE binding
    feature_groups['BACE_favorable'] = ((df['MW'] >= 300) & 
                                      (df['MW'] <= 600) & 
                                      (df['LogP'] >= 1) & 
                                      (df['LogP'] <= 5)).astype(int)
    
    # Add custom groups
    for name, group in feature_groups.items():
        df[name] = group
    
    return df

def adapt_state_dict(checkpoint_state_dict, model):
    """Adapt the saved state dict to work with the enhanced model architecture."""
    new_state_dict = {}
    model_state_dict = model.state_dict()
    
    # Initialize new state dict with current model's structure
    for key in model_state_dict.keys():
        if key in checkpoint_state_dict:
            if model_state_dict[key].shape == checkpoint_state_dict[key].shape:
                new_state_dict[key] = checkpoint_state_dict[key]
            else:
                # Handle dimension mismatches
                if 'node_encoder.0.weight' in key:
                    # Adapt node encoder weights
                    old_weight = checkpoint_state_dict[key]
                    new_weight = torch.zeros(model_state_dict[key].shape)
                    min_dim = min(old_weight.shape[1], new_weight.shape[1])
                    new_weight[:, :min_dim] = old_weight[:, :min_dim]
                    new_state_dict[key] = new_weight
                elif 'edge_encoder.0.weight' in key:
                    # Adapt edge encoder weights
                    old_weight = checkpoint_state_dict[key]
                    new_weight = torch.zeros(model_state_dict[key].shape)
                    min_dim = min(old_weight.shape[1], new_weight.shape[1])
                    new_weight[:, :min_dim] = old_weight[:, :min_dim]
                    new_state_dict[key] = new_weight
                else:
                    # Initialize other mismatched layers with current model's values
                    new_state_dict[key] = model_state_dict[key]
        else:
            # Initialize new layers with current model's values
            new_state_dict[key] = model_state_dict[key]
    
    return new_state_dict

def main():
    # Previous parameters remain the same
    contrastive_model_path = 'molecular_gan_cl_models_CA11.pt'
    input_dim = 8  # Enhanced atom features
    edge_dim = 6   # Enhanced bond features
    hidden_dim = 128
    output_dim = 128
    
    # Initialize the enhanced encoder
    encoder_model = EnhancedContrastiveEncoder(
        node_dim=input_dim,
        edge_dim=edge_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim
    ).to(device)
    
    # Load and adapt the pre-trained weights
    try:
        checkpoint = torch.load(contrastive_model_path, map_location=device)
        adapted_state_dict = adapt_state_dict(checkpoint['encoder_state_dict'], encoder_model)
        encoder_model.load_state_dict(adapted_state_dict, strict=False)
        print("Successfully loaded and adapted pre-trained model")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return

    # Modified initialization for batch normalization
    encoder_model.train()
    with torch.no_grad():
        # Create proper batch size for initialization
        batch_size = 32
        num_nodes = 16  # Multiple nodes per graph
        num_edges = 32  # Multiple edges per graph
        
        # Create dummy batch with multiple nodes per graph
        dummy_data = torch.randn(batch_size * num_nodes, input_dim).to(device)
        
        # Create edge indices that connect nodes within each graph
        edge_indices = []
        edge_attrs = []
        batch_assignments = []
        
        for i in range(batch_size):
            # Create edges for current graph
            start_idx = i * num_nodes
            for j in range(num_edges):
                src = start_idx + torch.randint(0, num_nodes, (1,)).item()
                dst = start_idx + torch.randint(0, num_nodes, (1,)).item()
                edge_indices.extend([[src, dst]])
                edge_attrs.append(torch.randn(edge_dim))
            
            # Batch assignments for current graph
            batch_assignments.extend([i] * num_nodes)
        
        dummy_edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous().to(device)
        dummy_edge_attr = torch.stack(edge_attrs).to(device)
        dummy_batch = torch.tensor(batch_assignments, dtype=torch.long).to(device)
        
        # Initialize BatchNorm statistics with proper batched data
        for _ in range(10):
            encoder_model(dummy_data, dummy_edge_index, dummy_edge_attr, dummy_batch)
    
    encoder_model.eval()
    
    print("\n=== BACE Classification Task ===")
    
    # Prepare data
    tasks, datasets, transformers = dc.molnet.load_bace_classification(
        featurizer='Raw',
        splitter='scaffold',
        transformers=[],
        reload=True
    )
    train_dataset, valid_dataset, test_dataset = datasets
    
    # Combine train and validation datasets for cross-validation
    combined_smiles = train_dataset.ids.tolist() + valid_dataset.ids.tolist()
    combined_labels = np.concatenate([train_dataset.y[:, 0], valid_dataset.y[:, 0]])
    
    # Process datasets
    combined_data = MolecularGraphDataset(
        combined_smiles,
        combined_labels.tolist()
    )
    
    test_data = MolecularGraphDataset(
        test_dataset.ids.tolist(),
        test_dataset.y[:, 0].tolist()
    )
    
    # Generate embeddings for test data
    test_loader = DataLoader([(data, label) for data, label in zip(test_data.graphs, test_data.processed_labels)], batch_size=32, shuffle=False)
    test_embeddings = []
    y_test = []
    
    with torch.no_grad():
        for batch_data in test_loader:
            batch_graphs, batch_labels = batch_data
            batch = Batch.from_data_list(batch_graphs).to(device)  # Use Batch.from_data_list
            embedding = encoder_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            test_embeddings.append(embedding.cpu().numpy())
            y_test.extend(batch_labels)
    
    X_test = np.vstack(test_embeddings)
    y_test = np.array(y_test)
    
    # Stratified K-Fold Cross-Validation
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_accuracies = []
    fold_auc_scores = []
    
    y_combined = np.array(combined_data.processed_labels)
    sensitive_attributes_df = create_molecular_groups_validated(combined_data)
    sensitive_attributes = sensitive_attributes_df['Ro5_compliant'].values  # Use appropriate length

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(y_combined)), y_combined)):
        print(f"\n=== Fold {fold + 1} ===")
        # Prepare train and validation data
        train_graphs = [combined_data.graphs[i] for i in train_idx]
        train_labels = [combined_data.processed_labels[i] for i in train_idx]
        val_graphs = [combined_data.graphs[i] for i in val_idx]
        val_labels = [combined_data.processed_labels[i] for i in val_idx]
        
        # Generate embeddings for training data
        train_loader = DataLoader([(data, label) for data, label in zip(train_graphs, train_labels)], batch_size=32, shuffle=True)
        val_loader = DataLoader([(data, label) for data, label in zip(val_graphs, val_labels)], batch_size=32, shuffle=False)
        
        # Initialize classifier
        classifier = MolecularClassifier(input_dim=output_dim).to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
        
        # Get sensitive attributes for the current training batch
        train_sensitive_attributes = sensitive_attributes[train_idx]
        val_sensitive_attributes = sensitive_attributes[val_idx]
        
        # Training loop
        num_epochs = 20
        for epoch in range(num_epochs):
            classifier.train()
            epoch_loss = 0
            for batch_data in train_loader:
                batch_graphs, batch_labels = batch_data
                batch = Batch.from_data_list(batch_graphs).to(device)  # Use Batch.from_data_list
                embeddings = encoder_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                outputs = classifier(embeddings).squeeze()
                labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
                
                # Fairness loss
                batch_indices = [train_idx[i] for i in range(len(batch_labels))]  # Get original indices
                batch_sensitive = torch.tensor(train_sensitive_attributes[batch_indices], dtype=torch.float32).to(device)
                fairness_loss = compute_fairness_loss(torch.sigmoid(outputs), labels, batch_sensitive)
                
                # Total loss
                loss = criterion(outputs, labels) + 0.1 * fairness_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}")
        
        # Evaluation on validation set
        classifier.eval()
        val_outputs = []
        val_labels_list = []
        with torch.no_grad():
            for batch_data in val_loader:
                batch_graphs, batch_labels = batch_data
                batch = Batch.from_data_list(batch_graphs).to(device)  # Use Batch.from_data_list
                embeddings = encoder_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                outputs = classifier(embeddings).squeeze()
                val_outputs.extend(torch.sigmoid(outputs).cpu().numpy())
                val_labels_list.extend(batch_labels)
        
        val_preds = (np.array(val_outputs) >= 0.5).astype(int)
        val_labels_array = np.array(val_labels_list)
        accuracy = accuracy_score(val_labels_array, val_preds)
        auc_score = roc_auc_score(val_labels_array, val_outputs)
        fold_accuracies.append(accuracy)
        fold_auc_scores.append(auc_score)
        print(f"Validation Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}")
    
    print("\n=== Cross-Validation Results ===")
    print(f"Average Accuracy: {np.mean(fold_accuracies):.4f}")
    print(f"Average AUC: {np.mean(fold_auc_scores):.4f}")
    
    # Retrain on full training data
    train_loader = DataLoader([(data, label) for data, label in zip(combined_data.graphs, combined_data.processed_labels)], batch_size=32, shuffle=True)
    classifier = MolecularClassifier(input_dim=output_dim).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
    
    # Training loop
    num_epochs = 20
    for epoch in range(num_epochs):
        classifier.train()
        epoch_loss = 0
        for batch_data in train_loader:
            batch_graphs, batch_labels = batch_data
            batch = Batch.from_data_list(batch_graphs).to(device)  # Use Batch.from_data_list
            embeddings = encoder_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            outputs = classifier(embeddings).squeeze()
            labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
            
            # Fairness loss
            batch_indices = [i for i in range(len(batch_labels))]
            batch_sensitive = torch.tensor(sensitive_attributes[batch_indices], dtype=torch.float32).to(device)
            fairness_loss = compute_fairness_loss(torch.sigmoid(outputs), labels, batch_sensitive)
            
            # Total loss
            loss = criterion(outputs, labels) + 0.1 * fairness_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}")
    
    # Evaluation on test set
    classifier.eval()
    test_outputs = []
    with torch.no_grad():
        for batch_data in test_loader:
            batch_graphs, batch_labels = batch_data
            batch = Batch.from_data_list(batch_graphs).to(device)  # Use Batch.from_data_list
            embeddings = encoder_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            outputs = classifier(embeddings).squeeze()
            test_outputs.extend(torch.sigmoid(outputs).cpu().numpy())
    
    y_pred = (np.array(test_outputs) >= 0.5).astype(int)
    y_pred_proba = np.array(test_outputs)
    
    print("\n=== Test Set Evaluation ===")
    plot_confusion_matrix(y_test, y_pred, "BACE Classification Confusion Matrix")
    print("\n=== Classification Metrics ===")
    print(classification_report(y_test, y_pred))
    
    # Visualizations
    plot_roc_curve(y_test, y_pred_proba, "BACE Classification ROC Curve")
    visualize_embedding_distribution(X_test, y_test, "BACE Embedding Distribution")
    
    # Fairness analysis
    print("\n=== Fairness Analysis ===")
    # Since the classifier is now a PyTorch model, we need to adjust the fairness analysis
    # We can use the test outputs and the sensitive attributes from the test set
    test_sensitive_attributes_df = create_molecular_groups_validated(test_data)
    test_sensitive_attributes = test_sensitive_attributes_df['Ro5_compliant'].values[:len(y_test)]
    test_sensitive_attributes = torch.tensor(test_sensitive_attributes, dtype=torch.float32).to(device)
    test_labels = torch.tensor(y_test, dtype=torch.float32).to(device)
    test_predictions = torch.tensor(y_pred_proba, dtype=torch.float32).to(device)
    fairness_loss = compute_fairness_loss(test_predictions, test_labels, test_sensitive_attributes)
    print(f"Fairness Loss on Test Set: {fairness_loss.item():.4f}")
    
    # Additional fairness metrics can be calculated as needed

if __name__ == "__main__":
    main()






  "class": algorithms.Blowfish,


Using device: cuda
Successfully loaded and adapted pre-trained model

=== BACE Classification Task ===


AttributeError: 'tupleBatch' object has no attribute 'stores_as'