# Data

In [None]:
import torch
import numpy as np
import os
import pickle
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# Path to processed data
data_path = "/kaggle/input/mimic-embedding/processed_mimic_data"

class ProcessedMIMICDataset(Dataset):
    """Dataset over pre-computed MIMIC embeddings stored in a directory.

    Three on-disk layouts are supported, selected by ``data_format``:

    * ``'pkl'`` - one ``all_data.pkl`` file holding every field.
    * ``'pt'``  - ``embeddings.pt`` and ``labels.pt`` plus ``ids.csv`` and
      ``demographics.pkl``.
    * ``'npy'`` - ``embeddings.npy`` and ``labels.npy`` plus the same two
      side files.

    NOTE: the pickle files are read with ``pickle.load``, which can execute
    arbitrary code; only point this class at trusted data.
    """

    def __init__(self, data_path, data_format='pt'):
        """
        Load a processed MIMIC dataset.

        Args:
            data_path: Directory containing the processed files
            data_format: One of 'pt' (PyTorch tensors), 'npy' (NumPy arrays)
                or 'pkl' (single pickle bundle)

        Raises:
            ValueError: If ``data_format`` is not one of the supported values
        """
        self.data_path = data_path
        self.data_format = data_format

        if data_format == 'pkl':
            self._load_pickle_bundle()
        elif data_format == 'pt':
            # weights_only=True restricts unpickling to tensor data and
            # avoids the torch.load security warning.
            self.embeddings = torch.load(self._resolve("embeddings.pt"), weights_only=True)
            self.labels = torch.load(self._resolve("labels.pt"), weights_only=True)
            self._load_side_files()
        elif data_format == 'npy':
            self.embeddings = np.load(self._resolve("embeddings.npy"))
            self.labels = np.load(self._resolve("labels.npy"))
            self._load_side_files()
        else:
            raise ValueError(f"Unsupported data format: {data_format}")

        print(f"Loaded dataset from {data_path} with {len(self.embeddings)} samples")

    def _resolve(self, filename):
        """Return the path of ``filename`` inside the dataset directory."""
        return os.path.join(self.data_path, filename)

    def _load_pickle_bundle(self):
        """Populate every field from the single ``all_data.pkl`` file."""
        with open(self._resolve("all_data.pkl"), 'rb') as handle:
            self.all_data = pickle.load(handle)

        bundle = self.all_data
        self.embeddings = bundle['embeddings']
        self.labels = bundle['labels']
        self.subject_ids = bundle['subject_ids']
        self.study_ids = bundle['study_ids']
        self.demographics = bundle['demographics']

    def _load_side_files(self):
        """Read the ID table and demographics shared by the 'pt' and 'npy' layouts."""
        id_table = pd.read_csv(self._resolve("ids.csv"))
        self.subject_ids = id_table['subject_id'].tolist()
        self.study_ids = id_table['study_id'].tolist()

        with open(self._resolve("demographics.pkl"), 'rb') as handle:
            self.demographics = pickle.load(handle)

    def __len__(self):
        """Number of samples, taken from the embeddings container."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return one sample as a dict of embedding, labels, IDs and demographics."""
        embedding, labels = self.embeddings[idx], self.labels[idx]

        # Only the 'pt' layout already stores tensors; everything else is
        # converted to float32 tensors on access.
        if self.data_format != 'pt':
            embedding = torch.tensor(embedding, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.float32)

        return {
            'embedding': embedding,
            'labels': labels,
            'subject_id': self.subject_ids[idx],
            'study_id': self.study_ids[idx],
            'demographics': self.demographics[idx],
        }

# Load the train and test splits; each lives in its own sub-directory of data_path
train_dataset, test_dataset = (
    ProcessedMIMICDataset(os.path.join(data_path, split_name), data_format='pt')
    for split_name in ("train", "test")
)

# Create validation set from train (if needed)
def create_train_val_split(train_dataset, val_ratio=0.1, random_seed=42):
    """Randomly partition ``train_dataset`` into training and validation subsets.

    Args:
        train_dataset: Dataset to split (anything accepted by ``random_split``)
        val_ratio: Fraction of samples assigned to validation; the validation
            size is ``int(val_ratio * len(train_dataset))``
        random_seed: Seed for the dedicated generator driving the split, so
            the same seed always yields the same partition

    Returns:
        Tuple ``(train_subset, val_subset)``
    """
    total = len(train_dataset)
    n_val = int(val_ratio * total)
    n_train = total - n_val

    # A private generator keeps the split reproducible without touching
    # the global torch RNG state.
    seeded_rng = torch.Generator().manual_seed(random_seed)
    train_part, val_part = torch.utils.data.random_split(
        train_dataset, [n_train, n_val], generator=seeded_rng
    )

    print(f"Split train dataset: {n_train} training samples, {n_val} validation samples")

    return train_part, val_part

# Create train/val split (optional)
# Relies on the defaults of create_train_val_split (val_ratio=0.1, random_seed=42).
train_subset, val_subset = create_train_val_split(train_dataset)

# Create data loaders
# Only the training loader shuffles; validation and test keep dataset order.
batch_size = 128
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Display dataset statistics
print(f"Training samples: {len(train_subset)}")
print(f"Validation samples: {len(val_subset)}")
print(f"Test samples: {len(test_dataset)}")

# Example of accessing a batch
# Pull a single batch to sanity-check the tensor shapes and the batch structure
# (the batch is a dict; its 'demographics' entry is itself keyed by attribute).
sample_batch = next(iter(train_loader))
print(f"Sample batch shapes:")
print(f"  Embeddings: {sample_batch['embedding'].shape}")
print(f"  Labels: {sample_batch['labels'].shape}")
print(f"  Batch keys {sample_batch.keys()}")
print(f"  Batch[demographic] keys {sample_batch['demographics'].keys()}")

In [None]:
# Number of distinct gender / insurance / race values seen in the sample batch
tuple(
    len(pd.Series(sample_batch['demographics'][attribute]).unique())
    for attribute in ('gender', 'insurance', 'race')
)

# LWBC Test

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, random_split
import numpy as np
import random
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import roc_auc_score

# MIMICClassifier: feed-forward network used for both the main classifier and the committee members
class MIMICClassifier(nn.Module):
    """Feed-forward network producing one raw logit per output class."""

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Build the layer stack for multi-label classification.

        Every hidden stage is Linear -> ReLU -> BatchNorm1d -> Dropout; the
        final Linear layer has no activation, so the network returns logits.

        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: Non-empty sequence of hidden layer widths
            output_dim: Number of output classes
            dropout_rate: Dropout probability applied after each hidden stage
        """
        super().__init__()

        # Consecutive pairs of this list give the (in, out) width of each stage.
        widths = [input_dim, *hidden_dims]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(dropout_rate),
            ]

        # Output projection from the last hidden width to the class logits
        stack.append(nn.Linear(hidden_dims[-1], output_dim))

        # Kept as one flat Sequential so parameter names stay model.<index>.*
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the logits."""
        return self.model(x)

