In [1]:
import pandas as pd
import tensorflow as tf
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm
print("GPU Available: ", tf.config.list_physical_devices('GPU'))

# Checksum manifest shipped at the root of the embeddings dataset
file_path = "/kaggle/input/mimic-data/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/SHA256SUMS.txt"

# Walk the manifest and, for every entry mentioning "files/", keep the last
# field obtained by splitting the stripped line once on whitespace
lines = []
with open(file_path, "r") as file:
    for line in file:
        if "files/" not in line:
            continue
        lines.append(line.strip().split(maxsplit=1)[-1])

# Single-column table holding the collected paths
df = pd.DataFrame(lines, columns=["file_paths"])

# Preview the first rows
print(df.head())


GPU Available:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
                                          file_paths
0  files/p10/p10000032/s50414267/02aa804e-bde0afd...
1  files/p10/p10000032/s53189527/2a2277a9-b0ded15...
2  files/p10/p10000032/s53911762/68b5c4b1-227d048...
3  files/p10/p10000032/s53911762/fffabebf-74fd3a1...
4  files/p10/p10000032/s56699142/ea030e7a-2e3b134...


In [None]:
import pandas as pd
import tensorflow as tf
import torch
from torch.utils.data import Dataset
import numpy as np
import os
import re
from tqdm import tqdm
import logging

# Configure logging: timestamped INFO-level records for this notebook, while
# TensorFlow's Python logger is restricted to ERROR and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
tf.get_logger().setLevel(logging.ERROR)
# NOTE(review): tensorflow has already been imported at this point; this variable
# is presumably read when TensorFlow's native runtime loads, so setting it here
# may not silence C++ log output -- confirm, and move above the import if needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Base path for dataset: root directory that relative TFRecord paths are joined onto
base_path = "/kaggle/input/mimic-data/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0"

# Load labels
# subject_id / study_id are cast to str so they compare equal to the string IDs
# that extract_ids_from_path() captures from image paths.
labels_path = "/kaggle/input/mimic-data/mimic-cxr-2.0.0-chexpert.csv"
try:
    labels_df = pd.read_csv(labels_path)
    labels_df['subject_id'] = labels_df['subject_id'].astype(str)
    labels_df['study_id'] = labels_df['study_id'].astype(str)
    logger.info(f"Loaded labels: {labels_df.shape[0]} rows")
except Exception as e:
    # Best-effort fallback: log the failure and continue with an empty frame.
    # With an empty frame no label lookup entries are built, so every record
    # loaded later is treated as having no matching labels.
    logger.error(f"Error loading labels: {e}")
    labels_df = pd.DataFrame(columns=['subject_id', 'study_id'])

# Label columns: the 14 finding columns selected from labels_df, in the order
# used for the label tensors produced by the dataset class
label_columns = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 
                'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 
                'Lung Opacity', 'No Finding', 'Pleural Effusion', 
                'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']

# TFRecord feature description: each serialized example is parsed into a
# fixed-length vector of 1376 float32 values plus two scalar string fields
feature_description = {
    'embedding': tf.io.FixedLenFeature([1376], tf.float32),
    'image/id': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string)
}

def extract_ids_from_path(path):
    """Return ``(subject_id, study_id)`` parsed from an image path.

    Two layouts are tried in order; the first one that matches wins:

    1. ``/p<digits>/p<subject>/s<study>/`` (slash-delimited on both sides)
    2. ``p<subject>/s<study>`` anywhere in the string

    Both IDs are returned as digit strings. ``(None, None)`` is returned when
    neither layout is found.
    """
    # (regex, capture group holding the subject, capture group holding the study)
    layouts = (
        (r'/p(\d+)/p(\d+)/s(\d+)/', 2, 3),
        (r'p(\d+)/s(\d+)', 1, 2),
    )

    for regex, subject_group, study_group in layouts:
        found = re.search(regex, path)
        if found is not None:
            return found.group(subject_group), found.group(study_group)

    return None, None

