In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from abc import ABC, abstractmethod
from aif360.datasets import AdultDataset as Aif360AdultDataset

# Constants
RANDOM_SEED = 42
ADULT_DATA_FEATURES = 102

class AIF360TorchDataset(TorchDataset):
    """
    Wraps an AIF360 BinaryLabelDataset as a PyTorch Dataset.
    
    This wrapper converts AIF360's numpy-based datasets into PyTorch-compatible
    format while preserving all the fairness-related metadata like protected attributes.
    """
    
    def __init__(self, aif360_dataset, include_protected=True):
        """
        Args:
            aif360_dataset: an AIF360 dataset object (e.g. AdultDataset(), COMPASDataset(), etc.)
            include_protected: if True, will also expose protected_attributes
        """
        # Extract numpy arrays from AIF360
        X = aif360_dataset.features
        y = aif360_dataset.labels.ravel()
        
        # Handle protected attributes if requested
        if include_protected:
            prot = aif360_dataset.protected_attributes
            self.protected_attrs = torch.tensor(prot, dtype=torch.float32)
        else:
            self.protected_attrs = None
        
        # Convert to PyTorch tensors with appropriate data types
        self.features = torch.tensor(X, dtype=torch.float32)
        # CRITICAL FIX: Use float32 for labels since we're using BCELoss
        self.labels = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return self.features.shape[0]
    
    def __getitem__(self, idx):
        """
        Return features and labels in the format expected by DataLoader.
        For simplicity in training loop, we return a tuple (features, label)
        rather than a dictionary.
        """
        return self.features[idx], self.labels[idx]