# Learning with Biased Committee implementation
class LWBC:
    """Learning with Biased Committee (LWBC) trainer for multi-label classification.

    A committee of auxiliary ``MIMICClassifier`` networks is trained on random
    subsets (sampled with replacement) of the training data. For every
    training sample, the fraction of committee members that predict all of its
    labels correctly determines a weight ``1 / (fraction + alpha)``, so samples
    the committee tends to get wrong count more in the main classifier's loss.
    After each main-classifier pass, every committee member is refreshed with
    a supervised pass over its own subset and a distillation pass (main
    classifier as teacher) over the samples outside its subset.

    Batches are expected to be dicts with ``'embedding'`` and ``'labels'``
    entries (plus ``'demographics'`` for ``evaluate_by_demographic``).
    """

    # Sigmoid threshold used when deciding whether a committee member's
    # per-label prediction is positive during consensus voting.
    CONSENSUS_THRESHOLD = 0.25
    # AUC recorded for a class whose labels are all 0 or all 1 in the evaluated
    # data. ROC-AUC is undefined there, so the chance level is used.
    UNDEFINED_AUC = 0.5
    # Worker processes for the per-member subset / complement loaders.
    NUM_WORKERS = 4

    def __init__(
        self,
        train_loader,
        val_loader,
        input_dim=1376,
        hidden_dims=None,
        output_dim=14,
        committee_size=30,
        subset_size=0.7,
        alpha=0.02,
        lambda_kd=0.6,
        temperature=2.0,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """
        LWBC implementation for MIMIC dataset.

        Args:
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            input_dim: Input embedding dimension
            hidden_dims: Hidden layer dimensions; ``None`` selects
                ``[512, 256, 128]``
            output_dim: Number of output classes
            committee_size: Number of classifiers in the committee
            subset_size: Size of subset for each committee member (proportion)
            alpha: Scaling parameter for weighting function
            lambda_kd: Balance parameter for knowledge distillation (stored on
                the instance; no method of this class currently reads it)
            temperature: Temperature for knowledge distillation
            device: Device to run computations on
        """
        # A None sentinel avoids sharing one mutable default list across instances.
        if hidden_dims is None:
            hidden_dims = [512, 256, 128]

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.committee_size = committee_size
        self.subset_size = subset_size
        self.alpha = alpha
        self.lambda_kd = lambda_kd
        self.temperature = temperature
        self.device = device

        # Initialize main classifier
        self.main_classifier = MIMICClassifier(
            input_dim, hidden_dims, output_dim
        ).to(device)

        # Initialize committee of auxiliary classifiers
        self.committee = [
            MIMICClassifier(input_dim, hidden_dims, output_dim).to(device)
            for _ in range(committee_size)
        ]

        # Initialize optimizers
        self.main_optimizer = optim.Adam(self.main_classifier.parameters(), lr=1e-3)
        self.committee_optimizers = [
            optim.Adam(classifier.parameters(), lr=1e-3)
            for classifier in self.committee
        ]

        # Unreduced loss so per-sample weights can be applied before averaging
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

        # Create subsets for each committee member
        self.create_subsets()

    def create_subsets(self):
        """Draw one random index subset of the training data per committee member.

        Indices are sampled with replacement using the global ``random``
        module and stored sorted in ``self.subsets``.
        """
        dataset_size = len(self.train_loader.dataset)
        subset_count = int(dataset_size * self.subset_size)

        self.subsets = []
        for _ in range(self.committee_size):
            # Sample indices with replacement
            subset_indices = random.choices(range(dataset_size), k=subset_count)
            self.subsets.append(sorted(subset_indices))

    def _make_loader(self, indices):
        """Return a shuffling DataLoader over the given training-set indices."""
        return DataLoader(
            Subset(self.train_loader.dataset, indices),
            batch_size=self.train_loader.batch_size,
            shuffle=True,
            num_workers=self.NUM_WORKERS,
        )

    def compute_sample_weights(self, outputs, labels):
        """
        Compute sample weights based on committee consensus

        A committee member counts as correct on a sample only when all of its
        thresholded label predictions match the ground truth.

        Args:
            outputs: List of logit tensors, one per committee member
            labels: Ground truth labels

        Returns:
            Tensor with one weight per sample, ``1 / (proportion + alpha)``
            where ``proportion`` is the share of correct committee members
        """
        correct_count = torch.zeros(labels.size(0), device=self.device)

        # Count correct predictions for each sample across committee
        for output in outputs:
            preds = (torch.sigmoid(output) > self.CONSENSUS_THRESHOLD).float()
            correct_count += (preds == labels).all(dim=1).float()

        # Proportion of committee members that predicted correctly
        correct_proportion = correct_count / self.committee_size

        # alpha keeps the weight finite when no member is correct
        return 1.0 / (correct_proportion + self.alpha)

    def knowledge_distillation_loss(self, committee_output, main_output):
        """
        Calculate knowledge distillation loss

        Each output dimension is treated as an independent Bernoulli variable;
        the loss is the mean KL divergence from the main classifier's
        temperature-scaled sigmoid distribution to the committee member's.

        Args:
            committee_output: Logits from committee member (student)
            main_output: Logits from main classifier (teacher)

        Returns:
            Scalar KD loss
        """
        # Apply temperature scaling
        student_logits = committee_output / self.temperature
        teacher_logits = main_output / self.temperature

        teacher_probs = torch.sigmoid(teacher_logits)

        # logsigmoid(z) = log p and logsigmoid(-z) = log(1 - p); working in
        # log space stays finite for saturated logits without epsilon terms.
        kl_div = (
            teacher_probs * (F.logsigmoid(teacher_logits) - F.logsigmoid(student_logits))
            + (1 - teacher_probs) * (F.logsigmoid(-teacher_logits) - F.logsigmoid(-student_logits))
        )

        return kl_div.mean()

    def train_committee_warmup(self, num_epochs=5):
        """Warm-up training: fit each committee member on its own subset.

        Args:
            num_epochs: Number of supervised epochs per committee member
        """
        print("Starting committee warm-up training...")

        for classifier_idx, classifier in enumerate(self.committee):
            classifier.train()
            optimizer = self.committee_optimizers[classifier_idx]

            # One loader per member; shuffle=True re-shuffles on every epoch.
            subset_loader = self._make_loader(self.subsets[classifier_idx])

            print(f"Training committee member {classifier_idx+1}/{self.committee_size}")

            for epoch in range(num_epochs):
                total_loss = 0.0
                num_batches = 0

                for batch in subset_loader:
                    X = batch['embedding'].to(self.device)
                    y = batch['labels'].to(self.device)

                    # Forward pass
                    outputs = classifier(X)
                    loss = self.criterion(outputs, y).mean()

                    # Backward pass and optimize
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1

                avg_loss = total_loss / max(1, num_batches)
                print(f"  Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    def train_epoch(self):
        """Train for one epoch.

        Phase 1 updates the main classifier with committee-weighted BCE loss.
        Phase 2 updates every committee member: supervised on its own subset,
        then distilled from the main classifier on the remaining samples.

        Returns:
            Tuple ``(average main loss, average distillation loss)``
        """
        dataset_size = len(self.train_loader.dataset)

        # ---- Phase 1: main classifier on committee-weighted samples ----
        self.main_classifier.train()
        # The committee only votes in this phase, so it runs in eval mode:
        # dropout is off and BatchNorm statistics are not modified.
        for classifier in self.committee:
            classifier.eval()

        total_main_loss = 0.0
        num_batches = 0

        print("Training main classifier with weighted samples...")
        for batch in tqdm(self.train_loader, desc="Training main classifier"):
            X = batch['embedding'].to(self.device)
            y = batch['labels'].to(self.device)

            # Weights are computed on the very batch being trained on. The
            # training loader may shuffle, so weights cached by batch position
            # from a separate pass would be applied to different samples.
            with torch.no_grad():
                committee_outputs = [classifier(X) for classifier in self.committee]
                sample_weights = self.compute_sample_weights(committee_outputs, y)

            # Train main classifier with weighted loss
            self.main_optimizer.zero_grad()
            main_outputs = self.main_classifier(X)
            main_loss = (self.criterion(main_outputs, y) * sample_weights.unsqueeze(1)).mean()
            main_loss.backward()
            self.main_optimizer.step()

            total_main_loss += main_loss.item()
            num_batches += 1

        # ---- Phase 2: committee refresh with knowledge distillation ----
        print("Training committee with knowledge distillation...")
        committee_loss_sum = 0.0
        committee_batches = 0

        # The main classifier acts purely as a teacher from here on.
        self.main_classifier.eval()

        for classifier_idx, classifier in enumerate(self.committee):
            classifier.train()
            optimizer = self.committee_optimizers[classifier_idx]
            subset_indices = self.subsets[classifier_idx]

            # Train on subset with supervised loss
            for batch in self._make_loader(subset_indices):
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)

                optimizer.zero_grad()
                outputs = classifier(X)
                ce_loss = self.criterion(outputs, y).mean()
                ce_loss.backward()
                optimizer.step()

            # Samples this member never saw; set lookup keeps this linear.
            in_subset = set(subset_indices)
            complement_indices = [i for i in range(dataset_size) if i not in in_subset]
            if not complement_indices:
                # Nothing to distill on (a shuffling loader cannot be built
                # over an empty dataset).
                continue

            # Train on complement with knowledge distillation
            for batch in self._make_loader(complement_indices):
                X = batch['embedding'].to(self.device)

                optimizer.zero_grad()
                committee_output = classifier(X)

                # Teacher targets carry no gradient
                with torch.no_grad():
                    main_output = self.main_classifier(X)

                kd_loss = self.knowledge_distillation_loss(committee_output, main_output)
                kd_loss.backward()
                optimizer.step()

                committee_loss_sum += kd_loss.item()
                committee_batches += 1

        avg_main_loss = total_main_loss / max(1, num_batches)
        avg_committee_loss = committee_loss_sum / max(1, committee_batches)

        return avg_main_loss, avg_committee_loss

    def _compute_metrics(self, outputs, labels):
        """Return ``(accuracy, macro_auc, per_class_auc)`` for CPU logits and labels.

        Accuracy is element-wise over all sample/class pairs at a 0.5 sigmoid
        threshold. A class with a single label value gets ``UNDEFINED_AUC``.
        """
        probs = torch.sigmoid(outputs)
        probs_numpy = probs.numpy()
        labels_numpy = labels.numpy()

        auc_scores = []
        for i in range(self.output_dim):
            if len(np.unique(labels_numpy[:, i])) > 1:
                auc_scores.append(roc_auc_score(labels_numpy[:, i], probs_numpy[:, i]))
            else:
                # ROC-AUC is undefined when only one class is present
                auc_scores.append(self.UNDEFINED_AUC)

        macro_auc = np.mean(auc_scores)

        predictions = (probs > 0.5).float()
        accuracy = (predictions == labels).float().mean().item()

        return accuracy, macro_auc, auc_scores

    def evaluate(self, loader):
        """Evaluate the main classifier on a validation or test loader.

        Args:
            loader: Iterable of batches with 'embedding' and 'labels'

        Returns:
            Tuple ``(accuracy, macro_auc, auc_scores)`` where ``auc_scores``
            holds one ROC-AUC per output class
        """
        self.main_classifier.eval()
        all_outputs = []
        all_labels = []

        with torch.no_grad():
            for batch in loader:
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)

                outputs = self.main_classifier(X)
                all_outputs.append(outputs.cpu())
                all_labels.append(y.cpu())

        all_outputs = torch.cat(all_outputs, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        return self._compute_metrics(all_outputs, all_labels)

    def evaluate_by_demographic(self, loader, demographic_attr='gender'):
        """Evaluate the main classifier separately for each demographic group.

        Args:
            loader: Iterable of batches with 'embedding', 'labels' and a
                'demographics' mapping
            demographic_attr: Key inside ``batch['demographics']`` to group by

        Returns:
            Dict mapping each group value to a dict with 'accuracy',
            'macro_auc' and 'count' (number of samples in the group)
        """
        self.main_classifier.eval()

        # Per-group lists of single-row output / label tensors
        group_outputs = {}
        group_labels = {}

        with torch.no_grad():
            for batch in loader:
                X = batch['embedding'].to(self.device)
                y = batch['labels'].to(self.device)

                demographic_values = batch['demographics'][demographic_attr]
                outputs = self.main_classifier(X)

                for i, demo_val in enumerate(demographic_values):
                    # Tensor-valued attributes are unwrapped to plain Python
                    # scalars so they can be used as dict keys.
                    demo_val = demo_val.item() if isinstance(demo_val, torch.Tensor) else demo_val

                    group_outputs.setdefault(demo_val, []).append(outputs[i:i+1].cpu())
                    group_labels.setdefault(demo_val, []).append(y[i:i+1].cpu())

        results = {}
        for group in group_outputs:
            group_output_tensor = torch.cat(group_outputs[group], dim=0)
            group_label_tensor = torch.cat(group_labels[group], dim=0)

            accuracy, macro_auc, _ = self._compute_metrics(group_output_tensor, group_label_tensor)

            results[group] = {
                'accuracy': accuracy,
                'macro_auc': macro_auc,
                'count': len(group_labels[group])
            }

        return results

    def train(self, num_epochs=20, warmup_epochs=5):
        """Full training procedure: committee warm-up, then LWBC epochs.

        The main classifier weights with the best validation macro AUC are
        restored before returning.

        Args:
            num_epochs: Number of LWBC epochs after warm-up
            warmup_epochs: Supervised warm-up epochs per committee member

        Returns:
            Dict of per-epoch lists: 'train_loss', 'committee_loss',
            'val_accuracy' and 'val_auc'
        """
        # Warm-up training for committee
        self.train_committee_warmup(num_epochs=warmup_epochs)

        best_auc = 0.0
        best_model = None
        history = {
            'train_loss': [],
            'committee_loss': [],
            'val_accuracy': [],
            'val_auc': []
        }

        for epoch in range(num_epochs):
            # Train one epoch
            train_loss, committee_loss = self.train_epoch()

            # Evaluate
            val_accuracy, val_auc, _ = self.evaluate(self.val_loader)

            # Store history
            history['train_loss'].append(train_loss)
            history['committee_loss'].append(committee_loss)
            history['val_accuracy'].append(val_accuracy)
            history['val_auc'].append(val_auc)

            # Print metrics
            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Committee Loss: {committee_loss:.4f}")
            print(f"  Val Accuracy: {val_accuracy:.4f}, Val AUC: {val_auc:.4f}")

            # Save best model (deepcopy: state_dict tensors alias live parameters)
            if val_auc > best_auc:
                best_auc = val_auc
                best_model = deepcopy(self.main_classifier.state_dict())
                print(f"  New best model with Val AUC: {val_auc:.4f}")

        # Load best model
        if best_model is not None:
            self.main_classifier.load_state_dict(best_model)

        return history

    def save_model(self, path):
        """Save the main classifier's state dict to ``path``."""
        torch.save(self.main_classifier.state_dict(), path)

    def load_model(self, path):
        """Load the main classifier's state dict from ``path``."""
        # weights_only=True matches the tensor loading used elsewhere in this
        # file and restricts unpickling to tensor data.
        state_dict = torch.load(path, map_location=self.device, weights_only=True)
        self.main_classifier.load_state_dict(state_dict)


In [None]:
# Initialize LWBC
# This run uses a small committee (3 members, each on a 30% random subset)
# rather than the class defaults (30 members, 70%).
lwbc = LWBC(
    train_loader=train_loader,
    val_loader=val_loader,
    input_dim=1376,
    hidden_dims=[512, 256, 128],
    output_dim=14,       
    committee_size=3,
    subset_size=0.3,
    alpha=0.02,
    lambda_kd=0.6
)

# Train the model
# train() runs the committee warm-up first, then the main LWBC epochs, and
# returns the per-epoch metric history.
history = lwbc.train(num_epochs=100, warmup_epochs=5)

# Evaluate on test set
# The third return value (per-class AUC scores) is not needed here.
test_accuracy, test_auc, _ = lwbc.evaluate(test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}, Test AUC: {test_auc:.4f}")

# Evaluate across demographic groups
# One result entry is produced per distinct gender value found in the test set.
gender_results = lwbc.evaluate_by_demographic(test_loader, demographic_attr='gender')
for gender, metrics in gender_results.items():
    print(f"Gender {gender}: Accuracy: {metrics['accuracy']:.4f}, AUC: {metrics['macro_auc']:.4f}, Count: {metrics['count']}")

# Save model
# Only the main classifier's weights are written; the committee is not saved.
lwbc.save_model('lwbc_mimic_model.pt')

# Downsampling

In [None]:
# Peek at the first sample of each split
train_subset[0],val_subset[0],test_dataset[0]
# Downsample the training set; if it ends up too small, fine-tune the biased model (TODO confirm intent)

In [None]:
import numpy as np
import torch
from collections import defaultdict
import random

def _downsample_to_smallest(groups):
    """Return samples from every group, each randomly cut to the smallest group's size."""
    if not groups:
        # No groups at all (empty input): nothing to balance.
        return []

    target = min(len(members) for members in groups.values())

    balanced = []
    for members in groups.values():
        balanced.extend(random.sample(members, target))
    return balanced


def _group_by(items, key_fn):
    """Bucket ``items`` by ``key_fn(item)``, preserving first-seen key order."""
    groups = defaultdict(list)
    for item in items:
        groups[key_fn(item)].append(item)
    return groups


def _labels_as_key(labels):
    """Convert a label vector (tensor or sequence) into a hashable tuple."""
    if isinstance(labels, torch.Tensor):
        return tuple(labels.cpu().numpy().tolist())
    return tuple(labels)


def create_fair_dataset(train_subset, random_seed=42):
    """
    Create a downsampled dataset by balancing gender, insurance, race and
    label groups one after another.

    Each stage keeps, for every group present at that stage, as many randomly
    chosen samples as the smallest group has. Only samples whose gender is
    'M' or 'F' are kept. Because the stages run sequentially, a later stage
    can disturb the balance produced by an earlier one; use
    ``verify_fairness`` to inspect the final distribution.

    Seeds the global ``random`` and ``numpy.random`` generators.

    Parameters:
    -----------
    train_subset : iterable
        Samples as dictionaries with a 'demographics' mapping ('gender',
        'insurance', 'race') and a 'labels' tensor or sequence
    random_seed : int
        Random seed for reproducibility

    Returns:
    --------
    list
        The downsampled samples; empty when the input is empty or when either
        gender has no samples
    """
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Gender downsampling: a single pass over the input, since indexing a
    # dataset can be costly. Other gender values are dropped.
    gender_groups = {'M': [], 'F': []}
    for item in train_subset:
        gender = item['demographics']['gender']
        if gender in gender_groups:
            gender_groups[gender].append(item)

    balanced_by_gender = _downsample_to_smallest(gender_groups)
    print(f"After gender balancing: {len(balanced_by_gender)} samples")

    # Insurance downsampling
    balanced_by_insurance = _downsample_to_smallest(
        _group_by(balanced_by_gender, lambda item: item['demographics']['insurance'])
    )
    print(f"After insurance balancing: {len(balanced_by_insurance)} samples")

    # Race downsampling
    balanced_by_race = _downsample_to_smallest(
        _group_by(balanced_by_insurance, lambda item: item['demographics']['race'])
    )
    print(f"After race balancing: {len(balanced_by_race)} samples")

    # Label downsampling: groups are exact label vectors, not individual classes
    balanced_by_label = _downsample_to_smallest(
        _group_by(balanced_by_race, lambda item: _labels_as_key(item['labels']))
    )
    print(f"After label balancing: {len(balanced_by_label)} samples")

    return balanced_by_label

def verify_fairness(dataset):
    """
    Tally how the samples of a dataset are distributed over demographic
    attributes and label vectors.

    Parameters:
    -----------
    dataset : list
        Samples as dictionaries with 'demographics' and 'labels' entries;
        labels may be PyTorch tensors or plain sequences

    Returns:
    --------
    dict
        Count dictionaries under 'gender', 'insurance', 'race' and 'labels'
        (keyed by the label vector as a tuple), plus 'total_samples'
    """
    demographic_attrs = ('gender', 'insurance', 'race')
    tallies = {name: defaultdict(int) for name in (*demographic_attrs, 'labels')}

    for record in dataset:
        for attr in demographic_attrs:
            tallies[attr][record['demographics'][attr]] += 1

        # Tensors are not usable as dict keys by value, so the label vector
        # is turned into a tuple first.
        raw_labels = record['labels']
        if isinstance(raw_labels, torch.Tensor):
            raw_labels = raw_labels.cpu().numpy().tolist()
        tallies['labels'][tuple(raw_labels)] += 1

    summary = {name: dict(counts) for name, counts in tallies.items()}
    summary['total_samples'] = len(dataset)
    return summary

# Example usage:
fair_train_subset = create_fair_dataset(train_subset)
fairness_metrics = verify_fairness(fair_train_subset)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
import time
from tqdm import tqdm
import os

# Define the neural network model
class MIMICClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Feed-forward network producing one logit per label for
        multi-label classification.

        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: List of hidden layer dimensions (at least one entry)
            output_dim: Number of output classes
            dropout_rate: Dropout probability used after every hidden block
        """
        super().__init__()

        # Each hidden block is Linear -> ReLU -> BatchNorm -> Dropout. Walking
        # consecutive width pairs builds the input block and the later hidden
        # blocks with the same code.
        widths = [input_dim] + list(hidden_dims)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(dropout_rate),
            ])

        # Output projection stays un-activated: BCEWithLogitsLoss applies the
        # sigmoid itself, so the network has to return raw logits.
        stack.append(nn.Linear(hidden_dims[-1], output_dim))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits computed by the layer stack for input x."""
        return self.model(x)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import roc_auc_score

def _fair_items_to_dataset(items, dtype=torch.float32):
    """Stack the 'embedding' and 'labels' entries of dict samples into a TensorDataset of `dtype`."""
    embeddings, labels = [], []
    # One pass only: `items` may be a Dataset/Subset whose indexing does work per access
    for item in items:
        embeddings.append(torch.as_tensor(item['embedding'], dtype=dtype))
        labels.append(torch.as_tensor(item['labels'], dtype=dtype))
    return TensorDataset(torch.stack(embeddings), torch.stack(labels))


def _split_batch(batch):
    """Return (inputs, targets) from an (inputs, targets) pair or a dict batch with 'embedding'/'labels'."""
    if isinstance(batch, dict):
        return batch['embedding'], batch['labels']
    inputs, targets = batch
    return inputs, targets


def finetune_on_fair_dataset(model, fair_train_subset, val_subset, num_epochs=20, 
                            learning_rate=0.0005, weight_decay=1e-4, batch_size=32):
    """
    Continue training an existing multi-label classifier on a list of
    dict samples, validating after every epoch.

    Args:
        model: nn.Module returning one logit per label; it is trained in place.
        fair_train_subset: Iterable of samples, each a dict with 'embedding'
            and 'labels' (tensors or array-likes).
        val_subset: Either an iterable of the same kind of samples, or a
            ready-made DataLoader. A DataLoader may yield (inputs, targets)
            pairs or dict batches with 'embedding' and 'labels'.
        num_epochs: Number of passes over the training data.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Batch size for the loaders built by this function.

    Returns:
        (model, history) where history maps 'train_loss', 'val_loss' and
        'val_auc' to one value per epoch. 'val_auc' is the mean ROC AUC over
        the label columns that contain both classes, or 0.0 if none does.
    """
    # Samples are cast to the model's parameter dtype: BCEWithLogitsLoss needs
    # floating-point targets, and integer or float64 inputs would otherwise
    # fail in the forward pass or in the loss.
    first_param = next(model.parameters(), None)
    data_dtype = first_param.dtype if first_param is not None else torch.get_default_dtype()

    train_dataset = _fair_items_to_dataset(fair_train_subset, data_dtype)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    if isinstance(val_subset, DataLoader):
        val_loader = val_subset
        val_dataset = val_subset.dataset
    else:
        val_dataset = _fair_items_to_dataset(val_subset, data_dtype)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # `verbose` is deliberately not passed: it is deprecated in PyTorch and is
    # not accepted by newer releases. Reductions are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_auc': []
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    print(f"Fine-tuning on fair dataset with {len(train_dataset)} samples")
    print(f"Using device: {device}")

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            inputs, targets = _split_batch(batch)
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # criterion returns a batch mean; re-weight so the epoch value is a per-sample mean
            train_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_dataset)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = _split_batch(batch)
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)

                val_preds.append(torch.sigmoid(outputs).cpu().numpy())
                val_targets.append(targets.cpu().numpy())

        val_loss /= len(val_dataset)

        val_preds = np.vstack(val_preds)
        val_targets = np.vstack(val_targets)

        # ROC AUC is undefined for a column with a single class, so such columns are skipped
        aucs = []
        for i in range(val_targets.shape[1]):
            column = val_targets[:, i]
            if np.sum(column > 0) > 0 and np.sum(column == 0) > 0:
                aucs.append(roc_auc_score(column, val_preds[:, i]))

        val_auc = np.mean(aucs) if aucs else 0.0

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_auc'].append(val_auc)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Learning rate reduced from {lr_before:.6g} to {lr_after:.6g}")

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")

    return model, history

