# Data

In [1]:
import torch
import numpy as np
import os
import pickle
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# Path to processed data
data_path = "/kaggle/input/mimic-embedding/processed_mimic_data"

class ProcessedMIMICDataset(Dataset):
    def __init__(self, data_path, data_format='pt'):
        """
        Load a processed MIMIC dataset
        
        Args:
            data_path: Path to the processed data directory
            data_format: Format to load ('pt' for PyTorch tensors, 'npy' for NumPy, 'pkl' for Pickle)
        """
        self.data_path = data_path
        self.data_format = data_format
        
        if data_format == 'pkl':
            # Load the pickle file
            with open(os.path.join(data_path, "all_data.pkl"), 'rb') as f:
                self.all_data = pickle.load(f)
            
            self.embeddings = self.all_data['embeddings']
            self.labels = self.all_data['labels']
            self.subject_ids = self.all_data['subject_ids']
            self.study_ids = self.all_data['study_ids']
            self.demographics = self.all_data['demographics']
            
        elif data_format == 'pt':
            # Load PyTorch tensors with weights_only=True to avoid security warnings
            self.embeddings = torch.load(os.path.join(data_path, "embeddings.pt"), weights_only=True)
            self.labels = torch.load(os.path.join(data_path, "labels.pt"), weights_only=True)
            
            # Load IDs from CSV
            ids_df = pd.read_csv(os.path.join(data_path, "ids.csv"))
            self.subject_ids = ids_df['subject_id'].tolist()
            self.study_ids = ids_df['study_id'].tolist()
            
            # Load demographics
            with open(os.path.join(data_path, "demographics.pkl"), 'rb') as f:
                self.demographics = pickle.load(f)
                
        elif data_format == 'npy':
            # Load NumPy arrays
            self.embeddings = np.load(os.path.join(data_path, "embeddings.npy"))
            self.labels = np.load(os.path.join(data_path, "labels.npy"))
            
            # Load IDs from CSV
            ids_df = pd.read_csv(os.path.join(data_path, "ids.csv"))
            self.subject_ids = ids_df['subject_id'].tolist()
            self.study_ids = ids_df['study_id'].tolist()
            
            # Load demographics
            with open(os.path.join(data_path, "demographics.pkl"), 'rb') as f:
                self.demographics = pickle.load(f)
                
        else:
            raise ValueError(f"Unsupported data format: {data_format}")
        
        print(f"Loaded dataset from {data_path} with {len(self.embeddings)} samples")
        
    def __len__(self):
        return len(self.embeddings)
    
    def __getitem__(self, idx):
        if self.data_format == 'pt':
            embedding = self.embeddings[idx]
            labels = self.labels[idx]
        else:
            embedding = torch.tensor(self.embeddings[idx], dtype=torch.float32)
            labels = torch.tensor(self.labels[idx], dtype=torch.float32)
            
        return {
            'embedding': embedding,
            'labels': labels,
            'subject_id': self.subject_ids[idx],
            'study_id': self.study_ids[idx],
            'demographics': self.demographics[idx]
        }

# Load the datasets
train_dataset = ProcessedMIMICDataset(os.path.join(data_path, "train"), data_format='pt')
test_dataset = ProcessedMIMICDataset(os.path.join(data_path, "test"), data_format='pt')

# Create validation set from train (if needed)
def create_train_val_split(train_dataset, val_ratio=0.1, random_seed=42):
    """Split training dataset into train and validation sets"""
    dataset_size = len(train_dataset)
    val_size = int(val_ratio * dataset_size)
    train_size = dataset_size - val_size
    
    train_subset, val_subset = torch.utils.data.random_split(
        train_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(random_seed)
    )
    
    print(f"Split train dataset: {train_size} training samples, {val_size} validation samples")
    
    return train_subset, val_subset

# Create train/val split (optional)
train_subset, val_subset = create_train_val_split(train_dataset)

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Display dataset statistics
print(f"Training samples: {len(train_subset)}")
print(f"Validation samples: {len(val_subset)}")
print(f"Test samples: {len(test_dataset)}")

# Example of accessing a batch
sample_batch = next(iter(train_loader))
print(f"Sample batch shapes:")
print(f"  Embeddings: {sample_batch['embedding'].shape}")
print(f"  Labels: {sample_batch['labels'].shape}")
print(f"  Batch keys {sample_batch.keys()}")
print(f"  Batch[demographic] keys {sample_batch['demographics'].keys()}")

Loaded dataset from /kaggle/input/mimic-embedding/processed_mimic_data/train with 207314 samples
Loaded dataset from /kaggle/input/mimic-embedding/processed_mimic_data/test with 21591 samples
Split train dataset: 186583 training samples, 20731 validation samples
Training samples: 186583
Validation samples: 20731
Test samples: 21591
Sample batch shapes:
  Embeddings: torch.Size([128, 1376])
  Labels: torch.Size([128, 14])
  Batch keys dict_keys(['embedding', 'labels', 'subject_id', 'study_id', 'demographics'])
  Batch[demographic] keys dict_keys(['gender', 'insurance', 'race', 'anchor_age'])


# LWBC Test

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, random_split
import numpy as np
import random
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import roc_auc_score

# The MIMICClassifier from your code
class MIMICClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Simple feed-forward neural network for multi-label classification
        
        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: List of hidden layer dimensions
            output_dim: Number of output classes
            dropout_rate: Dropout probability for regularization
        """
        super(MIMICClassifier, self).__init__()
        
        # Create the layers
        layers = []
        
        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(hidden_dims[0]))
        layers.append(nn.Dropout(dropout_rate))
        
        # Hidden layers
        for i in range(len(hidden_dims)-1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dims[i+1]))
            layers.append(nn.Dropout(dropout_rate))
        
        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], output_dim))
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

# Learning with Biased Committee implementation
class LWBC:
    def __init__(
        self,
        train_loader,
        val_loader,
        input_dim=1376,
        hidden_dims=[512, 256, 128],
        output_dim=14,
        committee_size=30,
        subset_size=0.7,
        alpha=0.02,
        lambda_kd=0.6,
        temperature=2.0,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """
        LWBC implementation for MIMIC dataset.
        
        Args:
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            input_dim: Input embedding dimension
            hidden_dims: Hidden layer dimensions
            output_dim: Number of output classes
            committee_size: Number of classifiers in the committee
            subset_size: Size of subset for each committee member (proportion)
            alpha: Scaling parameter for weighting function
            lambda_kd: Balance parameter for knowledge distillation
            temperature: Temperature for knowledge distillation
            device: Device to run computations on
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.committee_size = committee_size
        self.subset_size = subset_size
        self.alpha = alpha
        self.lambda_kd = lambda_kd
        self.temperature = temperature
        self.device = device
        
        # Initialize main classifier
        self.main_classifier = MIMICClassifier(
            input_dim, hidden_dims, output_dim
        ).to(device)
        
        # Initialize committee of auxiliary classifiers
        self.committee = [
            MIMICClassifier(input_dim, hidden_dims, output_dim).to(device)
            for _ in range(committee_size)
        ]
        
        # Initialize optimizers
        self.main_optimizer = optim.Adam(self.main_classifier.parameters(), lr=1e-3)
        self.committee_optimizers = [
            optim.Adam(classifier.parameters(), lr=1e-3)
            for classifier in self.committee
        ]
        
        # Loss function for multi-label classification
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')
        
        # Create subsets for each committee member
        self.create_subsets()
        
    def create_subsets(self):
        """Create random subsets of training data for each committee member"""
        # Count total samples in the training dataset
        dataset_size = len(self.train_loader.dataset)
        subset_count = int(dataset_size * self.subset_size)
        
        # Create subsets for each committee member
        self.subsets = []
        for _ in range(self.committee_size):
            # Sample indices with replacement
            subset_indices = random.choices(range(dataset_size), k=subset_count)
            self.subsets.append(sorted(subset_indices))
    
    def compute_sample_weights(self, outputs, labels):
        """
        Compute sample weights based on committee consensus
        
        Args:
            outputs: List of outputs from committee members
            labels: Ground truth labels
            
        Returns:
            Tensor of sample weights
        """
        batch_size = labels.size(0)
        correct_count = torch.zeros(batch_size, device=self.device)
        
        # Count correct predictions for each sample across committee
        for output in outputs:
            preds = (torch.sigmoid(output) > 0.5).float()
            correct = (preds == labels).all(dim=1).float()
            correct_count += correct
        
        # Calculate proportion of committee members that predicted correctly
        correct_proportion = correct_count / self.committee_size
        
        # Apply weighting function w(x) = 1 / (proportion + alpha)
        weights = 1.0 / (correct_proportion + self.alpha)
        
        return weights
    
    def knowledge_distillation_loss(self, committee_output, main_output):
        """
        Calculate knowledge distillation loss
        
        Args:
            committee_output: Output from committee member
            main_output: Output from main classifier
            
        Returns:
            KD loss
        """
        # Apply temperature scaling
        committee_logits = committee_output / self.temperature
        main_logits = main_output / self.temperature
        
        # KL divergence between softmax distributions
        committee_probs = torch.sigmoid(committee_logits)
        main_probs = torch.sigmoid(main_logits)
        
        # Calculate KL divergence for each output dimension
        kl_div = main_probs * torch.log(main_probs / (committee_probs + 1e-8) + 1e-8) + \
                (1 - main_probs) * torch.log((1 - main_probs) / (1 - committee_probs + 1e-8) + 1e-8)
                
        return kl_div.mean()
    
    def train_committee_warmup(self, num_epochs=5):
        """Warm-up training for committee members"""
        print("Starting committee warm-up training...")
        
        # Create sample masks for committee members
        committee_masks = []
        dataset_size = len(self.train_loader.dataset)
        for subset_indices in self.subsets:
            mask = torch.zeros(dataset_size, dtype=torch.bool)
            mask[subset_indices] = True
            committee_masks.append(mask)
        
        for classifier_idx, classifier in enumerate(self.committee):
            classifier.train()
            optimizer = self.committee_optimizers[classifier_idx]
            mask = committee_masks[classifier_idx]
            
            print(f"Training committee member {classifier_idx+1}/{self.committee_size}")
            
            for epoch in range(num_epochs):
                total_loss = 0.0
                num_batches = 0
                
                # Create a subset DataLoader
                subset_dataset = Subset(self.train_loader.dataset, self.subsets[classifier_idx])
                subset_loader = DataLoader(
                    subset_dataset, 
                    batch_size=self.train_loader.batch_size,
                    shuffle=True,
                    num_workers=0  # Reduce if having memory issues
                )
                
                for batch in subset_loader:
                    # Extract data
                    X = batch['embedding'].to(self.device)
                    y = batch['labels'].to(self.device)
                    
                    # Forward pass
                    outputs = classifier(X)
                    loss = self.criterion(outputs, y).mean()
                    
                    # Backward pass and optimize
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    num_batches += 1
                
                avg_loss = total_loss / max(1, num_batches)
                print(f"  Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    def train_epoch(self):
        """Train for one epoch"""
        self.main_classifier.train()
        for classifier in self.committee:
            classifier.train()
        
        total_main_loss = 0.0
        total_committee_loss = 0.0
        num_batches = 0
        
        # Calculate sample weights for all training data
        print("Computing sample weights based on committee consensus...")
        sample_weights_dict = {}  # Map batch_idx -> sample_weights
        
        # Create mapping from global index to subset inclusion
        dataset_size = len(self.train_loader.dataset)
        subset_masks = []
        for subset in self.subsets:
            mask = torch.zeros(dataset_size, dtype=torch.bool)
            mask[subset] = True
            subset_masks.append(mask)
            
        # Get weights for all samples
        global_sample_weights = torch.ones(dataset_size, device=self.device)
        
        # Since we can't fit all data at once, process in batches
        batch_size = self.train_loader.batch_size
        
        # First, get predictions from committee for all samples
        with torch.no_grad():
            all_committee_preds = [[] for _ in range(self.committee_size)]
            all_labels = []
            batch_offsets = []
            
            # Get predictions for the entire dataset
            offset = 0
            for batch in tqdm(self.train_loader, desc="Getting committee predictions"):
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)
                
                batch_offsets.append(offset)
                
                # Get predictions from each committee member
                for i, classifier in enumerate(self.committee):
                    preds = classifier(X)
                    all_committee_preds[i].append(preds)
                
                all_labels.append(y)
                offset += len(X)
            
            # Determine weights for each sample
            for batch_idx, offset in enumerate(batch_offsets):
                y = all_labels[batch_idx]
                batch_size = y.size(0)
                
                # Count correct predictions for each sample
                correct_count = torch.zeros(batch_size, device=self.device)
                for i in range(self.committee_size):
                    output = all_committee_preds[i][batch_idx]
                    preds = (torch.sigmoid(output) > 0.5).float()
                    correct = (preds == y).all(dim=1).float()
                    correct_count += correct
                
                # Calculate proportion of committee members that predicted correctly
                correct_proportion = correct_count / self.committee_size
                
                # Apply weighting function w(x) = 1 / (proportion + alpha)
                weights = 1.0 / (correct_proportion + self.alpha)
                
                # Store for later use
                sample_weights_dict[batch_idx] = weights
        
        # Main training loop
        print("Training main classifier with weighted samples...")
        for batch_idx, batch in enumerate(tqdm(self.train_loader, desc="Training main classifier")):
            X = batch['embedding'].to(self.device)
            y = batch['labels'].to(self.device)
            batch_size = X.size(0)
            
            # Get sample weights for this batch
            sample_weights = sample_weights_dict[batch_idx]
            
            # Train main classifier with weighted loss
            self.main_optimizer.zero_grad()
            main_outputs = self.main_classifier(X)
            main_loss = (self.criterion(main_outputs, y) * sample_weights.unsqueeze(1)).mean()
            main_loss.backward()
            self.main_optimizer.step()
            
            total_main_loss += main_loss.item()
            num_batches += 1
        
        # Train committee members with knowledge distillation
        print("Training committee with knowledge distillation...")
        committee_loss_sum = 0.0
        committee_batches = 0
        
        for classifier_idx, classifier in enumerate(self.committee):
            classifier.train()
            optimizer = self.committee_optimizers[classifier_idx]
            
            # Create subset dataset for this committee member
            subset_dataset = Subset(self.train_loader.dataset, self.subsets[classifier_idx])
            subset_loader = DataLoader(
                subset_dataset,
                batch_size=self.train_loader.batch_size,
                shuffle=True,
                num_workers=0
            )
            
            # Create complement dataset for knowledge distillation
            complement_indices = [i for i in range(dataset_size) if i not in self.subsets[classifier_idx]]
            complement_dataset = Subset(self.train_loader.dataset, complement_indices)
            complement_loader = DataLoader(
                complement_dataset,
                batch_size=self.train_loader.batch_size,
                shuffle=True,
                num_workers=0
            )
            
            # Train on subset with supervised loss
            for batch in subset_loader:
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)
                
                # Forward pass
                optimizer.zero_grad()
                outputs = classifier(X)
                ce_loss = self.criterion(outputs, y).mean()
                ce_loss.backward()
                optimizer.step()
            
            # Train on complement with knowledge distillation
            for batch in complement_loader:
                X = batch['embedding'].to(self.device)
                
                # Forward pass
                optimizer.zero_grad()
                committee_output = classifier(X)
                
                # Get main classifier output
                with torch.no_grad():
                    main_output = self.main_classifier(X)
                
                # Knowledge distillation loss
                kd_loss = self.knowledge_distillation_loss(committee_output, main_output)
                kd_loss.backward()
                optimizer.step()
                
                committee_loss_sum += kd_loss.item()
                committee_batches += 1
        
        avg_committee_loss = committee_loss_sum / max(1, committee_batches)
        
        return total_main_loss / num_batches, avg_committee_loss
    
    def evaluate(self, loader):
        """Evaluate model on validation or test set"""
        self.main_classifier.eval()
        all_outputs = []
        all_labels = []
        
        with torch.no_grad():
            for batch in loader:
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)
                
                outputs = self.main_classifier(X)
                all_outputs.append(outputs.cpu())
                all_labels.append(y.cpu())
        
        all_outputs = torch.cat(all_outputs, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        
        # Compute ROC-AUC score
        all_outputs_sigmoid = torch.sigmoid(all_outputs).numpy()
        all_labels_numpy = all_labels.numpy()
        
        # Handle case where a class might have all 0s or all 1s
        auc_scores = []
        for i in range(self.output_dim):
            if len(np.unique(all_labels_numpy[:, i])) > 1:
                auc_scores.append(roc_auc_score(all_labels_numpy[:, i], all_outputs_sigmoid[:, i]))
            else:
                auc_scores.append(0.5)  # Default for single-class
        
        macro_auc = np.mean(auc_scores)
        
        # Compute accuracy
        predictions = (torch.sigmoid(all_outputs) > 0.5).float()
        accuracy = (predictions == all_labels).float().mean().item()
        
        return accuracy, macro_auc, auc_scores
    
    def evaluate_by_demographic(self, loader, demographic_attr='gender'):
        """Evaluate model performance across demographic groups"""
        self.main_classifier.eval()
        results = {}
        
        # Initialize dictionaries to store outputs and labels for each group
        group_outputs = {}
        group_labels = {}
        
        with torch.no_grad():
            for batch in loader:
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)
                
                # Get demographic attributes
                demographic_values = batch['demographics'][demographic_attr]
                
                # Get model outputs
                outputs = self.main_classifier(X)
                
                # Group by demographic
                for i, demo_val in enumerate(demographic_values):
                    demo_val = demo_val.item() if isinstance(demo_val, torch.Tensor) else demo_val
                    if demo_val not in group_outputs:
                        group_outputs[demo_val] = []
                        group_labels[demo_val] = []
                    
                    group_outputs[demo_val].append(outputs[i:i+1].cpu())
                    group_labels[demo_val].append(y[i:i+1].cpu())
        
        # Calculate metrics for each group
        for group in group_outputs:
            group_output_tensor = torch.cat(group_outputs[group], dim=0)
            group_label_tensor = torch.cat(group_labels[group], dim=0)
            
            # Compute ROC-AUC
            group_output_sigmoid = torch.sigmoid(group_output_tensor).numpy()
            group_label_numpy = group_label_tensor.numpy()
            
            # Handle case where a class might have all 0s or all 1s
            auc_scores = []
            for i in range(self.output_dim):
                if len(np.unique(group_label_numpy[:, i])) > 1:
                    auc_scores.append(roc_auc_score(group_label_numpy[:, i], group_output_sigmoid[:, i]))
                else:
                    auc_scores.append(0.5)  # Default for single-class
            
            macro_auc = np.mean(auc_scores)
            
            # Compute accuracy
            predictions = (torch.sigmoid(group_output_tensor) > 0.5).float()
            accuracy = (predictions == group_label_tensor).float().mean().item()
            
            results[group] = {
                'accuracy': accuracy,
                'macro_auc': macro_auc,
                'count': len(group_labels[group])
            }
        
        return results
    
    def train(self, num_epochs=20, warmup_epochs=5):
        """Full training procedure"""
        # Warm-up training for committee
        self.train_committee_warmup(num_epochs=warmup_epochs)
        
        # Main training loop
        best_auc = 0.0
        best_model = None
        history = {
            'train_loss': [],
            'committee_loss': [],
            'val_accuracy': [],
            'val_auc': []
        }
        
        for epoch in range(num_epochs):
            # Train one epoch
            train_loss, committee_loss = self.train_epoch()
            
            # Evaluate
            val_accuracy, val_auc, _ = self.evaluate(self.val_loader)
            
            # Store history
            history['train_loss'].append(train_loss)
            history['committee_loss'].append(committee_loss)
            history['val_accuracy'].append(val_accuracy)
            history['val_auc'].append(val_auc)
            
            # Print metrics
            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Committee Loss: {committee_loss:.4f}")
            print(f"  Val Accuracy: {val_accuracy:.4f}, Val AUC: {val_auc:.4f}")
            
            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                best_model = deepcopy(self.main_classifier.state_dict())
                print(f"  New best model with Val AUC: {val_auc:.4f}")
        
        # Load best model
        if best_model is not None:
            self.main_classifier.load_state_dict(best_model)
        
        return history
    
    def save_model(self, path):
        """Save the model to disk"""
        torch.save(self.main_classifier.state_dict(), path)
    
    def load_model(self, path):
        """Load the model from disk"""
        self.main_classifier.load_state_dict(torch.load(path, map_location=self.device))


