# Fairness-Aware Training Integration Guide

This notebook shows you **exactly** how to integrate fairness-aware training into your distillation pipeline.

## 🎯 Your Situation
Based on your analysis:
- **Current fairness**: 1.10x ratio (excellent!)
- **Distillation impact**: Slightly improves fairness
- **Recommendation**: No immediate intervention needed, but this guide shows you how to add fairness-aware training for future experiments

## What You'll Learn
1. How to modify your training code for fairness
2. Different fairness loss functions and when to use them
3. How to monitor fairness during training
4. Code examples you can copy-paste

## Step 1: Setup and Imports

In [7]:
# Essential imports for fairness-aware training
import sys
sys.path.append('../')  # Add parent directory to path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Import your existing training modules
# from models.time_llm import TimeLLM  # Adjust path as needed
# from distillation.distillation_trainer import DistillationTrainer  # Adjust path as needed

# Import our fairness analyzer
from gender_fairness_analyzer import GenderFairnessAnalyzer

print("✅ All imports loaded successfully!")
print("Ready for fairness-aware training integration")

✅ All imports loaded successfully!
Ready for fairness-aware training integration


## Step 2: Fairness Loss Functions

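Here are practical loss functions you can add to your training loop to encourage comparable predictions across gender groups. Each one penalizes a specific gap between groups; none of them guarantees fairness on its own: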

In [8]:
class FairnessLosses(nn.Module):
    """Collection of fairness loss functions for training"""
    
    def __init__(self):
        super().__init__()
    
    def demographic_parity_loss(self, predictions, gender_labels, lambda_fairness=0.1):
        """
        Penalizes the gap between the two groups' mean predictions.
        Lower is more fair (0 = identical group means)
        """
        male_mask = (gender_labels == 1)  # Assuming 1 = male, 0 = female
        female_mask = (gender_labels == 0)
        
        if male_mask.sum() == 0 or female_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device)
        
        male_mean = predictions[male_mask].mean()
        female_mean = predictions[female_mask].mean()
        
        # Penalize difference in means
        parity_loss = torch.abs(male_mean - female_mean)
        return lambda_fairness * parity_loss
    
    def equalized_odds_loss(self, predictions, targets, gender_labels, lambda_fairness=0.1):
        """
        Enforces similar true/false positive rates across groups
        """
        male_mask = (gender_labels == 1)
        female_mask = (gender_labels == 0)
        
        if male_mask.sum() == 0 or female_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device)
        
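        # For regression: binarize around the batch mean of the targets.
        # A hard cut on the predictions has zero gradient, so this term would
        # never influence training; a steep sigmoid is used as a differentiable
        # stand-in (the temperature value is a tunable assumption).
        threshold = targets.mean()
        temperature = 0.1
        pred_binary = torch.sigmoid((predictions - threshold) / temperature)
        target_binary = (targets > threshold).float()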
        
        # True positive rates
        male_tpr = (pred_binary[male_mask] * target_binary[male_mask]).sum() / (target_binary[male_mask].sum() + 1e-8)
        female_tpr = (pred_binary[female_mask] * target_binary[female_mask]).sum() / (target_binary[female_mask].sum() + 1e-8)
        
        # False positive rates  
        male_fpr = (pred_binary[male_mask] * (1-target_binary[male_mask])).sum() / ((1-target_binary[male_mask]).sum() + 1e-8)
        female_fpr = (pred_binary[female_mask] * (1-target_binary[female_mask])).sum() / ((1-target_binary[female_mask]).sum() + 1e-8)
        
        tpr_diff = torch.abs(male_tpr - female_tpr)
        fpr_diff = torch.abs(male_fpr - female_fpr)
        
        return lambda_fairness * (tpr_diff + fpr_diff)
    
    def group_regularization_loss(self, predictions, gender_labels, lambda_fairness=0.1):
        """
        Penalizes the gap in prediction variance between groups
        """
        male_mask = (gender_labels == 1)
        female_mask = (gender_labels == 0)
        
        # var() needs at least two samples per group; with one it returns NaN
        if male_mask.sum() < 2 or female_mask.sum() < 2:
            return torch.tensor(0.0, device=predictions.device)
        
        male_var = predictions[male_mask].var()
        female_var = predictions[female_mask].var()
        
        # Encourage similar variance across groups
        variance_diff = torch.abs(male_var - female_var)
        return lambda_fairness * variance_diff

# Initialize fairness losses
fairness_losses = FairnessLosses()

print("✅ Fairness loss functions defined!")
print("Available losses:")
print("- demographic_parity_loss: Penalizes gaps in mean predictions between groups")
print("- equalized_odds_loss: Penalizes gaps in true/false positive rates")
print("- group_regularization_loss: Penalizes gaps in prediction variance")

✅ Fairness loss functions defined!
Available losses:
- demographic_parity_loss: Penalizes gaps in mean predictions between groups
- equalized_odds_loss: Penalizes gaps in true/false positive rates
- group_regularization_loss: Penalizes gaps in prediction variance
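
How predictions are binarized matters for the equalized-odds term. The sketch below, in plain Python rather than PyTorch, compares a hard cut with a sigmoid relaxation on a toy batch. The helper names, the toy numbers, and the `temperature` knob are illustrative assumptions, not part of the training code above.

```python
import math

def hard_tpr(preds, targets, threshold):
    """True positive rate using a hard cut on the predictions."""
    positives = [p for p, t in zip(preds, targets) if t > threshold]
    if not positives:
        return 0.0
    return sum(1.0 for p in positives if p > threshold) / len(positives)

def soft_tpr(preds, targets, threshold, temperature=0.1):
    """Same quantity with a sigmoid replacing the hard cut."""
    positives = [p for p, t in zip(preds, targets) if t > threshold]
    if not positives:
        return 0.0
    soft = [1.0 / (1.0 + math.exp(-(p - threshold) / temperature)) for p in positives]
    return sum(soft) / len(positives)

preds = [0.30, 0.62, 0.80, 0.45]
targets = [0.20, 0.70, 0.90, 0.60]
threshold = 0.5
nudged = [p + 0.01 for p in preds]  # a small change, like one optimizer step

hard_change = hard_tpr(nudged, targets, threshold) - hard_tpr(preds, targets, threshold)
soft_change = soft_tpr(nudged, targets, threshold) - soft_tpr(preds, targets, threshold)
print(f"hard TPR change: {hard_change:.4f}")
print(f"soft TPR change: {soft_change:.4f}")
```

The hard version does not move unless a prediction crosses the threshold, so it gives an optimizer nothing to follow; the relaxed version responds to every small shift. A lower temperature tracks the hard cut more closely at the cost of smaller gradients away from the threshold.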


## Step 3: Modified Training Loop

Here's how to integrate fairness losses into your existing training loop:

In [9]:
def fairness_aware_training_step(model, batch, optimizer, fairness_losses, lambda_fairness=0.1):
    """
    Modified training step that includes fairness constraints
    
    Args:
        model: Your TimeLLM or distillation model
        batch: Training batch (should include gender labels)
        optimizer: Your optimizer
        fairness_losses: FairnessLosses instance
        lambda_fairness: Weight for fairness vs accuracy tradeoff
    """
    
    # Extract batch components (adjust based on your data structure)
    inputs = batch['inputs']  # Time series data
    targets = batch['targets']  # Prediction targets
    gender_labels = batch['gender']  # Gender labels (0=female, 1=male)
    
    # Forward pass
    model.train()
    predictions = model(inputs)
    
    # Primary loss (your existing loss function)
    primary_loss = F.mse_loss(predictions, targets)  # Adjust based on your task
    
    # Fairness losses
    dp_loss = fairness_losses.demographic_parity_loss(predictions, gender_labels, lambda_fairness)
    eo_loss = fairness_losses.equalized_odds_loss(predictions, targets, gender_labels, lambda_fairness)
    gr_loss = fairness_losses.group_regularization_loss(predictions, gender_labels, lambda_fairness)
    
    # Total loss with fairness constraints
    total_loss = primary_loss + dp_loss + eo_loss + gr_loss
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # Return detailed loss breakdown for monitoring
    return {
        'total_loss': total_loss.item(),
        'primary_loss': primary_loss.item(),
        'demographic_parity_loss': dp_loss.item(),
        'equalized_odds_loss': eo_loss.item(),
        'group_regularization_loss': gr_loss.item()
    }

print("✅ Fairness-aware training step defined!")
print("This function integrates fairness constraints into your training loop")
print("Use lambda_fairness to control accuracy vs fairness tradeoff")

✅ Fairness-aware training step defined!
This function integrates fairness constraints into your training loop
Use lambda_fairness to control accuracy vs fairness tradeoff
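
To get a feel for what `lambda_fairness` does to the objective, it can help to work the arithmetic by hand. This is a minimal sketch with made-up loss values; `combined_loss` is a hypothetical helper that mirrors the sum in the training step, where each fairness term is the raw gap multiplied by the same weight.

```python
def combined_loss(primary, parity_gap, odds_gap, variance_gap, lambda_fairness=0.1):
    """Return (total, fairness_share) given unscaled gap values."""
    fairness = lambda_fairness * (parity_gap + odds_gap + variance_gap)
    total = primary + fairness
    share = fairness / total if total > 0 else 0.0
    return total, share

# Illustrative numbers only: primary MSE 0.50, gaps of 0.20, 0.10 and 0.30
for lam in (0.05, 0.1, 0.2):
    total, share = combined_loss(0.50, 0.20, 0.10, 0.30, lambda_fairness=lam)
    print(f"lambda={lam:<4} total={total:.3f} fairness share={share:.1%}")
```

If the fairness share stays near zero, the constraint is unlikely to change training much; if it dominates, accuracy will probably suffer. Logging this share alongside the loss breakdown is one way to pick a sensible starting weight.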


## Step 4: Data Loader with Gender Labels

Your data loader needs to include gender information. Here's how to modify it:

In [10]:
class FairnessAwareDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that includes gender labels for fairness training
    """
    
    def __init__(self, data_path, patient_info_path):
        """
        Args:
            data_path: Path to your time series data
            patient_info_path: Path to patient demographics (should include gender)
        """
        # Load your existing data
        self.data = self.load_time_series_data(data_path)
        
        # Load patient demographics
        self.patient_info = pd.read_csv(patient_info_path)
        
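        # Create gender mapping (adjust column names as needed)
        # 1 = male, 0 = female, -1 = missing or unrecognized value
        self.gender_map = {}
        for _, row in self.patient_info.iterrows():
            patient_id = row['patient_id']  # Adjust column name
            value = str(row['gender']).strip().lower()  # Adjust column name; str() tolerates NaN
            if value in ('m', 'male'):
                gender = 1
            elif value in ('f', 'female'):
                gender = 0
            else:
                gender = -1
            self.gender_map[patient_id] = gender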
    
    def load_time_series_data(self, data_path):
        """Load your time series data - implement based on your format"""
        # This is a placeholder - replace with your actual data loading logic
        return torch.randn(1000, 100, 1)  # Example: 1000 samples, 100 timesteps, 1 feature
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Get time series data
        inputs = self.data[idx]
        targets = self.get_targets(idx)  # Implement based on your task
        
        # Get patient ID and corresponding gender
        patient_id = self.get_patient_id(idx)  # Implement based on your data structure
        # Unknown patients get -1 so the ==0 / ==1 masks leave them out of both groups
        gender = self.gender_map.get(patient_id, -1)
        
        return {
            'inputs': inputs,
            'targets': targets,
            'gender': torch.tensor(gender, dtype=torch.float32),
            'patient_id': patient_id
        }
    
    def get_targets(self, idx):
        """Implement based on your prediction task"""
        # Placeholder - replace with actual target extraction
        return torch.randn(1)
    
    def get_patient_id(self, idx):
        """Extract patient ID from your data structure"""
        # Placeholder - implement based on how you store patient IDs
        return f"patient_{idx % 100}"  # Example mapping

# Example usage
def create_fairness_aware_dataloader(data_path, patient_info_path, batch_size=32):
    """
    Create a DataLoader that includes gender information
    """
    dataset = FairnessAwareDataset(data_path, patient_info_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

print("✅ Fairness-aware dataset class defined!")
print("Modify the load_time_series_data, get_targets, and get_patient_id methods")
print("to match your specific data format and structure")

✅ Fairness-aware dataset class defined!
Modify the load_time_series_data, get_targets, and get_patient_id methods
to match your specific data format and structure
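
Before training, it is worth checking how the raw demographics values are encoded and how many patients land in each group. The sketch below is one possible encoding, assuming a sentinel of `-1` for missing or unrecognized values so that masks of the form `gender_labels == 0` and `gender_labels == 1` skip them. `encode_gender` and `UNKNOWN` are hypothetical names.

```python
from collections import Counter

UNKNOWN = -1  # assumed sentinel; matched by neither the == 0 nor the == 1 mask

def encode_gender(raw):
    """Map a raw demographics value to 1 (male), 0 (female), or UNKNOWN."""
    value = str(raw).strip().lower()
    if value in ("m", "male"):
        return 1
    if value in ("f", "female"):
        return 0
    return UNKNOWN

raw_values = ["Male", "f", " M ", None, "female", "nonbinary", ""]
encoded = [encode_gender(v) for v in raw_values]
counts = Counter(encoded)
print(encoded)
print(f"male={counts[1]} female={counts[0]} unknown={counts[UNKNOWN]}")
```

A large unknown count, or a group with only a handful of patients, is a sign that per-batch fairness terms will be noisy and that validation ratios should be read with caution.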


## Step 5: Complete Training Integration

Put it all together - here's a complete training function with fairness monitoring:

In [11]:
def train_with_fairness_monitoring(model, train_loader, val_loader, num_epochs=10, lambda_fairness=0.1):
    """
    Complete training function with fairness monitoring and checkpointing
    """
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    fairness_losses = FairnessLosses()
    analyzer = GenderFairnessAnalyzer()
    
    # Tracking
    train_history = defaultdict(list)
    fairness_history = []
    
    best_fairness_ratio = float('inf')
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Training phase
        model.train()
        epoch_losses = defaultdict(list)
        
        for batch_idx, batch in enumerate(train_loader):
            # Fairness-aware training step
            loss_dict = fairness_aware_training_step(
                model, batch, optimizer, fairness_losses, lambda_fairness
            )
            
            # Track losses
            for key, value in loss_dict.items():
                epoch_losses[key].append(value)
            
            if batch_idx % 50 == 0:  # Print every 50 batches
                print(f"  Batch {batch_idx}: Total Loss = {loss_dict['total_loss']:.4f}")
        
        # Calculate epoch averages
        for key, values in epoch_losses.items():
            avg_value = np.mean(values)
            train_history[key].append(avg_value)
            print(f"  Avg {key}: {avg_value:.4f}")
        
        # Validation and fairness evaluation
        if val_loader is not None:
            fairness_metrics = evaluate_fairness(model, val_loader, analyzer)
            fairness_history.append(fairness_metrics)
            
            current_fairness_ratio = fairness_metrics['fairness_ratio']
            print(f"  Fairness Ratio: {current_fairness_ratio:.3f}")
            print(f"  Fairness Level: {fairness_metrics['fairness_level']}")
            
            # Save best model based on fairness
            if current_fairness_ratio < best_fairness_ratio and current_fairness_ratio >= 1.0:
                best_fairness_ratio = current_fairness_ratio
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'fairness_ratio': current_fairness_ratio,
                    'fairness_metrics': fairness_metrics
                }, 'best_fair_model.pth')
                print(f"  ✅ New best fair model saved! Ratio: {current_fairness_ratio:.3f}")
    
    return {
        'train_history': train_history,
        'fairness_history': fairness_history,
        'best_fairness_ratio': best_fairness_ratio
    }

def evaluate_fairness(model, val_loader, analyzer):
    """
    Evaluate fairness metrics on validation set
    """
    model.eval()
    
    all_predictions = []
    all_genders = []
    all_targets = []
    
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['inputs']
            targets = batch['targets']
            genders = batch['gender']
            
            predictions = model(inputs)
            
            all_predictions.extend(predictions.cpu().numpy())
            all_genders.extend(genders.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    
    # Calculate group-specific performance
    male_mask = np.array(all_genders) == 1
    female_mask = np.array(all_genders) == 0
    
    male_mse = np.mean((np.array(all_predictions)[male_mask] - np.array(all_targets)[male_mask]) ** 2)
    female_mse = np.mean((np.array(all_predictions)[female_mask] - np.array(all_targets)[female_mask]) ** 2)
    
    # Calculate fairness ratio (worse group MSE / better group MSE; 1.0 = equal)
    best_mse = min(male_mse, female_mse)
    fairness_ratio = max(male_mse, female_mse) / best_mse if best_mse > 0 else float('inf')
    
    # Determine fairness level
    if fairness_ratio <= 1.10:
        fairness_level = "Excellent"
    elif fairness_ratio <= 1.25:
        fairness_level = "Good"
    elif fairness_ratio <= 1.50:
        fairness_level = "Acceptable"
    else:
        fairness_level = "Poor"
    
    return {
        'male_mse': male_mse,
        'female_mse': female_mse,
        'fairness_ratio': fairness_ratio,
        'fairness_level': fairness_level,
        'male_count': male_mask.sum(),
        'female_count': female_mask.sum()
    }

print("✅ Complete fairness-aware training system defined!")
print("Key features:")
print("- Integrates fairness losses into training")
print("- Monitors fairness metrics during validation") 
print("- Saves best model based on fairness criteria")
print("- Provides detailed loss and fairness tracking")

✅ Complete fairness-aware training system defined!
Key features:
- Integrates fairness losses into training
- Monitors fairness metrics during validation
- Saves best model based on fairness criteria
- Provides detailed loss and fairness tracking
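
Choosing a checkpoint on the fairness ratio alone can favor a model that is equally inaccurate for both groups. One alternative, sketched here under assumptions, is to keep the most accurate epoch among those that stay under a fairness cap. The `val_mse` key and the per-epoch numbers are hypothetical; `evaluate_fairness` as written does not return an overall MSE.

```python
def select_checkpoint(history, max_ratio=1.25):
    """Pick the epoch with the lowest overall MSE among epochs whose
    fairness ratio is within max_ratio; otherwise fall back to the fairest."""
    eligible = [h for h in history if h["fairness_ratio"] <= max_ratio]
    if eligible:
        return min(eligible, key=lambda h: h["val_mse"])
    return min(history, key=lambda h: h["fairness_ratio"])

# Hypothetical per-epoch validation summaries
history = [
    {"epoch": 0, "val_mse": 0.90, "fairness_ratio": 1.02},
    {"epoch": 1, "val_mse": 0.40, "fairness_ratio": 1.18},
    {"epoch": 2, "val_mse": 0.35, "fairness_ratio": 1.40},
]
best = select_checkpoint(history)
print(f"selected epoch {best['epoch']} "
      f"(mse={best['val_mse']}, ratio={best['fairness_ratio']})")
```

Here epoch 0 has the best ratio but poor accuracy, and epoch 2 is the most accurate but exceeds the cap, so the rule settles on epoch 1.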


## Step 6: Usage Example

Here's how to use the fairness-aware training in practice:

In [12]:
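# Example: How to integrate with your existing training pipeline
# NOTE: the paths below are placeholders. Until you replace them, this cell
# stops with FileNotFoundError when the demographics CSV is read.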

# 1. Create fairness-aware data loaders
train_loader = create_fairness_aware_dataloader(
    data_path="path/to/your/train_data",
    patient_info_path="path/to/patient_demographics.csv",
    batch_size=32
)

val_loader = create_fairness_aware_dataloader(
    data_path="path/to/your/val_data", 
    patient_info_path="path/to/patient_demographics.csv",
    batch_size=32
)

# 2. Initialize your model (use your existing model)
# model = TimeLLM(config)  # Your existing model initialization
# model = DistillationModel(teacher, student)  # Or distillation setup

# 3. Train with fairness constraints
# results = train_with_fairness_monitoring(
#     model=model,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     num_epochs=50,
#     lambda_fairness=0.1  # Adjust this to balance fairness vs accuracy
# )

# 4. Plot training progress
# def plot_fairness_training_progress(results):
#     fig, axes = plt.subplots(2, 2, figsize=(15, 10))
#     
#     # Loss evolution
#     axes[0,0].plot(results['train_history']['total_loss'], label='Total Loss')
#     axes[0,0].plot(results['train_history']['primary_loss'], label='Primary Loss')
#     axes[0,0].set_title('Training Loss Evolution')
#     axes[0,0].legend()
#     
#     # Fairness losses
#     axes[0,1].plot(results['train_history']['demographic_parity_loss'], label='Demographic Parity')
#     axes[0,1].plot(results['train_history']['equalized_odds_loss'], label='Equalized Odds')
#     axes[0,1].plot(results['train_history']['group_regularization_loss'], label='Group Regularization')
#     axes[0,1].set_title('Fairness Losses')
#     axes[0,1].legend()
#     
#     # Fairness ratio evolution
#     fairness_ratios = [f['fairness_ratio'] for f in results['fairness_history']]
#     axes[1,0].plot(fairness_ratios)
#     axes[1,0].axhline(y=1.10, color='green', linestyle='--', label='Excellent (≤1.10)')
#     axes[1,0].axhline(y=1.25, color='yellow', linestyle='--', label='Good (≤1.25)')
#     axes[1,0].set_title('Fairness Ratio Evolution')
#     axes[1,0].legend()
#     
#     # Group performance
#     male_mses = [f['male_mse'] for f in results['fairness_history']]
#     female_mses = [f['female_mse'] for f in results['fairness_history']]
#     axes[1,1].plot(male_mses, label='Male MSE')
#     axes[1,1].plot(female_mses, label='Female MSE')
#     axes[1,1].set_title('Group-specific Performance')
#     axes[1,1].legend()
#     
#     plt.tight_layout()
#     plt.show()

print("📋 Usage Example Defined!")
print("\n🔧 Integration Steps:")
print("1. Replace placeholder paths with your actual data paths")
print("2. Replace model initialization with your actual model") 
print("3. Uncomment and run the training code")
print("4. Use plot_fairness_training_progress() to visualize results")
print("\n⚖️ Tuning Tips:")
print("- Start with lambda_fairness=0.05 for subtle fairness constraints")
print("- Increase to 0.1-0.2 for stronger fairness enforcement") 
print("- Monitor both accuracy and fairness metrics during training")

FileNotFoundError: [Errno 2] No such file or directory: 'path/to/patient_demographics.csv'

## Step 7: Quick Test with Your Current Setup

Let's test the fairness analyzer with your current distillation results:

In [None]:
# Test the fairness analyzer on your current results
import os

# Check current fairness level of your distillation experiments
analyzer = GenderFairnessAnalyzer()

# Look for the latest experiment
distillation_dir = "/workspace/LLM-TIME/distillation_experiments"
if os.path.exists(distillation_dir):
    print("🔍 Analyzing current fairness level...")
    
    try:
        # Run the analyzer
        results = analyzer.analyze_latest_experiment()
        
        print(f"\n📊 Current Fairness Analysis:")
        print(f"🚹 Male patients: {results['male_performance']['count']}")
        print(f"🚺 Female patients: {results['female_performance']['count']}")
        print(f"⚖️  Fairness ratio: {results['fairness_ratio']:.3f}")
        print(f"🎯 Fairness level: {results['fairness_level']}")
        
        if results['fairness_ratio'] <= 1.10:
            print("✅ Great! Your current model already shows excellent fairness")
            print("   The integration guide above can help maintain this during future training")
        elif results['fairness_ratio'] <= 1.25:
            print("👍 Good fairness level - slight improvements possible")
            print("   Consider using lambda_fairness=0.05 for gentle fairness constraints")
        else:
            print("⚠️  Fairness could be improved")
            print("   Recommend using lambda_fairness=0.1-0.2 for stronger fairness constraints")
            
    except Exception as e:
        print(f"⚠️  Could not analyze current results: {e}")
        print("   Make sure you have run some distillation experiments first")
        
else:
    print("📝 No distillation experiments found yet")
    print("   Run some experiments, then use this guide to make them fairer!")

print("\n🎯 Next Steps:")
print("1. If fairness is already good: Use this guide to maintain it during training")
print("2. If fairness needs improvement: Integrate the fairness losses above") 
print("3. Always monitor both accuracy AND fairness during training")

## Summary & Next Steps

🎉 **Congratulations!** You now have a complete fairness-aware training system.

### What You've Learned:
- **Fairness Loss Functions**: Demographic parity, equalized odds, and group regularization
- **Modified Training Loop**: Integrates fairness constraints with your existing training
- **Monitoring System**: Tracks both accuracy and fairness metrics during training  
- **Best Model Selection**: Saves models based on fairness criteria

### Integration Checklist:
- [ ] Modify your dataset class to include gender labels
- [ ] Add fairness losses to your training loop
- [ ] Set appropriate `lambda_fairness` values (start with 0.05-0.1)
- [ ] Monitor fairness metrics during validation
- [ ] Save best models based on fairness criteria

### Recommended Workflow:
1. **Baseline**: Train your model normally and measure fairness
2. **Integration**: Add fairness constraints with low λ (0.05)
3. **Tuning**: Gradually increase λ until you achieve desired fairness
4. **Monitoring**: Always validate on both accuracy and fairness metrics
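
The tuning step in this workflow can be sketched as a small sweep that stops at the first weight meeting the target. This is only an outline under assumptions: `tune_lambda` is a hypothetical helper, `train_and_evaluate` stands for whatever function runs your training and returns `(fairness_ratio, val_mse)`, and `fake_run` produces made-up numbers purely so the sketch executes.

```python
def tune_lambda(train_and_evaluate, candidates=(0.05, 0.1, 0.15, 0.2), target_ratio=1.10):
    """Try increasing lambda values; stop at the first that meets target_ratio."""
    trials = []
    for lam in candidates:
        ratio, mse = train_and_evaluate(lam)
        trials.append({"lambda": lam, "fairness_ratio": ratio, "val_mse": mse})
        if ratio <= target_ratio:
            break
    return trials

# Stand-in for a real training run: synthetic numbers, not measured results
def fake_run(lam):
    return 1.30 - 1.5 * lam, 0.40 + 0.5 * lam

trials = tune_lambda(fake_run)
for t in trials:
    print(f"lambda={t['lambda']:.2f} ratio={t['fairness_ratio']:.3f} mse={t['val_mse']:.3f}")
```

Keeping every trial, not just the last one, makes the accuracy cost of each step visible, which helps when deciding whether the final weight is worth it.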

### Fairness Thresholds:
- **Excellent**: Ratio ≤ 1.10 (≤10% difference between groups)
- **Good**: Ratio ≤ 1.25 (≤25% difference)
- **Acceptable**: Ratio ≤ 1.50 (≤50% difference)
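
As a worked example of these thresholds, the ratio is the worse group's MSE divided by the better group's MSE, and the percentage in parentheses is how far that ratio sits above 1. The sketch uses illustrative MSE values; the helper names are hypothetical, and anything above 1.50 is labeled "Poor" as in the evaluation code.

```python
def fairness_ratio(mse_a, mse_b):
    """Worse-group MSE over better-group MSE (1.0 means equal error)."""
    worse, better = max(mse_a, mse_b), min(mse_a, mse_b)
    return worse / better if better > 0 else float("inf")

def fairness_level(ratio):
    """Map a ratio onto the threshold bands listed above."""
    if ratio <= 1.10:
        return "Excellent"
    if ratio <= 1.25:
        return "Good"
    if ratio <= 1.50:
        return "Acceptable"
    return "Poor"

# Illustrative values: male MSE 0.024, female MSE 0.020
ratio = fairness_ratio(0.024, 0.020)
level = fairness_level(ratio)
print(f"ratio={ratio:.2f} gap={(ratio - 1) * 100:.0f}% level={level}")
```

Because the ratio is relative, the same absolute gap counts for more when both errors are small, so it is worth reading the ratio together with the raw per-group MSE.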

### Support Resources:
- `gender_fairness_analyzer.py` - Analyze existing results
- `Gender_Fairness_Analysis.ipynb` - Interactive fairness analysis
- `README.md` - Complete documentation and guides

**Happy Fair Training! ⚖️🚀**

## Extended Fairness Analysis - Multiple Demographics

Want to check fairness across other features like age, race, or disease severity? Here's how to extend the framework:

In [None]:
class MultiFairnessAnalyzer:
    """
    Analyze fairness across multiple demographic features
    """
    
    def __init__(self):
        self.supported_features = {
            'gender': {'male': 1, 'female': 0, 'm': 1, 'f': 0},
            'age_group': {'young': 0, 'middle': 1, 'old': 2},  # You can define age ranges
            'race': {'white': 0, 'black': 1, 'hispanic': 2, 'asian': 3, 'other': 4},
            'disease_severity': {'mild': 0, 'moderate': 1, 'severe': 2},
            'bmi_category': {'underweight': 0, 'normal': 1, 'overweight': 2, 'obese': 3}
        }
    
    def load_patient_demographics(self, demographics_path):
        """Load patient demographics with multiple attributes"""
        try:
            df = pd.read_csv(demographics_path)
            print(f"📊 Loaded demographics for {len(df)} patients")
            print(f"📋 Available columns: {list(df.columns)}")
            return df
        except Exception as e:
            print(f"❌ Error loading demographics: {e}")
            return None
    
    def analyze_feature_fairness(self, results_path, demographics_df, feature_column, feature_name=None):
        """
        Analyze fairness for any demographic feature
        
        Args:
            results_path: Path to experiment results JSON
            demographics_df: DataFrame with patient demographics
            feature_column: Column name for the demographic feature
            feature_name: Display name for the feature (optional)
        """
        
        if feature_name is None:
            feature_name = feature_column.title()
            
        print(f"\n🔍 Analyzing {feature_name} Fairness")
        print("=" * 50)
        
        # Load experiment results
        try:
            with open(results_path, 'r') as f:
                results = json.load(f)
        except Exception as e:
            print(f"❌ Error loading results: {e}")
            return None
        
        # Group patients by feature
        feature_groups = {}
        group_performance = {}
        
        for patient_id, patient_results in results.items():
            # Find demographic info for this patient
            patient_demo = demographics_df[demographics_df['patient_id'] == patient_id]
            
            if len(patient_demo) == 0:
                continue
                
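            feature_value = patient_demo[feature_column].iloc[0]
            
            # Skip patients without recorded errors; defaulting them to 0
            # would make their group look better than it is
            if 'mse' not in patient_results or 'mae' not in patient_results:
                continue
            
            if feature_value not in feature_groups:
                feature_groups[feature_value] = []
                
            feature_groups[feature_value].append({
                'patient_id': patient_id,
                'mse': patient_results['mse'],
                'mae': patient_results['mae']
            })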
        
        # Calculate performance for each group
        for group_value, patients in feature_groups.items():
            group_mse = np.mean([p['mse'] for p in patients])
            group_mae = np.mean([p['mae'] for p in patients])
            group_count = len(patients)
            
            group_performance[group_value] = {
                'count': group_count,
                'mse': group_mse,
                'mae': group_mae,
                'patients': patients
            }
            
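            print(f"🏷️  {group_value}: {group_count} patients, MSE = {group_mse:.6f}")
        
        # A ratio needs at least two groups; stop here if matching left fewer
        if len(group_performance) < 2:
            print(f"⚠️  Fewer than two {feature_name} groups matched; skipping")
            return None
        
        # Calculate fairness metrics
        mse_values = [perf['mse'] for perf in group_performance.values()]
        mae_values = [perf['mae'] for perf in group_performance.values()]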
        
        # Fairness ratio (worst/best performance)
        mse_ratio = max(mse_values) / min(mse_values) if min(mse_values) > 0 else float('inf')
        mae_ratio = max(mae_values) / min(mae_values) if min(mae_values) > 0 else float('inf')
        
        # Coefficient of variation (std/mean)
        mse_cv = np.std(mse_values) / np.mean(mse_values) if np.mean(mse_values) > 0 else float('inf')
        mae_cv = np.std(mae_values) / np.mean(mae_values) if np.mean(mae_values) > 0 else float('inf')
        
        # Fairness level classification
        fairness_level = self.classify_fairness_level(mse_ratio)
        
        print(f"\n📊 {feature_name} Fairness Metrics:")
        print(f"⚖️  MSE Fairness Ratio: {mse_ratio:.3f}")
        print(f"⚖️  MAE Fairness Ratio: {mae_ratio:.3f}")
        print(f"📈 MSE Coefficient of Variation: {mse_cv:.3f}")
        print(f"📈 MAE Coefficient of Variation: {mae_cv:.3f}")
        print(f"🎯 Fairness Level: {fairness_level}")
        
        return {
            'feature_name': feature_name,
            'group_performance': group_performance,
            'mse_ratio': mse_ratio,
            'mae_ratio': mae_ratio,
            'mse_cv': mse_cv,
            'mae_cv': mae_cv,
            'fairness_level': fairness_level
        }
    
    def classify_fairness_level(self, ratio):
        """Classify fairness level based on ratio"""
        if ratio <= 1.10:
            return "Excellent"
        elif ratio <= 1.25:
            return "Good" 
        elif ratio <= 1.50:
            return "Acceptable"
        else:
            return "Poor"
    
    def analyze_all_features(self, results_path, demographics_df, features_to_analyze):
        """
        Analyze fairness across multiple demographic features
        
        Args:
            results_path: Path to experiment results
            demographics_df: DataFrame with patient demographics  
            features_to_analyze: List of column names to analyze
        """
        
        fairness_summary = {}
        
        print("🔍 Multi-Feature Fairness Analysis")
        print("=" * 60)
        
        for feature_col in features_to_analyze:
            if feature_col in demographics_df.columns:
                result = self.analyze_feature_fairness(
                    results_path, demographics_df, feature_col
                )
                if result:
                    fairness_summary[feature_col] = result
            else:
                print(f"⚠️  Feature '{feature_col}' not found in demographics")
        
        # Summary comparison
        print(f"\n📋 Fairness Summary Across All Features:")
        print("=" * 60)
        
        for feature, metrics in fairness_summary.items():
            print(f"{metrics['feature_name']:15} | Ratio: {metrics['mse_ratio']:5.2f} | Level: {metrics['fairness_level']}")
        
        # Find most/least fair features
        if fairness_summary:
            most_fair = min(fairness_summary.items(), key=lambda x: x[1]['mse_ratio'])
            least_fair = max(fairness_summary.items(), key=lambda x: x[1]['mse_ratio'])
            
            print(f"\n🏆 Most Fair Feature: {most_fair[1]['feature_name']} (ratio: {most_fair[1]['mse_ratio']:.3f})")
            print(f"⚠️  Least Fair Feature: {least_fair[1]['feature_name']} (ratio: {least_fair[1]['mse_ratio']:.3f})")
        
        return fairness_summary

# Initialize the multi-feature analyzer
multi_analyzer = MultiFairnessAnalyzer()

print("✅ Multi-Feature Fairness Analyzer Ready!")
print("🎯 Supported features:")
for feature, mapping in multi_analyzer.supported_features.items():
    print(f"   • {feature}: {list(mapping.keys())}")
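
The analyzer reads results keyed by patient ID, with `mse` and `mae` fields per patient. The sketch below reproduces its two summary statistics on a toy input using only the standard library, so the expected record shape and the arithmetic are easy to check by hand. The patient IDs, group labels, and error values are invented for illustration; `pstdev` is used because `np.std` defaults to the population standard deviation.

```python
from statistics import mean, pstdev

# Assumed shape of the parsed results file: patient_id -> metrics
results = {
    "patient_001": {"mse": 0.020, "mae": 0.10},
    "patient_002": {"mse": 0.040, "mae": 0.15},
    "patient_003": {"mse": 0.045, "mae": 0.16},
    "patient_004": {"mse": 0.045, "mae": 0.17},
    "patient_005": {"mse": 0.500, "mae": 0.60},  # no demographics row, so skipped
}
# Hypothetical demographics lookup: patient_id -> age group
demographics = {
    "patient_001": "young", "patient_002": "young",
    "patient_003": "old", "patient_004": "old",
}

groups = {}
for pid, metrics in results.items():
    group = demographics.get(pid)
    if group is None or "mse" not in metrics:
        continue
    groups.setdefault(group, []).append(metrics["mse"])

group_mse = {g: mean(v) for g, v in groups.items()}
values = list(group_mse.values())
ratio = max(values) / min(values)   # worst group over best group
cv = pstdev(values) / mean(values)  # spread relative to the average
print(group_mse)
print(f"ratio={ratio:.3f} cv={cv:.3f}")
```

With only two groups the ratio and the coefficient of variation carry the same information; the coefficient of variation becomes more useful for features with many groups, where a single outlier group would otherwise set the ratio.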

### Example: Analyze Age Group Fairness

In [None]:
# Example: Create a sample demographics file with multiple features
sample_demographics = pd.DataFrame({
    'patient_id': [f'patient_{i:03d}' for i in range(100)],
    'gender': np.random.choice(['male', 'female'], 100),
    'age': np.random.randint(18, 80, 100),
    'race': np.random.choice(['white', 'black', 'hispanic', 'asian'], 100),
    'bmi': np.random.normal(25, 5, 100),
    'disease_severity': np.random.choice(['mild', 'moderate', 'severe'], 100)
})

# Create age groups
def categorize_age(age):
    if age < 30:
        return 'young'
    elif age < 60:
        return 'middle'
    else:
        return 'old'

sample_demographics['age_group'] = sample_demographics['age'].apply(categorize_age)

# Create BMI categories
def categorize_bmi(bmi):
    if bmi < 18.5:
        return 'underweight'
    elif bmi < 25:
        return 'normal'
    elif bmi < 30:
        return 'overweight'
    else:
        return 'obese'

sample_demographics['bmi_category'] = sample_demographics['bmi'].apply(categorize_bmi)

print("📊 Sample Demographics Created:")
print(f"   • Age groups: {sample_demographics['age_group'].value_counts().to_dict()}")
print(f"   • Race distribution: {sample_demographics['race'].value_counts().to_dict()}")
print(f"   • BMI categories: {sample_demographics['bmi_category'].value_counts().to_dict()}")

# Save sample demographics (you can replace this with your actual data)
# sample_demographics.to_csv('/workspace/LLM-TIME/fairness/sample_demographics.csv', index=False)

print("💾 Sample not saved yet; uncomment the to_csv line above to write sample_demographics.csv")

### Multi-Feature Fairness Training

In [None]:
class MultiFeatureFairnessLoss(nn.Module):
    """
    Fairness losses that work across multiple demographic features
    """
    
    def __init__(self, features_to_consider=('gender', 'age_group', 'race')):
        super().__init__()
        # Copy into a fresh list so instances never share a mutable default
        self.features = list(features_to_consider)
    
    def multi_feature_parity_loss(self, predictions, demographic_features, lambda_fairness=0.1):
        """
        Enforce fairness across multiple demographic features simultaneously
        
        Args:
            predictions: Model predictions
            demographic_features: Dict with feature names as keys, labels as values
            lambda_fairness: Weight for fairness constraint
        """
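        # Start from a zero tensor (not a Python float) so the result always supports .item(),
        # even when no requested feature is present in the batch
        total_fairness_loss = predictions.new_zeros(())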
        
        for feature_name in self.features:
            if feature_name in demographic_features:
                feature_labels = demographic_features[feature_name]
                
                # Get unique groups for this feature
                unique_groups = torch.unique(feature_labels)
                
                if len(unique_groups) <= 1:
                    continue
                
                # Calculate mean prediction for each group
                group_means = []
                for group in unique_groups:
                    group_mask = (feature_labels == group)
                    if group_mask.sum() > 0:
                        group_mean = predictions[group_mask].mean()
                        group_means.append(group_mean)
                
                if len(group_means) > 1:
                    # Penalize variance across groups for this feature
                    group_means = torch.stack(group_means)
                    feature_fairness_loss = group_means.var()
                    total_fairness_loss += feature_fairness_loss
        
        return lambda_fairness * total_fairness_loss
    
    def intersectional_fairness_loss(self, predictions, demographic_features, lambda_fairness=0.1):
        """
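        Penalize disparity across intersectional groups (e.g., young women vs. old men).

        Only the first two entries of demographic_features are combined, so the
        key order of that dict determines which intersection is checked.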
        """
        if len(demographic_features) < 2:
            return torch.tensor(0.0, device=predictions.device)
        
        # Create intersectional groups by combining features
        feature_names = list(demographic_features.keys())[:2]  # Use first 2 features
        
        feature1_labels = demographic_features[feature_names[0]]
        feature2_labels = demographic_features[feature_names[1]]
        
        # Combine features to create intersectional groups
        unique_groups = []
        group_means = []
        
        for val1 in torch.unique(feature1_labels):
            for val2 in torch.unique(feature2_labels):
                intersect_mask = (feature1_labels == val1) & (feature2_labels == val2)
                
                if intersect_mask.sum() > 0:  # Only if this intersection exists
                    group_mean = predictions[intersect_mask].mean()
                    group_means.append(group_mean)
                    unique_groups.append((val1.item(), val2.item()))
        
        if len(group_means) > 1:
            group_means = torch.stack(group_means)
            intersectional_loss = group_means.var()
            return lambda_fairness * intersectional_loss
        
        return torch.tensor(0.0, device=predictions.device)

def multi_feature_training_step(model, batch, optimizer, fairness_loss_fn, lambda_fairness=0.1):
    """
    Training step with multi-feature fairness constraints
    """
    
    # Extract batch data
    inputs = batch['inputs']
    targets = batch['targets']
    
    # Extract all demographic features
    demographic_features = {}
    for key, value in batch.items():
        if key not in ['inputs', 'targets', 'patient_id']:
            demographic_features[key] = value
    
    # Forward pass
    model.train()
    predictions = model(inputs)
    
    # Primary loss
    primary_loss = F.mse_loss(predictions, targets)
    
    # Multi-feature fairness losses
    parity_loss = fairness_loss_fn.multi_feature_parity_loss(predictions, demographic_features, lambda_fairness)
    intersectional_loss = fairness_loss_fn.intersectional_fairness_loss(predictions, demographic_features, lambda_fairness)
    
    # Total loss
    total_loss = primary_loss + parity_loss + intersectional_loss
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return {
        'total_loss': total_loss.item(),
        'primary_loss': primary_loss.item(),
        'multi_feature_parity_loss': parity_loss.item(),
        'intersectional_fairness_loss': intersectional_loss.item()
    }

print("✅ Multi-Feature Fairness Training System Ready!")
print("🎯 Capabilities:")
print("   • Simultaneous fairness across multiple demographics")
print("   • Intersectional fairness (e.g., age + gender combinations)")
print("   • Flexible feature selection")
print("   • Comprehensive loss tracking")
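
To make the parity penalty concrete, here is a small worked example in plain Python with made-up numbers. `multi_feature_parity_loss` takes the variance of the per-group mean predictions; `torch.Tensor.var` uses the unbiased (n-1) estimator by default, which is also what `statistics.variance` computes, so the arithmetic below should match what the loss returns for a single feature.

```python
from statistics import mean, variance

# Made-up predictions and age-group ids for six samples (illustrative only)
predictions = [100.0, 110.0, 120.0, 130.0, 140.0, 150.0]
age_group = [0, 0, 1, 1, 2, 2]

# Mean prediction per group
group_means = [
    mean(p for p, g in zip(predictions, age_group) if g == group)
    for group in sorted(set(age_group))
]

# Sample variance (n-1) of the group means, scaled by lambda
lambda_fairness = 0.1
penalty = lambda_fairness * variance(group_means)

print(group_means)  # [105.0, 125.0, 145.0]
print(penalty)      # 40.0
```

Both the MSE and this penalty are in squared prediction units, so `lambda_fairness` directly sets how much a gap between group means costs relative to prediction error.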

### Usage Examples for Multi-Feature Fairness

In [None]:
# Example 1: Analyze fairness across multiple features
print("🔍 Example 1: Multi-Feature Fairness Analysis")
print("-" * 50)

# Load your actual demographics (replace with real path)
# demographics_df = multi_analyzer.load_patient_demographics('/path/to/demographics.csv')

# For demo, use sample data
demographics_df = sample_demographics

# Analyze specific features
features_to_check = ['age_group', 'race', 'bmi_category', 'disease_severity']

# Example: If you have experiment results
# results_path = '/workspace/LLM-TIME/distillation_experiments/latest/results.json'
# fairness_results = multi_analyzer.analyze_all_features(results_path, demographics_df, features_to_check)

print("📋 To use with your real data:")
print("1. Load demographics: demographics_df = multi_analyzer.load_patient_demographics('your_file.csv')")
print("2. Analyze features: fairness_results = multi_analyzer.analyze_all_features(results_path, demographics_df, features_list)")

# Example 2: Multi-feature dataset
print("\n🔍 Example 2: Multi-Feature Dataset")
print("-" * 50)

class MultiFeatureDataset(torch.utils.data.Dataset):
    """Dataset that includes multiple demographic features"""
    
    def __init__(self, data_path, demographics_df):
        self.data = self.load_time_series_data(data_path)
        self.demographics = demographics_df
        
        # Create mappings for categorical features
        self.feature_mappings = {
            'gender': {'male': 1, 'female': 0},
            'age_group': {'young': 0, 'middle': 1, 'old': 2},
            'race': {'white': 0, 'black': 1, 'hispanic': 2, 'asian': 3},
            'bmi_category': {'underweight': 0, 'normal': 1, 'overweight': 2, 'obese': 3},
            'disease_severity': {'mild': 0, 'moderate': 1, 'severe': 2}
        }
    
    def load_time_series_data(self, data_path):
        # Placeholder - replace with your data loading
        return torch.randn(100, 100, 1)  # 100 samples
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        inputs = self.data[idx]
        targets = torch.randn(1)  # Placeholder target
        
        # Get patient demographics
        patient_demo = self.demographics.iloc[idx % len(self.demographics)]
        
        # Convert categorical features to numeric
        batch_item = {
            'inputs': inputs,
            'targets': targets,
            'patient_id': patient_demo['patient_id']
        }
        
        # Add all demographic features
        for feature, mapping in self.feature_mappings.items():
            if feature in patient_demo:
                feature_value = patient_demo[feature]
                if feature_value in mapping:
                    batch_item[feature] = torch.tensor(mapping[feature_value], dtype=torch.float32)
                else:
                    # Unknown or missing values get -1 so they are not silently merged into category 0
                    batch_item[feature] = torch.tensor(-1, dtype=torch.float32)
        
        return batch_item

# Create multi-feature data loader
# dataset = MultiFeatureDataset('path/to/data', demographics_df)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print("📋 Multi-feature dataset ready! Includes:")
for feature in ['gender', 'age_group', 'race', 'bmi_category', 'disease_severity']:
    print(f"   • {feature}")

# Example 3: Training with multiple fairness constraints
print("\n🔍 Example 3: Multi-Feature Training")
print("-" * 50)

# Initialize multi-feature fairness loss
# fairness_loss_fn = MultiFeatureFairnessLoss(features_to_consider=['gender', 'age_group', 'race'])

# Training loop example
# for epoch in range(num_epochs):
#     for batch in dataloader:
#         loss_dict = multi_feature_training_step(
#             model=model,
#             batch=batch,
#             optimizer=optimizer,
#             fairness_loss_fn=fairness_loss_fn,
#             lambda_fairness=0.1
#         )

print("🎯 Each training step penalizes disparity across:")
print("   • Individual features (gender, age group, race)")
print("   • Intersections of the first two demographic features in the batch (e.g., gender + age group)")
print("   • All of the above in one total loss (primary + parity + intersectional)")

print("\n✅ Multi-Feature Fairness System Complete!")
print("🚀 You can now analyze and enforce fairness across any demographic features!")
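
Example 3 is left commented out because it depends on your own model and data loader. The sketch below runs one step of the same pattern end to end with a stand-in `nn.Linear` model and a synthetic batch, which can help check shapes and gradient flow before wiring in real components. The model, the batch shapes, and the helper `group_mean_variance` are assumptions made for this sketch; it mirrors `multi_feature_training_step` in spirit but is not a replacement for it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def group_mean_variance(predictions, labels):
    """Variance of per-group mean predictions; a zero tensor if fewer than two groups."""
    means = [predictions[labels == g].mean() for g in torch.unique(labels)]
    if len(means) < 2:
        return predictions.new_zeros(())
    return torch.stack(means).var()

# Stand-in model and synthetic batch (shapes are assumptions for this sketch)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch = {
    "inputs": torch.randn(32, 8),
    "targets": torch.randn(32, 1),
    "gender": torch.randint(0, 2, (32,)),
    "age_group": torch.randint(0, 3, (32,)),
}

model.train()
predictions = model(batch["inputs"]).squeeze(-1)  # shape (32,)
primary_loss = F.mse_loss(predictions, batch["targets"].squeeze(-1))

# One variance-of-group-means penalty per demographic feature
fairness_loss = sum(
    group_mean_variance(predictions, batch[name]) for name in ("gender", "age_group")
)
total_loss = primary_loss + 0.1 * fairness_loss

optimizer.zero_grad()
total_loss.backward()
optimizer.step()

print(f"primary={primary_loss.item():.4f}  fairness={fairness_loss.item():.4f}")
```

If this step runs cleanly, the same batch layout (a dict with `inputs`, `targets`, and one tensor per demographic feature) should work with the training step defined earlier.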