# Hyperparameters for the fine-tuning run
hidden_dims = [512, 256, 128]
dropout_rate = 0.3
learning_rate = 0.0005
weight_decay = 1e-4
num_epochs = 20
label_columns = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 
            'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 
            'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 
            'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# NOTE(review): `model`, `train_subset`, `val_subset` and `create_fair_dataset`
# are expected to exist already in the notebook session.
# map_location="cpu" lets a checkpoint saved on a GPU load in a CPU-only
# session; load_state_dict then copies the values into the model's parameters.
model.load_state_dict(
    torch.load(
        "/kaggle/input/biased-basline/pytorch/default/1/baseline_classifier_model.pt",
        map_location="cpu",
    )
)

fair_train_subset = create_fair_dataset(train_subset)

# Pass the hyperparameters declared above instead of repeating the literals,
# so editing them in one place is enough.
finetuned_model, training_history = finetune_on_fair_dataset(
    model=model,
    fair_train_subset=fair_train_subset,
    val_subset=val_subset,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    num_epochs=num_epochs
)

torch.save(finetuned_model.state_dict(), "finetuned_fair_model.pt")

# Ensemble

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import roc_auc_score


def _fair_items_to_dataset(items, dtype=torch.float32):
    """Stack the 'embedding' and 'labels' entries of dict samples into a TensorDataset of `dtype`."""
    embeddings, labels = [], []
    # One pass only: `items` may be a Dataset/Subset whose indexing does work per access
    for item in items:
        embeddings.append(torch.as_tensor(item['embedding'], dtype=dtype))
        labels.append(torch.as_tensor(item['labels'], dtype=dtype))
    return TensorDataset(torch.stack(embeddings), torch.stack(labels))


def _split_batch(batch):
    """Return (inputs, targets) from an (inputs, targets) pair or a dict batch with 'embedding'/'labels'."""
    if isinstance(batch, dict):
        return batch['embedding'], batch['labels']
    inputs, targets = batch
    return inputs, targets


def train_from_scratch(fair_train_subset, val_subset, input_dim=1376, 
                      hidden_dims=None, output_dim=14, 
                      learning_rate=0.001, weight_decay=1e-5, 
                      num_epochs=50, batch_size=32, dropout_rate=0.3):
    """
    Build a fresh MIMICClassifier and train it on a list of dict samples,
    validating after every epoch.

    Args:
        fair_train_subset: Iterable of samples, each a dict with 'embedding'
            and 'labels' (tensors or array-likes).
        val_subset: Either an iterable of the same kind of samples, or a
            ready-made DataLoader. A DataLoader may yield (inputs, targets)
            pairs or dict batches with 'embedding' and 'labels'.
        input_dim: Embedding dimension fed to the classifier.
        hidden_dims: Hidden layer widths; defaults to [512, 256, 128].
        output_dim: Number of labels (logits) produced by the classifier.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        num_epochs: Number of passes over the training data.
        batch_size: Batch size for the loaders built by this function.
        dropout_rate: Dropout probability passed to the classifier.

    Returns:
        (model, history) where history maps 'train_loss', 'val_loss' and
        'val_auc' to one value per epoch. 'val_auc' is the mean ROC AUC over
        the label columns that contain both classes, or 0.0 if none does.
    """
    # A list literal as default would be one shared object across calls
    if hidden_dims is None:
        hidden_dims = [512, 256, 128]

    # Initialize a new model
    model = MIMICClassifier(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        output_dim=output_dim,
        dropout_rate=dropout_rate
    )

    # Samples are cast to the model's parameter dtype: BCEWithLogitsLoss needs
    # floating-point targets, and integer or float64 inputs would otherwise
    # fail in the forward pass or in the loss.
    first_param = next(model.parameters(), None)
    data_dtype = first_param.dtype if first_param is not None else torch.get_default_dtype()

    # Process training data
    train_dataset = _fair_items_to_dataset(fair_train_subset, data_dtype)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Process validation data
    if isinstance(val_subset, DataLoader):
        val_loader = val_subset
        val_dataset = val_subset.dataset
    else:
        val_dataset = _fair_items_to_dataset(val_subset, data_dtype)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Loss function, optimizer and scheduler
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # `verbose` is deliberately not passed: it is deprecated in PyTorch and is
    # not accepted by newer releases. Reductions are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_auc': []
    }

    # Use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    print(f"Training new model from scratch on fair dataset with {len(train_dataset)} samples")
    print(f"Using device: {device}")

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            inputs, targets = _split_batch(batch)
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # criterion returns a batch mean; re-weight so the epoch value is a per-sample mean
            train_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_dataset)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = _split_batch(batch)
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)

                val_preds.append(torch.sigmoid(outputs).cpu().numpy())
                val_targets.append(targets.cpu().numpy())

        val_loss /= len(val_dataset)

        # Calculate AUC; it is undefined for a column with a single class, so such columns are skipped
        val_preds = np.vstack(val_preds)
        val_targets = np.vstack(val_targets)

        aucs = []
        for i in range(val_targets.shape[1]):
            column = val_targets[:, i]
            if np.sum(column > 0) > 0 and np.sum(column == 0) > 0:
                aucs.append(roc_auc_score(column, val_preds[:, i]))

        val_auc = np.mean(aucs) if aucs else 0.0

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_auc'].append(val_auc)

        # Update learning rate and report a reduction when one happens
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Learning rate reduced from {lr_before:.6g} to {lr_after:.6g}")

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")

    return model, history

# Example usage with your code structure
# NOTE(review): `train_subset`, `val_subset` and `create_fair_dataset` are not
# defined in this cell; they must already exist in the notebook session.
hidden_dims = [512, 256, 128]
dropout_rate = 0.3
learning_rate = 0.001  # Higher learning rate for training from scratch
weight_decay = 1e-5
num_epochs = 50  # More epochs for training from scratch
# One classifier output per entry; len(label_columns) is passed as output_dim below
label_columns = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 
            'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 
            'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 
            'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# Create fair dataset
fair_train_subset = create_fair_dataset(train_subset)

# Train a new model from scratch
# input_dim=1376 is hard-coded here -- presumably the embedding width; verify against the data
model_from_scratch, training_history = train_from_scratch(
    fair_train_subset=fair_train_subset,
    val_subset=val_subset,
    input_dim=1376,
    hidden_dims=hidden_dims,
    output_dim=len(label_columns),
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    num_epochs=num_epochs,
    dropout_rate=dropout_rate
)

# Save the trained model (weights only, via state_dict)
torch.save(model_from_scratch.state_dict(), "fair_model_from_scratch.pt")

# Adversary Models

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
import time
from tqdm import tqdm
import os