class MIMICEmbeddingDataset(Dataset):
    """In-memory PyTorch dataset of image embeddings read from TFRecord files.

    On construction every TFRecord listed in ``file_paths`` (relative to
    ``base_path``) is parsed with the module-level ``feature_description``.
    Each record is stored in ``self.data`` together with the subject/study IDs
    extracted from its ``image/id`` field and, when available, the matching
    row of label values from ``labels_df``.

    Args:
        file_paths: Sequence of TFRecord paths relative to ``base_path``.
        base_path: Directory the relative paths are joined onto.
        labels_df: DataFrame with string ``subject_id`` / ``study_id`` columns
            and one column per entry of the module-level ``label_columns``.

    Attributes:
        data: List of per-record dicts (embedding, image_id, subject_id,
            study_id, labels).
        matched_count / unmatched_count: Records with / without a label row.
        skipped_files: Files that were missing or raised while being read.
        failed_records: Individual records that could not be parsed.
    """

    def __init__(self, file_paths, base_path, labels_df):
        self.file_paths = file_paths
        self.base_path = base_path
        self.labels_df = labels_df
        self.label_columns = label_columns  # module-level list of label names
        self.data = []

        self.matched_count = 0
        self.unmatched_count = 0
        self.skipped_files = 0
        self.failed_records = 0

        # Build the (subject_id, study_id) -> {label: value} lookup in one
        # vectorised pass instead of iterating rows one at a time.
        # As with a row-by-row build, a later duplicate key overwrites an earlier one.
        self.label_dict = {}
        if not self.labels_df.empty:
            label_records = self.labels_df[self.label_columns].to_dict('records')
            id_pairs = zip(self.labels_df['subject_id'], self.labels_df['study_id'])
            self.label_dict = dict(zip(id_pairs, label_records))

        logger.info(f"Created label dictionary with {len(self.label_dict)} entries")

        # Test extraction with a sample path
        if self.file_paths:
            test_path = self.file_paths[0]
            full_test_path = os.path.join(self.base_path, test_path)
            if os.path.exists(full_test_path):
                self._test_extraction(full_test_path)

        # Load dataset
        logger.info("Loading TFRecord files...")
        self._load_data()
        logger.info(f"Records with matched labels: {self.matched_count}")
        logger.info(f"Records without matched labels: {self.unmatched_count}")
        logger.info(f"Skipped files: {self.skipped_files}")
        logger.info(f"Records that failed to parse: {self.failed_records}")

    def _test_extraction(self, test_path):
        """Log the IDs extracted from the first record of ``test_path``.

        Purely diagnostic: reports whether the extracted key exists in the
        label lookup. Any error is logged as a warning and swallowed.
        """
        try:
            dataset = tf.data.TFRecordDataset(test_path)
            for record in dataset:
                parsed = tf.io.parse_single_example(record, feature_description)
                image_id = parsed['image/id'].numpy().decode('utf-8')
                subject_id, study_id = extract_ids_from_path(image_id)
                logger.info(f"Sample image_id: {image_id}")
                logger.info(f"Extracted IDs: subject_id={subject_id}, study_id={study_id}")
                key = (subject_id, study_id)
                if key in self.label_dict:
                    logger.info(f"✓ Found matching entry in labels")
                else:
                    logger.warning(f"✗ No matching entry in labels")
                # Only the first record is inspected
                return
        except Exception as e:
            logger.warning(f"Error testing extraction: {e}")

    def _load_data(self):
        """Read every listed TFRecord file into ``self.data``.

        Files are processed in chunks of 500 so progress can be logged and the
        garbage collector run between chunks. Missing or unreadable files are
        counted in ``skipped_files`` and do not abort the load.
        """
        import gc  # function-scope import: gc is not imported at file level

        batch_size = 500
        total_files = len(self.file_paths)
        total_batches = (total_files + batch_size - 1) // batch_size

        for batch_start in range(0, total_files, batch_size):
            batch_paths = self.file_paths[batch_start:batch_start + batch_size]
            batch_num = batch_start // batch_size + 1

            logger.info(f"Processing batch {batch_num}/{total_batches}")

            for path in tqdm(batch_paths, desc=f"Batch {batch_num}/{total_batches}"):
                full_path = os.path.join(self.base_path, path)

                if not os.path.exists(full_path):
                    self.skipped_files += 1
                    continue

                try:
                    self._load_file(full_path)
                except Exception as e:
                    # Best effort: records read before the failure are kept,
                    # the file is counted as skipped and loading continues.
                    self.skipped_files += 1
                    logger.debug(f"Skipping file {full_path}: {e}")

            # Log progress and free memory
            logger.info(f"Batch {batch_num} complete. Total records: {len(self.data)}")
            gc.collect()

    def _load_file(self, full_path):
        """Parse one TFRecord file and append its records to ``self.data``."""
        dataset = tf.data.TFRecordDataset(
            full_path,
            buffer_size=200,
            num_parallel_reads=tf.data.experimental.AUTOTUNE
        )

        for record in dataset:
            try:
                parsed = tf.io.parse_single_example(record, feature_description)
                embedding = parsed['embedding'].numpy()
                image_id = parsed['image/id'].numpy().decode('utf-8')
                subject_id, study_id = extract_ids_from_path(image_id)
            except Exception as e:
                # A record that cannot be parsed is dropped, but it is counted
                # so the loss is visible in the final summary.
                self.failed_records += 1
                logger.debug(f"Skipping unparsable record in {full_path}: {e}")
                continue

            # Find matching labels (None when IDs are missing or unknown)
            labels = None
            if subject_id and study_id:
                labels = self.label_dict.get((subject_id, study_id))

            if labels is not None:
                self.matched_count += 1
            else:
                self.unmatched_count += 1

            self.data.append({
                'embedding': embedding,
                'image_id': image_id,
                'subject_id': subject_id,
                'study_id': study_id,
                'labels': labels
            })

    def __len__(self):
        """Number of records held in memory."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one record as a dict of tensors and identifiers.

        Keys:
            embedding: float32 tensor built from the stored embedding.
            image_id, subject_id, study_id: identifiers stored at load time.
            labels: float32 tensor of raw label values with NaN replaced by 0
                (other values, e.g. -1, are passed through unchanged).
            labels_one_hot: float32 multi-hot tensor with 1 at every label whose
                value equals 1; if there is none (or the record has no labels)
                only the 'No Finding' position is set.
        """
        item = self.data[idx]
        num_labels = len(self.label_columns)

        # NOTE(review): subject_id / study_id are None when ID extraction failed;
        # presumably the default DataLoader collate cannot batch None -- confirm.
        result = {
            'embedding': torch.tensor(item['embedding'], dtype=torch.float32),
            'image_id': item['image_id'],
            'subject_id': item['subject_id'],
            'study_id': item['study_id']
        }

        raw_labels = item['labels']
        if raw_labels is not None:
            # Raw label values in label_columns order, NaN / missing -> 0
            label_values = []
            for col in self.label_columns:
                value = raw_labels.get(col, 0)
                label_values.append(0.0 if pd.isna(value) else float(value))
            result['labels'] = torch.tensor(label_values, dtype=torch.float32)

            # Multi-hot encoding of the positive findings
            one_hot = torch.zeros(num_labels)
            positive_indices = [i for i, val in enumerate(label_values) if val == 1]
            if positive_indices:
                for position in positive_indices:
                    one_hot[position] = 1
            else:
                # If no positives, mark as "No Finding"
                one_hot[self.label_columns.index('No Finding')] = 1
            result['labels_one_hot'] = one_hot
        else:
            # Default labels if none available: all zeros, flagged as "No Finding"
            result['labels'] = torch.zeros(num_labels, dtype=torch.float32)
            one_hot = torch.zeros(num_labels, dtype=torch.float32)
            one_hot[self.label_columns.index('No Finding')] = 1
            result['labels_one_hot'] = one_hot

        return result

# Load data and create dataset
try:
    # The SHA256SUMS.txt manifest sits at the dataset root, so derive its
    # location from base_path instead of repeating the absolute path.
    file_path = os.path.join(base_path, "SHA256SUMS.txt")

    # Read the manifest and keep the path field of every entry under files/
    with open(file_path, "r") as file:
        lines = [line.strip().split(maxsplit=1)[-1] for line in file if "files/" in line]

    # Filter for TFRecord files
    file_paths = [path for path in lines if path.endswith('.tfrecord')]

    logger.info(f"Found {len(file_paths)} TFRecord files")
    if file_paths:
        logger.info(f"Sample paths: {file_paths[:2]}")

    # Create the dataset (this reads every TFRecord file into memory)
    dataset = MIMICEmbeddingDataset(file_paths, base_path, labels_df)
    logger.info(f"Dataset size: {len(dataset)}")

    # Display sample
    if dataset.data:
        sample = dataset[0]
        logger.info(f"Embedding shape: {sample['embedding'].shape}")
        logger.info(f"Labels shape: {sample['labels'].shape}")
except Exception as e:
    # Keep the notebook running, but record the full traceback: on failure
    # `dataset` may be left undefined and later cells need the root cause.
    logger.exception(f"Error: {e}")