class BaseDataset(ABC):
    """
    Abstract base class for dataset wrappers that provides a universal interface
    for loading, processing and accessing datasets in any format.
    """
    
    def __init__(self):
        """Initialize the dataset by loading the underlying AIF360 data."""
        self._aif360_dataset = self.load_data()
    
    @abstractmethod
    def load_data(self):
        """
        Loads the dataset from the AIF360 dataset object.
        Must be implemented by subclasses.
        
        Returns:
            AIF360 dataset object
        """
        raise NotImplementedError("load_data() must be implemented in subclasses")
    
    def to_torch(self, include_protected=True):
        """
        Converts the AIF360 dataset to PyTorch Datasets with train/val/test splits.
        
        Args:
            include_protected: if True, will also expose protected_attributes
            
        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        converted_dataset = AIF360TorchDataset(self._aif360_dataset, include_protected)
        
        # Calculate split sizes
        total_size = len(converted_dataset)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size
        
        # Perform the split with fixed random seed for reproducibility
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            converted_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(RANDOM_SEED)
        )
        
        return train_dataset, val_dataset, test_dataset

class AdultDataset(BaseDataset):  # CRITICAL FIX: Inherit from BaseDataset
    """
    Adult dataset class for loading and preprocessing Adult dataset.
    
    This class handles the specific configuration needed for the Adult/Census Income
    dataset, including proper handling of categorical features and protected attributes.
    """
    
    def __init__(self):
        super().__init__()  # This calls BaseDataset.__init__() which calls load_data()
    
    def load_data(self):
        """
        Load and configure the Adult dataset from AIF360.
        
        The configuration here follows fairness research best practices:
        - 'sex' is the primary protected attribute for fairness evaluation
        - Males are considered the privileged group
        - Categorical features are properly encoded
        - Missing values (marked as '?') are handled by dropping rows
        """
        adult_ds = Aif360AdultDataset(
            protected_attribute_names=['sex'],  # Primary protected attribute
            privileged_classes=[['Male']],      # Privileged group definition
            categorical_features=['workclass', 'education', 'marital-status',
                                'occupation', 'relationship', 'race', 'native-country'],
            features_to_keep=['age', 'workclass', 'education', 'education-num',
                            'marital-status', 'occupation', 'relationship', 'race',
                            'sex', 'capital-gain', 'capital-loss', 'hours-per-week',
                            'native-country'],
            na_values=['?'],  # Handle missing values
            custom_preprocessing=lambda df: df.dropna()  # Simple approach: drop missing values
        )
        
        # Verify the expected feature dimensionality
        assert adult_ds.features.shape[1] == ADULT_DATA_FEATURES, \
            f"Expected {ADULT_DATA_FEATURES} features, got {adult_ds.features.shape[1]}"
        
        return adult_ds

class MLP(nn.Module):
    """
    Multi-Layer Perceptron for classification tasks.
    
    This implementation is based on the Fair-Fairness Benchmark repository
    and includes some modifications for better integration with our training framework.
    """
    
    def __init__(self, n_features, mlp_layers=[512, 256, 64], p_dropout=0.2, num_classes=1):
        super(MLP, self).__init__()
        self.num_classes = num_classes
        self.mlp_layers = [n_features] + mlp_layers
        self.p_dropout = p_dropout
        
        # Create the hidden layers
        self.network = nn.ModuleList([
            nn.Linear(i, o) for i, o in zip(self.mlp_layers[:-1], self.mlp_layers[1:])
        ])
        
        # Final classification head
        self.head = nn.Linear(self.mlp_layers[-1], num_classes)
    
    def forward(self, x):
        """
        Forward pass through the network.
        
        Returns both hidden representation and final prediction to allow for
        analysis of intermediate representations (useful for fairness research).
        """
        # Pass through hidden layers with ReLU activation and dropout
        for layer in self.network:
            x = layer(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.p_dropout, training=self.training)
        
        # Store hidden representation (useful for fairness analysis)
        h = x
        
        # Final classification layer with sigmoid activation
        x = self.head(x)
        logits = torch.sigmoid(x)
        
        return h, logits

class SGDMechanism:
    """
    A mechanism that trains a model using SGD, designed for fairness research.
    
    This implementation includes proper early stopping, comprehensive logging,
    and handles the dual-output nature of our MLP model.
    """
    
    def __init__(self):
        pass
    
    def train(self, model, dataset: BaseDataset, **kwargs):
        """
        Train the model on the dataset using SGD.
        
        Args:
            model: The untrained model to be trained
            dataset: The dataset to train on (must be a BaseDataset subclass)
            **kwargs: Additional hyperparameters for training
            
        Returns:
            Dictionary containing training history and metrics
        """
        # Set up device (GPU if available, otherwise CPU)
        # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = "cpu"
        model.to(device)
        print(f"Training on device: {device}")
        
        # Extract hyperparameters with sensible defaults
        num_epochs = kwargs.get('num_epochs', 100)
        learning_rate = kwargs.get('learning_rate', 0.01)
        batch_size = kwargs.get('batch_size', 32)
        patience = kwargs.get('patience', 10)  # For early stopping
        
        print(f"Training configuration:")
        print(f"  Epochs: {num_epochs}")
        print(f"  Learning rate: {learning_rate}")
        print(f"  Batch size: {batch_size}")
        print(f"  Early stopping patience: {patience}")
        
        # Create data loaders
        train_dataset, val_dataset, test_dataset = dataset.to_torch(include_protected=False)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        print(f"Dataset splits:")
        print(f"  Training samples: {len(train_dataset)}")
        print(f"  Validation samples: {len(val_dataset)}")
        print(f"  Test samples: {len(test_dataset)}")
        
        # Set up optimizer and loss function
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
        criterion = nn.BCELoss()  # Binary Cross-Entropy for binary classification
        
        # Training history tracking
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []
        
        # Early stopping variables
        best_val_loss = float('inf')
        patience_counter = 0
        best_model_state = None
        
        print("\nStarting training...")
        
        for epoch in range(num_epochs):
            # ===== TRAINING PHASE =====
            model.train()  # Enable dropout and batch norm training mode
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            
            for batch_X, batch_y in train_loader:
                # Move data to device
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                
                # Ensure labels have correct shape for BCELoss
                batch_y = batch_y.view(-1, 1)
                
                # Zero gradients
                optimizer.zero_grad()
                
                # Forward pass - IMPORTANT: handle dual output from MLP
                hidden_repr, outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                
                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                
                # Accumulate training statistics
                train_loss += loss.item()
                predicted = (outputs > 0.5).float()
                train_total += batch_y.size(0)
                train_correct += (predicted == batch_y).sum().item()
            
            # ===== VALIDATION PHASE =====
            model.eval()  # Disable dropout and set batch norm to eval mode
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            
            with torch.no_grad():  # Disable gradient computation for efficiency
                for batch_X, batch_y in val_loader:
                    batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                    batch_y = batch_y.view(-1, 1)
                    
                    # Forward pass
                    hidden_repr, outputs = model(batch_X)
                    loss = criterion(outputs, batch_y)
                    
                    # Accumulate validation statistics
                    val_loss += loss.item()
                    predicted = (outputs > 0.5).float()
                    val_total += batch_y.size(0)
                    val_correct += (predicted == batch_y).sum().item()
            
            # Calculate epoch metrics
            avg_train_loss = train_loss / len(train_loader)
            avg_val_loss = val_loss / len(val_loader)
            train_acc = train_correct / train_total
            val_acc = val_correct / val_total
            
            # Store history
            train_losses.append(avg_train_loss)
            train_accuracies.append(train_acc)
            val_losses.append(avg_val_loss)
            val_accuracies.append(val_acc)
            
            # Print progress
            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f'Epoch [{epoch+1:3d}/{num_epochs}] | '
                      f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f} | '
                      f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')
            
            # CRITICAL FIX: Implement early stopping logic
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                # Save the best model state
                best_model_state = model.state_dict().copy()
                print(f'New best validation loss: {best_val_loss:.4f}')
            else:
                patience_counter += 1
            
            # Check if we should stop early
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch + 1} epochs')
                print(f'Best validation loss: {best_val_loss:.4f}')
                break
        
        # Restore the best model
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
            print("Restored best model weights")
        
        print("Training completed!")
        
        return {
            'model': model,
            'train_losses': train_losses,
            'train_accuracies': train_accuracies,
            'val_losses': val_losses,
            'val_accuracies': val_accuracies,
            'best_val_loss': best_val_loss
        }

# ===== USAGE EXAMPLE =====
def main():
    """
    Demonstrates how to use the corrected framework to train an MLP on the Adult dataset.
    """
    
    print("=== Fair MLP Training on Adult Dataset ===\n")
    
    # 1. Load the dataset
    print("Loading Adult dataset...")
    adult_dataset = AdultDataset()
    print(f"Dataset loaded with {adult_dataset._aif360_dataset.features.shape[0]} samples")
    print(f"Feature dimensionality: {adult_dataset._aif360_dataset.features.shape[1]}")
    
    # 2. Create the model
    print("\nCreating MLP model...")
    model = MLP(
        n_features=ADULT_DATA_FEATURES,
        mlp_layers=[256, 256],  # Following Fair-Fairness Benchmark
        p_dropout=0.2,
        num_classes=1
    )
    
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model created with {total_params:,} trainable parameters")
    
    # 3. Create the training mechanism
    print("\nInitializing SGD training mechanism...")
    sgd_mechanism = SGDMechanism()
    
    # 4. Train the model
    print("\n" + "="*50)
    training_results = sgd_mechanism.train(
        model=model,
        dataset=adult_dataset,
        num_epochs=100,
        learning_rate=0.01,
        batch_size=32,
        patience=50
    )
    
    # 5. Display results
    print("\n" + "="*50)
    print("TRAINING RESULTS SUMMARY")
    print("="*50)
    
    final_train_acc = training_results['train_accuracies'][-1]
    final_val_acc = training_results['val_accuracies'][-1]
    best_val_loss = training_results['best_val_loss']
    
    print(f"Final Training Accuracy: {final_train_acc:.4f} ({final_train_acc*100:.2f}%)")
    print(f"Final Validation Accuracy: {final_val_acc:.4f} ({final_val_acc*100:.2f}%)")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    
    epochs_trained = len(training_results['train_losses'])
    print(f"Training completed in {epochs_trained} epochs")
    
    # The trained model is now ready for fairness evaluation
    trained_model = training_results['model']
    print(f"\nTrained model ready for fairness evaluation!")
    
    return training_results

if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(RANDOM_SEED)
    import numpy as np
    np.random.seed(RANDOM_SEED)
    
    results = main()

=== Fair MLP Training on Adult Dataset ===

Loading Adult dataset...
Dataset loaded with 45222 samples
Feature dimensionality: 102

Creating MLP model...
Model created with 92,417 trainable parameters

Initializing SGD training mechanism...

Training on device: cpu
Training configuration:
  Epochs: 100
  Learning rate: 0.01
  Batch size: 32
  Early stopping patience: 50
Dataset splits:
  Training samples: 31655
  Validation samples: 6783
  Test samples: 6784

Starting training...
Epoch [  1/100] | Train Loss: 5.9257, Train Acc: 0.7692 | Val Loss: 5.9788, Val Acc: 0.7708
New best validation loss: 5.9788
New best validation loss: 5.9780
New best validation loss: 5.9772
New best validation loss: 5.9769
New best validation loss: 5.9762
Epoch [ 10/100] | Train Loss: 5.8841, Train Acc: 0.7751 | Val Loss: 5.9756, Val Acc: 0.7708
New best validation loss: 5.9756
New best validation loss: 5.9746
New best validation loss: 5.9730
Epoch [ 20/100] | Train Loss: 5.8863, Train Acc: 0.7751 | Val Loss:

KeyboardInterrupt: 