In [None]:
# Initialize LWBC
lwbc = LWBC(
    train_loader=train_loader,
    val_loader=val_loader,
    input_dim=1376,
    hidden_dims=[512, 256, 128],
    output_dim=14,       
    committee_size=3,
    subset_size=0.7,
    alpha=0.02,
    lambda_kd=0.6
)

# Train the model
history = lwbc.train(num_epochs=20, warmup_epochs=5)

# Evaluate on test set
test_accuracy, test_auc, _ = lwbc.evaluate(test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}, Test AUC: {test_auc:.4f}")

# Evaluate across demographic groups
gender_results = lwbc.evaluate_by_demographic(test_loader, demographic_attr='gender')
for gender, metrics in gender_results.items():
    print(f"Gender {gender}: Accuracy: {metrics['accuracy']:.4f}, AUC: {metrics['macro_auc']:.4f}, Count: {metrics['count']}")

# Save model
lwbc.save_model('lwbc_mimic_model.pt')

Starting committee warm-up training...
Training committee member 1/3
  Epoch 1/5, Loss: 0.2999
  Epoch 2/5, Loss: 0.2573
  Epoch 3/5, Loss: 0.2545
  Epoch 4/5, Loss: 0.2527
  Epoch 5/5, Loss: 0.2516
Training committee member 2/3
  Epoch 1/5, Loss: 0.3003
  Epoch 2/5, Loss: 0.2572
  Epoch 3/5, Loss: 0.2547
  Epoch 4/5, Loss: 0.2528
  Epoch 5/5, Loss: 0.2518
Training committee member 3/3
  Epoch 1/5, Loss: 0.3007
  Epoch 2/5, Loss: 0.2573
  Epoch 3/5, Loss: 0.2546
  Epoch 4/5, Loss: 0.2530
  Epoch 5/5, Loss: 0.2517
Computing sample weights based on committee consensus...


Getting committee predictions: 100%|██████████| 1458/1458 [00:12<00:00, 113.98it/s]


Training main classifier with weighted samples...


Training main classifier: 100%|██████████| 1458/1458 [00:16<00:00, 91.07it/s]


Training committee with knowledge distillation...


# Models

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
import time
from tqdm import tqdm
import os