Batch 1/487: 100%|██████████| 500/500 [00:12<00:00, 40.55it/s]
Batch 2/487: 100%|██████████| 500/500 [00:12<00:00, 38.91it/s]
Batch 3/487: 100%|██████████| 500/500 [00:12<00:00, 39.10it/s]
Batch 4/487: 100%|██████████| 500/500 [00:12<00:00, 39.30it/s]
Batch 5/487: 100%|██████████| 500/500 [00:12<00:00, 38.65it/s]
Batch 6/487: 100%|██████████| 500/500 [00:12<00:00, 38.95it/s]
Batch 7/487: 100%|██████████| 500/500 [00:13<00:00, 35.94it/s]
Batch 8/487: 100%|██████████| 500/500 [00:22<00:00, 22.43it/s]
Batch 9/487: 100%|██████████| 500/500 [00:22<00:00, 22.44it/s]
Batch 10/487: 100%|██████████| 500/500 [00:23<00:00, 21.41it/s]
Batch 11/487: 100%|██████████| 500/500 [00:14<00:00, 34.58it/s]
Batch 12/487: 100%|██████████| 500/500 [00:12<00:00, 39.70it/s]
Batch 13/487: 100%|██████████| 500/500 [00:12<00:00, 40.46it/s]
Batch 14/487: 100%|██████████| 500/500 [00:12<00:00, 38.94it/s]
Batch 15/487: 100%|██████████| 500/500 [00:12<00:00, 39.04it/s]
Batch 16/487: 100%|██████████| 500/500 [00:12<00:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os
import seaborn as sns

# Load demographic data from MIMIC
def load_demographic_data(data_dir="/kaggle/input/mimic-patients"):
    """
    Load patient demographics from MIMIC CSV files

    Reads ``admissions.csv`` and ``patients.csv`` from ``data_dir`` and reduces
    each table to one row per ``subject_id`` so that a later merge on
    ``subject_id`` cannot multiply rows.

    Args:
        data_dir: Directory containing ``admissions.csv`` and ``patients.csv``.

    Returns: 
        admissions_df: DataFrame with columns subject_id, insurance,
            marital_status, ethnicity (one row per subject)
        patients_df: DataFrame with columns subject_id, gender,
            anchor_year_group (one row per subject)

    Raises:
        KeyError: if one of the expected columns is missing from a CSV.
    """
    admissions_df = pd.read_csv(os.path.join(data_dir, "admissions.csv"))
    patients_df = pd.read_csv(os.path.join(data_dir, "patients.csv"))

    # Collapse repeated admissions to a single row per patient
    if admissions_df['subject_id'].duplicated().any():
        print("Warning: duplicate subject_ids found in admissions data")
        # Prefer the most recent admission when an admission timestamp exists;
        # otherwise fall back to the last row in file order. Either way the
        # duplicates are removed.
        if 'admittime' in admissions_df.columns:
            admissions_df = admissions_df.sort_values('admittime')
        admissions_df = admissions_df.drop_duplicates('subject_id', keep='last')

    # Same one-row-per-subject guarantee for the patients table
    patients_df = patients_df.drop_duplicates('subject_id', keep='last')

    # Only keep relevant columns
    admissions_cols = ['subject_id', 'insurance', 'marital_status', 'ethnicity']
    patients_cols = ['subject_id', 'gender', 'anchor_year_group']

    admissions_df = admissions_df[admissions_cols]
    patients_df = patients_df[patients_cols]

    return admissions_df, patients_df

# ChestXrayClassifier model definition (same as before)
class ChestXrayClassifier(nn.Module):
    """Multi-label classifier over fixed-size embedding vectors.

    Architecture, in forward order:
      * batch-norm on the raw input
      * three Linear -> BatchNorm -> LeakyReLU(0.1) -> Dropout(0.3) stages
      * a linear shortcut from the normalised input added after the second stage
      * an element-wise sigmoid gate computed from the third stage's output
      * one small two-layer head per output class, each producing a single logit

    ``forward`` returns raw logits of shape (batch, output_dim); no sigmoid is
    applied to the final outputs.
    """

    def __init__(self, input_dim=1376, hidden_dims=[512, 384, 256], output_dim=14):
        super().__init__()

        first_width = hidden_dims[0]
        second_width = hidden_dims[1]
        third_width = hidden_dims[2]

        # Normalisation applied directly to the incoming features
        self.batch_norm_input = nn.BatchNorm1d(input_dim)

        # Stage 1
        self.fc1 = nn.Linear(input_dim, first_width)
        self.bn1 = nn.BatchNorm1d(first_width)
        self.dropout1 = nn.Dropout(0.3)

        # Stage 2
        self.fc2 = nn.Linear(first_width, second_width)
        self.bn2 = nn.BatchNorm1d(second_width)
        self.dropout2 = nn.Dropout(0.3)

        # Stage 3
        self.fc3 = nn.Linear(second_width, third_width)
        self.bn3 = nn.BatchNorm1d(third_width)
        self.dropout3 = nn.Dropout(0.3)

        # Shortcut projecting the normalised input to the stage-2 width
        self.res_fc1 = nn.Linear(input_dim, second_width)

        # Bottleneck gate: squeeze to a quarter of the width, expand back, squash to (0, 1)
        self.attention = nn.Sequential(
            nn.Linear(third_width, third_width // 4),
            nn.ReLU(),
            nn.Linear(third_width // 4, third_width),
            nn.Sigmoid()
        )

        # One independent head per output class
        heads = []
        for _ in range(output_dim):
            heads.append(
                nn.Sequential(
                    nn.Linear(third_width, third_width // 2),
                    nn.ReLU(),
                    nn.Linear(third_width // 2, 1)
                )
            )
        self.disease_layers = nn.ModuleList(heads)

    def forward(self, x):
        activate = nn.functional.leaky_relu

        normed = self.batch_norm_input(x)

        # Stage 1
        hidden = self.dropout1(activate(self.bn1(self.fc1(normed)), 0.1))

        # Shortcut taken from the normalised input
        shortcut = self.res_fc1(normed)

        # Stage 2, then merge the shortcut
        hidden = self.dropout2(activate(self.bn2(self.fc2(hidden)), 0.1))
        hidden = hidden + shortcut

        # Stage 3
        hidden = self.dropout3(activate(self.bn3(self.fc3(hidden)), 0.1))

        # Gate the features element-wise
        hidden = hidden * self.attention(hidden)

        # One logit per head, concatenated along the feature axis
        return torch.cat([head(hidden) for head in self.disease_layers], dim=1)

# Modified function to save predictions with demographic information
def save_predictions_to_csv(dataset_loader, model, device, label_columns, file_name, 
                          admissions_df=None, patients_df=None):
    """
    Generate and save model predictions with demographic information
    
    Args:
        dataset_loader: Iterable of batch dicts with 'embedding' and
            'labels_one_hot' tensors; 'subject_id', 'study_id' and 'image_id'
            are used when present, otherwise sequential placeholders are made
        model: Trained model to generate predictions (outputs raw logits)
        device: Device to run model on
        label_columns: List of disease label names
        file_name: Where to save the CSV
        admissions_df: DataFrame with admission information (not modified)
        patients_df: DataFrame with patient information (not modified)
    
    Returns:
        all_probs: Array of prediction probabilities (sigmoid of the logits)
        all_targets: Array of ground truth labels
        pred_df: DataFrame that was written to ``file_name``; one row per
            sample with IDs, ``<label>_true`` / ``<label>`` columns and, when
            both demographic frames are given, the demographic columns
    """
    model.eval()
    prob_chunks = []
    target_chunks = []
    subject_ids = []
    study_ids = []
    image_ids = []
    
    with torch.no_grad():
        for batch in dataset_loader:
            inputs = batch['embedding'].to(device)
            targets = batch['labels_one_hot'].to(device)
            rows_in_batch = inputs.size(0)
            
            # Extract subject and study IDs from the batch
            if 'subject_id' in batch and 'study_id' in batch:
                subject_ids.extend(batch['subject_id'])
                study_ids.extend(batch['study_id'])
            else:
                # If IDs don't exist, create sequential placeholders
                start_idx = len(subject_ids)
                subject_ids.extend([f"subject_{i}" for i in range(start_idx, start_idx + rows_in_batch)])
                study_ids.extend([f"study_{i}" for i in range(start_idx, start_idx + rows_in_batch)])
            
            # Get image IDs if available
            if 'image_id' in batch:
                image_ids.extend(batch['image_id'])
            else:
                start_idx = len(image_ids)
                image_ids.extend([f"image_{i}" for i in range(start_idx, start_idx + rows_in_batch)])
            
            outputs = model(inputs)
            # torch.sigmoid avoids the overflow warnings a hand-written
            # 1 / (1 + exp(-x)) can emit for large-magnitude logits
            prob_chunks.append(torch.sigmoid(outputs).cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
    
    all_probs = np.vstack(prob_chunks)
    all_targets = np.vstack(target_chunks)
    all_binary_preds = (all_probs >= 0.5).astype(int)  # Convert to binary predictions
    
    # Create DataFrame for predictions
    pred_df = pd.DataFrame()
    pred_df['subject_id'] = subject_ids
    pred_df['study_id'] = study_ids
    pred_df['image_id'] = image_ids
    
    # Add only ground truth and binary predictions for each disease
    for i, label in enumerate(label_columns):
        pred_df[f"{label}_true"] = all_targets[:, i]
        pred_df[f"{label}"] = all_binary_preds[:, i]  # Only binary predictions
    
    # Add demographic information if available
    if admissions_df is not None and patients_df is not None:
        # Merge keys must share a type; work on copies so the caller's frames
        # keep their original dtypes, and keep one row per subject so the
        # left merges cannot duplicate prediction rows.
        pred_df['subject_id'] = pred_df['subject_id'].astype(str)
        admissions_lookup = (
            admissions_df[['subject_id', 'insurance', 'marital_status', 'ethnicity']]
            .assign(subject_id=lambda frame: frame['subject_id'].astype(str))
            .drop_duplicates('subject_id', keep='last')
        )
        patients_lookup = (
            patients_df[['subject_id', 'gender', 'anchor_year_group']]
            .assign(subject_id=lambda frame: frame['subject_id'].astype(str))
            .drop_duplicates('subject_id', keep='last')
        )
        
        # Merge with admissions data
        pred_df = pd.merge(pred_df, admissions_lookup, on='subject_id', how='left')
        
        # Merge with patients data
        pred_df = pd.merge(pred_df, patients_lookup, on='subject_id', how='left')
    
    # Save to CSV
    pred_df.to_csv(file_name, index=False)
    print(f"Predictions saved to {file_name}")
    
    return all_probs, all_targets, pred_df

# Function to analyze predictions by demographic subgroups
def analyze_by_subgroups(predictions_df, label_columns, subgroup_cols=None):
    """
    Compute per-label classification metrics within each demographic subgroup

    Args:
        predictions_df: DataFrame holding, for every label, a ``<label>_true``
            column and a ``<label>`` prediction column, plus demographic columns
        label_columns: List of disease labels
        subgroup_cols: Demographic columns to break results down by
                      (default: ['gender', 'ethnicity', 'insurance', 'marital_status', 'anchor_year_group']);
                      columns absent from predictions_df are ignored

    Returns:
        results_dict: ``{column: {value: {'sample_count': int,
            'disease_metrics': {label: {...}}}}}``. Subgroups with fewer than
            10 rows are omitted, as are labels whose ground truth sums to zero
            within a subgroup.
    """
    requested_cols = subgroup_cols
    if requested_cols is None:
        requested_cols = ['gender', 'ethnicity', 'insurance', 'marital_status', 'anchor_year_group']

    # Only columns that actually exist in the frame can be analysed
    usable_cols = [col for col in requested_cols if col in predictions_df.columns]

    results_dict = {}
    for group_col in usable_cols:
        per_value = {}

        for value in predictions_df[group_col].dropna().unique():
            members = predictions_df[predictions_df[group_col] == value]
            member_count = len(members)

            # Subgroups below 10 rows are left out
            if member_count >= 10:
                per_label = {}

                for label in label_columns:
                    y_true = members[f"{label}_true"].values
                    y_pred = members[f"{label}"].values

                    # Metrics are only computed when the ground-truth sum is non-zero
                    positives = sum(y_true)
                    if positives != 0:
                        per_label[label] = {
                            'accuracy': accuracy_score(y_true, y_pred),
                            'precision': precision_score(y_true, y_pred, zero_division=0),
                            'recall': recall_score(y_true, y_pred, zero_division=0),
                            'f1': f1_score(y_true, y_pred, zero_division=0),
                            'sample_count': len(y_true),
                            'positive_count': positives
                        }

                per_value[value] = {
                    'sample_count': member_count,
                    'disease_metrics': per_label
                }

        results_dict[group_col] = per_value

    return results_dict

# Function to create visualizations of subgroup performance
def visualize_subgroup_performance(subgroup_results, metric='f1', output_dir='subgroup_analysis'):
    """
    Save a heatmap and a grouped bar chart of one metric for every subgroup column

    Args:
        subgroup_results: Results dictionary from analyze_by_subgroups
        metric: Which metric to visualize ('accuracy', 'precision', 'recall', 'f1')
        output_dir: Directory to save plots (created if it does not exist)

    For each top-level key of ``subgroup_results`` that has at least one
    per-label metric entry, two PNG files are written to ``output_dir``:
    ``<key>_<metric>_heatmap.png`` and ``<key>_<metric>_barplot.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for group_name, group_data in subgroup_results.items():
        # Flatten the nested results into one record per (subgroup, disease)
        records = [
            {
                'Subgroup': subgroup,
                'Disease': disease,
                metric: scores[metric],
                'Sample Count': scores['sample_count']
            }
            for subgroup, subgroup_data in group_data.items()
            for disease, scores in subgroup_data['disease_metrics'].items()
        ]

        # Nothing to draw for this column
        if len(records) == 0:
            continue

        plot_df = pd.DataFrame(records)
        heading = metric.capitalize()

        # Heatmap: diseases as rows, subgroup values as columns
        plt.figure(figsize=(12, 8))
        metric_grid = plot_df.pivot_table(values=metric, index='Disease', columns='Subgroup')
        sns.heatmap(metric_grid, annot=True, cmap='YlGnBu', fmt='.2f', linewidths=.5)
        plt.title(f'{heading} by {group_name} Subgroups')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group_name}_{metric}_heatmap.png")
        plt.close()

        # Grouped bars: one cluster per disease, one bar per subgroup value
        plt.figure(figsize=(14, 10))
        sns.barplot(x='Disease', y=metric, hue='Subgroup', data=plot_df)
        plt.title(f'{heading} by Disease and {group_name}')
        plt.xticks(rotation=45, ha='right')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group_name}_{metric}_barplot.png")
        plt.close()

# Main training function with demographic integration
def train_and_evaluate(dataset, label_columns, test_size=0.2, batch_size=64, num_epochs=25):
    """
    Complete pipeline for training, evaluation, and subgroup analysis.

    Steps, in order:
      1. Load demographic tables and create the output directories.
      2. Split ``dataset`` into train/test with a fixed seed (42).
      3. Train a ChestXrayClassifier with BCEWithLogitsLoss and Adam (lr=0.001),
         tracking train/test loss and the mean per-label test AUC every epoch.
      4. Every 5 epochs and at the last epoch, write prediction CSVs; at the
         last epoch also run the subgroup analysis and its visualizations.
      5. After training: plot the curves, write final predictions, a per-label
         metrics summary, a fairness table, a disparity table, and the TPR
         analysis; save the model weights to 'chest_xray_model.pth'.

    Args:
        dataset: The dataset to use; each batch must provide 'embedding' and
            'labels_one_hot' entries
        label_columns: List of disease labels
        test_size: Fraction of data to use for testing
        batch_size: Batch size for training
        num_epochs: Number of training epochs

    Returns:
        Tuple ``(model, metrics_df, fairness_df, disparity_df, tpr_disparities)``
    """
    # Load demographic data
    print("Loading demographic data...")
    admissions_df, patients_df = load_demographic_data()
    print(f"Loaded admissions data for {len(admissions_df)} patients")
    print(f"Loaded patient data for {len(patients_df)} patients")

    # Create output directories ('tpr_analysis' is written to directly near the
    # end of this function, so it must not rely on a helper to create it)
    for directory in ('predictions', 'subgroup_analysis', 'tpr_analysis'):
        os.makedirs(directory, exist_ok=True)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Split dataset (seeded so the split is reproducible across runs)
    dataset_size = len(dataset)
    test_count = int(test_size * dataset_size)
    train_count = dataset_size - test_count
    train_dataset, test_dataset = random_split(
        dataset, [train_count, test_count],
        generator=torch.Generator().manual_seed(42)
    )
    print(f"Training on {len(train_dataset)} samples, testing on {len(test_dataset)} samples")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Initialize model
    model = ChestXrayClassifier(input_dim=1376, output_dim=len(label_columns))
    model.to(device)

    # Initialize loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training setup
    train_losses = []
    test_losses = []
    test_aucs = []

    # Filled in at the last epoch; stays None only if the loop never runs
    subgroup_results = None

    # Training loop
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            # Move data to device
            inputs = batch['embedding'].to(device)
            targets = batch['labels_one_hot'].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update statistics (loss is a batch mean, so weight by batch size)
            batch_loss = loss.item()
            train_loss += batch_loss * inputs.size(0)
            progress_bar.set_postfix({'loss': batch_loss})

        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)

        # Evaluation
        model.eval()
        test_loss = 0.0
        all_outputs = []
        all_targets = []

        with torch.no_grad():
            for batch in test_loader:
                inputs = batch['embedding'].to(device)
                targets = batch['labels_one_hot'].to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item() * inputs.size(0)
                all_outputs.append(outputs.cpu().numpy())
                all_targets.append(targets.cpu().numpy())

        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)

        # Calculate AUC for each label
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        all_probs = 1 / (1 + np.exp(-all_outputs))  # sigmoid

        aucs = {}
        for i, label in enumerate(label_columns):
            label_targets = all_targets[:, i]
            positives = label_targets.sum()
            # roc_auc_score raises unless both classes are present, so require
            # at least one positive AND one negative (assumes 0/1 targets)
            if 0 < positives < len(label_targets):
                aucs[label] = roc_auc_score(label_targets, all_probs[:, i])

        # Avoid np.mean([]) (RuntimeWarning) when no label was scorable
        mean_auc = np.mean(list(aucs.values())) if aucs else float('nan')
        test_aucs.append(mean_auc)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Test Loss: {test_loss:.4f}")
        print(f"  Mean AUC: {mean_auc:.4f}")

        # Save predictions with demographic data every 5 epochs and at the last epoch
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            train_probs, train_targets, train_pred_df = save_predictions_to_csv(
                train_loader, model, device, label_columns,
                f'predictions/train_predictions_epoch_{epoch+1}.csv',
                admissions_df, patients_df
            )

            test_probs, test_targets, test_pred_df = save_predictions_to_csv(
                test_loader, model, device, label_columns,
                f'predictions/test_predictions_epoch_{epoch+1}.csv',
                admissions_df, patients_df
            )

            # Perform subgroup analysis at final epoch
            if epoch == num_epochs - 1:
                print("\nPerforming subgroup analysis on test predictions...")
                subgroup_results = analyze_by_subgroups(test_pred_df, label_columns)

                # Save subgroup analysis results
                with open('subgroup_analysis/subgroup_results.txt', 'w') as f:
                    for group, group_data in subgroup_results.items():
                        f.write(f"===== {group} =====\n")
                        for subgroup, metrics in group_data.items():
                            f.write(f"\n-- {subgroup} (n={metrics['sample_count']}) --\n")
                            for disease, disease_metrics in metrics['disease_metrics'].items():
                                f.write(f"  {disease}:\n")
                                for metric_name, value in disease_metrics.items():
                                    if metric_name in ['accuracy', 'precision', 'recall', 'f1']:
                                        f.write(f"    {metric_name}: {value:.4f}\n")
                        f.write("\n\n")

                # Create visualizations
                print("Creating subgroup performance visualizations...")
                visualize_subgroup_performance(subgroup_results, metric='f1')
                visualize_subgroup_performance(subgroup_results, metric='accuracy')

    # Plot training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(test_aucs, label='Mean AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.legend()
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.show()

    # Final evaluation and save predictions
    print("\nGenerating final prediction CSVs...")
    train_probs, train_targets, train_pred_df = save_predictions_to_csv(
        train_loader, model, device, label_columns,
        'predictions/train_predictions_final.csv',
        admissions_df, patients_df
    )

    test_probs, test_targets, test_pred_df = save_predictions_to_csv(
        test_loader, model, device, label_columns,
        'predictions/test_predictions_final.csv',
        admissions_df, patients_df
    )

    if subgroup_results is None:
        # Only reachable when the epoch loop never ran (num_epochs == 0);
        # without this the fairness report below would hit an unbound name.
        subgroup_results = analyze_by_subgroups(test_pred_df, label_columns)

    # Calculate final metrics (0.5 decision threshold on the probabilities)
    test_preds = (test_probs >= 0.5).astype(int)
    metric_rows = []

    print("\nFinal metrics by condition:")
    for i, label in enumerate(label_columns):
        y_true = test_targets[:, i]
        y_hat = test_preds[:, i]
        positives = y_true.sum()

        # Skip if no positive examples
        if positives > 0:
            # AUC is undefined without negatives; report NaN instead of letting
            # roc_auc_score raise (pandas' mean() below skips NaN).
            if positives < len(y_true):
                auc = roc_auc_score(y_true, test_probs[:, i])
            else:
                auc = float('nan')
            accuracy = accuracy_score(y_true, y_hat)
            precision = precision_score(y_true, y_hat, zero_division=0)
            recall = recall_score(y_true, y_hat, zero_division=0)
            f1 = f1_score(y_true, y_hat, zero_division=0)

            print(f"{label}:")
            print(f"  AUC: {auc:.4f}")
            print(f"  Accuracy: {accuracy:.4f}")
            print(f"  Precision: {precision:.4f}")
            print(f"  Recall: {recall:.4f}")
            print(f"  F1: {f1:.4f}")

            metric_rows.append({
                'Label': label,
                'AUC': auc,
                'Accuracy': accuracy,
                'Precision': precision,
                'Recall': recall,
                'F1': f1
            })

    # Build the frame once from records instead of concatenating row by row
    metrics_df = pd.DataFrame(
        metric_rows,
        columns=['Label', 'AUC', 'Accuracy', 'Precision', 'Recall', 'F1']
    )

    print(f"\nOverall Mean AUC: {metrics_df['AUC'].mean():.4f}")
    metrics_df.to_csv('predictions/metrics_summary.csv', index=False)

    # Save model
    torch.save(model.state_dict(), 'chest_xray_model.pth')
    print("Model saved to 'chest_xray_model.pth'")

    # Generate demographic fairness report
    print("\nGenerating fairness analysis across demographic groups...")
    fairness_rows = [
        {
            'Group': group_name,
            'Subgroup': subgroup,
            'Label': disease,
            'AUC': metrics.get('auc', 0),
            'Accuracy': metrics['accuracy'],
            'F1': metrics['f1'],
            'Count': metrics['sample_count']
        }
        for group_name, group_data in subgroup_results.items()
        for subgroup, subgroup_data in group_data.items()
        for disease, metrics in subgroup_data['disease_metrics'].items()
    ]
    fairness_df = pd.DataFrame(
        fairness_rows,
        columns=['Group', 'Subgroup', 'Label', 'AUC', 'Accuracy', 'F1', 'Count']
    )

    fairness_df.to_csv('predictions/fairness_analysis.csv', index=False)

    # Create disparity visualizations (one figure per metric/group, closed after saving)
    for metric in ['Accuracy', 'F1']:
        for group in fairness_df['Group'].unique():
            group_data = fairness_df[fairness_df['Group'] == group]

            plt.figure(figsize=(12, 8))
            sns.boxplot(x='Label', y=metric, hue='Subgroup', data=group_data)
            plt.title(f'{metric} Disparities by {group}')
            plt.xticks(rotation=45, ha='right')
            plt.tight_layout()
            plt.savefig(f'subgroup_analysis/{group}_{metric}_disparity.png')
            plt.close()

    # Summary analysis: Create a disparity metric
    # (max performance - min performance) for each disease/group pair
    disparity_df = fairness_df.groupby(['Group', 'Label']).agg({
        'F1': ['min', 'max', lambda x: x.max() - x.min()],
        'Accuracy': ['min', 'max', lambda x: x.max() - x.min()],
        'Count': 'sum'
    }).reset_index()

    # Rename the columns for clarity (flattens the two-level header from agg)
    disparity_df.columns = [
        'Group', 'Label', 'F1_Min', 'F1_Max', 'F1_Disparity',
        'Acc_Min', 'Acc_Max', 'Acc_Disparity', 'Total_Count'
    ]

    # Sort by F1 disparity (largest disparities first)
    disparity_df = disparity_df.sort_values('F1_Disparity', ascending=False)
    disparity_df.to_csv('predictions/performance_disparities.csv', index=False)

    # Visualization of the largest disparities
    top_disparities = disparity_df.head(10)  # Top 10 disparities

    plt.figure(figsize=(12, 8))
    sns.barplot(x='Label', y='F1_Disparity', hue='Group', data=top_disparities)
    plt.title('Top Performance Disparities (F1 Score)')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig('subgroup_analysis/top_disparities.png')
    # Display, then release the figure so it does not linger while the TPR
    # analysis below creates its own figures
    plt.show()
    plt.close()

    # Perform TPR analysis
    print("\nPerforming TPR analysis by demographic subgroups...")
    plot_subgroup_tpr_and_disparities(test_pred_df, label_columns)

    # Statistical analysis of TPR disparities
    tpr_disparities = analyze_tpr_fairness(test_pred_df, label_columns)
    tpr_disparities.to_csv('tpr_analysis/significant_tpr_disparities.csv', index=False)

    # Print significant disparities (an empty result may carry no columns at all)
    if 'Significant' in tpr_disparities.columns:
        significant_disparities = tpr_disparities[tpr_disparities['Significant'].astype(bool)]
    else:
        significant_disparities = tpr_disparities
    if len(significant_disparities) > 0:
        print("\nSignificant TPR disparities found:")
        for _, row in significant_disparities.sort_values('TPR_Disparity', ascending=False).head(5).iterrows():
            print(f"  {row['Disease']} - {row['Demographic_Group']}: {row['Reference']} vs {row['Subgroup']}")
            print(f"    TPR disparity: {row['TPR_Disparity']:.4f} (p-value: {row['P_Value']:.4f})")

    return model, metrics_df, fairness_df, disparity_df, tpr_disparities

# Function for detailed analysis on a specific demographic group
def analyze_specific_subgroup(predictions_df, label_columns, group_col, subgroup_value):
    """
    Perform detailed analysis on a specific demographic subgroup

    For every label, the ground truth is read from the ``<label>_true`` column
    and the prediction from the ``<label>`` column. Confusion-matrix cells are
    counted by exact comparison with 0 and 1, so predictions are expected to be
    binary; values that are neither 0 nor 1 fall into no cell.

    Args:
        predictions_df: DataFrame with predictions and demographic data
        label_columns: List of disease labels
        group_col: Column name for the demographic group (e.g., 'gender', 'ethnicity')
        subgroup_value: Specific value to analyze (e.g., 'F', 'White')

    Returns:
        subgroup_metrics: Dictionary with keys 'sample_count', 'disease_metrics'
            (per-label metrics; labels without positive examples are omitted)
            and 'demographics' (value counts of the other demographic columns),
            or None when the subgroup has no rows
    """
    # Filter for the specific subgroup
    subgroup_df = predictions_df[predictions_df[group_col] == subgroup_value]

    if len(subgroup_df) == 0:
        print(f"No samples found for {group_col}={subgroup_value}")
        return None

    print(f"Analyzing {len(subgroup_df)} samples for {group_col}={subgroup_value}")

    # Calculate metrics for each disease
    disease_metrics = {}
    for label in label_columns:
        true_col = f"{label}_true"
        pred_col = f"{label}"

        # Extract ground truth and predictions
        y_true = subgroup_df[true_col].values
        y_pred = subgroup_df[pred_col].values

        # Skip if no positive examples
        positive_total = y_true.sum()
        if positive_total == 0:
            continue

        # Calculate confusion matrix elements
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        tp = np.sum((y_true == 1) & (y_pred == 1))

        # Calculate metrics; every ratio falls back to 0 on an empty
        # denominator (accuracy included, so it can never become NaN)
        counted = tp + tn + fp + fn
        accuracy = (tp + tn) / counted if counted > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # Store detailed metrics
        disease_metrics[label] = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'specificity': specificity,
            'f1': f1,
            'true_positives': int(tp),
            'false_positives': int(fp),
            'true_negatives': int(tn),
            'false_negatives': int(fn),
            'positive_count': int(positive_total),
            'sample_count': len(y_true),
            'positive_rate': positive_total / len(y_true)
        }

    # Create demographics breakdown for this subgroup (the grouping column
    # itself is excluded because it is constant within the subgroup)
    demographics = {}
    for demo_col in ['gender', 'ethnicity', 'insurance', 'marital_status', 'anchor_year_group']:
        if demo_col != group_col and demo_col in subgroup_df.columns:
            demographics[demo_col] = subgroup_df[demo_col].value_counts().to_dict()

    subgroup_metrics = {
        'sample_count': len(subgroup_df),
        'disease_metrics': disease_metrics,
        'demographics': demographics
    }

    return subgroup_metrics

# Function to analyze TPR and disparities by subgroup
def plot_subgroup_tpr_and_disparities(predictions_df, label_columns, output_dir='tpr_analysis'):
    """
    Create plots showing:
    1. TPR (recall) for each subgroup within each demographic group
    2. Disparity from overall TPR for each subgroup

    The underlying table is also written to ``<output_dir>/tpr_by_subgroup.csv``.
    Subgroups with fewer than 10 rows, and label/subgroup pairs without any
    positive case, are left out. When nothing qualifies, the CSV contains only
    the header and no plots are produced.

    Args:
        predictions_df: DataFrame with predictions and demographic data
        label_columns: List of disease labels
        output_dir: Directory to save output plots
    """
    os.makedirs(output_dir, exist_ok=True)

    # Fixed schema so an empty result still has every column the code below reads
    tpr_columns = [
        'Demographic_Group', 'Subgroup', 'Disease', 'TPR', 'Overall_TPR',
        'TPR_Disparity', 'Sample_Count', 'Positive_Count'
    ]

    # Define demographic groups to analyze
    demographic_groups = ['gender', 'ethnicity', 'insurance', 'marital_status', 'anchor_year_group']
    demographic_groups = [g for g in demographic_groups if g in predictions_df.columns]

    # First, calculate overall TPR for each disease
    overall_tpr = {}
    for label in label_columns:
        true_col = f"{label}_true"
        pred_col = f"{label}"

        # Calculate overall TPR (recall)
        y_true = predictions_df[true_col].values
        y_pred = predictions_df[pred_col].values

        # Skip if no positive cases
        if y_true.sum() == 0:
            continue

        overall_tpr[label] = recall_score(y_true, y_pred, zero_division=0)

    # Now calculate TPR for each subgroup and the disparity
    tpr_data = []

    for group in demographic_groups:
        # The subgroup slices do not depend on the label, so build them once
        # per demographic group; subgroups with too few samples are dropped here.
        subgroup_frames = []
        for subgroup in predictions_df[group].dropna().unique():
            subgroup_df = predictions_df[predictions_df[group] == subgroup]
            if len(subgroup_df) < 10:
                continue
            subgroup_frames.append((subgroup, subgroup_df))

        for label in overall_tpr.keys():
            true_col = f"{label}_true"
            pred_col = f"{label}"

            for subgroup, subgroup_df in subgroup_frames:
                y_true = subgroup_df[true_col].values
                y_pred = subgroup_df[pred_col].values

                # Skip if no positive cases in this subgroup
                positive_count = y_true.sum()
                if positive_count == 0:
                    continue

                # Calculate TPR for this subgroup
                subgroup_tpr = recall_score(y_true, y_pred, zero_division=0)

                # Calculate disparity from overall TPR
                tpr_disparity = subgroup_tpr - overall_tpr[label]

                # Add to data collection
                tpr_data.append({
                    'Demographic_Group': group,
                    'Subgroup': subgroup,
                    'Disease': label,
                    'TPR': subgroup_tpr,
                    'Overall_TPR': overall_tpr[label],
                    'TPR_Disparity': tpr_disparity,
                    'Sample_Count': len(y_true),
                    'Positive_Count': positive_count
                })

    # Convert to DataFrame
    tpr_df = pd.DataFrame(tpr_data, columns=tpr_columns)

    # Save the data
    tpr_df.to_csv(f"{output_dir}/tpr_by_subgroup.csv", index=False)

    # Create plots for each demographic group
    for group in demographic_groups:
        group_data = tpr_df[tpr_df['Demographic_Group'] == group]

        if len(group_data) == 0:
            continue

        # Plot 1: TPR by subgroup for each disease
        plt.figure(figsize=(14, 8))
        sns.barplot(x='Disease', y='TPR', hue='Subgroup', data=group_data)
        plt.title(f'True Positive Rate by {group} Subgroup')
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0, 1)  # TPR ranges from 0 to 1
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group}_tpr_by_subgroup.png")
        plt.close()

        # Plot 2: TPR Disparity from overall TPR
        plt.figure(figsize=(14, 8))
        sns.barplot(x='Disease', y='TPR_Disparity', hue='Subgroup', data=group_data)
        plt.title(f'TPR Disparity from Overall by {group} Subgroup')
        plt.xticks(rotation=45, ha='right')
        plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)  # Add a line at y=0
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group}_tpr_disparity.png")
        plt.close()

        # Heatmaps are only drawn when there is more than one row and column to compare
        if len(group_data['Subgroup'].unique()) > 1 and len(group_data['Disease'].unique()) > 1:
            # Plot 3: Heatmap of TPR by subgroup and disease
            plt.figure(figsize=(12, 8))
            pivot_tpr = group_data.pivot_table(
                values='TPR',
                index='Disease',
                columns='Subgroup'
            )
            sns.heatmap(pivot_tpr, annot=True, cmap='YlGnBu', fmt='.2f', linewidths=.5)
            plt.title(f'TPR Heatmap by {group} Subgroup')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/{group}_tpr_heatmap.png")
            plt.close()

            # Plot 4: Heatmap of TPR disparities
            plt.figure(figsize=(12, 8))
            pivot_disparity = group_data.pivot_table(
                values='TPR_Disparity',
                index='Disease',
                columns='Subgroup'
            )
            # Use a diverging colormap centered at 0
            sns.heatmap(pivot_disparity, annot=True, cmap='RdBu_r', fmt='.2f',
                         linewidths=.5, center=0)
            plt.title(f'TPR Disparity Heatmap by {group} Subgroup')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/{group}_tpr_disparity_heatmap.png")
            plt.close()

# Function to perform statistical analysis of TPR disparities
def analyze_tpr_fairness(predictions_df, label_columns, significance_threshold=0.05):
    """
    Perform statistical analysis on TPR disparities to identify significant differences

    For each demographic column present in ``predictions_df``, the most
    frequent value is taken as the reference subgroup. Every other subgroup
    with at least 20 rows is compared against it, per label, using Fisher's
    exact test on the 2x2 table of (true positives, false negatives).
    Comparisons where either side has no positive cases are skipped.

    Args:
        predictions_df: DataFrame with predictions and demographic info
        label_columns: List of disease labels
        significance_threshold: p-value threshold for statistical significance

    Returns:
        DataFrame with one row per (demographic group, label, subgroup)
        comparison, including a boolean 'Significant' column. The columns are
        always present, even when no comparison could be made.
    """
    from scipy.stats import fisher_exact

    # Fixed schema so callers can index columns such as 'Significant' on an
    # empty result
    result_columns = [
        'Demographic_Group', 'Reference', 'Subgroup', 'Disease',
        'Reference_TPR', 'Subgroup_TPR', 'TPR_Disparity', 'P_Value',
        'Significant', 'Reference_Positive_Count', 'Subgroup_Positive_Count'
    ]

    demographic_groups = ['gender', 'ethnicity', 'insurance', 'marital_status', 'anchor_year_group']
    demographic_groups = [g for g in demographic_groups if g in predictions_df.columns]

    result_data = []

    for group in demographic_groups:
        # Get the majority subgroup as reference; a column with only missing
        # values has no subgroup to compare and is skipped
        group_counts = predictions_df[group].value_counts()
        if group_counts.empty:
            continue
        reference = group_counts.index[0]

        # Neither slice depends on the label, so compute them once per group
        ref_df = predictions_df[predictions_df[group] == reference]
        comparison_frames = []
        for subgroup in predictions_df[group].dropna().unique():
            if subgroup == reference:
                continue

            subgroup_df = predictions_df[predictions_df[group] == subgroup]

            # Skip if too few samples
            if len(subgroup_df) < 20:
                continue

            comparison_frames.append((subgroup, subgroup_df))

        for label in label_columns:
            true_col = f"{label}_true"
            pred_col = f"{label}"

            # Get reference subgroup contingency table
            ref_tp = np.sum((ref_df[true_col] == 1) & (ref_df[pred_col] == 1))
            ref_fn = np.sum((ref_df[true_col] == 1) & (ref_df[pred_col] == 0))

            # Skip if no positive cases in the reference subgroup
            if (ref_tp + ref_fn) == 0:
                continue
            ref_tpr = ref_tp / (ref_tp + ref_fn)

            # Compare with other subgroups
            for subgroup, subgroup_df in comparison_frames:
                # Calculate TPR
                sg_tp = np.sum((subgroup_df[true_col] == 1) & (subgroup_df[pred_col] == 1))
                sg_fn = np.sum((subgroup_df[true_col] == 1) & (subgroup_df[pred_col] == 0))

                # Skip if no positive cases
                if (sg_tp + sg_fn) == 0:
                    continue
                sg_tpr = sg_tp / (sg_tp + sg_fn)

                # Statistical test (Fisher's exact test on the contingency table)
                contingency = np.array([[ref_tp, ref_fn], [sg_tp, sg_fn]])
                _, p_value = fisher_exact(contingency)

                tpr_disparity = sg_tpr - ref_tpr
                significant = p_value < significance_threshold

                result_data.append({
                    'Demographic_Group': group,
                    'Reference': reference,
                    'Subgroup': subgroup,
                    'Disease': label,
                    'Reference_TPR': ref_tpr,
                    'Subgroup_TPR': sg_tpr,
                    'TPR_Disparity': tpr_disparity,
                    'P_Value': p_value,
                    'Significant': significant,
                    'Reference_Positive_Count': ref_tp + ref_fn,
                    'Subgroup_Positive_Count': sg_tp + sg_fn
                })

    result_df = pd.DataFrame(result_data, columns=result_columns)
    if result_df.empty:
        # Keep 'Significant' usable as a boolean mask on an empty result
        result_df = result_df.astype({'Significant': bool})
    return result_df

# Function for fairness gap analysis across multiple groups
def fairness_gap_analysis(predictions_df, label_columns, reference_groups=None):
    """
    Calculate fairness gaps (performance differences) between demographic groups

    For each demographic column, every non-reference subgroup is compared with
    the reference subgroup on accuracy, precision, recall and F1, per disease.
    The gap is ``comparison - reference``. Columns missing from
    ``predictions_df`` and reference values that do not occur in the data are
    skipped silently.

    Args:
        predictions_df: DataFrame with predictions and demographic data
        label_columns: List of disease labels
        reference_groups: Dictionary mapping demographic columns to reference groups
                         (e.g., {'gender': 'M', 'ethnicity': 'White'}).
                         When omitted, the most frequent gender (if a 'gender'
                         column with values exists), 'White', 'Private' and
                         'Married' are used.

    Returns:
        gaps_df: DataFrame with fairness gaps, one row per
            (group, comparison subgroup, disease, metric), plus an
            'Absolute_Gap' column. Empty, but with all columns, when nothing
            could be compared.
    """
    gap_columns = [
        'Group', 'Reference', 'Comparison', 'Disease', 'Metric',
        'Reference_Value', 'Comparison_Value', 'Gap', 'Relative_Gap',
        'Reference_Count', 'Comparison_Count'
    ]

    if reference_groups is None:
        # Default reference groups (typically majority groups in each category).
        # The gender default is derived from the data, so only add it when the
        # column exists and has at least one non-missing value.
        reference_groups = {}
        if 'gender' in predictions_df.columns:
            gender_counts = predictions_df['gender'].value_counts()
            if not gender_counts.empty:
                reference_groups['gender'] = gender_counts.index[0]
        reference_groups.update({
            'ethnicity': 'White',
            'insurance': 'Private',
            'marital_status': 'Married'
        })

    gaps_data = []

    # For each demographic group
    for group_col, ref_group in reference_groups.items():
        if group_col not in predictions_df.columns:
            continue

        # Skip if reference group doesn't exist in the data
        if ref_group not in predictions_df[group_col].values:
            continue

        # Get metrics for reference group
        ref_metrics = analyze_specific_subgroup(predictions_df, label_columns, group_col, ref_group)

        if ref_metrics is None:
            continue

        # Compare with other groups
        for other_group in predictions_df[group_col].dropna().unique():
            if other_group == ref_group:
                continue

            other_metrics = analyze_specific_subgroup(predictions_df, label_columns, group_col, other_group)

            if other_metrics is None:
                continue

            # Calculate gaps for each disease and metric; a disease is only
            # compared when both subgroups produced metrics for it
            for disease in ref_metrics['disease_metrics'].keys():
                if disease not in other_metrics['disease_metrics']:
                    continue

                ref_disease = ref_metrics['disease_metrics'][disease]
                other_disease = other_metrics['disease_metrics'][disease]

                for metric in ['accuracy', 'precision', 'recall', 'f1']:
                    ref_value = ref_disease[metric]
                    other_value = other_disease[metric]
                    gap = other_value - ref_value

                    gaps_data.append({
                        'Group': group_col,
                        'Reference': ref_group,
                        'Comparison': other_group,
                        'Disease': disease,
                        'Metric': metric,
                        'Reference_Value': ref_value,
                        'Comparison_Value': other_value,
                        'Gap': gap,
                        # Relative gap is reported as 0 when the reference value is 0
                        'Relative_Gap': gap / ref_value if ref_value > 0 else 0,
                        'Reference_Count': ref_disease['sample_count'],
                        'Comparison_Count': other_disease['sample_count']
                    })

    # Convert to DataFrame (explicit columns keep the schema when nothing was collected)
    gaps_df = pd.DataFrame(gaps_data, columns=gap_columns)

    # Add an absolute gap column for sorting
    gaps_df['Absolute_Gap'] = gaps_df['Gap'].astype(float).abs()

    return gaps_df

# Main execution
if __name__ == "__main__":
    # `dataset` and `label_columns` are not created here; they are expected to
    # exist already in the notebook namespace.

    # Run the full pipeline: training, evaluation and demographic analysis.
    (
        model,
        metrics_df,
        fairness_df,
        disparity_df,
        tpr_disparities,
    ) = train_and_evaluate(dataset, label_columns, num_epochs=25)

    print("\nTraining complete. Analysis saved to 'predictions', 'subgroup_analysis', and 'tpr_analysis' directories.")

In [None]:
# Preview the admissions table; nrows=5 parses only the rows that are displayed
# instead of loading the whole file just to call .head()
pd.read_csv("/kaggle/input/mimic-patients/admissions.csv", nrows=5)

In [None]:
# Preview the patients table; nrows=5 parses only the rows that are displayed
# instead of loading the whole file just to call .head()
pd.read_csv("/kaggle/input/mimic-patients/patients.csv", nrows=5)