# Define the neural network model
class MIMICClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Simple multi-layer perceptron for multi-label classification.

        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: List of hidden layer dimensions (at least one entry)
            output_dim: Number of output classes
            dropout_rate: Dropout probability for regularization
        """
        super().__init__()

        def hidden_block(n_in, n_out):
            # One hidden stage: affine map, non-linearity, normalization, dropout
            return [
                nn.Linear(n_in, n_out),
                nn.ReLU(),
                nn.BatchNorm1d(n_out),
                nn.Dropout(dropout_rate),
            ]

        # First stage consumes the embedding, later stages chain the hidden widths
        modules = hidden_block(input_dim, hidden_dims[0])
        for previous_width, next_width in zip(hidden_dims, hidden_dims[1:]):
            modules += hidden_block(previous_width, next_width)

        # The last layer emits raw logits; BCEWithLogitsLoss applies the
        # sigmoid internally, so no activation is added here.
        modules.append(nn.Linear(hidden_dims[-1], output_dim))

        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Run x through the layer stack and return the logits."""
        return self.model(x)



In [None]:
class LinearVAE(nn.Module):
    """Fully-connected variational autoencoder over flat embedding vectors."""

    def __init__(self, input_dim=1376, hidden_dim=512, latent_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        narrow = hidden_dim // 2

        # Encoder: input -> hidden -> hidden/2, each step Linear + BatchNorm + ReLU
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, narrow),
            nn.BatchNorm1d(narrow),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Two heads on the encoder output give the latent mean and log-variance
        self.fc_mu = nn.Linear(narrow, latent_dim)
        self.fc_logvar = nn.Linear(narrow, latent_dim)

        # Decoder mirrors the encoder and squashes the output into (0, 1).
        # The final Sigmoid only suits embeddings normalized to 0-1; remove it
        # when another normalization is used.
        decoder_layers = [
            nn.Linear(latent_dim, narrow),
            nn.BatchNorm1d(narrow),
            nn.ReLU(),
            nn.Linear(narrow, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps sampled from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent vector back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return (reconstruction, mu, logvar, z) for x."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

    def get_latent(self, x):
        """Get a sampled latent representation without reconstruction"""
        return self.reparameterize(*self.encode(x))

def VAE_LOSS(reconstructed, x, mu, logvar, kld_weight=0.005):
    """
    VAE loss with KL divergence and reconstruction loss

    Args:
        reconstructed: Reconstructed input from decoder
        x: Original input 
        mu: Mean from encoder
        logvar: Log variance from encoder
        kld_weight: Weight for KL divergence term

    Returns:
        (loss, recon_loss, kld_loss): the weighted total, the summed squared
        reconstruction error and the summed KL divergence. All three are
        sums over the batch, not per-sample averages.
    """
    # Reconstruction loss (summed squared error). `nn.functional` is used
    # directly so the function does not depend on a separately imported `F`
    # alias being available when this cell runs.
    recon_loss = nn.functional.mse_loss(reconstructed, x, reduction='sum')

    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    loss = recon_loss + kld_weight * kld_loss

    return loss, recon_loss, kld_loss
    
    

# Training

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset # Added Subset
import torch.nn.functional as F # Added for loss functions
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, mean_squared_error # Added MSE
from sklearn.preprocessing import LabelEncoder # To handle categorical labels
import time
from tqdm import tqdm
import os
import pickle
import copy # To copy models if needed


# Path to processed data
data_path = "/kaggle/input/mimic-embedding/processed_mimic_data"

class ProcessedMIMICDataset(Dataset):
    """
    Dataset over pre-computed embeddings, labels, ids and demographics.

    Every item is a dict with 'embedding', 'labels', 'subject_id',
    'study_id' and 'demographics'. The demographics entry holds integer
    codes for 'gender', 'insurance' and 'race' (-1 for a value missing from
    the mapping) and 'anchor_age' as a float tensor.
    """

    def __init__(self, data_path, data_format='pt', demographic_mappings=None):
        """
        Args:
            data_path: Directory containing the processed files.
            data_format: 'pt' (torch tensors), 'npy' (NumPy arrays) or 'pkl'
                (single pickle with everything).
            demographic_mappings: Optional dict with 'gender', 'insurance'
                and 'race' entries, each mapping a raw value to an int code.
                When omitted or empty, the mappings are computed from the
                loaded demographics.

        Raises:
            ValueError: If data_format is not one of the supported formats.
        """
        self.data_path = data_path
        self.data_format = data_format
        self.demographic_mappings = demographic_mappings if demographic_mappings else {}

        # NOTE(security): pickle.load can execute arbitrary code stored in the
        # file; only point this class at data you trust.
        # --- Load data based on format ---
        if data_format == 'pkl':
            with open(os.path.join(data_path, "all_data.pkl"), 'rb') as f:
                self.all_data = pickle.load(f)
            self.embeddings = self.all_data['embeddings']
            self.labels = self.all_data['labels']
            self.subject_ids = self.all_data['subject_ids']
            self.study_ids = self.all_data['study_ids']
            # The rest of the class reads `demographics_raw`; without this
            # assignment the 'pkl' format failed with AttributeError.
            self.demographics_raw = self.all_data['demographics']
            # Kept under the old attribute name as well for existing callers
            self.demographics = self.demographics_raw

        elif data_format == 'pt':
            # map_location keeps the tensors on CPU regardless of where they
            # were saved; weights_only restricts unpickling to tensor data.
            self.embeddings = torch.load(os.path.join(data_path, "embeddings.pt"),
                                         map_location='cpu', weights_only=True)
            self.labels = torch.load(os.path.join(data_path, "labels.pt"),
                                     map_location='cpu', weights_only=True)
            self._load_ids_and_demographics()

        elif data_format == 'npy':
            self.embeddings = np.load(os.path.join(data_path, "embeddings.npy"))
            self.labels = np.load(os.path.join(data_path, "labels.npy"))
            self._load_ids_and_demographics()

        else:
            raise ValueError(f"Unsupported data format: {data_format}")

        print(f"Loaded dataset from {data_path} with {len(self.embeddings)} samples")

        # --- Precompute demographic mappings if not provided ---
        if not self.demographic_mappings:
            print("Computing demographic mappings...")
            for attribute in ('gender', 'insurance', 'race'):
                values = [d[attribute] for d in self.demographics_raw]
                # Sorted unique values give codes that do not depend on sample order
                self.demographic_mappings[attribute] = {
                    label: i for i, label in enumerate(sorted(pd.Series(values).unique()))
                }
            print("Mappings computed:", self.demographic_mappings)

    def _load_ids_and_demographics(self):
        """Read ids.csv and demographics.pkl, shared by the 'pt' and 'npy' formats."""
        ids_df = pd.read_csv(os.path.join(self.data_path, "ids.csv"))
        self.subject_ids = ids_df['subject_id'].tolist()
        self.study_ids = ids_df['study_id'].tolist()
        with open(os.path.join(self.data_path, "demographics.pkl"), 'rb') as f:
            # Indexed by sample position; each entry is read as a dict with
            # 'gender', 'insurance', 'race' and 'anchor_age'.
            self.demographics_raw = pickle.load(f)

    def _encode(self, attribute, value):
        """Return the int code of `value` for `attribute` as a long tensor (-1 if unmapped)."""
        code = self.demographic_mappings[attribute].get(value, -1)
        return torch.tensor(code, dtype=torch.long)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        if self.data_format == 'pt':
            # Already tensors; returned as stored
            embedding = self.embeddings[idx]
            labels = self.labels[idx]
        else:
            embedding = torch.tensor(self.embeddings[idx], dtype=torch.float32)
            labels = torch.tensor(self.labels[idx], dtype=torch.float32) # Assuming labels are multi-label float

        # Process demographics for the current item
        demo_raw = self.demographics_raw[idx]
        demographics_processed = {
            'gender': self._encode('gender', demo_raw['gender']),
            'insurance': self._encode('insurance', demo_raw['insurance']),
            'race': self._encode('race', demo_raw['race']),
            'anchor_age': torch.tensor(demo_raw['anchor_age'], dtype=torch.float32)
        }

        return {
            'embedding': embedding,
            'labels': labels,
            'subject_id': self.subject_ids[idx],
            'study_id': self.study_ids[idx],
            'demographics': demographics_processed # Return processed demographics
        }

# --- Define Models ---

class MIMICClassifier(nn.Module):
    """Feed-forward classifier returning one raw logit per output class."""

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        """
        Args:
            input_dim: Dimension of input embeddings
            hidden_dims: Hidden layer dimensions (at least one entry)
            output_dim: Number of output classes
            dropout_rate: Dropout probability after every hidden block
        """
        super().__init__()
        # dims[k-1] -> dims[k] describes the k-th hidden block
        dims = (input_dim, *hidden_dims)
        body = []
        for k in range(1, len(dims)):
            body += [
                nn.Linear(dims[k - 1], dims[k]),
                nn.ReLU(),
                nn.BatchNorm1d(dims[k]),
                nn.Dropout(dropout_rate),
            ]
        # No activation after the head: the module outputs logits
        head = nn.Linear(hidden_dims[-1], output_dim)
        self.model = nn.Sequential(*body, head)

    def forward(self, x):
        """Return the logits for x."""
        return self.model(x)


class LinearVAE(nn.Module):
    """Fully-connected VAE whose decoder ends in a plain Linear layer (no output activation)."""

    def __init__(self, input_dim=1376, hidden_dim=512, latent_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        half_hidden = hidden_dim // 2

        # Encoder: two Linear + BatchNorm + ReLU stages narrowing to hidden/2
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_hidden),
            nn.BatchNorm1d(half_hidden),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.fc_mu = nn.Linear(half_hidden, latent_dim)
        self.fc_logvar = nn.Linear(half_hidden, latent_dim)

        # Decoder mirrors the encoder. No Sigmoid at the end: the output is
        # unbounded, which suits embeddings that are not normalized to 0-1.
        decoder_layers = [
            nn.Linear(latent_dim, half_hidden),
            nn.BatchNorm1d(half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) for an input of shape (batch, input_dim)."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps drawn from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return (reconstruction, mu, logvar, z); x is flattened to (-1, input_dim) first."""
        flat = x.view(-1, self.input_dim)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

    def get_latent(self, x):
        """Return a sampled latent vector for x (flattened to (-1, input_dim))."""
        return self.reparameterize(*self.encode(x.view(-1, self.input_dim)))

    def get_reconstructed(self, x):
        """Return the decoder output for a freshly sampled latent of x."""
        return self.decode(self.get_latent(x))


class Adversary(nn.Module):
    """
    Predicts demographics from latent space.

    A shared hidden layer feeds one classification head per attribute.
    forward() returns logits for 'gender', 'insurance' and 'race'.

    Args:
        latent_dim: Size of the latent vectors given to forward().
        hidden_dim: Width of the shared hidden layer.
        num_genders: Number of gender classes.
        num_insurances: Number of insurance classes.
        num_races: Number of race classes.
        dropout_rate: Dropout probability on the shared representation.
    """
    def __init__(self, latent_dim=128, hidden_dim=256, num_genders=2, num_insurances=3, num_races=7,
                 dropout_rate=0.3):
        super().__init__()
        self.layer1 = nn.Linear(latent_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

        # Output heads for each demographic attribute
        self.gender_head = nn.Linear(hidden_dim, num_genders)
        self.insurance_head = nn.Linear(hidden_dim, num_insurances)
        self.race_head = nn.Linear(hidden_dim, num_races)
        # Age regression head: its output is not part of forward()'s result, so
        # it is no longer evaluated there. The layer stays registered so that
        # state_dict keys remain compatible with saved checkpoints.
        self.age_head = nn.Linear(hidden_dim, 1)

    def forward(self, z):
        """Return a dict of logits keyed by 'gender', 'insurance' and 'race'."""
        shared = self.dropout(self.relu(self.bn1(self.layer1(z))))

        return {
            'gender': self.gender_head(shared),
            'insurance': self.insurance_head(shared),
            'race': self.race_head(shared),
        }

# --- Loss Functions ---

def VAE_LOSS(reconstructed, x, mu, logvar, kld_weight=1.0): # Adjusted default kld_weight
    """
    Per-sample-averaged VAE objective.

    Returns (loss, recon_loss, kld_loss) where recon_loss is the summed
    squared error divided by the batch size, kld_loss is the summed KL
    divergence divided by the batch size, and
    loss = recon_loss + kld_weight * kld_loss.
    """
    n_samples = x.size(0)

    # Flatten both tensors to (batch, features) so their shapes line up
    target = x.view(n_samples, -1)
    prediction = reconstructed.view(n_samples, -1)

    squared_error = F.mse_loss(prediction, target, reduction='sum')
    recon_loss = squared_error / n_samples

    # KL divergence to a standard normal: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld_loss = -0.5 * torch.sum(kl_terms) / n_samples

    return recon_loss + kld_weight * kld_loss, recon_loss, kld_loss

def _masked_cross_entropy(logits, targets):
    """Cross-entropy over the samples whose target is not the missing-value marker -1."""
    valid = targets != -1
    kept_logits = logits[valid]
    if kept_logits.shape[0] == 0:
        # Nothing to score. Summing the empty selection yields 0 while staying
        # attached to the autograd graph, so backward() on the total still
        # works. A constant zero tensor would have no grad_fn, and the mean
        # over zero samples that cross_entropy computes would be NaN.
        return kept_logits.sum()
    return F.cross_entropy(kept_logits, targets[valid])


def ADVERSARY_LOSS(adv_preds, true_demographics, device):
    """
    Calculate loss for the adversary.

    Args:
        adv_preds: Dict with logits under 'gender', 'insurance' and 'race'.
        true_demographics: Dict with integer class tensors under the same
            keys; entries equal to -1 are treated as missing and ignored.
            Other keys (e.g. 'anchor_age') are not read.
        device: Device the target tensors are moved to before the loss is
            computed.

    Returns:
        (total_loss, per_attribute) where total_loss is the unweighted sum of
        the three cross-entropy losses and per_attribute maps 'gender',
        'insurance' and 'race' to their individual losses.
    """
    per_attribute = {}
    for attribute in ('gender', 'insurance', 'race'):
        targets = true_demographics[attribute].to(device)
        per_attribute[attribute] = _masked_cross_entropy(adv_preds[attribute], targets)

    # Combine losses (can add weights here if needed)
    total_loss = per_attribute['gender'] + per_attribute['insurance'] + per_attribute['race']

    return total_loss, per_attribute

# --- Training Functions ---

def train_vae_adversarial(vae, adversary, train_loader, val_loader, vae_optimizer, adv_optimizer,
                          num_epochs, device, kld_weight=1.0, adversary_weight=5.0, # adversary_weight > 1 typically needed
                          log_interval=100, save_path='models'):
    """
    Trains the VAE with adversarial debiasing.

    Each batch runs two steps: the adversary is updated to predict
    demographics from the VAE's latent sample, then the VAE is updated to
    minimize its own loss minus `adversary_weight` times the adversary's loss.

    Args:
        vae: Model whose forward returns (reconstruction, mu, logvar, z).
        adversary: Model mapping z to the prediction dict ADVERSARY_LOSS expects.
        train_loader, val_loader: Loaders yielding dict batches with
            'embedding' and 'demographics'.
        vae_optimizer, adv_optimizer: Optimizers for the two models.
        num_epochs: Number of training epochs.
        device: Device both models and the embeddings are moved to.
        kld_weight: Weight of the KL term passed to VAE_LOSS.
        adversary_weight: Weight of the adversary loss in the VAE objective.
        log_interval: Progress-bar statistics are refreshed every this many batches.
        save_path: Directory for 'best_vae_adversarial.pt' and 'best_adversary.pt'.

    Returns:
        (vae, adversary) carrying the weights of the epoch with the lowest
        validation VAE loss, or the final-epoch weights when no epoch ever
        improved on the initial best (e.g. NaN losses or num_epochs == 0).
    """
    vae.to(device)
    adversary.to(device)
    os.makedirs(save_path, exist_ok=True)
    best_val_loss = float('inf')
    # Best weights are kept in memory too, so the final restore never depends
    # on a checkpoint file that was not written during this call.
    best_vae_state = None
    best_adversary_state = None

    print(f"Starting adversarial VAE training for {num_epochs} epochs...")
    print(f"KLD Weight: {kld_weight}, Adversary Weight: {adversary_weight}")

    for epoch in range(num_epochs):
        vae.train()
        adversary.train()
        total_vae_loss_epoch = 0.0
        total_recon_loss_epoch = 0.0
        total_kld_loss_epoch = 0.0
        total_adv_loss_epoch = 0.0
        start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for batch_idx, batch in enumerate(progress_bar):
            embeddings = batch['embedding'].to(device)
            demographics = batch['demographics'] # Keep dict on CPU until needed in loss

            # --- Adversary Training Step ---
            adv_optimizer.zero_grad()
            with torch.no_grad(): # Latent sample carries no graph, so the VAE is untouched by this step
                _, _, _, z = vae(embeddings)
            adv_preds = adversary(z)
            adv_loss, _ = ADVERSARY_LOSS(adv_preds, demographics, device)
            adv_loss.backward()
            adv_optimizer.step()
            total_adv_loss_epoch += adv_loss.item()

            # --- VAE Training Step ---
            vae_optimizer.zero_grad()
            reconstructed, mu, logvar, z_vae = vae(embeddings)
            vae_loss, recon_loss, kld_loss = VAE_LOSS(reconstructed, embeddings, mu, logvar, kld_weight)
            # z_vae tracks gradients back to the encoder, which is what lets
            # the VAE learn to fool the adversary.
            adv_preds_for_vae = adversary(z_vae)
            adv_loss_for_vae, _ = ADVERSARY_LOSS(adv_preds_for_vae, demographics, device)
            # VAE minimizes reconstruction/KLD AND maximizes adversary loss (note the minus sign).
            # Gradients this leaves on the adversary's parameters are cleared
            # by adv_optimizer.zero_grad() at the start of the next batch.
            combined_vae_loss = vae_loss - adversary_weight * adv_loss_for_vae
            combined_vae_loss.backward()
            vae_optimizer.step()

            total_vae_loss_epoch += combined_vae_loss.item()
            total_recon_loss_epoch += recon_loss.item()
            total_kld_loss_epoch += kld_loss.item()

            if batch_idx % log_interval == 0:
                progress_bar.set_postfix({
                    'VAE Loss': f"{combined_vae_loss.item():.4f}",
                    'Recon Loss': f"{recon_loss.item():.4f}",
                    'KLD Loss': f"{kld_loss.item():.4f}",
                    'Adv Loss': f"{adv_loss.item():.4f}"
                })

        # --- Validation ---
        vae.eval()
        adversary.eval()
        val_vae_loss = 0.0
        val_adv_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                embeddings = batch['embedding'].to(device)
                demographics = batch['demographics']
                reconstructed, mu, logvar, z = vae(embeddings)
                vae_loss, _, _ = VAE_LOSS(reconstructed, embeddings, mu, logvar, kld_weight)
                adv_preds = adversary(z)
                adv_loss, _ = ADVERSARY_LOSS(adv_preds, demographics, device)

                # The plain (non-adversarial) VAE loss drives model selection;
                # the adversary loss is tracked separately for reporting.
                val_vae_loss += vae_loss.item()
                val_adv_loss += adv_loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        n_train_batches = max(len(train_loader), 1)
        n_val_batches = max(len(val_loader), 1)
        avg_train_vae_loss = total_vae_loss_epoch / n_train_batches
        avg_train_adv_loss = total_adv_loss_epoch / n_train_batches
        avg_train_recon_loss = total_recon_loss_epoch / n_train_batches
        avg_train_kld_loss = total_kld_loss_epoch / n_train_batches
        avg_val_vae_loss = val_vae_loss / n_val_batches
        avg_val_adv_loss = val_adv_loss / n_val_batches
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s")
        print(f"  Train VAE Loss: {avg_train_vae_loss:.4f} | Train Adv Loss: {avg_train_adv_loss:.4f}")
        print(f"  Train Recon Loss: {avg_train_recon_loss:.4f} | Train KLD Loss: {avg_train_kld_loss:.4f}")
        print(f"  Val VAE Loss  : {avg_val_vae_loss:.4f} | Val Adv Loss  : {avg_val_adv_loss:.4f}")

        # Save best model based on validation VAE loss (reconstruction focus)
        if avg_val_vae_loss < best_val_loss:
            best_val_loss = avg_val_vae_loss
            best_vae_state = copy.deepcopy(vae.state_dict())
            best_adversary_state = copy.deepcopy(adversary.state_dict())
            torch.save(best_vae_state, os.path.join(save_path, 'best_vae_adversarial.pt'))
            torch.save(best_adversary_state, os.path.join(save_path, 'best_adversary.pt'))
            print(f"  Saved best models with Val VAE Loss: {best_val_loss:.4f}")

    print("Adversarial VAE training finished.")
    # Restore the best weights seen during this call, if any epoch produced one
    if best_vae_state is not None:
        vae.load_state_dict(best_vae_state)
        adversary.load_state_dict(best_adversary_state)
    else:
        print("  No checkpoint was saved (validation loss never improved); returning final-epoch weights.")
    return vae, adversary


def train_classifier(classifier, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                     log_interval=100, save_path='models', model_name='best_classifier.pt',
                     use_debiased_embeddings=False, vae=None): # Added options for debiased input
    """Train the downstream classifier and return it with this run's best weights.

    After every epoch the classifier is scored on ``val_loader`` through
    ``evaluate_classifier``. Whenever the validation AUC improves, the state
    dict is written to ``save_path/model_name``; once training ends that
    checkpoint is loaded back into ``classifier``.

    Args:
        classifier: Model called as ``classifier(embeddings)``; its output is
            passed straight to ``criterion``.
        train_loader: Iterable of batches indexed with ``'embedding'`` and
            ``'labels'``.
        val_loader: Loader forwarded unchanged to ``evaluate_classifier``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are invoked
            once per training batch.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the models and batch tensors are moved to.
        log_interval: The progress-bar loss is refreshed every this many batches.
        save_path: Directory for the checkpoint (created if missing).
        model_name: Checkpoint file name inside ``save_path``.
        use_debiased_embeddings: When True and ``vae`` is given, the classifier
            is fed ``vae.get_reconstructed(embeddings)`` instead of the raw
            embeddings.
        vae: VAE used (in eval mode, under ``no_grad``) to produce the
            debiased input. Without it the original embeddings are used.

    Returns:
        The same ``classifier`` object. If a checkpoint was saved during this
        call, its weights are restored from it; otherwise the classifier is
        returned with the weights it has at the end of training.
    """
    classifier.to(device)
    debias = use_debiased_embeddings and vae is not None
    if debias:
        vae.to(device)
        vae.eval() # VAE should be frozen if used for debiasing input
        print("Training classifier on DEBIASED embeddings.")
    else:
        print("Training classifier on ORIGINAL embeddings.")

    os.makedirs(save_path, exist_ok=True)
    checkpoint_path = os.path.join(save_path, model_name)
    best_val_auc = 0.0
    # Tracks whether *this* run wrote the checkpoint, so a stale file left by
    # an earlier run is never mistaken for the current best model.
    checkpoint_saved = False

    print(f"Starting classifier training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # evaluate_classifier() switches the model to eval mode, so training
        # mode has to be re-enabled at the start of every epoch.
        classifier.train()
        total_loss_epoch = 0.0
        start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for batch_idx, batch in enumerate(progress_bar):
            embeddings = batch['embedding'].to(device)
            labels = batch['labels'].to(device) # Assuming labels are multi-label float (0 or 1)

            # Get embeddings: original or debiased
            if debias:
                with torch.no_grad(): # Don't train VAE here
                    # Use reconstructed embeddings as debiased input
                    input_embeddings = vae.get_reconstructed(embeddings)
            else:
                input_embeddings = embeddings

            # --- Classifier Training Step ---
            optimizer.zero_grad()
            outputs = classifier(input_embeddings)
            loss = criterion(outputs, labels) # BCEWithLogitsLoss expects raw logits
            loss.backward()
            optimizer.step()

            total_loss_epoch += loss.item()

            if batch_idx % log_interval == 0:
                progress_bar.set_postfix({'Loss': f"{loss.item():.4f}"})

        # --- Validation ---
        avg_val_loss, avg_val_auc = evaluate_classifier(classifier, val_loader, criterion, device, use_debiased_embeddings, vae)
        avg_train_loss = total_loss_epoch / len(train_loader)
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss  : {avg_val_loss:.4f} | Val AUC  : {avg_val_auc:.4f}")

        # Save best model based on validation AUC (a NaN AUC never compares
        # greater, so such an epoch is simply not checkpointed).
        if avg_val_auc > best_val_auc:
            best_val_auc = avg_val_auc
            torch.save(classifier.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  Saved best model with Val AUC: {best_val_auc:.4f}")

    print("Classifier training finished.")
    if checkpoint_saved:
        # map_location keeps the load valid on whichever device is in use;
        # weights_only=True restricts unpickling to tensors/state-dict data.
        best_state = torch.load(checkpoint_path, map_location=device, weights_only=True)
        classifier.load_state_dict(best_state)
    else:
        print("  Warning: no checkpoint was saved during this run; returning the classifier with its current weights.")
    return classifier


def evaluate_classifier(classifier, data_loader, criterion, device, use_debiased_embeddings=False, vae=None):
    """Evaluate the classifier on a dataset and return ``(avg_loss, avg_auc)``.

    Args:
        classifier: Model called as ``classifier(embeddings)``; it is moved to
            ``device`` and put in eval mode (the caller must switch it back to
            train mode if training continues).
        data_loader: Sized iterable of batches indexed with ``'embedding'``
            and ``'labels'``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the models and batch tensors are moved to.
        use_debiased_embeddings: When True and ``vae`` is given, the classifier
            is fed ``vae.get_reconstructed(embeddings)``.
        vae: VAE used to produce the debiased input; without it the original
            embeddings are used.

    Returns:
        Tuple ``(avg_loss, avg_auc)``: the mean per-batch loss, and the
        unweighted mean of the per-label ROC AUCs computed on the sigmoid of
        the classifier outputs. Labels whose AUC is undefined are left out of
        the mean; if no label has a defined AUC, ``avg_auc`` is NaN.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    # Mirror train_classifier(): make sure the model lives on the target
    # device so a freshly constructed classifier can be evaluated directly.
    classifier.to(device)
    classifier.eval()
    debias = use_debiased_embeddings and vae is not None
    if debias:
        vae.to(device)
        vae.eval()

    num_batches = len(data_loader)
    if num_batches == 0:
        # Fail with an explicit message instead of a ZeroDivisionError below.
        raise ValueError("evaluate_classifier received an empty data_loader; cannot compute loss or AUC.")

    total_loss = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in data_loader:
            embeddings = batch['embedding'].to(device)
            labels = batch['labels'].to(device)

            if debias:
                input_embeddings = vae.get_reconstructed(embeddings)
            else:
                input_embeddings = embeddings

            outputs = classifier(input_embeddings)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Store labels and predictions (probabilities) for AUC calculation
            all_labels.append(labels.cpu().numpy())
            all_preds.append(torch.sigmoid(outputs).cpu().numpy()) # Apply sigmoid to get probs

    avg_loss = total_loss / num_batches

    # Concatenate results from all batches
    all_labels = np.concatenate(all_labels, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)

    # Calculate AUC (macro average over labels)
    auc_scores = []
    for i in range(all_labels.shape[1]): # Iterate over each label column
        try:
            auc_scores.append(roc_auc_score(all_labels[:, i], all_preds[:, i]))
        except ValueError:
            # roc_auc_score rejected this column (e.g. only one class present
            # in the labels); mark it undefined and exclude it from the mean.
            auc_scores.append(np.nan)

    auc_scores = np.asarray(auc_scores, dtype=float)
    defined = ~np.isnan(auc_scores)
    # Average only the defined AUCs; avoids numpy's "Mean of empty slice"
    # warning when every label column is undefined.
    avg_auc = auc_scores[defined].mean() if defined.any() else float('nan')

    return avg_loss, avg_auc


# --- Main Execution ---
if __name__ == "__main__":
    # End-to-end experiment: train the adversarial VAE, train one classifier on
    # the original embeddings and one on VAE-reconstructed ("debiased")
    # embeddings, then report test loss/AUC for both classifiers and the
    # adversary's demographic prediction accuracy on the test set.

    # --- Configuration ---
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    VAE_LATENT_DIM = 128  # Passed as latent_dim to both the VAE and the adversary
    VAE_HIDDEN_DIM = 512
    ADV_HIDDEN_DIM = 256
    CLASSIFIER_INPUT_DIM = 1376 # Original embedding size
    CLASSIFIER_HIDDEN_DIMS = [512, 256]
    CLASSIFIER_OUTPUT_DIM = 14 # Number of labels
    EMBEDDING_DIM = 1376

    VAE_EPOCHS = 50
    CLASSIFIER_EPOCHS = 30  # Used for both the original and the debiased classifier
    BATCH_SIZE = 128
    VAE_LR = 1e-4
    ADV_LR = 1e-4
    CLASSIFIER_LR = 1e-3
    KLD_WEIGHT = 1.0 # Weight for KL divergence in VAE loss
    ADVERSARY_WEIGHT = 10.0 # How much to penalize VAE for demographic leakage

    # --- Load Data ---
    print("Loading datasets...")
    # Load train once to get mappings
    full_train_dataset = ProcessedMIMICDataset(os.path.join(data_path, "train"), data_format='pt')
    demographic_mappings = full_train_dataset.demographic_mappings

    # The test set reuses the mappings taken from the training set rather than
    # deriving its own.
    test_dataset = ProcessedMIMICDataset(os.path.join(data_path, "test"), data_format='pt', demographic_mappings=demographic_mappings)

    # Create validation split using Subset
    val_ratio = 0.1
    dataset_size = len(full_train_dataset)
    val_size = int(val_ratio * dataset_size)
    train_size = dataset_size - val_size
    indices = list(range(dataset_size))
    np.random.seed(42) # for reproducibility
    np.random.shuffle(indices)
    # First train_size shuffled indices train the models; the remainder validate.
    train_indices, val_indices = indices[:train_size], indices[train_size:]

    train_subset = Subset(full_train_dataset, train_indices)
    val_subset = Subset(full_train_dataset, val_indices)

    print(f"Data loaded: Train={len(train_subset)}, Val={len(val_subset)}, Test={len(test_dataset)}")

    # Create data loaders (only the training loader shuffles)
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    # --- Initialize Models ---
    print("Initializing models...")
    # Get number of unique classes for adversary heads from mappings
    num_genders = len(demographic_mappings['gender'])
    num_insurances = len(demographic_mappings['insurance'])
    num_races = len(demographic_mappings['race'])

    vae = LinearVAE(input_dim=EMBEDDING_DIM, hidden_dim=VAE_HIDDEN_DIM, latent_dim=VAE_LATENT_DIM)
    adversary = Adversary(latent_dim=VAE_LATENT_DIM, hidden_dim=ADV_HIDDEN_DIM,
                          num_genders=num_genders, num_insurances=num_insurances, num_races=num_races)
    classifier_original = MIMICClassifier(input_dim=CLASSIFIER_INPUT_DIM, hidden_dims=CLASSIFIER_HIDDEN_DIMS, output_dim=CLASSIFIER_OUTPUT_DIM)
    # Classifier for debiased embeddings might need different input dim if using latent space
    classifier_debiased = MIMICClassifier(input_dim=CLASSIFIER_INPUT_DIM, hidden_dims=CLASSIFIER_HIDDEN_DIMS, output_dim=CLASSIFIER_OUTPUT_DIM) # Assuming reconstructed embeddings

    # --- Optimizers and Criterion ---
    vae_optimizer = optim.Adam(vae.parameters(), lr=VAE_LR)
    adv_optimizer = optim.Adam(adversary.parameters(), lr=ADV_LR)
    # Using BCEWithLogitsLoss for multi-label classification is standard
    classifier_criterion = nn.BCEWithLogitsLoss()
    # Each classifier gets its own optimizer; the loss object is shared.
    classifier_orig_optimizer = optim.Adam(classifier_original.parameters(), lr=CLASSIFIER_LR)
    classifier_debiased_optimizer = optim.Adam(classifier_debiased.parameters(), lr=CLASSIFIER_LR)

    # --- Train Adversarial VAE ---
    print("\n--- Training Adversarial VAE ---")
    trained_vae, trained_adversary = train_vae_adversarial(
        vae, adversary, train_loader, val_loader, vae_optimizer, adv_optimizer,
        num_epochs=VAE_EPOCHS, device=DEVICE, kld_weight=KLD_WEIGHT,
        adversary_weight=ADVERSARY_WEIGHT, save_path='models/vae_adv'
    )

    # --- Train Downstream Classifiers ---
    # Both runs share the same loaders, criterion, epoch count and save
    # directory; they differ in model, optimizer, checkpoint name and input.
    # 1. Train on ORIGINAL embeddings
    print("\n--- Training Classifier on ORIGINAL Embeddings ---")
    trained_classifier_original = train_classifier(
        classifier_original, train_loader, val_loader, classifier_criterion, classifier_orig_optimizer,
        num_epochs=CLASSIFIER_EPOCHS, device=DEVICE, save_path='models/classifier',
        model_name='best_classifier_original.pt', use_debiased_embeddings=False
    )

    # 2. Train on DEBIASED (reconstructed) embeddings
    print("\n--- Training Classifier on DEBIASED Embeddings ---")
    trained_classifier_debiased = train_classifier(
        classifier_debiased, train_loader, val_loader, classifier_criterion, classifier_debiased_optimizer,
        num_epochs=CLASSIFIER_EPOCHS, device=DEVICE, save_path='models/classifier',
        model_name='best_classifier_debiased.pt', use_debiased_embeddings=True, vae=trained_vae # Pass the trained VAE
    )

    # --- Final Evaluation on Test Set ---
    print("\n--- Evaluating on Test Set ---")

    # Evaluate Original Classifier
    test_loss_orig, test_auc_orig = evaluate_classifier(
        trained_classifier_original, test_loader, classifier_criterion, DEVICE,
        use_debiased_embeddings=False
    )
    print(f"Classifier (Original Embeddings) Test Loss: {test_loss_orig:.4f}, Test AUC: {test_auc_orig:.4f}")

    # Evaluate Debiased Classifier (same VAE reconstruction path as in training)
    test_loss_debiased, test_auc_debiased = evaluate_classifier(
        trained_classifier_debiased, test_loader, classifier_criterion, DEVICE,
        use_debiased_embeddings=True, vae=trained_vae
    )
    print(f"Classifier (Debiased Embeddings) Test Loss: {test_loss_debiased:.4f}, Test AUC: {test_auc_debiased:.4f}")

    # Optional: Evaluate how well the adversary predicts demographics on the test set
    # (using the latent space from the trained VAE)
    trained_vae.eval()
    trained_adversary.eval()
    test_adv_loss_total = 0
    # Per-attribute lists of per-batch arrays; concatenated after the loop.
    all_true_demos = {'gender': [], 'insurance': [], 'race': []}
    all_pred_demos = {'gender': [], 'insurance': [], 'race': []}

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating Adversary on Test Set", leave=False):
            embeddings = batch['embedding'].to(DEVICE)
            # Demographics are not moved to DEVICE here; ADVERSARY_LOSS is
            # handed DEVICE and is presumably responsible for that -- confirm.
            demographics = batch['demographics']
            # Only the fourth VAE output (z) is needed for the adversary.
            _, _, _, z = trained_vae(embeddings)
            adv_preds = trained_adversary(z)
            adv_loss, adv_losses_dict = ADVERSARY_LOSS(adv_preds, demographics, DEVICE)
            test_adv_loss_total += adv_loss.item()

            # Store true and predicted class indices for the accuracy report below
            for key in all_true_demos.keys():
                 # NOTE(review): -1 is treated as "no valid label" for this
                 # attribute, presumably matching ADVERSARY_LOSS -- confirm.
                 valid_mask = demographics[key] != -1 # Use the same masking logic
                 if valid_mask.any():
                    all_true_demos[key].append(demographics[key][valid_mask].cpu().numpy())
                    # Get predicted class index
                    pred_classes = torch.argmax(adv_preds[key][valid_mask], dim=1)
                    all_pred_demos[key].append(pred_classes.cpu().numpy())


    # Mean of the per-batch adversary losses over the test loader.
    avg_test_adv_loss = test_adv_loss_total / len(test_loader)
    print(f"\nAdversary Final Test Loss: {avg_test_adv_loss:.4f}")

    # Calculate and print per-attribute accuracy for demographics
    for key in all_true_demos.keys():
        if not all_true_demos[key]: continue # Skip if no valid data for this demographic
        true_vals = np.concatenate(all_true_demos[key])
        pred_vals = np.concatenate(all_pred_demos[key])
        accuracy = np.mean(true_vals == pred_vals)
        print(f"  Adversary {key.capitalize()} Prediction Test Accuracy: {accuracy:.4f}")

    print("\n--- Script Finished ---")

Using device: cuda
Loading datasets...


  self.embeddings = torch.load(os.path.join(data_path, "embeddings.pt"), map_location='cpu') # Use map_location for flexibility
  self.labels = torch.load(os.path.join(data_path, "labels.pt"), map_location='cpu')


Loaded dataset from /kaggle/input/mimic-embedding/processed_mimic_data/train with 207314 samples
Computing demographic mappings...
Mappings computed: {'gender': {'F': 0, 'M': 1}, 'insurance': {'Medicaid': 0, 'Medicare': 1, 'Other': 2}, 'race': {'AMERICAN INDIAN/ALASKA NATIVE': 0, 'ASIAN': 1, 'BLACK/AFRICAN AMERICAN': 2, 'HISPANIC/LATINO': 3, 'OTHER': 4, 'UNABLE TO OBTAIN': 5, 'UNKNOWN': 6, 'WHITE': 7}}
Loaded dataset from /kaggle/input/mimic-embedding/processed_mimic_data/test with 21591 samples


  self.embeddings = torch.load(os.path.join(data_path, "embeddings.pt"), map_location='cpu') # Use map_location for flexibility
  self.labels = torch.load(os.path.join(data_path, "labels.pt"), map_location='cpu')


Data loaded: Train=186583, Val=20731, Test=21591
Initializing models...

--- Training Adversarial VAE ---
Starting adversarial VAE training for 50 epochs...
KLD Weight: 1.0, Adversary Weight: 10.0


                                                                                                                                           

Epoch 1 Summary | Time: 22.88s
  Train VAE Loss: 439.6012 | Train Adv Loss: 2.8074
  Val VAE Loss  : 257.2622 | Val Adv Loss  : 2.6589
  Saved best models with Val VAE Loss: 257.2622


                                                                                                                                          

Epoch 2 Summary | Time: 22.81s
  Train VAE Loss: 219.5490 | Train Adv Loss: 2.6666
  Val VAE Loss  : 229.5839 | Val Adv Loss  : 2.6245
  Saved best models with Val VAE Loss: 229.5839


                                                                                                                                          

Epoch 3 Summary | Time: 22.58s
  Train VAE Loss: 198.1468 | Train Adv Loss: 2.6385
  Val VAE Loss  : 210.7411 | Val Adv Loss  : 2.6101
  Saved best models with Val VAE Loss: 210.7411


                                                                                                                                          

Epoch 4 Summary | Time: 22.84s
  Train VAE Loss: 182.0475 | Train Adv Loss: 2.6288
  Val VAE Loss  : 196.8231 | Val Adv Loss  : 2.6117
  Saved best models with Val VAE Loss: 196.8231


                                                                                                                                          

Epoch 5 Summary | Time: 22.91s
  Train VAE Loss: 170.1875 | Train Adv Loss: 2.6240
  Val VAE Loss  : 186.6139 | Val Adv Loss  : 2.6036
  Saved best models with Val VAE Loss: 186.6139


                                                                                                                                          

Epoch 6 Summary | Time: 22.58s
  Train VAE Loss: 161.8059 | Train Adv Loss: 2.6205
  Val VAE Loss  : 179.6211 | Val Adv Loss  : 2.5953
  Saved best models with Val VAE Loss: 179.6211


                                                                                                                                          

Epoch 7 Summary | Time: 22.79s
  Train VAE Loss: 155.7518 | Train Adv Loss: 2.6135
  Val VAE Loss  : 174.1435 | Val Adv Loss  : 2.5854
  Saved best models with Val VAE Loss: 174.1435


                                                                                                                                          

Epoch 8 Summary | Time: 22.70s
  Train VAE Loss: 151.6310 | Train Adv Loss: 2.6080
  Val VAE Loss  : 170.4801 | Val Adv Loss  : 2.5801
  Saved best models with Val VAE Loss: 170.4801


                                                                                                                                          

Epoch 9 Summary | Time: 22.70s
  Train VAE Loss: 148.4655 | Train Adv Loss: 2.6044
  Val VAE Loss  : 168.0477 | Val Adv Loss  : 2.5741
  Saved best models with Val VAE Loss: 168.0477


                                                                                                                                           

Epoch 10 Summary | Time: 22.75s
  Train VAE Loss: 146.0656 | Train Adv Loss: 2.6024
  Val VAE Loss  : 166.1689 | Val Adv Loss  : 2.5781
  Saved best models with Val VAE Loss: 166.1689


                                                                                                                                           

Epoch 11 Summary | Time: 22.93s
  Train VAE Loss: 144.1415 | Train Adv Loss: 2.5974
  Val VAE Loss  : 164.5329 | Val Adv Loss  : 2.5720
  Saved best models with Val VAE Loss: 164.5329


                                                                                                                                           

Epoch 12 Summary | Time: 22.61s
  Train VAE Loss: 142.6411 | Train Adv Loss: 2.5965
  Val VAE Loss  : 162.9445 | Val Adv Loss  : 2.5711
  Saved best models with Val VAE Loss: 162.9445


                                                                                                                                           

Epoch 13 Summary | Time: 22.69s
  Train VAE Loss: 141.3231 | Train Adv Loss: 2.5933
  Val VAE Loss  : 162.4651 | Val Adv Loss  : 2.5693
  Saved best models with Val VAE Loss: 162.4651


                                                                                                                                           

Epoch 14 Summary | Time: 22.79s
  Train VAE Loss: 140.2041 | Train Adv Loss: 2.5903
  Val VAE Loss  : 161.2158 | Val Adv Loss  : 2.5652
  Saved best models with Val VAE Loss: 161.2158


                                                                                                                                           

Epoch 15 Summary | Time: 22.93s
  Train VAE Loss: 139.3136 | Train Adv Loss: 2.5902
  Val VAE Loss  : 160.2457 | Val Adv Loss  : 2.5611
  Saved best models with Val VAE Loss: 160.2457


                                                                                                                                           

Epoch 16 Summary | Time: 22.74s
  Train VAE Loss: 138.5566 | Train Adv Loss: 2.5871
  Val VAE Loss  : 159.8294 | Val Adv Loss  : 2.5644
  Saved best models with Val VAE Loss: 159.8294


                                                                                                                                           

Epoch 17 Summary | Time: 22.65s
  Train VAE Loss: 137.8285 | Train Adv Loss: 2.5868
  Val VAE Loss  : 159.4091 | Val Adv Loss  : 2.5603
  Saved best models with Val VAE Loss: 159.4091


                                                                                                                                           

Epoch 18 Summary | Time: 22.82s
  Train VAE Loss: 137.1714 | Train Adv Loss: 2.5829
  Val VAE Loss  : 158.1226 | Val Adv Loss  : 2.5546
  Saved best models with Val VAE Loss: 158.1226


                                                                                                                                           

Epoch 19 Summary | Time: 22.65s
  Train VAE Loss: 136.6484 | Train Adv Loss: 2.5817
  Val VAE Loss  : 157.9289 | Val Adv Loss  : 2.5542
  Saved best models with Val VAE Loss: 157.9289


                                                                                                                                           

Epoch 20 Summary | Time: 22.53s
  Train VAE Loss: 136.1526 | Train Adv Loss: 2.5779
  Val VAE Loss  : 157.2107 | Val Adv Loss  : 2.5525
  Saved best models with Val VAE Loss: 157.2107


                                                                                                                                           

Epoch 21 Summary | Time: 22.73s
  Train VAE Loss: 135.6436 | Train Adv Loss: 2.5754
  Val VAE Loss  : 156.6194 | Val Adv Loss  : 2.5515
  Saved best models with Val VAE Loss: 156.6194


                                                                                                                                           

Epoch 22 Summary | Time: 22.83s
  Train VAE Loss: 135.3284 | Train Adv Loss: 2.5734
  Val VAE Loss  : 156.0278 | Val Adv Loss  : 2.5428
  Saved best models with Val VAE Loss: 156.0278


                                                                                                                                           

Epoch 23 Summary | Time: 22.86s
  Train VAE Loss: 135.0027 | Train Adv Loss: 2.5713
  Val VAE Loss  : 156.3340 | Val Adv Loss  : 2.5416


                                                                                                                                           

Epoch 24 Summary | Time: 22.84s
  Train VAE Loss: 134.5364 | Train Adv Loss: 2.5670
  Val VAE Loss  : 155.6024 | Val Adv Loss  : 2.5377
  Saved best models with Val VAE Loss: 155.6024


                                                                                                                                           

Epoch 25 Summary | Time: 22.75s
  Train VAE Loss: 134.1848 | Train Adv Loss: 2.5662
  Val VAE Loss  : 155.2099 | Val Adv Loss  : 2.5387
  Saved best models with Val VAE Loss: 155.2099


                                                                                                                                           

Epoch 26 Summary | Time: 22.89s
  Train VAE Loss: 133.8772 | Train Adv Loss: 2.5640
  Val VAE Loss  : 155.6491 | Val Adv Loss  : 2.5360


                                                                                                                                           

Epoch 27 Summary | Time: 22.70s
  Train VAE Loss: 133.6078 | Train Adv Loss: 2.5628
  Val VAE Loss  : 155.2398 | Val Adv Loss  : 2.5333


                                                                                                                                           

Epoch 28 Summary | Time: 22.97s
  Train VAE Loss: 133.2900 | Train Adv Loss: 2.5610
  Val VAE Loss  : 154.3463 | Val Adv Loss  : 2.5288
  Saved best models with Val VAE Loss: 154.3463


                                                                                                                                           

Epoch 29 Summary | Time: 22.81s
  Train VAE Loss: 132.9565 | Train Adv Loss: 2.5594
  Val VAE Loss  : 154.2978 | Val Adv Loss  : 2.5314
  Saved best models with Val VAE Loss: 154.2978


                                                                                                                                           

Epoch 30 Summary | Time: 22.48s
  Train VAE Loss: 132.7996 | Train Adv Loss: 2.5594
  Val VAE Loss  : 154.5430 | Val Adv Loss  : 2.5288


                                                                                                                                           

Epoch 31 Summary | Time: 22.70s
  Train VAE Loss: 132.4689 | Train Adv Loss: 2.5579
  Val VAE Loss  : 153.8357 | Val Adv Loss  : 2.5272
  Saved best models with Val VAE Loss: 153.8357


                                                                                                                                           

Epoch 32 Summary | Time: 22.77s
  Train VAE Loss: 132.1838 | Train Adv Loss: 2.5588
  Val VAE Loss  : 154.3584 | Val Adv Loss  : 2.5223


                                                                                                                                           

Epoch 33 Summary | Time: 22.62s
  Train VAE Loss: 131.9648 | Train Adv Loss: 2.5575
  Val VAE Loss  : 153.9225 | Val Adv Loss  : 2.5239


                                                                                                                                           

Epoch 34 Summary | Time: 22.55s
  Train VAE Loss: 131.7948 | Train Adv Loss: 2.5562
  Val VAE Loss  : 153.0954 | Val Adv Loss  : 2.5288
  Saved best models with Val VAE Loss: 153.0954


                                                                                                                                           

Epoch 35 Summary | Time: 22.82s
  Train VAE Loss: 131.5300 | Train Adv Loss: 2.5569
  Val VAE Loss  : 153.5816 | Val Adv Loss  : 2.5307


                                                                                                                                           

Epoch 36 Summary | Time: 22.77s
  Train VAE Loss: 131.3298 | Train Adv Loss: 2.5566
  Val VAE Loss  : 152.5787 | Val Adv Loss  : 2.5257
  Saved best models with Val VAE Loss: 152.5787


                                                                                                                                           

Epoch 37 Summary | Time: 23.12s
  Train VAE Loss: 131.0881 | Train Adv Loss: 2.5565
  Val VAE Loss  : 152.8988 | Val Adv Loss  : 2.5238


                                                                                                                                           

Epoch 38 Summary | Time: 22.78s
  Train VAE Loss: 130.9157 | Train Adv Loss: 2.5579
  Val VAE Loss  : 152.6945 | Val Adv Loss  : 2.5247


                                                                                                                                           

Epoch 39 Summary | Time: 22.77s
  Train VAE Loss: 130.8138 | Train Adv Loss: 2.5554
  Val VAE Loss  : 152.6319 | Val Adv Loss  : 2.5237


                                                                                                                                           

Epoch 40 Summary | Time: 22.80s
  Train VAE Loss: 130.6246 | Train Adv Loss: 2.5557
  Val VAE Loss  : 152.1912 | Val Adv Loss  : 2.5244
  Saved best models with Val VAE Loss: 152.1912


                                                                                                                                           

Epoch 41 Summary | Time: 22.55s
  Train VAE Loss: 130.4412 | Train Adv Loss: 2.5557
  Val VAE Loss  : 151.7469 | Val Adv Loss  : 2.5227
  Saved best models with Val VAE Loss: 151.7469


                                                                                                                                           

Epoch 42 Summary | Time: 22.79s
  Train VAE Loss: 130.2314 | Train Adv Loss: 2.5566
  Val VAE Loss  : 151.8168 | Val Adv Loss  : 2.5244


                                                                                                                                           

Epoch 43 Summary | Time: 22.76s
  Train VAE Loss: 130.0816 | Train Adv Loss: 2.5559
  Val VAE Loss  : 151.7312 | Val Adv Loss  : 2.5265
  Saved best models with Val VAE Loss: 151.7312


                                                                                                                                           

Epoch 44 Summary | Time: 22.45s
  Train VAE Loss: 130.0748 | Train Adv Loss: 2.5559
  Val VAE Loss  : 151.7638 | Val Adv Loss  : 2.5264


                                                                                                                                           

Epoch 45 Summary | Time: 22.59s
  Train VAE Loss: 129.8252 | Train Adv Loss: 2.5558
  Val VAE Loss  : 151.5553 | Val Adv Loss  : 2.5236
  Saved best models with Val VAE Loss: 151.5553


                                                                                                                                           

Epoch 46 Summary | Time: 22.80s
  Train VAE Loss: 129.6323 | Train Adv Loss: 2.5541
  Val VAE Loss  : 151.0754 | Val Adv Loss  : 2.5230
  Saved best models with Val VAE Loss: 151.0754


                                                                                                                                           

Epoch 47 Summary | Time: 22.66s
  Train VAE Loss: 129.5246 | Train Adv Loss: 2.5539
  Val VAE Loss  : 151.0880 | Val Adv Loss  : 2.5226


                                                                                                                                           

Epoch 48 Summary | Time: 22.62s
  Train VAE Loss: 129.4947 | Train Adv Loss: 2.5532
  Val VAE Loss  : 151.0502 | Val Adv Loss  : 2.5244
  Saved best models with Val VAE Loss: 151.0502


                                                                                                                                           

Epoch 49 Summary | Time: 22.60s
  Train VAE Loss: 129.2381 | Train Adv Loss: 2.5531
  Val VAE Loss  : 151.2663 | Val Adv Loss  : 2.5312


  vae.load_state_dict(torch.load(os.path.join(save_path, 'best_vae_adversarial.pt')))
  adversary.load_state_dict(torch.load(os.path.join(save_path, 'best_adversary.pt')))


Epoch 50 Summary | Time: 23.08s
  Train VAE Loss: 129.1322 | Train Adv Loss: 2.5536
  Val VAE Loss  : 150.6480 | Val Adv Loss  : 2.5209
  Saved best models with Val VAE Loss: 150.6480
Adversarial VAE training finished.

--- Training Classifier on ORIGINAL Embeddings ---
Training classifier on ORIGINAL embeddings.
Starting classifier training for 30 epochs...


                                                                             

Epoch 1 Summary | Time: 14.05s
  Train Loss: 0.2821
  Val Loss  : 0.2546 | Val AUC  : 0.8143
  Saved best model with Val AUC: 0.8143


                                                                             

Epoch 2 Summary | Time: 13.79s
  Train Loss: 0.2552
  Val Loss  : 0.2502 | Val AUC  : 0.8232
  Saved best model with Val AUC: 0.8232


                                                                             

Epoch 3 Summary | Time: 14.16s
  Train Loss: 0.2533
  Val Loss  : 0.2496 | Val AUC  : 0.8242
  Saved best model with Val AUC: 0.8242


                                                                             

Epoch 4 Summary | Time: 14.10s
  Train Loss: 0.2522
  Val Loss  : 0.2489 | Val AUC  : 0.8281
  Saved best model with Val AUC: 0.8281


                                                                             

Epoch 5 Summary | Time: 14.46s
  Train Loss: 0.2513
  Val Loss  : 0.2508 | Val AUC  : 0.8288
  Saved best model with Val AUC: 0.8288


                                                                             

Epoch 6 Summary | Time: 13.99s
  Train Loss: 0.2506
  Val Loss  : 0.2474 | Val AUC  : 0.8310
  Saved best model with Val AUC: 0.8310


                                                                             

Epoch 7 Summary | Time: 14.23s
  Train Loss: 0.2502
  Val Loss  : 0.2478 | Val AUC  : 0.8305


                                                                             

Epoch 8 Summary | Time: 13.66s
  Train Loss: 0.2497
  Val Loss  : 0.2478 | Val AUC  : 0.8308


                                                                             

Epoch 9 Summary | Time: 13.89s
  Train Loss: 0.2492
  Val Loss  : 0.2487 | Val AUC  : 0.8300


                                                                              

Epoch 10 Summary | Time: 14.09s
  Train Loss: 0.2488
  Val Loss  : 0.2462 | Val AUC  : 0.8315
  Saved best model with Val AUC: 0.8315


                                                                              

Epoch 11 Summary | Time: 13.88s
  Train Loss: 0.2485
  Val Loss  : 0.2505 | Val AUC  : 0.8322
  Saved best model with Val AUC: 0.8322


                                                                              

Epoch 12 Summary | Time: 13.85s
  Train Loss: 0.2483
  Val Loss  : 0.2472 | Val AUC  : 0.8329
  Saved best model with Val AUC: 0.8329


                                                                              

Epoch 13 Summary | Time: 13.54s
  Train Loss: 0.2481
  Val Loss  : 0.2459 | Val AUC  : 0.8342
  Saved best model with Val AUC: 0.8342


                                                                              

Epoch 14 Summary | Time: 13.85s
  Train Loss: 0.2479
  Val Loss  : 0.2472 | Val AUC  : 0.8329


                                                                              

Epoch 15 Summary | Time: 14.15s
  Train Loss: 0.2479
  Val Loss  : 0.2466 | Val AUC  : 0.8339


                                                                              

Epoch 16 Summary | Time: 13.91s
  Train Loss: 0.2476
  Val Loss  : 0.3052 | Val AUC  : 0.8336


                                                                              

Epoch 17 Summary | Time: 13.66s
  Train Loss: 0.2475
  Val Loss  : 0.2476 | Val AUC  : 0.8351
  Saved best model with Val AUC: 0.8351


                                                                              

Epoch 18 Summary | Time: 14.24s
  Train Loss: 0.2473
  Val Loss  : 0.2468 | Val AUC  : 0.8351
  Saved best model with Val AUC: 0.8351


                                                                              

Epoch 19 Summary | Time: 14.16s
  Train Loss: 0.2468
  Val Loss  : 0.2478 | Val AUC  : 0.8339


                                                                              

Epoch 20 Summary | Time: 13.96s
  Train Loss: 0.2469
  Val Loss  : 0.2471 | Val AUC  : 0.8346


                                                                              

Epoch 21 Summary | Time: 14.04s
  Train Loss: 0.2467
  Val Loss  : 0.2486 | Val AUC  : 0.8352
  Saved best model with Val AUC: 0.8352


                                                                              

Epoch 22 Summary | Time: 13.37s
  Train Loss: 0.2465
  Val Loss  : 0.2455 | Val AUC  : 0.8352
  Saved best model with Val AUC: 0.8352


                                                                              

Epoch 23 Summary | Time: 13.18s
  Train Loss: 0.2465
  Val Loss  : 0.2456 | Val AUC  : 0.8344


                                                                              

Epoch 24 Summary | Time: 13.28s
  Train Loss: 0.2461
  Val Loss  : 0.2466 | Val AUC  : 0.8339


                                                                              

Epoch 25 Summary | Time: 13.14s
  Train Loss: 0.2460
  Val Loss  : 0.2460 | Val AUC  : 0.8348


                                                                              

Epoch 26 Summary | Time: 13.11s
  Train Loss: 0.2459
  Val Loss  : 0.2455 | Val AUC  : 0.8361
  Saved best model with Val AUC: 0.8361


                                                                              

Epoch 27 Summary | Time: 13.05s
  Train Loss: 0.2456
  Val Loss  : 0.2451 | Val AUC  : 0.8358


                                                                              

Epoch 28 Summary | Time: 13.30s
  Train Loss: 0.2456
  Val Loss  : 0.2456 | Val AUC  : 0.8348


                                                                              

Epoch 29 Summary | Time: 13.28s
  Train Loss: 0.2454
  Val Loss  : 0.2454 | Val AUC  : 0.8355


  classifier.load_state_dict(torch.load(os.path.join(save_path, model_name)))


Epoch 30 Summary | Time: 12.87s
  Train Loss: 0.2454
  Val Loss  : 0.2457 | Val AUC  : 0.8356
Classifier training finished.

--- Training Classifier on DEBIASED Embeddings ---
Training classifier on DEBIASED embeddings.
Starting classifier training for 30 epochs...


                                                                             

Epoch 1 Summary | Time: 14.06s
  Train Loss: 0.2852
  Val Loss  : 0.2569 | Val AUC  : 0.8055
  Saved best model with Val AUC: 0.8055


                                                                             

Epoch 2 Summary | Time: 13.89s
  Train Loss: 0.2602
  Val Loss  : 0.2551 | Val AUC  : 0.8046


                                                                             

Epoch 3 Summary | Time: 14.02s
  Train Loss: 0.2591
  Val Loss  : 0.2557 | Val AUC  : 0.8109
  Saved best model with Val AUC: 0.8109


                                                                             

Epoch 4 Summary | Time: 13.88s
  Train Loss: 0.2585
  Val Loss  : 0.2553 | Val AUC  : 0.8121
  Saved best model with Val AUC: 0.8121


                                                                             

Epoch 5 Summary | Time: 13.96s
  Train Loss: 0.2582
  Val Loss  : 0.2556 | Val AUC  : 0.8134
  Saved best model with Val AUC: 0.8134


                                                                             

Epoch 6 Summary | Time: 14.00s
  Train Loss: 0.2579
  Val Loss  : 0.2547 | Val AUC  : 0.8125


                                                                             

Epoch 7 Summary | Time: 13.80s
  Train Loss: 0.2577
  Val Loss  : 0.2529 | Val AUC  : 0.8124


                                                                             

Epoch 8 Summary | Time: 13.95s
  Train Loss: 0.2574
  Val Loss  : 0.2528 | Val AUC  : 0.8137
  Saved best model with Val AUC: 0.8137


                                                                             

Epoch 9 Summary | Time: 14.05s
  Train Loss: 0.2572
  Val Loss  : 0.2542 | Val AUC  : 0.8133


                                                                              

Epoch 10 Summary | Time: 13.99s
  Train Loss: 0.2571
  Val Loss  : 0.2528 | Val AUC  : 0.8142
  Saved best model with Val AUC: 0.8142


                                                                              

Epoch 11 Summary | Time: 14.08s
  Train Loss: 0.2573
  Val Loss  : 0.2854 | Val AUC  : 0.8128


                                                                              

Epoch 12 Summary | Time: 14.22s
  Train Loss: 0.2570
  Val Loss  : 0.2527 | Val AUC  : 0.8133


                                                                              

Epoch 13 Summary | Time: 14.03s
  Train Loss: 0.2568
  Val Loss  : 0.2528 | Val AUC  : 0.8145
  Saved best model with Val AUC: 0.8145


                                                                              

Epoch 14 Summary | Time: 13.95s
  Train Loss: 0.2567
  Val Loss  : 0.2529 | Val AUC  : 0.8152
  Saved best model with Val AUC: 0.8152


                                                                              

Epoch 15 Summary | Time: 14.12s
  Train Loss: 0.2567
  Val Loss  : 0.2983 | Val AUC  : 0.8140


                                                                              

Epoch 16 Summary | Time: 14.09s
  Train Loss: 0.2565
  Val Loss  : 0.2717 | Val AUC  : 0.8163
  Saved best model with Val AUC: 0.8163


                                                                              

Epoch 17 Summary | Time: 14.05s
  Train Loss: 0.2564
  Val Loss  : 0.2524 | Val AUC  : 0.8150


                                                                              

Epoch 18 Summary | Time: 13.71s
  Train Loss: 0.2564
  Val Loss  : 0.2522 | Val AUC  : 0.8145


                                                                              

Epoch 19 Summary | Time: 14.03s
  Train Loss: 0.2562
  Val Loss  : 0.2519 | Val AUC  : 0.8152


                                                                              

Epoch 20 Summary | Time: 13.90s
  Train Loss: 0.2562
  Val Loss  : 0.2519 | Val AUC  : 0.8157


                                                                              

Epoch 21 Summary | Time: 14.09s
  Train Loss: 0.2561
  Val Loss  : 0.2521 | Val AUC  : 0.8159


                                                                              

Epoch 22 Summary | Time: 13.91s
  Train Loss: 0.2562
  Val Loss  : 0.2516 | Val AUC  : 0.8163
  Saved best model with Val AUC: 0.8163


                                                                              

Epoch 23 Summary | Time: 13.94s
  Train Loss: 0.2560
  Val Loss  : 0.2517 | Val AUC  : 0.8146


                                                                              

Epoch 24 Summary | Time: 13.97s
  Train Loss: 0.2557
  Val Loss  : 0.2512 | Val AUC  : 0.8169
  Saved best model with Val AUC: 0.8169


                                                                              

Epoch 25 Summary | Time: 14.03s
  Train Loss: 0.2560
  Val Loss  : 0.2523 | Val AUC  : 0.8160


                                                                              

Epoch 26 Summary | Time: 14.15s
  Train Loss: 0.2558
  Val Loss  : 0.2517 | Val AUC  : 0.8161


                                                                              

Epoch 27 Summary | Time: 13.98s
  Train Loss: 0.2557
  Val Loss  : 0.2521 | Val AUC  : 0.8155


                                                                              

Epoch 28 Summary | Time: 14.02s
  Train Loss: 0.2557
  Val Loss  : 0.2524 | Val AUC  : 0.8161


                                                                              

Epoch 29 Summary | Time: 13.85s
  Train Loss: 0.2556
  Val Loss  : 0.2515 | Val AUC  : 0.8163


                                                                              

Epoch 30 Summary | Time: 13.70s
  Train Loss: 0.2556
  Val Loss  : 0.2520 | Val AUC  : 0.8160
Classifier training finished.

--- Evaluating on Test Set ---


  classifier.load_state_dict(torch.load(os.path.join(save_path, model_name)))


Classifier (Original Embeddings) Test Loss: 0.2451, Test AUC: 0.8347
Classifier (Debiased Embeddings) Test Loss: 0.2511, Test AUC: 0.8187


                                                                                    


Adversary Final Test Loss: 2.5136
  Adversary Gender Prediction Test Accuracy: 0.6743
  Adversary Insurance Prediction Test Accuracy: 0.5946
  Adversary Race Prediction Test Accuracy: 0.6755

--- Script Finished ---




## 