# Define the neural network model
class MIMICClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Simple feed-forward neural network for multi-label classification
        
        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: List of hidden layer dimensions
            output_dim: Number of output classes
            dropout_rate: Dropout probability for regularization
        """
        super(MIMICClassifier, self).__init__()
        
        # Create the layers
        layers = []
        
        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(hidden_dims[0]))
        layers.append(nn.Dropout(dropout_rate))
        
        # Hidden layers
        for i in range(len(hidden_dims)-1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dims[i+1]))
            layers.append(nn.Dropout(dropout_rate))
        
        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], output_dim))
        
        # No activation function here since we're using BCEWithLogitsLoss
        # which combines sigmoid and binary cross entropy
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)



In [11]:
class LinearVAE(nn.Module):
    def __init__(self, input_dim=1376, hidden_dim=512, latent_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Use sigmoid if your embeddings are normalized between 0-1
                          # or remove if using other normalization
        )
        
    def encode(self, x):
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar, z
    
    def get_latent(self, x):
        """Get latent representation without reconstruction"""
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar)

def VAE_LOSS(reconstructed, x, mu, logvar, kld_weight=0.005):
    """
    VAE loss with KL divergence and reconstruction loss
    
    Args:
        reconstructed: Reconstructed input from decoder
        x: Original input 
        mu: Mean from encoder
        logvar: Log variance from encoder
        kld_weight: Weight for KL divergence term
    """
    # Reconstruction loss (MSE or BCE depending on your data)
    recon_loss = F.mse_loss(reconstructed, x, reduction='sum')
    
    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Total loss
    loss = recon_loss + kld_weight * kld_loss
    
    return loss, recon_loss, kld_loss
    
    

# Training

In [31]:
# Function to create "target" labels for adversarial debiasing
def create_uniform_targets(batch_size, num_classes, device):
    """
    Creates uniform probability targets (maximum uncertainty)
    For binary classification this is 0.5, for multi-class it's 1/num_classes
    """
    return torch.ones(batch_size, num_classes, device=device) / num_classes

# Training function
def train_models(train_loader, val_loader, input_dim, output_dim, num_epochs=20, phase1_epochs=5, 
                 lambda_fair=1.0, lambda_adv=2.0, lambda_recon=0.5):
    """
    Training pipeline for adversarial debiasing without demographic labels:
    - Phase 1: Train attacker on original embeddings to learn inherent biases
    - Phase 2: Adversarial training where generator tries to fool attacker
    
    Args:
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        input_dim: Dimension of the embeddings
        output_dim: Dimension of the task labels
        num_epochs: Total number of epochs
        phase1_epochs: Number of epochs for pre-training attacker
        lambda_fair: Weight for fair model loss
        lambda_adv: Weight for adversarial loss
        lambda_recon: Weight for reconstruction loss
    """
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize models with provided classes
    hidden_dims = [512, 256, 128]
    
    # The attacker model will learn biases in original embeddings
    attacker_model = MIMICClassifier(input_dim, hidden_dims, output_dim)
    
    # The fair model will only be trained on debiased embeddings
    fair_model = MIMICClassifier(input_dim, hidden_dims, output_dim)
    
    # The generator creates debiased embeddings
    generator = LinearVAE(input_dim=input_dim, hidden_dim=512, latent_dim=128)
    
    # Move models to device
    attacker_model = attacker_model.to(device)
    fair_model = fair_model.to(device)
    generator = generator.to(device)
    
    # Define optimizers
    attacker_optimizer = optim.Adam(attacker_model.parameters(), lr=1e-4)
    fair_optimizer = optim.Adam(fair_model.parameters(), lr=1e-4)
    generator_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
    
    # Loss function for classification
    bce_loss = nn.BCEWithLogitsLoss()
    
    best_val_loss = float('inf')
    history = {
        'attacker_loss': [], 'fair_loss': [], 'generator_loss': [],
        'val_attacker_loss': [], 'val_fair_loss': []
    }
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        # Phase 1: Pre-training - ONLY train the attacker to be biased
        if epoch < phase1_epochs:
            print(f"Epoch {epoch+1}/{num_epochs} - Phase 1: Pre-training attacker only")
            
            # Set models to appropriate modes
            generator.eval()  # Generator not trained in phase 1
            fair_model.eval()  # Fair model not trained in phase 1
            attacker_model.train()  # Only attacker is trained
            
            epoch_attacker_loss = 0.0
            att_batches = 0
            
            for batch in tqdm(train_loader, desc="Training Attacker (on original data)"):
                embeddings = batch['embedding'].to(device)
                labels = batch['labels'].to(device)
                
                # Train attacker with original embeddings only to make it biased
                attacker_optimizer.zero_grad()
                pred = attacker_model(embeddings)
                loss = bce_loss(pred, labels)
                loss.backward()
                attacker_optimizer.step()
                
                epoch_attacker_loss += loss.item()
                att_batches += 1
            
            # Save metrics
            avg_attacker_loss = epoch_attacker_loss / att_batches
            history['attacker_loss'].append(avg_attacker_loss)
            history['fair_loss'].append(0)  # Fair model not trained
            history['generator_loss'].append(0)  # Generator not trained
            
            print(f"Attacker Loss (original data): {avg_attacker_loss:.4f}")
            
            # Evaluate attacker's bias on validation data
            attacker_model.eval()
            val_attacker_loss = 0.0
            val_batches = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    embeddings = batch['embedding'].to(device)
                    labels = batch['labels'].to(device)
                    
                    # Check attacker performance (should be getting better/more biased)
                    attacker_pred = attacker_model(embeddings)
                    val_loss = bce_loss(attacker_pred, labels).item()
                    
                    val_attacker_loss += val_loss
                    val_batches += 1
            
            val_attacker_loss /= val_batches
            history['val_attacker_loss'].append(val_attacker_loss)
            history['val_fair_loss'].append(0)  # Fair model not evaluated
            
            print(f"Validation - Attacker Loss: {val_attacker_loss:.4f}")
        
        # Phase 2: Adversarial training
        else:
            print(f"Epoch {epoch+1}/{num_epochs} - Phase 2: Adversarial training")
            
            # PHASE 2.1: Train Attacker (discriminator) - Only on original data
            attacker_model.train()
            fair_model.eval()
            generator.eval()
            
            epoch_attacker_loss = 0.0
            batches = 0
            
            for batch in tqdm(train_loader, desc="Training Attacker"):
                embeddings = batch['embedding'].to(device)
                labels = batch['labels'].to(device)
                
                attacker_optimizer.zero_grad()
                
                # Train on real data to maintain bias
                pred_real = attacker_model(embeddings)
                attacker_loss = bce_loss(pred_real, labels)
                attacker_loss.backward()
                attacker_optimizer.step()
                
                epoch_attacker_loss += attacker_loss.item()
                batches += 1
            
            avg_attacker_loss = epoch_attacker_loss / batches
            history['attacker_loss'].append(avg_attacker_loss)
            
            # PHASE 2.2: Train Fair Model - Only on generated data
            attacker_model.eval()
            fair_model.train()
            generator.eval()
            
            epoch_fair_loss = 0.0
            batches = 0
            
            for batch in tqdm(train_loader, desc="Training Fair Model"):
                embeddings = batch['embedding'].to(device)
                labels = batch['labels'].to(device)
                
                fair_optimizer.zero_grad()
                
                # Train ONLY on reconstructed embeddings
                with torch.no_grad():
                    reconstructed, _, _, _ = generator(embeddings)
                
                pred = fair_model(reconstructed)
                fair_loss = bce_loss(pred, labels)
                fair_loss.backward()
                fair_optimizer.step()
                
                epoch_fair_loss += fair_loss.item()
                batches += 1
            
            avg_fair_loss = epoch_fair_loss / batches
            history['fair_loss'].append(avg_fair_loss)
            
            # PHASE 2.3: Train Generator (adversarial)
            attacker_model.eval()
            fair_model.eval()
            generator.train()
            
            epoch_generator_loss = 0.0
            batches = 0
            
            for batch in tqdm(train_loader, desc="Training Generator"):
                embeddings = batch['embedding'].to(device)
                labels = batch['labels'].to(device)
                batch_size = embeddings.size(0)
                
                generator_optimizer.zero_grad()
                
                # Get reconstructed embeddings
                reconstructed, mu, logvar, latent = generator(embeddings)
                
                # 1. Reconstruction loss - should reconstruct well enough
                recon_loss, bce, kld = VAE_LOSS(reconstructed, embeddings, mu, logvar, kld_weight=0.005)
                weighted_recon_loss = lambda_recon * recon_loss
                
                # 2. Adversarial loss - attacker should NOT predict well
                pred_attacker = attacker_model(reconstructed)
                
                # Create "uncertain" targets (e.g., 0.5 for binary classification)
                uncertain_targets = create_uniform_targets(batch_size, output_dim, device)
                
                # Generator wants attacker predictions to be uncertain
                # We use KL divergence between uniform distribution and attacker predictions
                # Higher KL means predictions are less uniform (more certain)
                adv_loss = lambda_adv * F.kl_div(
                    F.log_softmax(pred_attacker, dim=1),
                    uncertain_targets,
                    reduction='batchmean'
                )
                
                # 3. Fair model loss - fair model SHOULD predict well
                pred_fair = fair_model(reconstructed)
                fair_model_loss = lambda_fair * bce_loss(pred_fair, labels)
                
                # Total generator loss
                # Note: we want to minimize adv_loss (KL divergence), 
                # so we use it directly rather than negating it
                generator_loss = weighted_recon_loss + adv_loss + fair_model_loss
                generator_loss.backward()
                generator_optimizer.step()
                
                epoch_generator_loss += generator_loss.item()
                batches += 1
                
                # Debug outputs
                if batches % 100 == 0:
                    print(f"Batch {batches}: Recon Loss: {weighted_recon_loss.item():.4f}, "
                          f"Adv Loss: {adv_loss.item():.4f}, "
                          f"Fair Loss: {fair_model_loss.item():.4f}")
            
            avg_generator_loss = epoch_generator_loss / batches
            history['generator_loss'].append(avg_generator_loss)
            
            print(f"Epoch Summary - Attacker Loss: {avg_attacker_loss:.4f}, "
                  f"Fair Loss: {avg_fair_loss:.4f}, "
                  f"Generator Loss: {avg_generator_loss:.4f}")
        
            # Validation phase
            attacker_model.eval()
            fair_model.eval()
            generator.eval()
            val_attacker_loss_orig = 0.0
            val_attacker_loss_recon = 0.0
            val_fair_loss = 0.0
            val_batches = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    embeddings = batch['embedding'].to(device)
                    labels = batch['labels'].to(device)
                    
                    # Validate attacker on original embeddings (baseline)
                    attacker_pred_orig = attacker_model(embeddings)
                    attacker_orig_loss = bce_loss(attacker_pred_orig, labels).item()
                    
                    # Generate debiased embeddings
                    reconstructed, _, _, _ = generator(embeddings)
                    
                    # Validate attacker on reconstructed (debiased) embeddings
                    attacker_pred_recon = attacker_model(reconstructed)
                    attacker_recon_loss = bce_loss(attacker_pred_recon, labels).item()
                    
                    # Validate fair model on generated embeddings
                    fair_pred = fair_model(reconstructed)
                    val_fair_loss += bce_loss(fair_pred, labels).item()
                    
                    val_attacker_loss_orig += attacker_orig_loss
                    val_attacker_loss_recon += attacker_recon_loss
                    val_batches += 1
            
            val_attacker_loss_orig /= val_batches
            val_attacker_loss_recon /= val_batches
            val_fair_loss /= val_batches
            
            history['val_attacker_loss'].append(val_attacker_loss_recon)
            history['val_fair_loss'].append(val_fair_loss)
            
            # Debiasing effectiveness = difference between attacker performance on original vs. debiased
            debiasing_effect = val_attacker_loss_recon - val_attacker_loss_orig
            
            print(f"Validation - Attacker Loss (orig): {val_attacker_loss_orig:.4f}")
            print(f"Validation - Attacker Loss (recon): {val_attacker_loss_recon:.4f}")
            print(f"Validation - Fair Loss: {val_fair_loss:.4f}")
            print(f"Validation - Debiasing Effect: {debiasing_effect:.4f} (higher is better)")
            
            # Save best model based on validation metrics
            # We want high debiasing effect and low fair loss
            fairness_score = debiasing_effect - val_fair_loss
            
            if fairness_score > best_val_loss:
                best_val_loss = fairness_score
                torch.save({
                    'generator': generator.state_dict(),
                    'fair_model': fair_model.state_dict(),
                    'attacker_model': attacker_model.state_dict(),
                    'epoch': epoch,
                    'fairness_score': fairness_score
                }, 'best_adversarial_models.pt')
                print(f"Saved best model with fairness score: {fairness_score:.4f}")
                
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch completed in {epoch_time:.2f} seconds")
        print("-" * 80)
    
    return history, generator, attacker_model, fair_model

In [38]:
# Updated Training Call
history, generator, attacker_model, fair_model = train_models(
    train_loader=train_loader,
    val_loader=val_loader,
    input_dim=1376,
    output_dim=14,
    num_epochs=30,          # ⬆️ More time for Phase 2 to work
    phase1_epochs=5,        # Same (enough for attacker to become biased)
    lambda_fair=5.0,        # ⬆️ Help fair model learn more strongly
    lambda_adv=5.0,        # ⬆️ Stronger pressure on generator to fool attacker
    lambda_recon=0.2        # ⬇️ Less emphasis on perfect reconstruction
)


Epoch 1/30 - Phase 1: Pre-training attacker only


Training Attacker (on original data): 100%|██████████| 1458/1458 [00:07<00:00, 183.32it/s]


Attacker Loss (original data): 0.4917
Validation - Attacker Loss: 0.2913
Epoch completed in 8.40 seconds
--------------------------------------------------------------------------------
Epoch 2/30 - Phase 1: Pre-training attacker only


Training Attacker (on original data): 100%|██████████| 1458/1458 [00:07<00:00, 185.92it/s]


Attacker Loss (original data): 0.2738
Validation - Attacker Loss: 0.2546
Epoch completed in 8.29 seconds
--------------------------------------------------------------------------------
Epoch 3/30 - Phase 1: Pre-training attacker only


Training Attacker (on original data): 100%|██████████| 1458/1458 [00:08<00:00, 180.97it/s]


Attacker Loss (original data): 0.2605
Validation - Attacker Loss: 0.2539
Epoch completed in 8.51 seconds
--------------------------------------------------------------------------------
Epoch 4/30 - Phase 1: Pre-training attacker only


Training Attacker (on original data): 100%|██████████| 1458/1458 [00:08<00:00, 180.51it/s]


Attacker Loss (original data): 0.2578
Validation - Attacker Loss: 0.2528
Epoch completed in 8.52 seconds
--------------------------------------------------------------------------------
Epoch 5/30 - Phase 1: Pre-training attacker only


Training Attacker (on original data): 100%|██████████| 1458/1458 [00:08<00:00, 167.36it/s]


Attacker Loss (original data): 0.2562
Validation - Attacker Loss: 0.2514
Epoch completed in 9.15 seconds
--------------------------------------------------------------------------------
Epoch 6/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 188.93it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 158.45it/s]
Training Generator:   8%|▊         | 115/1458 [00:01<00:14, 91.12it/s]

Batch 100: Recon Loss: 73337.8281, Adv Loss: 14.0440, Fair Loss: 4.4847


Training Generator:  15%|█▍        | 215/1458 [00:02<00:13, 91.50it/s]

Batch 200: Recon Loss: 70462.2734, Adv Loss: 9.4831, Fair Loss: 3.7010


Training Generator:  22%|██▏       | 315/1458 [00:03<00:12, 92.80it/s]

Batch 300: Recon Loss: 69795.2344, Adv Loss: 9.1542, Fair Loss: 3.8428


Training Generator:  28%|██▊       | 415/1458 [00:04<00:10, 94.96it/s]

Batch 400: Recon Loss: 68358.5859, Adv Loss: 8.7667, Fair Loss: 4.1666


Training Generator:  35%|███▌      | 515/1458 [00:05<00:09, 95.49it/s]

Batch 500: Recon Loss: 69894.4531, Adv Loss: 9.4139, Fair Loss: 4.1904


Training Generator:  42%|████▏     | 615/1458 [00:06<00:08, 93.96it/s]

Batch 600: Recon Loss: 66504.1484, Adv Loss: 9.9384, Fair Loss: 3.7764


Training Generator:  49%|████▉     | 715/1458 [00:07<00:07, 93.28it/s]

Batch 700: Recon Loss: 67403.1484, Adv Loss: 10.3599, Fair Loss: 3.4607


Training Generator:  56%|█████▌    | 815/1458 [00:08<00:06, 94.33it/s]

Batch 800: Recon Loss: 67116.4688, Adv Loss: 10.2478, Fair Loss: 4.5995


Training Generator:  63%|██████▎   | 915/1458 [00:09<00:05, 94.89it/s]

Batch 900: Recon Loss: 66047.5547, Adv Loss: 10.1918, Fair Loss: 4.0953


Training Generator:  70%|██████▉   | 1014/1458 [00:10<00:04, 91.43it/s]

Batch 1000: Recon Loss: 67062.2656, Adv Loss: 10.6366, Fair Loss: 4.8465


Training Generator:  76%|███████▋  | 1114/1458 [00:12<00:03, 93.47it/s]

Batch 1100: Recon Loss: 65512.7266, Adv Loss: 10.5845, Fair Loss: 4.0545


Training Generator:  83%|████████▎ | 1214/1458 [00:13<00:02, 93.47it/s]

Batch 1200: Recon Loss: 66833.9453, Adv Loss: 10.4241, Fair Loss: 5.2684


Training Generator:  90%|█████████ | 1314/1458 [00:14<00:01, 94.37it/s]

Batch 1300: Recon Loss: 65175.9805, Adv Loss: 10.7169, Fair Loss: 4.5109


Training Generator:  97%|█████████▋| 1414/1458 [00:15<00:00, 93.65it/s]

Batch 1400: Recon Loss: 68550.0703, Adv Loss: 10.6805, Fair Loss: 4.2437


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 92.88it/s]


Epoch Summary - Attacker Loss: 0.2546, Fair Loss: 0.5604, Generator Loss: 68918.8549
Validation - Attacker Loss (orig): 0.2508
Validation - Attacker Loss (recon): 0.4873
Validation - Fair Loss: 0.9191
Validation - Debiasing Effect: 0.2365 (higher is better)
Epoch completed in 33.40 seconds
--------------------------------------------------------------------------------
Epoch 7/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 191.44it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 159.63it/s]
Training Generator:   7%|▋         | 109/1458 [00:01<00:14, 92.37it/s]

Batch 100: Recon Loss: 65954.4141, Adv Loss: 10.1342, Fair Loss: 1.4048


Training Generator:  14%|█▍        | 209/1458 [00:02<00:13, 92.23it/s]

Batch 200: Recon Loss: 68449.3984, Adv Loss: 10.2169, Fair Loss: 1.3565


Training Generator:  21%|██        | 309/1458 [00:03<00:12, 92.08it/s]

Batch 300: Recon Loss: 64915.6445, Adv Loss: 10.1943, Fair Loss: 1.2914


Training Generator:  28%|██▊       | 409/1458 [00:04<00:11, 92.19it/s]

Batch 400: Recon Loss: 64383.1680, Adv Loss: 10.4197, Fair Loss: 1.3862


Training Generator:  35%|███▍      | 509/1458 [00:05<00:10, 93.33it/s]

Batch 500: Recon Loss: 67481.0391, Adv Loss: 10.7831, Fair Loss: 1.4104


Training Generator:  42%|████▏     | 609/1458 [00:06<00:09, 91.31it/s]

Batch 600: Recon Loss: 69996.4922, Adv Loss: 10.8170, Fair Loss: 1.3213


Training Generator:  49%|████▉     | 718/1458 [00:07<00:07, 92.62it/s]

Batch 700: Recon Loss: 64495.6523, Adv Loss: 10.7023, Fair Loss: 1.2866


Training Generator:  56%|█████▌    | 818/1458 [00:08<00:06, 93.29it/s]

Batch 800: Recon Loss: 69433.0391, Adv Loss: 10.9384, Fair Loss: 1.2621


Training Generator:  63%|██████▎   | 918/1458 [00:09<00:05, 92.49it/s]

Batch 900: Recon Loss: 65767.2031, Adv Loss: 10.7655, Fair Loss: 1.4167


Training Generator:  70%|██████▉   | 1018/1458 [00:11<00:04, 92.27it/s]

Batch 1000: Recon Loss: 66720.6562, Adv Loss: 10.8150, Fair Loss: 1.2192


Training Generator:  77%|███████▋  | 1118/1458 [00:12<00:03, 91.18it/s]

Batch 1100: Recon Loss: 64861.4336, Adv Loss: 10.8323, Fair Loss: 1.2914


Training Generator:  84%|████████▎ | 1218/1458 [00:13<00:02, 92.56it/s]

Batch 1200: Recon Loss: 66746.8359, Adv Loss: 11.2551, Fair Loss: 1.4162


Training Generator:  90%|█████████ | 1318/1458 [00:14<00:01, 93.27it/s]

Batch 1300: Recon Loss: 67601.0625, Adv Loss: 11.1067, Fair Loss: 1.3250


Training Generator:  97%|█████████▋| 1418/1458 [00:15<00:00, 92.00it/s]

Batch 1400: Recon Loss: 65751.6250, Adv Loss: 11.2311, Fair Loss: 1.3135


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 92.20it/s]


Epoch Summary - Attacker Loss: 0.2534, Fair Loss: 0.2909, Generator Loss: 66259.5821
Validation - Attacker Loss (orig): 0.2501
Validation - Attacker Loss (recon): 0.4570
Validation - Fair Loss: 0.2675
Validation - Debiasing Effect: 0.2069 (higher is better)
Epoch completed in 33.33 seconds
--------------------------------------------------------------------------------
Epoch 8/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 192.72it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 157.88it/s]
Training Generator:   8%|▊         | 118/1458 [00:01<00:14, 93.08it/s]

Batch 100: Recon Loss: 62425.7461, Adv Loss: 13.4719, Fair Loss: 1.3601


Training Generator:  15%|█▍        | 218/1458 [00:02<00:13, 93.47it/s]

Batch 200: Recon Loss: 63377.0000, Adv Loss: 13.3900, Fair Loss: 1.3100


Training Generator:  22%|██▏       | 318/1458 [00:03<00:12, 92.41it/s]

Batch 300: Recon Loss: 64224.5898, Adv Loss: 13.4708, Fair Loss: 1.2717


Training Generator:  29%|██▊       | 417/1458 [00:04<00:11, 90.99it/s]

Batch 400: Recon Loss: 65697.0703, Adv Loss: 13.2880, Fair Loss: 1.2930


Training Generator:  35%|███▌      | 517/1458 [00:05<00:10, 92.02it/s]

Batch 500: Recon Loss: 64654.4570, Adv Loss: 13.2979, Fair Loss: 1.2995


Training Generator:  42%|████▏     | 617/1458 [00:06<00:09, 91.98it/s]

Batch 600: Recon Loss: 67977.8047, Adv Loss: 13.5902, Fair Loss: 1.3266


Training Generator:  49%|████▉     | 717/1458 [00:07<00:08, 91.29it/s]

Batch 700: Recon Loss: 69125.2969, Adv Loss: 12.7847, Fair Loss: 1.3370


Training Generator:  56%|█████▌    | 817/1458 [00:08<00:06, 92.79it/s]

Batch 800: Recon Loss: 66145.3984, Adv Loss: 13.4832, Fair Loss: 1.3052


Training Generator:  63%|██████▎   | 917/1458 [00:09<00:05, 92.72it/s]

Batch 900: Recon Loss: 66797.6250, Adv Loss: 13.6762, Fair Loss: 1.3106


Training Generator:  70%|██████▉   | 1017/1458 [00:11<00:04, 92.12it/s]

Batch 1000: Recon Loss: 65295.0312, Adv Loss: 13.4917, Fair Loss: 1.1848


Training Generator:  77%|███████▋  | 1116/1458 [00:12<00:03, 89.17it/s]

Batch 1100: Recon Loss: 66770.7109, Adv Loss: 13.4016, Fair Loss: 1.2606


Training Generator:  83%|████████▎ | 1216/1458 [00:13<00:02, 92.22it/s]

Batch 1200: Recon Loss: 64618.4062, Adv Loss: 13.2958, Fair Loss: 1.3320


Training Generator:  90%|█████████ | 1315/1458 [00:14<00:01, 91.22it/s]

Batch 1300: Recon Loss: 68954.7344, Adv Loss: 12.9414, Fair Loss: 1.3817


Training Generator:  97%|█████████▋| 1415/1458 [00:15<00:00, 93.64it/s]

Batch 1400: Recon Loss: 67321.7656, Adv Loss: 13.7416, Fair Loss: 1.2816


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 91.78it/s]


Epoch Summary - Attacker Loss: 0.2526, Fair Loss: 0.2679, Generator Loss: 65923.6439
Validation - Attacker Loss (orig): 0.2494
Validation - Attacker Loss (recon): 0.4822
Validation - Fair Loss: 0.2591
Validation - Debiasing Effect: 0.2328 (higher is better)
Epoch completed in 33.45 seconds
--------------------------------------------------------------------------------
Epoch 9/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 190.84it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:08<00:00, 162.66it/s]
Training Generator:   8%|▊         | 117/1458 [00:01<00:14, 93.59it/s]

Batch 100: Recon Loss: 65217.0508, Adv Loss: 18.2726, Fair Loss: 1.2989


Training Generator:  15%|█▍        | 217/1458 [00:02<00:13, 92.25it/s]

Batch 200: Recon Loss: 65567.7812, Adv Loss: 18.4617, Fair Loss: 1.1482


Training Generator:  22%|██▏       | 317/1458 [00:03<00:12, 92.87it/s]

Batch 300: Recon Loss: 62964.2383, Adv Loss: 17.7824, Fair Loss: 1.3719


Training Generator:  29%|██▊       | 417/1458 [00:04<00:11, 92.15it/s]

Batch 400: Recon Loss: 66232.8984, Adv Loss: 18.0172, Fair Loss: 1.2823


Training Generator:  35%|███▌      | 517/1458 [00:05<00:10, 91.01it/s]

Batch 500: Recon Loss: 65595.8281, Adv Loss: 17.7747, Fair Loss: 1.2472


Training Generator:  42%|████▏     | 617/1458 [00:06<00:09, 91.72it/s]

Batch 600: Recon Loss: 63430.2891, Adv Loss: 17.6950, Fair Loss: 1.2993


Training Generator:  49%|████▉     | 717/1458 [00:07<00:08, 91.91it/s]

Batch 700: Recon Loss: 62495.3516, Adv Loss: 18.1965, Fair Loss: 1.3789


Training Generator:  56%|█████▌    | 817/1458 [00:08<00:07, 91.37it/s]

Batch 800: Recon Loss: 67889.1406, Adv Loss: 17.9163, Fair Loss: 1.2435


Training Generator:  63%|██████▎   | 917/1458 [00:10<00:05, 91.69it/s]

Batch 900: Recon Loss: 63628.1758, Adv Loss: 18.0352, Fair Loss: 1.2697


Training Generator:  70%|██████▉   | 1015/1458 [00:11<00:04, 89.52it/s]

Batch 1000: Recon Loss: 62274.4805, Adv Loss: 17.8474, Fair Loss: 1.2789


Training Generator:  76%|███████▋  | 1115/1458 [00:12<00:03, 91.86it/s]

Batch 1100: Recon Loss: 63861.0078, Adv Loss: 18.0368, Fair Loss: 1.2313


Training Generator:  83%|████████▎ | 1215/1458 [00:13<00:02, 93.48it/s]

Batch 1200: Recon Loss: 65081.3398, Adv Loss: 18.1239, Fair Loss: 1.3163


Training Generator:  90%|█████████ | 1315/1458 [00:14<00:01, 93.04it/s]

Batch 1300: Recon Loss: 67922.6797, Adv Loss: 17.4993, Fair Loss: 1.2194


Training Generator:  97%|█████████▋| 1415/1458 [00:15<00:00, 92.62it/s]

Batch 1400: Recon Loss: 65763.0234, Adv Loss: 17.7400, Fair Loss: 1.2111


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 91.49it/s]


Epoch Summary - Attacker Loss: 0.2520, Fair Loss: 0.2634, Generator Loss: 65791.9047
Validation - Attacker Loss (orig): 0.2493
Validation - Attacker Loss (recon): 0.5725
Validation - Fair Loss: 0.2572
Validation - Debiasing Effect: 0.3231 (higher is better)
Epoch completed in 33.32 seconds
--------------------------------------------------------------------------------
Epoch 10/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 187.44it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 160.75it/s]
Training Generator:   7%|▋         | 109/1458 [00:01<00:14, 93.29it/s]

Batch 100: Recon Loss: 65565.1172, Adv Loss: 15.9999, Fair Loss: 1.2532


Training Generator:  15%|█▌        | 219/1458 [00:02<00:13, 94.64it/s]

Batch 200: Recon Loss: 62496.9062, Adv Loss: 16.4214, Fair Loss: 1.2686


Training Generator:  21%|██        | 309/1458 [00:03<00:12, 93.89it/s]

Batch 300: Recon Loss: 67202.7188, Adv Loss: 16.1303, Fair Loss: 1.3953


Training Generator:  28%|██▊       | 409/1458 [00:04<00:11, 92.74it/s]

Batch 400: Recon Loss: 64985.7305, Adv Loss: 15.9231, Fair Loss: 1.3096


Training Generator:  35%|███▍      | 509/1458 [00:05<00:10, 92.20it/s]

Batch 500: Recon Loss: 66611.2812, Adv Loss: 16.0869, Fair Loss: 1.0524


Training Generator:  42%|████▏     | 609/1458 [00:06<00:09, 91.84it/s]

Batch 600: Recon Loss: 64716.4570, Adv Loss: 15.7562, Fair Loss: 1.2096


Training Generator:  49%|████▊     | 709/1458 [00:07<00:08, 91.10it/s]

Batch 700: Recon Loss: 66041.5938, Adv Loss: 14.9416, Fair Loss: 1.3685


Training Generator:  55%|█████▌    | 809/1458 [00:08<00:07, 91.68it/s]

Batch 800: Recon Loss: 63383.7461, Adv Loss: 16.0273, Fair Loss: 1.2256


Training Generator:  63%|██████▎   | 919/1458 [00:09<00:05, 94.05it/s]

Batch 900: Recon Loss: 65866.9531, Adv Loss: 16.1681, Fair Loss: 1.2075


Training Generator:  69%|██████▉   | 1009/1458 [00:10<00:04, 93.13it/s]

Batch 1000: Recon Loss: 64334.0938, Adv Loss: 15.9939, Fair Loss: 1.3700


Training Generator:  76%|███████▌  | 1109/1458 [00:12<00:03, 90.44it/s]

Batch 1100: Recon Loss: 64555.2695, Adv Loss: 17.0725, Fair Loss: 1.1627


Training Generator:  83%|████████▎ | 1209/1458 [00:13<00:02, 91.36it/s]

Batch 1200: Recon Loss: 63384.0625, Adv Loss: 16.0998, Fair Loss: 1.1922


Training Generator:  90%|█████████ | 1319/1458 [00:14<00:01, 93.94it/s]

Batch 1300: Recon Loss: 65837.8047, Adv Loss: 14.6357, Fair Loss: 1.2556


Training Generator:  97%|█████████▋| 1409/1458 [00:15<00:00, 93.06it/s]

Batch 1400: Recon Loss: 65627.2578, Adv Loss: 16.3618, Fair Loss: 1.2766


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 92.16it/s]


Epoch Summary - Attacker Loss: 0.2513, Fair Loss: 0.2612, Generator Loss: 65714.1139
Validation - Attacker Loss (orig): 0.2496
Validation - Attacker Loss (recon): 0.5679
Validation - Fair Loss: 0.2560
Validation - Debiasing Effect: 0.3182 (higher is better)
Epoch completed in 33.44 seconds
--------------------------------------------------------------------------------
Epoch 11/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 191.19it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:08<00:00, 162.79it/s]
Training Generator:   7%|▋         | 109/1458 [00:01<00:14, 93.16it/s]

Batch 100: Recon Loss: 66090.5781, Adv Loss: 16.4364, Fair Loss: 1.2953


Training Generator:  14%|█▍        | 209/1458 [00:02<00:13, 93.36it/s]

Batch 200: Recon Loss: 67186.8047, Adv Loss: 15.9311, Fair Loss: 1.2272


Training Generator:  21%|██        | 309/1458 [00:03<00:12, 91.34it/s]

Batch 300: Recon Loss: 68384.2734, Adv Loss: 16.0917, Fair Loss: 1.2669


Training Generator:  28%|██▊       | 409/1458 [00:04<00:11, 91.26it/s]

Batch 400: Recon Loss: 63328.9023, Adv Loss: 15.8820, Fair Loss: 1.2330


Training Generator:  36%|███▌      | 519/1458 [00:05<00:09, 93.92it/s]

Batch 500: Recon Loss: 65872.7266, Adv Loss: 15.8989, Fair Loss: 1.1585


Training Generator:  42%|████▏     | 619/1458 [00:06<00:08, 93.76it/s]

Batch 600: Recon Loss: 62902.9492, Adv Loss: 16.0047, Fair Loss: 1.3261


Training Generator:  49%|████▊     | 709/1458 [00:07<00:07, 93.86it/s]

Batch 700: Recon Loss: 65585.1328, Adv Loss: 16.7924, Fair Loss: 1.4406


Training Generator:  55%|█████▌    | 809/1458 [00:08<00:06, 92.88it/s]

Batch 800: Recon Loss: 62348.7773, Adv Loss: 15.9303, Fair Loss: 1.1456


Training Generator:  62%|██████▏   | 909/1458 [00:09<00:05, 91.72it/s]

Batch 900: Recon Loss: 62017.3320, Adv Loss: 16.3441, Fair Loss: 1.2288


Training Generator:  69%|██████▉   | 1009/1458 [00:10<00:04, 93.12it/s]

Batch 1000: Recon Loss: 65502.6211, Adv Loss: 16.2348, Fair Loss: 1.3224


Training Generator:  76%|███████▌  | 1109/1458 [00:11<00:03, 93.05it/s]

Batch 1100: Recon Loss: 64515.6758, Adv Loss: 16.3821, Fair Loss: 1.4104


Training Generator:  84%|████████▎ | 1219/1458 [00:13<00:02, 93.21it/s]

Batch 1200: Recon Loss: 63204.5117, Adv Loss: 15.7456, Fair Loss: 1.2104


Training Generator:  90%|█████████ | 1318/1458 [00:14<00:01, 88.88it/s]

Batch 1300: Recon Loss: 63718.2266, Adv Loss: 16.1565, Fair Loss: 1.2082


Training Generator:  97%|█████████▋| 1413/1458 [00:15<00:00, 89.85it/s]

Batch 1400: Recon Loss: 65344.7773, Adv Loss: 16.3207, Fair Loss: 1.3455


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 91.80it/s]


Epoch Summary - Attacker Loss: 0.2505, Fair Loss: 0.2597, Generator Loss: 65663.8118
Validation - Attacker Loss (orig): 0.2485
Validation - Attacker Loss (recon): 0.5494
Validation - Fair Loss: 0.2561
Validation - Debiasing Effect: 0.3008 (higher is better)
Epoch completed in 33.26 seconds
--------------------------------------------------------------------------------
Epoch 12/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:08<00:00, 180.23it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 160.46it/s]
Training Generator:   7%|▋         | 109/1458 [00:01<00:14, 93.52it/s]

Batch 100: Recon Loss: 66105.3594, Adv Loss: 16.3066, Fair Loss: 1.1983


Training Generator:  15%|█▌        | 219/1458 [00:02<00:13, 94.88it/s]

Batch 200: Recon Loss: 64311.6641, Adv Loss: 16.3147, Fair Loss: 1.2892


Training Generator:  22%|██▏       | 319/1458 [00:03<00:12, 93.71it/s]

Batch 300: Recon Loss: 65781.9141, Adv Loss: 17.1278, Fair Loss: 1.3543


Training Generator:  28%|██▊       | 409/1458 [00:04<00:12, 87.22it/s]

Batch 400: Recon Loss: 64434.7773, Adv Loss: 16.3843, Fair Loss: 1.2724


Training Generator:  35%|███▍      | 509/1458 [00:05<00:10, 91.09it/s]

Batch 500: Recon Loss: 63967.9258, Adv Loss: 16.7239, Fair Loss: 1.3777


Training Generator:  42%|████▏     | 609/1458 [00:06<00:09, 91.97it/s]

Batch 600: Recon Loss: 66696.8359, Adv Loss: 16.9817, Fair Loss: 1.3722


Training Generator:  49%|████▊     | 709/1458 [00:07<00:08, 91.25it/s]

Batch 700: Recon Loss: 66921.4297, Adv Loss: 16.4874, Fair Loss: 1.2885


Training Generator:  55%|█████▌    | 809/1458 [00:08<00:07, 91.55it/s]

Batch 800: Recon Loss: 66273.3203, Adv Loss: 15.3262, Fair Loss: 1.2333


Training Generator:  62%|██████▏   | 909/1458 [00:09<00:05, 92.05it/s]

Batch 900: Recon Loss: 66053.5000, Adv Loss: 16.9420, Fair Loss: 1.1613


Training Generator:  70%|██████▉   | 1018/1458 [00:11<00:04, 91.01it/s]

Batch 1000: Recon Loss: 69271.1484, Adv Loss: 16.1389, Fair Loss: 1.2200


Training Generator:  77%|███████▋  | 1118/1458 [00:12<00:03, 91.99it/s]

Batch 1100: Recon Loss: 66106.2422, Adv Loss: 15.1955, Fair Loss: 1.3081


Training Generator:  84%|████████▎ | 1218/1458 [00:13<00:02, 92.89it/s]

Batch 1200: Recon Loss: 66662.7344, Adv Loss: 16.0926, Fair Loss: 1.4381


Training Generator:  90%|█████████ | 1318/1458 [00:14<00:01, 93.12it/s]

Batch 1300: Recon Loss: 65903.1719, Adv Loss: 16.1563, Fair Loss: 1.2395


Training Generator:  97%|█████████▋| 1418/1458 [00:15<00:00, 92.62it/s]

Batch 1400: Recon Loss: 67058.3203, Adv Loss: 16.5429, Fair Loss: 1.1851


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 92.24it/s]


Epoch Summary - Attacker Loss: 0.2500, Fair Loss: 0.2584, Generator Loss: 65628.8958
Validation - Attacker Loss (orig): 0.2486
Validation - Attacker Loss (recon): 0.5398
Validation - Fair Loss: 0.2548
Validation - Debiasing Effect: 0.2912 (higher is better)
Epoch completed in 33.75 seconds
--------------------------------------------------------------------------------
Epoch 13/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 187.67it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:08<00:00, 162.35it/s]
Training Generator:   8%|▊         | 119/1458 [00:01<00:14, 95.17it/s]

Batch 100: Recon Loss: 65064.4336, Adv Loss: 15.1165, Fair Loss: 1.3130


Training Generator:  15%|█▌        | 219/1458 [00:02<00:13, 94.78it/s]

Batch 200: Recon Loss: 67936.8047, Adv Loss: 15.3153, Fair Loss: 1.2340


Training Generator:  22%|██▏       | 319/1458 [00:03<00:12, 94.89it/s]

Batch 300: Recon Loss: 65317.4766, Adv Loss: 15.0293, Fair Loss: 1.2903


Training Generator:  29%|██▊       | 419/1458 [00:04<00:10, 95.64it/s]

Batch 400: Recon Loss: 64564.5703, Adv Loss: 16.3672, Fair Loss: 1.2361


Training Generator:  36%|███▌      | 519/1458 [00:05<00:09, 95.99it/s]

Batch 500: Recon Loss: 68569.8047, Adv Loss: 15.7300, Fair Loss: 1.4137


Training Generator:  42%|████▏     | 609/1458 [00:06<00:09, 90.29it/s]

Batch 600: Recon Loss: 65286.1445, Adv Loss: 16.4004, Fair Loss: 1.0948


Training Generator:  49%|████▉     | 719/1458 [00:07<00:07, 94.62it/s]

Batch 700: Recon Loss: 64528.9023, Adv Loss: 16.3316, Fair Loss: 1.2162


Training Generator:  56%|█████▌    | 819/1458 [00:08<00:06, 95.01it/s]

Batch 800: Recon Loss: 65295.3867, Adv Loss: 15.0338, Fair Loss: 1.2189


Training Generator:  62%|██████▏   | 909/1458 [00:09<00:05, 93.79it/s]

Batch 900: Recon Loss: 63261.7891, Adv Loss: 15.6373, Fair Loss: 1.2763


Training Generator:  70%|██████▉   | 1019/1458 [00:10<00:04, 96.61it/s]

Batch 1000: Recon Loss: 62759.0625, Adv Loss: 16.6053, Fair Loss: 1.3040


Training Generator:  77%|███████▋  | 1119/1458 [00:11<00:03, 96.60it/s]

Batch 1100: Recon Loss: 65996.4609, Adv Loss: 16.6486, Fair Loss: 1.2018


Training Generator:  84%|████████▎ | 1219/1458 [00:12<00:02, 96.35it/s]

Batch 1200: Recon Loss: 63811.6836, Adv Loss: 15.8544, Fair Loss: 1.3796


Training Generator:  90%|█████████ | 1319/1458 [00:13<00:01, 95.04it/s]

Batch 1300: Recon Loss: 67077.5000, Adv Loss: 15.8944, Fair Loss: 1.2929


Training Generator:  97%|█████████▋| 1419/1458 [00:14<00:00, 95.01it/s]

Batch 1400: Recon Loss: 65725.2812, Adv Loss: 16.1486, Fair Loss: 1.1881


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 94.83it/s]


Epoch Summary - Attacker Loss: 0.2494, Fair Loss: 0.2575, Generator Loss: 65599.0076
Validation - Attacker Loss (orig): 0.2493
Validation - Attacker Loss (recon): 0.5418
Validation - Fair Loss: 0.2544
Validation - Debiasing Effect: 0.2925 (higher is better)
Epoch completed in 32.93 seconds
--------------------------------------------------------------------------------
Epoch 14/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 193.00it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:09<00:00, 161.03it/s]
Training Generator:   7%|▋         | 109/1458 [00:01<00:14, 93.28it/s]

Batch 100: Recon Loss: 66772.9062, Adv Loss: 18.1086, Fair Loss: 1.2181


Training Generator:  14%|█▍        | 209/1458 [00:02<00:13, 91.78it/s]

Batch 200: Recon Loss: 61985.4805, Adv Loss: 17.2846, Fair Loss: 1.3216


Training Generator:  22%|██▏       | 319/1458 [00:03<00:12, 92.80it/s]

Batch 300: Recon Loss: 67227.4531, Adv Loss: 17.7201, Fair Loss: 1.2801


Training Generator:  29%|██▊       | 418/1458 [00:04<00:11, 92.42it/s]

Batch 400: Recon Loss: 63188.5625, Adv Loss: 17.7044, Fair Loss: 1.2844


Training Generator:  36%|███▌      | 518/1458 [00:05<00:10, 92.05it/s]

Batch 500: Recon Loss: 66621.5156, Adv Loss: 18.2846, Fair Loss: 1.1601


Training Generator:  42%|████▏     | 618/1458 [00:06<00:09, 93.05it/s]

Batch 600: Recon Loss: 66976.6406, Adv Loss: 17.8235, Fair Loss: 1.2667


Training Generator:  49%|████▉     | 718/1458 [00:07<00:07, 92.52it/s]

Batch 700: Recon Loss: 65726.8672, Adv Loss: 18.6174, Fair Loss: 1.1354


Training Generator:  56%|█████▌    | 818/1458 [00:08<00:06, 92.79it/s]

Batch 800: Recon Loss: 62511.8867, Adv Loss: 18.1183, Fair Loss: 1.2832


Training Generator:  63%|██████▎   | 918/1458 [00:09<00:05, 93.22it/s]

Batch 900: Recon Loss: 65018.0430, Adv Loss: 17.7450, Fair Loss: 1.3873


Training Generator:  70%|██████▉   | 1018/1458 [00:11<00:04, 92.59it/s]

Batch 1000: Recon Loss: 65627.5469, Adv Loss: 17.5426, Fair Loss: 1.3611


Training Generator:  77%|███████▋  | 1118/1458 [00:12<00:03, 93.52it/s]

Batch 1100: Recon Loss: 63103.5195, Adv Loss: 17.8242, Fair Loss: 1.1402


Training Generator:  84%|████████▎ | 1218/1458 [00:13<00:02, 93.53it/s]

Batch 1200: Recon Loss: 67780.5234, Adv Loss: 16.2769, Fair Loss: 1.2596


Training Generator:  90%|█████████ | 1318/1458 [00:14<00:01, 91.59it/s]

Batch 1300: Recon Loss: 67717.2109, Adv Loss: 18.5727, Fair Loss: 1.2741


Training Generator:  97%|█████████▋| 1413/1458 [00:15<00:00, 84.52it/s]

Batch 1400: Recon Loss: 63505.4453, Adv Loss: 18.1275, Fair Loss: 1.2254


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 91.49it/s]


Epoch Summary - Attacker Loss: 0.2489, Fair Loss: 0.2567, Generator Loss: 65578.1422
Validation - Attacker Loss (orig): 0.2481
Validation - Attacker Loss (recon): 0.5276
Validation - Fair Loss: 0.2538
Validation - Debiasing Effect: 0.2795 (higher is better)
Epoch completed in 33.34 seconds
--------------------------------------------------------------------------------
Epoch 15/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:07<00:00, 193.87it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:08<00:00, 169.40it/s]
Training Generator:   8%|▊         | 115/1458 [00:01<00:15, 87.27it/s]

Batch 100: Recon Loss: 67660.1406, Adv Loss: 18.2774, Fair Loss: 1.2973


Training Generator:  14%|█▍        | 211/1458 [00:02<00:13, 90.48it/s]

Batch 200: Recon Loss: 64317.7617, Adv Loss: 18.5821, Fair Loss: 1.2407


Training Generator:  21%|██▏       | 311/1458 [00:03<00:12, 90.13it/s]

Batch 300: Recon Loss: 67752.2344, Adv Loss: 18.6306, Fair Loss: 1.2541


Training Generator:  29%|██▊       | 417/1458 [00:04<00:11, 91.37it/s]

Batch 400: Recon Loss: 66453.0391, Adv Loss: 18.8147, Fair Loss: 1.2423


Training Generator:  35%|███▌      | 517/1458 [00:05<00:10, 91.51it/s]

Batch 500: Recon Loss: 66382.5234, Adv Loss: 17.3697, Fair Loss: 1.2678


Training Generator:  42%|████▏     | 617/1458 [00:06<00:09, 90.67it/s]

Batch 600: Recon Loss: 66162.7422, Adv Loss: 18.4331, Fair Loss: 1.2278


Training Generator:  49%|████▉     | 717/1458 [00:07<00:08, 90.76it/s]

Batch 700: Recon Loss: 61003.4023, Adv Loss: 18.6045, Fair Loss: 1.1322


Training Generator:  56%|█████▌    | 817/1458 [00:09<00:07, 90.80it/s]

Batch 800: Recon Loss: 68553.9375, Adv Loss: 18.0039, Fair Loss: 1.2360


Training Generator:  62%|██████▏   | 909/1458 [00:10<00:06, 87.38it/s]

Batch 900: Recon Loss: 64554.8516, Adv Loss: 17.2930, Fair Loss: 1.2769


Training Generator:  69%|██████▉   | 1013/1458 [00:11<00:04, 102.77it/s]

Batch 1000: Recon Loss: 64281.9492, Adv Loss: 18.1439, Fair Loss: 1.1675


Training Generator:  77%|███████▋  | 1121/1458 [00:12<00:02, 117.88it/s]

Batch 1100: Recon Loss: 63378.4961, Adv Loss: 17.8597, Fair Loss: 1.2392


Training Generator:  83%|████████▎ | 1217/1458 [00:12<00:02, 117.69it/s]

Batch 1200: Recon Loss: 65536.5625, Adv Loss: 19.3103, Fair Loss: 1.2671


Training Generator:  90%|█████████ | 1314/1458 [00:13<00:01, 105.17it/s]

Batch 1300: Recon Loss: 62105.2891, Adv Loss: 17.5603, Fair Loss: 1.4170


Training Generator:  97%|█████████▋| 1413/1458 [00:14<00:00, 99.09it/s] 

Batch 1400: Recon Loss: 67250.0312, Adv Loss: 18.8470, Fair Loss: 1.2437


Training Generator: 100%|██████████| 1458/1458 [00:15<00:00, 95.67it/s] 


Epoch Summary - Attacker Loss: 0.2484, Fair Loss: 0.2562, Generator Loss: 65557.5817
Validation - Attacker Loss (orig): 0.2478
Validation - Attacker Loss (recon): 0.5316
Validation - Fair Loss: 0.2538
Validation - Debiasing Effect: 0.2838 (higher is better)
Epoch completed in 32.04 seconds
--------------------------------------------------------------------------------
Epoch 16/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 233.38it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 183.69it/s]
Training Generator:   8%|▊         | 119/1458 [00:01<00:11, 115.85it/s]

Batch 100: Recon Loss: 65046.7305, Adv Loss: 18.4070, Fair Loss: 1.1542


Training Generator:  15%|█▍        | 217/1458 [00:01<00:10, 118.95it/s]

Batch 200: Recon Loss: 68316.8516, Adv Loss: 18.5711, Fair Loss: 1.1733


Training Generator:  22%|██▏       | 320/1458 [00:02<00:09, 119.42it/s]

Batch 300: Recon Loss: 64410.9492, Adv Loss: 18.8841, Fair Loss: 1.2345


Training Generator:  29%|██▉       | 422/1458 [00:03<00:08, 120.60it/s]

Batch 400: Recon Loss: 64452.7188, Adv Loss: 19.6574, Fair Loss: 1.2606


Training Generator:  35%|███▌      | 513/1458 [00:04<00:07, 121.14it/s]

Batch 500: Recon Loss: 65349.6758, Adv Loss: 18.8675, Fair Loss: 1.1437


Training Generator:  42%|████▏     | 615/1458 [00:05<00:07, 120.27it/s]

Batch 600: Recon Loss: 68703.0391, Adv Loss: 18.2583, Fair Loss: 1.2190


Training Generator:  49%|████▉     | 714/1458 [00:06<00:06, 118.41it/s]

Batch 700: Recon Loss: 62985.0938, Adv Loss: 19.0355, Fair Loss: 1.1584


Training Generator:  56%|█████▋    | 823/1458 [00:06<00:05, 118.22it/s]

Batch 800: Recon Loss: 63179.6328, Adv Loss: 19.6365, Fair Loss: 1.2524


Training Generator:  63%|██████▎   | 919/1458 [00:07<00:04, 117.83it/s]

Batch 900: Recon Loss: 66535.4766, Adv Loss: 19.2167, Fair Loss: 1.2794


Training Generator:  70%|██████▉   | 1015/1458 [00:08<00:03, 116.73it/s]

Batch 1000: Recon Loss: 67956.1328, Adv Loss: 18.4172, Fair Loss: 1.3287


Training Generator:  77%|███████▋  | 1123/1458 [00:09<00:02, 116.86it/s]

Batch 1100: Recon Loss: 66422.5469, Adv Loss: 18.3080, Fair Loss: 1.1471


Training Generator:  84%|████████▎ | 1219/1458 [00:10<00:02, 118.51it/s]

Batch 1200: Recon Loss: 66529.4531, Adv Loss: 18.7045, Fair Loss: 1.2409


Training Generator:  90%|█████████ | 1315/1458 [00:11<00:01, 113.46it/s]

Batch 1300: Recon Loss: 68683.8438, Adv Loss: 19.0404, Fair Loss: 1.3254


Training Generator:  98%|█████████▊| 1423/1458 [00:12<00:00, 118.32it/s]

Batch 1400: Recon Loss: 66715.5859, Adv Loss: 19.3590, Fair Loss: 1.1939


Training Generator: 100%|██████████| 1458/1458 [00:12<00:00, 117.85it/s]


Epoch Summary - Attacker Loss: 0.2482, Fair Loss: 0.2554, Generator Loss: 65541.7538
Validation - Attacker Loss (orig): 0.2482
Validation - Attacker Loss (recon): 0.5416
Validation - Fair Loss: 0.2536
Validation - Debiasing Effect: 0.2934 (higher is better)
Epoch completed in 27.21 seconds
--------------------------------------------------------------------------------
Epoch 17/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 230.15it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 198.62it/s]
Training Generator:   8%|▊         | 120/1458 [00:01<00:11, 118.63it/s]

Batch 100: Recon Loss: 64623.5586, Adv Loss: 19.3641, Fair Loss: 1.1591


Training Generator:  15%|█▌        | 219/1458 [00:01<00:10, 120.70it/s]

Batch 200: Recon Loss: 68926.8047, Adv Loss: 18.7996, Fair Loss: 1.2607


Training Generator:  22%|██▏       | 323/1458 [00:02<00:09, 120.42it/s]

Batch 300: Recon Loss: 68553.1172, Adv Loss: 20.0304, Fair Loss: 1.1995


Training Generator:  28%|██▊       | 412/1458 [00:03<00:08, 119.32it/s]

Batch 400: Recon Loss: 65498.7695, Adv Loss: 19.0436, Fair Loss: 1.3724


Training Generator:  35%|███▌      | 512/1458 [00:04<00:08, 113.40it/s]

Batch 500: Recon Loss: 64375.3203, Adv Loss: 18.8302, Fair Loss: 1.2267


Training Generator:  42%|████▏     | 615/1458 [00:05<00:06, 121.12it/s]

Batch 600: Recon Loss: 64099.8711, Adv Loss: 19.5898, Fair Loss: 1.2558


Training Generator:  49%|████▉     | 718/1458 [00:06<00:06, 119.97it/s]

Batch 700: Recon Loss: 65206.7070, Adv Loss: 18.6241, Fair Loss: 1.4564


Training Generator:  56%|█████▋    | 822/1458 [00:06<00:05, 120.45it/s]

Batch 800: Recon Loss: 64127.3203, Adv Loss: 18.6975, Fair Loss: 1.2851


Training Generator:  63%|██████▎   | 913/1458 [00:07<00:04, 121.23it/s]

Batch 900: Recon Loss: 64846.8711, Adv Loss: 18.9649, Fair Loss: 1.2448


Training Generator:  70%|██████▉   | 1014/1458 [00:08<00:04, 102.80it/s]

Batch 1000: Recon Loss: 66614.1328, Adv Loss: 18.7187, Fair Loss: 1.1993


Training Generator:  76%|███████▋  | 1115/1458 [00:09<00:03, 89.07it/s] 

Batch 1100: Recon Loss: 66010.1484, Adv Loss: 19.3052, Fair Loss: 1.2278


Training Generator:  83%|████████▎ | 1211/1458 [00:10<00:02, 94.32it/s]

Batch 1200: Recon Loss: 63242.8945, Adv Loss: 18.5947, Fair Loss: 1.2964


Training Generator:  90%|████████▉ | 1312/1458 [00:11<00:01, 112.28it/s]

Batch 1300: Recon Loss: 64875.4453, Adv Loss: 19.5515, Fair Loss: 1.2934


Training Generator:  97%|█████████▋| 1411/1458 [00:12<00:00, 116.69it/s]

Batch 1400: Recon Loss: 65399.2461, Adv Loss: 19.1334, Fair Loss: 1.2186


Training Generator: 100%|██████████| 1458/1458 [00:12<00:00, 112.74it/s]


Epoch Summary - Attacker Loss: 0.2479, Fair Loss: 0.2549, Generator Loss: 65527.2582
Validation - Attacker Loss (orig): 0.2479
Validation - Attacker Loss (recon): 0.5438
Validation - Fair Loss: 0.2541
Validation - Debiasing Effect: 0.2959 (higher is better)
Epoch completed in 27.27 seconds
--------------------------------------------------------------------------------
Epoch 18/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 233.05it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 189.06it/s]
Training Generator:   8%|▊         | 113/1458 [00:01<00:12, 105.58it/s]

Batch 100: Recon Loss: 68412.1172, Adv Loss: 18.6440, Fair Loss: 1.1376


Training Generator:  15%|█▍        | 215/1458 [00:02<00:10, 120.08it/s]

Batch 200: Recon Loss: 67861.2734, Adv Loss: 19.0163, Fair Loss: 1.2665


Training Generator:  22%|██▏       | 315/1458 [00:02<00:09, 118.32it/s]

Batch 300: Recon Loss: 62917.6133, Adv Loss: 18.1383, Fair Loss: 1.2748


Training Generator:  29%|██▉       | 424/1458 [00:03<00:08, 119.45it/s]

Batch 400: Recon Loss: 65209.8086, Adv Loss: 19.9982, Fair Loss: 1.1910


Training Generator:  35%|███▌      | 514/1458 [00:04<00:07, 120.03it/s]

Batch 500: Recon Loss: 67076.7578, Adv Loss: 18.1915, Fair Loss: 1.2161


Training Generator:  42%|████▏     | 616/1458 [00:05<00:07, 116.83it/s]

Batch 600: Recon Loss: 66314.1406, Adv Loss: 18.7827, Fair Loss: 1.2562


Training Generator:  49%|████▉     | 719/1458 [00:06<00:06, 120.10it/s]

Batch 700: Recon Loss: 66254.9297, Adv Loss: 19.0270, Fair Loss: 1.2287


Training Generator:  56%|█████▌    | 818/1458 [00:07<00:05, 118.26it/s]

Batch 800: Recon Loss: 66304.1797, Adv Loss: 18.5170, Fair Loss: 1.4707


Training Generator:  63%|██████▎   | 920/1458 [00:07<00:04, 120.03it/s]

Batch 900: Recon Loss: 67210.0234, Adv Loss: 18.4021, Fair Loss: 1.3179


Training Generator:  70%|██████▉   | 1017/1458 [00:08<00:03, 117.97it/s]

Batch 1000: Recon Loss: 64285.9961, Adv Loss: 19.3986, Fair Loss: 1.1909


Training Generator:  76%|███████▋  | 1113/1458 [00:09<00:02, 119.03it/s]

Batch 1100: Recon Loss: 66718.7812, Adv Loss: 18.2339, Fair Loss: 1.2769


Training Generator:  83%|████████▎ | 1214/1458 [00:10<00:02, 120.66it/s]

Batch 1200: Recon Loss: 66026.5547, Adv Loss: 19.4802, Fair Loss: 1.2918


Training Generator:  90%|█████████ | 1316/1458 [00:11<00:01, 119.29it/s]

Batch 1300: Recon Loss: 66540.9062, Adv Loss: 18.6221, Fair Loss: 1.3180


Training Generator:  97%|█████████▋| 1419/1458 [00:12<00:00, 121.44it/s]

Batch 1400: Recon Loss: 66925.7422, Adv Loss: 18.8935, Fair Loss: 1.2633


Training Generator: 100%|██████████| 1458/1458 [00:12<00:00, 117.04it/s]


Epoch Summary - Attacker Loss: 0.2475, Fair Loss: 0.2544, Generator Loss: 65514.1437
Validation - Attacker Loss (orig): 0.2476
Validation - Attacker Loss (recon): 0.5364
Validation - Fair Loss: 0.2533
Validation - Debiasing Effect: 0.2888 (higher is better)
Epoch completed in 27.07 seconds
--------------------------------------------------------------------------------
Epoch 19/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 218.74it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 198.14it/s]
Training Generator:   8%|▊         | 121/1458 [00:01<00:11, 119.09it/s]

Batch 100: Recon Loss: 67945.2969, Adv Loss: 19.1324, Fair Loss: 1.2770


Training Generator:  15%|█▌        | 223/1458 [00:01<00:10, 120.21it/s]

Batch 200: Recon Loss: 64673.9062, Adv Loss: 18.7595, Fair Loss: 1.2623


Training Generator:  22%|██▏       | 321/1458 [00:02<00:09, 118.39it/s]

Batch 300: Recon Loss: 63610.2930, Adv Loss: 17.9384, Fair Loss: 1.2088


Training Generator:  29%|██▉       | 424/1458 [00:03<00:08, 120.70it/s]

Batch 400: Recon Loss: 66609.4375, Adv Loss: 17.3382, Fair Loss: 1.2057


Training Generator:  35%|███▌      | 514/1458 [00:04<00:07, 119.91it/s]

Batch 500: Recon Loss: 65964.8516, Adv Loss: 18.5282, Fair Loss: 1.2137


Training Generator:  42%|████▏     | 616/1458 [00:05<00:07, 119.88it/s]

Batch 600: Recon Loss: 65938.2109, Adv Loss: 18.1874, Fair Loss: 1.2539


Training Generator:  49%|████▉     | 719/1458 [00:06<00:06, 121.73it/s]

Batch 700: Recon Loss: 66548.3203, Adv Loss: 19.5034, Fair Loss: 1.2303


Training Generator:  56%|█████▋    | 823/1458 [00:06<00:05, 121.83it/s]

Batch 800: Recon Loss: 65792.7734, Adv Loss: 18.2196, Fair Loss: 1.1844


Training Generator:  63%|██████▎   | 914/1458 [00:07<00:04, 118.42it/s]

Batch 900: Recon Loss: 65873.6094, Adv Loss: 18.7520, Fair Loss: 1.2164


Training Generator:  69%|██████▉   | 1010/1458 [00:08<00:04, 91.20it/s]

Batch 1000: Recon Loss: 67751.6562, Adv Loss: 18.6279, Fair Loss: 1.2563


Training Generator:  77%|███████▋  | 1118/1458 [00:09<00:03, 89.61it/s]

Batch 1100: Recon Loss: 64558.3438, Adv Loss: 19.6836, Fair Loss: 1.3071


Training Generator:  84%|████████▎ | 1218/1458 [00:10<00:02, 94.34it/s]

Batch 1200: Recon Loss: 67972.9609, Adv Loss: 18.8111, Fair Loss: 1.2073


Training Generator:  90%|█████████ | 1315/1458 [00:11<00:01, 117.80it/s]

Batch 1300: Recon Loss: 66557.9297, Adv Loss: 18.8103, Fair Loss: 1.1538


Training Generator:  97%|█████████▋| 1418/1458 [00:12<00:00, 121.22it/s]

Batch 1400: Recon Loss: 69191.3828, Adv Loss: 18.4793, Fair Loss: 1.1961


Training Generator: 100%|██████████| 1458/1458 [00:13<00:00, 112.11it/s]


Epoch Summary - Attacker Loss: 0.2471, Fair Loss: 0.2539, Generator Loss: 65502.8615
Validation - Attacker Loss (orig): 0.2482
Validation - Attacker Loss (recon): 0.5340
Validation - Fair Loss: 0.2529
Validation - Debiasing Effect: 0.2858 (higher is better)
Epoch completed in 27.68 seconds
--------------------------------------------------------------------------------
Epoch 20/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 237.13it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 189.00it/s]
Training Generator:   8%|▊         | 114/1458 [00:01<00:15, 86.01it/s]

Batch 100: Recon Loss: 68029.1328, Adv Loss: 18.1742, Fair Loss: 1.2964


Training Generator:  15%|█▍        | 217/1458 [00:02<00:13, 94.08it/s]

Batch 200: Recon Loss: 65509.1055, Adv Loss: 18.8091, Fair Loss: 1.3460


Training Generator:  22%|██▏       | 316/1458 [00:03<00:09, 118.03it/s]

Batch 300: Recon Loss: 65333.8203, Adv Loss: 17.6835, Fair Loss: 1.2754


Training Generator:  28%|██▊       | 415/1458 [00:04<00:08, 119.43it/s]

Batch 400: Recon Loss: 66545.6953, Adv Loss: 17.3309, Fair Loss: 1.3333


Training Generator:  35%|███▌      | 516/1458 [00:05<00:07, 118.65it/s]

Batch 500: Recon Loss: 65128.7891, Adv Loss: 17.6890, Fair Loss: 1.0351


Training Generator:  42%|████▏     | 618/1458 [00:05<00:06, 120.07it/s]

Batch 600: Recon Loss: 63912.9023, Adv Loss: 17.4640, Fair Loss: 1.2705


Training Generator:  49%|████▉     | 718/1458 [00:06<00:06, 119.67it/s]

Batch 700: Recon Loss: 64417.4766, Adv Loss: 17.5971, Fair Loss: 1.2055


Training Generator:  56%|█████▌    | 820/1458 [00:07<00:05, 116.56it/s]

Batch 800: Recon Loss: 64144.3945, Adv Loss: 17.8761, Fair Loss: 1.3328


Training Generator:  63%|██████▎   | 923/1458 [00:08<00:04, 121.46it/s]

Batch 900: Recon Loss: 63687.9648, Adv Loss: 17.3071, Fair Loss: 1.1762


Training Generator:  70%|██████▉   | 1014/1458 [00:09<00:03, 120.61it/s]

Batch 1000: Recon Loss: 66194.4141, Adv Loss: 17.4134, Fair Loss: 1.3271


Training Generator:  76%|███████▋  | 1114/1458 [00:10<00:02, 116.70it/s]

Batch 1100: Recon Loss: 66323.7969, Adv Loss: 17.0246, Fair Loss: 1.1612


Training Generator:  83%|████████▎ | 1216/1458 [00:10<00:02, 120.73it/s]

Batch 1200: Recon Loss: 67598.2891, Adv Loss: 18.2742, Fair Loss: 1.2743


Training Generator:  90%|█████████ | 1319/1458 [00:11<00:01, 114.56it/s]

Batch 1300: Recon Loss: 69657.0391, Adv Loss: 19.0070, Fair Loss: 1.2658


Training Generator:  97%|█████████▋| 1418/1458 [00:12<00:00, 116.57it/s]

Batch 1400: Recon Loss: 65119.3828, Adv Loss: 17.7785, Fair Loss: 1.2814


Training Generator: 100%|██████████| 1458/1458 [00:13<00:00, 111.96it/s]


Epoch Summary - Attacker Loss: 0.2468, Fair Loss: 0.2535, Generator Loss: 65491.4496
Validation - Attacker Loss (orig): 0.2481
Validation - Attacker Loss (recon): 0.5412
Validation - Fair Loss: 0.2527
Validation - Debiasing Effect: 0.2931 (higher is better)
Epoch completed in 27.54 seconds
--------------------------------------------------------------------------------
Epoch 21/30 - Phase 2: Adversarial training


Training Attacker: 100%|██████████| 1458/1458 [00:06<00:00, 218.26it/s]
Training Fair Model: 100%|██████████| 1458/1458 [00:07<00:00, 200.59it/s]
Training Generator:   8%|▊         | 114/1458 [00:00<00:11, 120.49it/s]

Batch 100: Recon Loss: 65365.7500, Adv Loss: 17.7607, Fair Loss: 1.3139


Training Generator:  15%|█▍        | 218/1458 [00:01<00:10, 120.94it/s]

Batch 200: Recon Loss: 64794.8398, Adv Loss: 18.0883, Fair Loss: 1.1002


Training Generator:  22%|██▏       | 320/1458 [00:02<00:09, 119.63it/s]

Batch 300: Recon Loss: 67286.6797, Adv Loss: 17.7737, Fair Loss: 1.2254


Training Generator:  29%|██▊       | 418/1458 [00:03<00:08, 119.02it/s]

Batch 400: Recon Loss: 67558.9766, Adv Loss: 17.7110, Fair Loss: 1.2110


Training Generator:  35%|███▌      | 515/1458 [00:04<00:08, 111.78it/s]

Batch 500: Recon Loss: 63990.6875, Adv Loss: 17.5081, Fair Loss: 1.2865


Training Generator:  42%|████▏     | 618/1458 [00:05<00:07, 119.72it/s]

Batch 600: Recon Loss: 62137.9492, Adv Loss: 17.7746, Fair Loss: 1.3072


Training Generator:  49%|████▉     | 711/1458 [00:06<00:08, 90.93it/s] 

Batch 700: Recon Loss: 63440.4883, Adv Loss: 18.1203, Fair Loss: 1.3375


Training Generator:  56%|█████▌    | 816/1458 [00:07<00:07, 89.62it/s]

Batch 800: Recon Loss: 66418.5781, Adv Loss: 19.1419, Fair Loss: 1.2600


Training Generator:  63%|██████▎   | 921/1458 [00:08<00:05, 103.79it/s]

Batch 900: Recon Loss: 65858.4375, Adv Loss: 17.7233, Fair Loss: 1.3866


Training Generator:  70%|██████▉   | 1017/1458 [00:09<00:03, 115.54it/s]

Batch 1000: Recon Loss: 65046.4883, Adv Loss: 18.9069, Fair Loss: 1.3564


Training Generator:  76%|███████▋  | 1115/1458 [00:10<00:02, 116.09it/s]

Batch 1100: Recon Loss: 66116.3359, Adv Loss: 17.9465, Fair Loss: 1.1511


Training Generator:  83%|████████▎ | 1214/1458 [00:10<00:02, 119.26it/s]

Batch 1200: Recon Loss: 64752.3203, Adv Loss: 17.4905, Fair Loss: 1.2976


Training Generator:  90%|█████████ | 1318/1458 [00:11<00:01, 119.97it/s]

Batch 1300: Recon Loss: 67085.5547, Adv Loss: 17.9099, Fair Loss: 1.3977


Training Generator:  97%|█████████▋| 1420/1458 [00:12<00:00, 121.94it/s]

Batch 1400: Recon Loss: 63907.9492, Adv Loss: 18.1865, Fair Loss: 1.3008


Training Generator: 100%|██████████| 1458/1458 [00:13<00:00, 112.02it/s]


Epoch Summary - Attacker Loss: 0.2465, Fair Loss: 0.2533, Generator Loss: 65481.9356


KeyboardInterrupt: 

# Evaluation on test

In [None]:
# Required imports for evaluation
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    hamming_loss,
    precision_recall_curve
)
from collections import defaultdict

# Make sure plots are displayed properly in Kaggle
%matplotlib inline
plt.style.use('seaborn-whitegrid')

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

def find_optimal_threshold(y_true, y_pred):
    """
    Find the threshold that maximizes F1 score for binary classification
    by calculating precision and recall at each threshold
    
    Args:
        y_true: Ground truth labels (numpy array)
        y_pred: Predicted probabilities (numpy array)
        
    Returns:
        Optimal threshold value
    """
    # Handle case where all labels are same class (all 0s or all 1s)
    if np.all(y_true == 0) or np.all(y_true == 1):
        return 0.5  # Default to 0.5 if all labels are the same
    
    # Get precision-recall pairs for different thresholds
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
    
    # Calculate F1 score for each precision-recall pair
    # F1 = 2 * (precision * recall) / (precision + recall)
    # Add a small epsilon to avoid division by zero
    epsilon = 1e-7
    f1_scores = 2 * (precision * recall) / (precision + recall + epsilon)
    
    # Find the threshold with the best F1 score
    # Note: precision_recall_curve returns one more precision/recall value than thresholds
    # So we need to handle this edge case
    if len(f1_scores) == len(thresholds) + 1:
        # Use all but the last f1_score value
        best_idx = np.argmax(f1_scores[:-1])
    else:
        best_idx = np.argmax(f1_scores)
        
    best_threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
    best_f1 = f1_scores[best_idx]
    
    # Print out the precision, recall, and F1 at the optimal threshold
    if best_idx < len(precision) and best_idx < len(recall):
        best_precision = precision[best_idx]
        best_recall = recall[best_idx]
        print(f"  Optimal threshold: {best_threshold:.3f}, F1: {best_f1:.3f}, Precision: {best_precision:.3f}, Recall: {best_recall:.3f}")
    
    return best_threshold

def calculate_multilabel_metrics(y_true, y_pred, y_score=None):
    """
    Calculate metrics for multilabel classification
    
    Args:
        y_true: True labels
        y_pred: Predicted binary labels
        y_score: Prediction scores (probabilities)
    
    Returns:
        Dictionary of metrics
    """
    metrics = {}
    
    # Sample-averaged metrics
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    
    try:
        metrics["precision"] = precision_score(y_true, y_pred, average='samples', zero_division=0)
        metrics["recall"] = recall_score(y_true, y_pred, average='samples', zero_division=0)
        metrics["f1"] = f1_score(y_true, y_pred, average='samples', zero_division=0)
    except:
        # Fallback to macro average if samples doesn't work
        metrics["precision"] = precision_score(y_true, y_pred, average='macro', zero_division=0)
        metrics["recall"] = recall_score(y_true, y_pred, average='macro', zero_division=0)
        metrics["f1"] = f1_score(y_true, y_pred, average='macro', zero_division=0)
    
    # Calculate Hamming loss (fraction of incorrect labels)
    metrics["hamming_loss"] = hamming_loss(y_true, y_pred)
    
    # Calculate AUC if scores are provided
    if y_score is not None:
        try:
            metrics["macro_auc"] = roc_auc_score(y_true, y_score, average='macro')
            metrics["micro_auc"] = roc_auc_score(y_true, y_score, average='micro')
        except:
            print("Warning: Could not calculate AUC (possibly due to single-class issues)")
    
    # Print out metrics
    for metric, value in metrics.items():
        print(f"  {metric}: {value:.4f}")
    
    return metrics

def calculate_tpr(y_true, y_pred):
    """
    Calculate True Positive Rate (Recall/Sensitivity) across all labels
    
    Args:
        y_true: True labels
        y_pred: Predicted binary labels
    
    Returns:
        TPR value
    """
    # Calculate TPR for each label
    tpr_by_label = []
    
    for i in range(y_true.shape[1]):
        y_true_i = y_true[:, i]
        y_pred_i = y_pred[:, i]
        
        # Only calculate if there are positive examples
        if np.sum(y_true_i) > 0:
            # True positives / (True positives + False negatives)
            tp = np.sum((y_true_i == 1) & (y_pred_i == 1))
            fn = np.sum((y_true_i == 1) & (y_pred_i == 0))
            
            if tp + fn > 0:
                tpr_by_label.append(tp / (tp + fn))
    
    # Return macro-averaged TPR
    if len(tpr_by_label) > 0:
        return np.mean(tpr_by_label)
    else:
        return 0.0

def evaluate_with_demographics(test_loader, generator, attacker_model, fair_model, device):
    """
    Evaluate models on test data with demographic subgroup analysis.
    Compares performance of original biased embeddings vs debiased embeddings.
    Uses optimal thresholds to maximize F1 score for each label.
    
    Args:
        test_loader: DataLoader for test data
        generator: Trained generator model
        attacker_model: Trained attacker model
        fair_model: Trained fair model
        device: Computation device
    """
    print("Starting evaluation with demographic analysis...")
    
    # Set models to evaluation mode
    generator.eval()
    attacker_model.eval()
    fair_model.eval()
    
    # Initialize collectors for predictions and metadata
    all_preds_original = []  # Predictions using original embeddings
    all_preds_debiased = []  # Predictions using debiased embeddings (fair model)
    all_labels = []
    all_demographics = {
        'gender': [],
        'insurance': [],
        'race': [],
        'anchor_age': []
    }
    all_subject_ids = []
    
    # Collect all predictions and demographics
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating on test data"):
            # Get batch data
            embeddings = batch['embedding'].to(device)
            labels = batch['labels'].to(device)
            demographics = batch['demographics']
            subject_ids = batch['subject_id']
            
            # Generate debiased embeddings
            reconstructed, _, _, _ = generator(embeddings)
            
            # Get predictions
            original_preds = torch.sigmoid(attacker_model(embeddings))
            debiased_preds = torch.sigmoid(fair_model(reconstructed))
            
            # Store predictions and metadata
            all_preds_original.append(original_preds.cpu().numpy())
            all_preds_debiased.append(debiased_preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            
            # Store demographics
            for demo_key in all_demographics.keys():
                if demo_key in demographics:
                    all_demographics[demo_key].extend(demographics[demo_key])
            
            all_subject_ids.extend(subject_ids)
    
    # Convert lists to numpy arrays for easier processing
    all_preds_original = np.concatenate(all_preds_original, axis=0)
    all_preds_debiased = np.concatenate(all_preds_debiased, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    
    # Find optimal threshold for each label
    print("\n=== FINDING OPTIMAL THRESHOLDS ===")
    num_labels = all_labels.shape[1]
    original_thresholds = np.zeros(num_labels)
    debiased_thresholds = np.zeros(num_labels)
    
    for i in range(num_labels):
        print(f"\nLabel {i+1}:")
        
        # Get data for this label
        label_true = all_labels[:, i]
        orig_pred = all_preds_original[:, i]
        debias_pred = all_preds_debiased[:, i]
        
        # Calculate class distribution
        pos_rate = np.mean(label_true)
        print(f"  Class distribution: {pos_rate*100:.1f}% positive, {(1-pos_rate)*100:.1f}% negative")
        
        # Find optimal thresholds
        print("  Original model:")
        original_thresholds[i] = find_optimal_threshold(label_true, orig_pred)
        
        print("  Debiased model:")
        debiased_thresholds[i] = find_optimal_threshold(label_true, debias_pred)
    
    # Apply optimal thresholds to get binary predictions
    binary_preds_original = np.zeros_like(all_preds_original, dtype=int)
    binary_preds_debiased = np.zeros_like(all_preds_debiased, dtype=int)
    
    for i in range(num_labels):
        binary_preds_original[:, i] = (all_preds_original[:, i] >= original_thresholds[i]).astype(int)
        binary_preds_debiased[:, i] = (all_preds_debiased[:, i] >= debiased_thresholds[i]).astype(int)
    
    # Overall performance metrics with optimal thresholds
    print("\n=== OVERALL PERFORMANCE WITH OPTIMAL THRESHOLDS ===")
    
    print("\nOriginal Model Performance:")
    original_metrics = calculate_multilabel_metrics(all_labels, binary_preds_original, all_preds_original)
    
    print("\nDebiased Model Performance:")
    debiased_metrics = calculate_multilabel_metrics(all_labels, binary_preds_debiased, all_preds_debiased)
    
    # Calculate performance gap between original and debiased
    print("\nPerformance Change (Debiased - Original):")
    for metric in original_metrics:
        if metric in debiased_metrics:
            change = debiased_metrics[metric] - original_metrics[metric]
            print(f"  {metric}: {change:.4f}")
    
    # Demographic subgroup analysis
    print("\n\n=== DEMOGRAPHIC SUBGROUP ANALYSIS ===")
    
    # Analyze each demographic category
    for demo_key in all_demographics.keys():
        if len(all_demographics[demo_key]) == 0:
            print(f"\nSkipping {demo_key} - No data available")
            continue
            
        print(f"\n--- {demo_key.upper()} ANALYSIS ---")
        
        # Get unique values for this demographic
        unique_values = np.unique(all_demographics[demo_key])
        
        # Skip if only one value (no comparison needed)
        if len(unique_values) <= 1:
            print(f"  Only one value for {demo_key}: {unique_values[0]}")
            continue
        
        # Calculate TPR for each subgroup
        tpr_results = {}
        
        for model_name, binary_preds in [("Original", binary_preds_original), 
                                         ("Debiased", binary_preds_debiased)]:
            tpr_results[model_name] = {}
            
            # Calculate overall TPR first
            overall_tpr = calculate_tpr(all_labels, binary_preds)
            tpr_results[model_name]["Overall"] = overall_tpr
            
            # Calculate TPR for each demographic subgroup
            for value in unique_values:
                # Create mask for this subgroup
                mask = np.array(all_demographics[demo_key]) == value
                
                # Skip if too few samples
                if np.sum(mask) < 10:
                    print(f"  Skipping {demo_key}={value} (insufficient samples: {np.sum(mask)})")
                    continue
                
                # Calculate TPR for this subgroup
                subgroup_tpr = calculate_tpr(all_labels[mask], binary_preds[mask])
                tpr_results[model_name][value] = subgroup_tpr
        
        # Display TPR results and disparities
        print(f"\nTrue Positive Rate (TPR) by {demo_key} subgroups:")
        print(f"{'Subgroup':<15} {'Original TPR':<15} {'Debiased TPR':<15} {'Difference':<15}")
        print("-" * 60)
        
        # First show overall
        orig_overall = tpr_results["Original"]["Overall"]
        deb_overall = tpr_results["Debiased"]["Overall"]
        diff = deb_overall - orig_overall
        print(f"{'Overall':<15} {orig_overall:.4f}{'':<9} {deb_overall:.4f}{'':<9} {diff:.4f}{'':<9}")
        
        # Then show each subgroup
        for value in unique_values:
            if value in tpr_results["Original"] and value in tpr_results["Debiased"]:
                orig_tpr = tpr_results["Original"][value]
                deb_tpr = tpr_results["Debiased"][value]
                diff = deb_tpr - orig_tpr
                print(f"{str(value)[:13]:<15} {orig_tpr:.4f}{'':<9} {deb_tpr:.4f}{'':<9} {diff:.4f}{'':<9}")
        
        # Calculate and display TPR disparity (max difference between any two groups)
        print("\nTPR Disparity (max difference between subgroups):")
        
        # Calculate disparities
        orig_values = [v for k, v in tpr_results["Original"].items() if k != "Overall"]
        deb_values = [v for k, v in tpr_results["Debiased"].items() if k != "Overall"]
        
        if len(orig_values) >= 2 and len(deb_values) >= 2:
            orig_disparity = max(orig_values) - min(orig_values)
            deb_disparity = max(deb_values) - min(deb_values)
            diff = deb_disparity - orig_disparity
            
            print(f"  Original model disparity: {orig_disparity:.4f}")
            print(f"  Debiased model disparity: {deb_disparity:.4f}")
            print(f"  Disparity change: {diff:.4f} ({'reduced' if diff < 0 else 'increased'})")
        else:
            print("  Not enough subgroups to calculate disparity")
    
    # Calculate per-label F1 scores
    print("\n=== PER-LABEL F1 SCORES ===")
    original_f1 = []
    debiased_f1 = []
    
    for i in range(num_labels):
        label_true = all_labels[:, i]
        orig_pred = binary_preds_original[:, i]
        debias_pred = binary_preds_debiased[:, i]
        
        orig_f1 = f1_score(label_true, orig_pred, zero_division=0)
        deb_f1 = f1_score(label_true, debias_pred, zero_division=0)
        
        original_f1.append(orig_f1)
        debiased_f1.append(deb_f1)
        
        print(f"Label {i+1}:")
        print(f"  Original F1: {orig_f1:.4f} (threshold: {original_thresholds[i]:.3f})")
        print(f"  Debiased F1: {deb_f1:.4f} (threshold: {debiased_thresholds[i]:.3f})")
        print(f"  F1 Change: {deb_f1 - orig_f1:.4f}")
    
    print(f"\nAverage Original F1: {np.mean(original_f1):.4f}")
    print(f"Average Debiased F1: {np.mean(debiased_f1):.4f}")
    print(f"Average F1 Change: {np.mean(debiased_f1) - np.mean(original_f1):.4f}")
    
    # Return results for potential further analysis
    return {
        "original_preds": all_preds_original,
        "debiased_preds": all_preds_debiased,
        "labels": all_labels,
        "demographics": all_demographics,
        "subject_ids": all_subject_ids,
        "original_thresholds": original_thresholds,
        "debiased_thresholds": debiased_thresholds,
        "binary_preds_original": binary_preds_original,
        "binary_preds_debiased": binary_preds_debiased
    }

# Run the evaluation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Running evaluation on device: {device}")
print(f"Test data has {len(test_loader.dataset)} samples")

evaluation_results = evaluate_with_demographics(
    test_loader=test_loader,
    generator=generator,
    attacker_model=attacker_model,
    fair_model=fair_model,
    device=device
)

# Visualize TPR by demographic subgroups
plt.figure(figsize=(12, 8))

# Choose which demographic to visualize
demo_key = 'race'  # Change to 'gender', 'insurance', or 'anchor_age' as needed
if len(evaluation_results['demographics'][demo_key]) > 0:
    # Get unique subgroups
    unique_values = np.unique(evaluation_results['demographics'][demo_key])
    
    # Calculate TPR for original and debiased models
    original_tprs = []
    debiased_tprs = []
    labels = []
    
    # Overall TPR first
    orig_overall = calculate_tpr(evaluation_results['labels'], 
                                evaluation_results['binary_preds_original'])
    deb_overall = calculate_tpr(evaluation_results['labels'], 
                               evaluation_results['binary_preds_debiased'])
    
    original_tprs.append(orig_overall)
    debiased_tprs.append(deb_overall)
    labels.append('Overall')
    
    # Calculate for each subgroup
    for value in unique_values:
        # Create mask for this subgroup
        mask = np.array(evaluation_results['demographics'][demo_key]) == value
        
        # Skip if too few samples
        if np.sum(mask) < 10:
            continue
            
        # Calculate TPR
        orig_tpr = calculate_tpr(
            evaluation_results['labels'][mask], 
            evaluation_results['binary_preds_original'][mask]
        )
        
        deb_tpr = calculate_tpr(
            evaluation_results['labels'][mask], 
            evaluation_results['binary_preds_debiased'][mask]
        )
        
        original_tprs.append(orig_tpr)
        debiased_tprs.append(deb_tpr)
        labels.append(str(value))
    
    # Create bar chart
    x = np.arange(len(labels))
    width = 0.35
    
    plt.bar(x - width/2, original_tprs, width, label='Original Model')
    plt.bar(x + width/2, debiased_tprs, width, label='Debiased Model')
    
    plt.ylabel('True Positive Rate (TPR)')
    plt.title(f'TPR Comparison by {demo_key.capitalize()} Subgroups')
    plt.xticks(x, labels, rotation=45 if len(labels) > 4 else 0)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(f'tpr_comparison_{demo_key}.png')
    plt.show()
else:
    print(f"No data available for demographic: {demo_key}")

# Create a visualization of optimal thresholds
plt.figure(figsize=(10, 6))
x = np.arange(len(evaluation_results['original_thresholds']))
width = 0.35

plt.bar(x - width/2, evaluation_results['original_thresholds'], width, 
        label='Original Model')
plt.bar(x + width/2, evaluation_results['debiased_thresholds'], width,
        label='Debiased Model')

plt.xlabel('Label Index')
plt.ylabel('Optimal Threshold')
plt.title('Optimal F1 Thresholds by Label')
plt.xticks(x, [f"Label {i+1}" for i in x])
plt.legend()

plt.tight_layout()
plt.savefig('optimal_thresholds.png')
plt.show()

# Save the evaluation results
np.save('evaluation_results.npy', evaluation_results)
print("Evaluation completed and results saved.")