In [1]:
import pandas as pd
import tensorflow as tf
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm
print("GPU Available: ", tf.config.list_physical_devices('GPU'))


import pickle

# Path to the file
file_path = "/kaggle/input/preprocessed-mimic-train-test/preprocessed_data_train_test.pkl"

# Load the data
with open(file_path, 'rb') as file:
    meta_data = pickle.load(file)

# Now you can work with the loaded data
# Let's see what's inside
print(type(meta_data))
print(meta_data.keys())

GPU Available:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]
<class 'pandas.core.frame.DataFrame'>
Index(['path', 'subject_id', 'study_id', 'dicom_id', 'split', 'gender',
       'insurance', 'anchor_age', 'race', 'Enlarged Cardiomediastinum',
       'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
       'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
       'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'],
      dtype='object')


In [2]:
meta_data.head()

Unnamed: 0,path,subject_id,study_id,dicom_id,split,gender,insurance,anchor_age,race,Enlarged Cardiomediastinum,...,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,No Finding
0,generalized-image-embeddings-for-the-mimic-che...,10000032,50414267,02aa804e-bde0afdd-112c0b34-7bc16630-4e384014,train,F,Medicaid,52.0,WHITE,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1,generalized-image-embeddings-for-the-mimic-che...,10000032,53189527,2a2277a9-b0ded155-c0de8eb9-c124d10e-82c5caab,train,F,Medicaid,52.0,WHITE,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
2,generalized-image-embeddings-for-the-mimic-che...,10000032,53911762,68b5c4b1-227d0485-9cc38c3f-7b84ab51-4b472714,train,F,Medicaid,52.0,WHITE,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3,generalized-image-embeddings-for-the-mimic-che...,10000032,53911762,fffabebf-74fd3a1f-673b6b41-96ec0ac9-2ab69818,train,F,Medicaid,52.0,WHITE,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
4,generalized-image-embeddings-for-the-mimic-che...,10000032,56699142,ea030e7a-2e3b1346-bc518786-7a8fd698-f673b44c,train,F,Medicaid,52.0,WHITE,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


In [3]:
import pandas as pd
import tensorflow as tf
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import os
from tqdm import tqdm
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
tf.get_logger().setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Base path for dataset
base_path = "/kaggle/input/mimic-data/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/"

# Label columns
label_columns = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 
                'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 
                'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 
                'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# TFRecord feature description
feature_description = {
    'embedding': tf.io.FixedLenFeature([1376], tf.float32),
    'image/id': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string)
}

def parse_tfrecord(record):
    """Parse a TFRecord into a dictionary of features"""
    example = tf.io.parse_single_example(record, feature_description)
    return {
        'embedding': example['embedding'].numpy(),
        'image_id': example['image/id'].numpy().decode('utf-8')
    }

class MIMICEmbeddingDataset(Dataset):
    def __init__(self, metadata_df, base_path):
        self.metadata_df = metadata_df  # Your existing metadata DataFrame
        self.base_path = base_path
        self.label_columns = label_columns
        self.data = []
        
        # Process each row in the metadata
        logger.info("Processing metadata to connect with embeddings...")
        self._load_data()
        logger.info(f"Dataset size: {len(self.data)}")
    
    def _load_data(self):
        skipped_files = 0
        
        # Process each row in the metadata
        for idx, row in tqdm(self.metadata_df.iterrows(), total=len(self.metadata_df), desc="Processing metadata"):
            # Get the relative path
            tfrecord_path = os.path.join(self.base_path, row['path'])
            
            if not os.path.exists(tfrecord_path):
                logger.warning(f"File not found: {tfrecord_path}")
                skipped_files += 1
                continue
            
            # Load the TFRecord
            dataset = tf.data.TFRecordDataset(tfrecord_path)
            
            for record in dataset:
                # Parse the TFRecord
                parsed = parse_tfrecord(record)
                embedding = parsed['embedding']
                
                # Add to dataset with metadata
                self.data.append({
                    'embedding': embedding,
                    'subject_id': row['subject_id'],
                    'study_id': row['study_id'],
                    'labels': {col: row[col] for col in self.label_columns},
                    'split': row['split'],
                    'demographics': {
                        'gender': row['gender'],
                        'insurance': row['insurance'],
                        'race': row['race'],
                        'anchor_age': row['anchor_age']
                    }
                })
                
                # We only need one record per row
                break
        
        logger.info(f"Processed {len(self.data)} records. Skipped files: {skipped_files}")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        embedding_tensor = torch.tensor(item['embedding'], dtype=torch.float32)
        
        result = {
            'embedding': embedding_tensor,
            'subject_id': item['subject_id'],
            'study_id': item['study_id'],
            'split': item['split']
        }
        
        # Process labels
        label_values = []
        for col in self.label_columns:
            value = item['labels'].get(col, 0)
            if pd.isna(value):
                value = 0
            label_values.append(float(value))
        
        result['labels'] = torch.tensor(label_values, dtype=torch.float32)
        
        return result

# Create dataset with your existing meta_data DataFrame
dataset = MIMICEmbeddingDataset(meta_data, base_path)

# Create train and test datasets based on the split column
train_indices = [i for i, item in enumerate(dataset.data) if item['split'] == 'train']
test_indices = [i for i, item in enumerate(dataset.data) if item['split'] == 'test']

# Split the train set into train and validation (90/10)
train_size = int(0.9 * len(train_indices))
val_size = len(train_indices) - train_size

# Create a random generator with fixed seed for reproducibility
generator = torch.Generator().manual_seed(42)

# Split train indices into train and val
train_indices, val_indices = random_split(train_indices, [train_size, val_size], generator=generator)

# Create subset datasets
class SubsetDatasetWithIndices(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

train_dataset = SubsetDatasetWithIndices(dataset, train_indices.indices)
val_dataset = SubsetDatasetWithIndices(dataset, val_indices.indices)
test_dataset = SubsetDatasetWithIndices(dataset, test_indices)

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Print dataset sizes
logger.info(f"Total dataset size: {len(dataset)}")
logger.info(f"Train dataset size: {len(train_dataset)}")
logger.info(f"Validation dataset size: {len(val_dataset)}")
logger.info(f"Test dataset size: {len(test_dataset)}")

# Test the datasets
if len(train_dataset) > 0:
    sample = train_dataset[0]
    logger.info(f"Sample embedding shape: {sample['embedding'].shape}")
    logger.info(f"Sample labels shape: {sample['labels'].shape}")

Processing metadata: 100%|██████████| 228905/228905 [1:44:50<00:00, 36.39it/s]  


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os
import seaborn as sns
from scipy.stats import fisher_exact

# ChestXrayClassifier model definition (unchanged)
class ChestXrayClassifier(nn.Module):
    def __init__(self, input_dim=1376, hidden_dims=[512, 384, 256], output_dim=14):
        super(ChestXrayClassifier, self).__init__()
        
        # Input normalization layer (learnable)
        self.batch_norm_input = nn.BatchNorm1d(input_dim)
        
        # Main network with residual connections
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.bn1 = nn.BatchNorm1d(hidden_dims[0])
        self.dropout1 = nn.Dropout(0.3)
        
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.bn2 = nn.BatchNorm1d(hidden_dims[1])
        self.dropout2 = nn.Dropout(0.3)
        
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.bn3 = nn.BatchNorm1d(hidden_dims[2])
        self.dropout3 = nn.Dropout(0.3)
        
        # Residual connection (from input to hidden layer 2)
        self.res_fc1 = nn.Linear(input_dim, hidden_dims[1])
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dims[2], hidden_dims[2] // 4),
            nn.ReLU(),
            nn.Linear(hidden_dims[2] // 4, hidden_dims[2]),
            nn.Sigmoid()
        )
        
        # Disease-specific layers for each output class
        self.disease_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dims[2], hidden_dims[2] // 2),
                nn.ReLU(),
                nn.Linear(hidden_dims[2] // 2, 1)
            ) for _ in range(output_dim)
        ])
    
    def forward(self, x):
        # Input normalization
        x_norm = self.batch_norm_input(x)
        
        # First block with residual connection
        res = x_norm
        x = self.fc1(x_norm)
        x = self.bn1(x)
        x = nn.functional.leaky_relu(x, 0.1)
        x = self.dropout1(x)
        
        # Residual connection from input to second layer
        res = self.res_fc1(res)
        
        # Second block
        x = self.fc2(x)
        x = self.bn2(x)
        x = nn.functional.leaky_relu(x, 0.1)
        x = self.dropout2(x)
        
        # Add residual connection
        x = x + res
        
        # Third block
        x = self.fc3(x)
        x = self.bn3(x)
        x = nn.functional.leaky_relu(x, 0.1)
        x = self.dropout3(x)
        
        # Apply attention
        attention_weights = self.attention(x)
        x = x * attention_weights
        
        # Disease-specific predictions
        outputs = []
        for disease_layer in self.disease_layers:
            outputs.append(disease_layer(x))
        
        # Concatenate all outputs
        return torch.cat(outputs, dim=1)

# Modified function to save predictions with demographic information
def save_predictions_to_csv(dataset_loader, model, device, label_columns, file_name):
    """
    Generate and save model predictions
    
    Args:
        dataset_loader: DataLoader with the dataset
        model: Trained model to generate predictions
        device: Device to run model on
        label_columns: List of disease label names
        file_name: Where to save the CSV
    
    Returns:
        all_probs: Array of prediction probabilities
        all_targets: Array of ground truth labels
    """
    model.eval()
    all_outputs = []
    all_targets = []
    subject_ids = []
    study_ids = []
    splits = []
    demo_data = {
        'gender': [],
        'race': [],
        'insurance': [],
        'anchor_age': []
    }
    
    with torch.no_grad():
        for batch in tqdm(dataset_loader, desc="Generating predictions"):
            inputs = batch['embedding'].to(device)
            targets = batch['labels'].to(device)
            
            # Extract IDs and split
            subject_ids.extend(batch['subject_id'])
            study_ids.extend(batch['study_id'])
            splits.extend(batch['split'])
            
            # Extract demographic info
            for key in demo_data:
                if key in batch:
                    demo_data[key].extend(batch[key])
            
            outputs = model(inputs)
            all_outputs.append(outputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    
    all_outputs = np.vstack(all_outputs)
    all_targets = np.vstack(all_targets)
    all_probs = 1 / (1 + np.exp(-all_outputs))  # sigmoid
    all_binary_preds = (all_probs >= 0.5).astype(int)  # Convert to binary predictions
    
    # Create DataFrame for predictions
    pred_df = pd.DataFrame()
    pred_df['subject_id'] = subject_ids
    pred_df['study_id'] = study_ids
    pred_df['split'] = splits
    
    # Add demographics that we have
    for key, values in demo_data.items():
        if values:  # Only add if we have values
            pred_df[key] = values
    
    # Add true labels and predictions for each disease
    for i, label in enumerate(label_columns):
        pred_df[f"{label}_true"] = all_targets[:, i]
        pred_df[f"{label}"] = all_binary_preds[:, i]  # Binary predictions
        pred_df[f"{label}_prob"] = all_probs[:, i]    # Probabilities
    
    # Save to CSV
    pred_df.to_csv(file_name, index=False)
    print(f"Predictions saved to {file_name}")
    
    return all_probs, all_targets, pred_df

# Function to analyze predictions by demographic subgroups
def analyze_by_subgroups(predictions_df, label_columns, demographic_cols=None):
    """
    Analyze model performance across different demographic subgroups
    
    Args:
        predictions_df: DataFrame with predictions and demographic info
        label_columns: List of disease labels
        demographic_cols: List of demographic columns to analyze by
                      (default: ['gender', 'race', 'insurance', 'anchor_age'])
    
    Returns:
        results_dict: Dictionary with performance metrics by subgroup
    """
    if demographic_cols is None:
        demographic_cols = ['gender', 'race', 'insurance', 'anchor_age']
    
    results_dict = {}
    
    # Analyze each subgroup separately
    for group_col in demographic_cols:
        if group_col not in predictions_df.columns:
            continue
            
        # Get unique values for this subgroup
        subgroups = predictions_df[group_col].dropna().unique()
        
        group_results = {}
        for subgroup in subgroups:
            # Filter for this subgroup
            subgroup_df = predictions_df[predictions_df[group_col] == subgroup]
            
            if len(subgroup_df) < 10:  # Skip if too few samples
                continue
                
            # Calculate metrics for each disease label
            disease_metrics = {}
            for label in label_columns:
                true_col = f"{label}_true"
                pred_col = f"{label}"
                prob_col = f"{label}_prob"
                
                # Skip if column doesn't exist
                if true_col not in subgroup_df.columns or pred_col not in subgroup_df.columns:
                    continue
                
                # Extract ground truth and predictions
                y_true = subgroup_df[true_col].values
                y_pred = subgroup_df[pred_col].values
                y_prob = subgroup_df[prob_col].values if prob_col in subgroup_df.columns else None
                
                # Skip if no positive examples
                if sum(y_true) == 0:
                    continue
                    
                # Calculate metrics
                accuracy = accuracy_score(y_true, y_pred)
                precision = precision_score(y_true, y_pred, zero_division=0)
                recall = recall_score(y_true, y_pred, zero_division=0)
                f1 = f1_score(y_true, y_pred, zero_division=0)
                
                # Calculate AUC if we have probabilities
                auc = None
                if y_prob is not None:
                    try:
                        auc = roc_auc_score(y_true, y_prob)
                    except:
                        pass
                
                # Confusion matrix elements for TPR/FPR analysis
                tn = sum((y_true == 0) & (y_pred == 0))
                fp = sum((y_true == 0) & (y_pred == 1))
                fn = sum((y_true == 1) & (y_pred == 0))
                tp = sum((y_true == 1) & (y_pred == 1))
                
                # Calculate TPR and FPR
                tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
                fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
                
                # Store results
                disease_metrics[label] = {
                    'accuracy': accuracy,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1,
                    'auc': auc,
                    'tpr': tpr,
                    'fpr': fpr,
                    'true_positives': tp,
                    'false_positives': fp,
                    'true_negatives': tn,
                    'false_negatives': fn,
                    'sample_count': len(y_true),
                    'positive_count': sum(y_true)
                }
            
            group_results[subgroup] = {
                'sample_count': len(subgroup_df),
                'disease_metrics': disease_metrics
            }
        
        results_dict[group_col] = group_results
    
    return results_dict

# Function to create visualizations of subgroup performance
def visualize_subgroup_performance(subgroup_results, metric='f1', output_dir='subgroup_analysis'):
    """
    Create visualizations comparing model performance across subgroups
    
    Args:
        subgroup_results: Results dictionary from analyze_by_subgroups
        metric: Which metric to visualize ('accuracy', 'precision', 'recall', 'f1', 'tpr', 'fpr')
        output_dir: Directory to save plots
    """
    os.makedirs(output_dir, exist_ok=True)
    
    for group_name, group_data in subgroup_results.items():
        # Get all diseases across all subgroups
        all_diseases = set()
        for subgroup, subgroup_data in group_data.items():
            all_diseases.update(subgroup_data['disease_metrics'].keys())
        
        # Create a dataframe for plotting
        plot_data = []
        for subgroup, subgroup_data in group_data.items():
            for disease, metrics in subgroup_data['disease_metrics'].items():
                if metric in metrics:
                    plot_data.append({
                        'Subgroup': subgroup,
                        'Disease': disease,
                        metric: metrics[metric],
                        'Sample Count': metrics['sample_count']
                    })
        
        if not plot_data:
            continue
            
        plot_df = pd.DataFrame(plot_data)
        
        # Create heatmap
        plt.figure(figsize=(12, 8))
        pivot_table = plot_df.pivot_table(
            values=metric, 
            index='Disease', 
            columns='Subgroup'
        )
        
        sns.heatmap(pivot_table, annot=True, cmap='YlGnBu', fmt='.2f', linewidths=.5)
        plt.title(f'{metric.capitalize()} by {group_name} Subgroups')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group_name}_{metric}_heatmap.png")
        plt.close()
        
        # Create grouped bar plot
        plt.figure(figsize=(14, 10))
        sns.barplot(x='Disease', y=metric, hue='Subgroup', data=plot_df)
        plt.title(f'{metric.capitalize()} by Disease and {group_name}')
        plt.xticks(rotation=45, ha='right')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group_name}_{metric}_barplot.png")
        plt.close()

# Function to analyze TPR and disparities by subgroup
# Function to analyze TPR fairness (fixed)
def analyze_tpr_fairness(predictions_df, label_columns, demographic_cols=None, significance_threshold=0.05):
    """
    Perform statistical analysis on TPR disparities to identify significant differences
    
    Args:
        predictions_df: DataFrame with predictions and demographic info
        label_columns: List of disease labels
        demographic_cols: List of demographic columns to analyze by
        significance_threshold: p-value threshold for statistical significance
        
    Returns:
        DataFrame with TPR disparities by demographic group
    """
    from scipy.stats import fisher_exact
    
    if demographic_cols is None:
        demographic_cols = ['gender', 'race', 'insurance', 'anchor_age']
    
    demographic_cols = [g for g in demographic_cols if g in predictions_df.columns]
    
    result_data = []
    
    for group in demographic_cols:
        # Get the majority subgroup as reference
        if group not in predictions_df.columns or predictions_df[group].isnull().all():
            continue
            
        reference = predictions_df[group].value_counts().index[0]
        
        for label in label_columns:
            true_col = f"{label}_true"
            pred_col = f"{label}"
            
            # Skip if columns don't exist
            if true_col not in predictions_df.columns or pred_col not in predictions_df.columns:
                continue
            
            # Get reference subgroup contingency table 
            ref_df = predictions_df[predictions_df[group] == reference]
            ref_tp = np.sum((ref_df[true_col] == 1) & (ref_df[pred_col] == 1))
            ref_fn = np.sum((ref_df[true_col] == 1) & (ref_df[pred_col] == 0))
            ref_tpr = ref_tp / (ref_tp + ref_fn) if (ref_tp + ref_fn) > 0 else 0
            
            # Compare with other subgroups
            for subgroup in predictions_df[group].dropna().unique():
                if subgroup == reference:
                    continue
                
                subgroup_df = predictions_df[predictions_df[group] == subgroup]
                
                # Skip if too few samples
                if len(subgroup_df) < 20:
                    continue
                
                # Calculate TPR
                sg_tp = np.sum((subgroup_df[true_col] == 1) & (subgroup_df[pred_col] == 1))
                sg_fn = np.sum((subgroup_df[true_col] == 1) & (subgroup_df[pred_col] == 0))
                sg_tpr = sg_tp / (sg_tp + sg_fn) if (sg_tp + sg_fn) > 0 else 0
                
                # Skip if no positive cases
                if (ref_tp + ref_fn) == 0 or (sg_tp + sg_fn) == 0:
                    continue
                
                # Statistical test (Fisher's exact test on the contingency table)
                contingency = np.array([[ref_tp, ref_fn], [sg_tp, sg_fn]])
                
                try:
                    odds_ratio, p_value = fisher_exact(contingency)
                    significant = p_value < significance_threshold
                except:
                    # If statistical test fails, skip
                    continue
                
                tpr_disparity = sg_tpr - ref_tpr
                
                result_data.append({
                    'Demographic_Group': group,
                    'Reference': reference,
                    'Subgroup': subgroup,
                    'Disease': label,
                    'Reference_TPR': ref_tpr,
                    'Subgroup_TPR': sg_tpr,
                    'TPR_Disparity': tpr_disparity,
                    'P_Value': p_value,
                    'Significant': significant,
                    'Reference_Positive_Count': ref_tp + ref_fn,
                    'Subgroup_Positive_Count': sg_tp + sg_fn
                })
    
    # Create DataFrame and ensure 'Significant' column exists
    if not result_data:
        # Return empty DataFrame with all required columns
        return pd.DataFrame(columns=[
            'Demographic_Group', 'Reference', 'Subgroup', 'Disease',
            'Reference_TPR', 'Subgroup_TPR', 'TPR_Disparity', 'P_Value',
            'Significant', 'Reference_Positive_Count', 'Subgroup_Positive_Count'
        ])
    
    result_df = pd.DataFrame(result_data)
    
    # Ensure Significant column exists (should already be there but just to be safe)
    if 'Significant' not in result_df.columns:
        result_df['Significant'] = result_df['P_Value'] < significance_threshold
    
    return result_df
                


# Function to plot TPR and disparities by subgroup
def plot_subgroup_tpr_and_disparities(predictions_df, label_columns, output_dir='tpr_analysis'):
    """
    Create plots showing:
    1. TPR (recall) for each subgroup within each demographic group
    2. Disparity from overall TPR for each subgroup
    
    Args:
        predictions_df: DataFrame with predictions and demographic data
        label_columns: List of disease labels
        output_dir: Directory to save output plots
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Define demographic groups to analyze
    demographic_groups = ['gender', 'race', 'insurance', 'anchor_age']
    demographic_groups = [g for g in demographic_groups if g in predictions_df.columns]
    
    # First, calculate overall TPR for each disease
    overall_tpr = {}
    for label in label_columns:
        true_col = f"{label}_true"
        pred_col = f"{label}"
        
        # Skip if columns don't exist
        if true_col not in predictions_df.columns or pred_col not in predictions_df.columns:
            continue
        
        # Calculate overall TPR (recall)
        y_true = predictions_df[true_col].values
        y_pred = predictions_df[pred_col].values
        
        # Skip if no positive cases
        if sum(y_true) == 0:
            continue
            
        overall_tpr[label] = recall_score(y_true, y_pred, zero_division=0)
    
    # Now calculate TPR for each subgroup and the disparity
    tpr_data = []
    
    for group in demographic_groups:
        for label in overall_tpr.keys():
            true_col = f"{label}_true"
            pred_col = f"{label}"
            
            # Calculate TPR for each subgroup in this demographic group
            for subgroup in predictions_df[group].dropna().unique():
                subgroup_df = predictions_df[predictions_df[group] == subgroup]
                
                # Skip if too few samples
                if len(subgroup_df) < 10:
                    continue
                
                y_true = subgroup_df[true_col].values
                y_pred = subgroup_df[pred_col].values
                
                # Skip if no positive cases in this subgroup
                if sum(y_true) == 0:
                    continue
                
                # Calculate TPR for this subgroup
                subgroup_tpr = recall_score(y_true, y_pred, zero_division=0)
                
                # Calculate disparity from overall TPR
                tpr_disparity = subgroup_tpr - overall_tpr[label]
                
                # Add to data collection
                tpr_data.append({
                    'Demographic_Group': group,
                    'Subgroup': subgroup,
                    'Disease': label,
                    'TPR': subgroup_tpr,
                    'Overall_TPR': overall_tpr[label],
                    'TPR_Disparity': tpr_disparity,
                    'Sample_Count': len(y_true),
                    'Positive_Count': sum(y_true)
                })
    
    # Convert to DataFrame
    tpr_df = pd.DataFrame(tpr_data)
    
    # Save the data
    tpr_df.to_csv(f"{output_dir}/tpr_by_subgroup.csv", index=False)
    
    # Create plots for each demographic group
    for group in demographic_groups:
        group_data = tpr_df[tpr_df['Demographic_Group'] == group]
        
        if len(group_data) == 0:
            continue
        
        # Plot 1: TPR by subgroup for each disease
        plt.figure(figsize=(14, 8))
        sns.barplot(x='Disease', y='TPR', hue='Subgroup', data=group_data)
        plt.title(f'True Positive Rate by {group} Subgroup')
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0, 1)  # TPR ranges from 0 to 1
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group}_tpr_by_subgroup.png")
        plt.close()
        
        # Plot 2: TPR Disparity from overall TPR
        plt.figure(figsize=(14, 8))
        sns.barplot(x='Disease', y='TPR_Disparity', hue='Subgroup', data=group_data)
        plt.title(f'TPR Disparity from Overall by {group} Subgroup')
        plt.xticks(rotation=45, ha='right')
        plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)  # Add a line at y=0
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{group}_tpr_disparity.png")
        plt.close()
        
        # Plot 3: Heatmap of TPR by subgroup and disease
        if len(group_data['Subgroup'].unique()) > 1 and len(group_data['Disease'].unique()) > 1:
            plt.figure(figsize=(12, 8))
            pivot_tpr = group_data.pivot_table(
                values='TPR', 
                index='Disease', 
                columns='Subgroup'
            )
            sns.heatmap(pivot_tpr, annot=True, cmap='YlGnBu', fmt='.2f', linewidths=.5)
            plt.title(f'TPR Heatmap by {group} Subgroup')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/{group}_tpr_heatmap.png")
            plt.close()
            
            # Plot 4: Heatmap of TPR disparities
            plt.figure(figsize=(12, 8))
            pivot_disparity = group_data.pivot_table(
                values='TPR_Disparity', 
                index='Disease', 
                columns='Subgroup'
            )
            # Use a diverging colormap centered at 0
            sns.heatmap(pivot_disparity, annot=True, cmap='RdBu_r', fmt='.2f', 
                        linewidths=.5, center=0)
            plt.title(f'TPR Disparity Heatmap by {group} Subgroup')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/{group}_tpr_disparity_heatmap.png")
            plt.close()

# Function to analyze fairness metrics across subgroups
def analyze_fairness_metrics(predictions_df, label_columns, demographic_cols=None):
    """
    Calculate comprehensive fairness metrics across demographic groups
    
    Args:
        predictions_df: DataFrame with predictions and demographic info
        label_columns: List of disease labels
        demographic_cols: List of demographic columns to analyze
        
    Returns:
        DataFrame with fairness metrics
    """
    if demographic_cols is None:
        demographic_cols = ['gender', 'race', 'insurance', 'anchor_age']
    
    demographic_cols = [g for g in demographic_cols if g in predictions_df.columns]
    
    fairness_metrics = []
    
    for group in demographic_cols:
        # Get subgroups
        subgroups = predictions_df[group].dropna().unique()
        
        # Get the majority subgroup as reference
        reference = predictions_df[group].value_counts().index[0]
        
        for label in label_columns:
            true_col = f"{label}_true"
            pred_col = f"{label}"
            
            # Skip if columns don't exist
            if true_col not in predictions_df.columns or pred_col not in predictions_df.columns:
                continue
            
            # Calculate overall metrics for this disease
            y_true_all = predictions_df[true_col].values
            y_pred_all = predictions_df[pred_col].values
            
            # Skip if no positive cases
            if sum(y_true_all) == 0:
                continue
                
            # Calculate TPR and FPR for each subgroup
            for subgroup in subgroups:
                subgroup_df = predictions_df[predictions_df[group] == subgroup]
                
                # Skip if too few samples
                if len(subgroup_df) < 20:
                    continue
                
                y_true = subgroup_df[true_col].values
                y_pred = subgroup_df[pred_col].values
                
                # Skip if no positive or negative cases
                if sum(y_true) == 0 or sum(y_true) == len(y_true):
                    continue
                
                # Calculate metrics
                accuracy = accuracy_score(y_true, y_pred)
                precision = precision_score(y_true, y_pred, zero_division=0)
                recall = recall_score(y_true, y_pred, zero_division=0)  # Same as TPR
                f1 = f1_score(y_true, y_pred, zero_division=0)
                
                # Confusion matrix elements
                tp = sum((y_true == 1) & (y_pred == 1))
                fp = sum((y_true == 0) & (y_pred == 1))
                tn = sum((y_true == 0) & (y_pred == 0))
                fn = sum((y_true == 1) & (y_pred == 0))
                
                # Fairness metrics
                tpr = recall  # True Positive Rate = Recall
                fpr = fp / (fp + tn) if (fp + tn) > 0 else 0  # False Positive Rate
                fnr = fn / (tp + fn) if (tp + fn) > 0 else 0  # False Negative Rate
                tnr = tn / (tn + fp) if (tn + fp) > 0 else 0  # True Negative Rate
                
                # Calculate prevalence (percentage of positive cases)
                prevalence = sum(y_true) / len(y_true)
                
                # Disparate impact = (% predicted positive for protected group) / (% predicted positive for reference group)
                # We'll calculate this at group level after collecting data for all subgroups
                predicted_positive_rate = sum(y_pred) / len(y_pred)
                
                # Add to results
                fairness_metrics.append({
                    'Demographic_Group': group,
                    'Subgroup': subgroup,
                    'Is_Reference': subgroup == reference,
                    'Disease': label,
                    'Accuracy': accuracy,
                    'Precision': precision,
                    'Recall': recall,
                    'F1': f1,
                    'TPR': tpr,
                    'FPR': fpr,
                    'TNR': tnr,
                    'FNR': fnr,
                    'Prevalence': prevalence,
                    'Predicted_Positive_Rate': predicted_positive_rate,
                    'Sample_Count': len(y_true),
                    'Positive_Count': sum(y_true),
                    'Predicted_Positive_Count': sum(y_pred)
                })
    
    fairness_df = pd.DataFrame(fairness_metrics)
    
    # Calculate disparate impact and equalized odds for each (demographic_group, disease) pair
    disparity_results = []
    
    for group in demographic_cols:
        for label in label_columns:
            group_disease_df = fairness_df[(fairness_df['Demographic_Group'] == group) & 
                                           (fairness_df['Disease'] == label)]
            
            if len(group_disease_df) <= 1:  # Need at least two subgroups to compare
                continue
                
            reference_row = group_disease_df[group_disease_df['Is_Reference'] == True]
            if len(reference_row) == 0:
                continue
                
            reference_tpr = reference_row['TPR'].values[0]
            reference_fpr = reference_row['FPR'].values[0]
            reference_ppr = reference_row['Predicted_Positive_Rate'].values[0]
            
            for idx, row in group_disease_df[group_disease_df['Is_Reference'] == False].iterrows():
                # Calculate disparity metrics
                tpr_disparity = row['TPR'] - reference_tpr
                fpr_disparity = row['FPR'] - reference_fpr
                
                # Disparate impact: ratio of predicted positive rates
                disparate_impact = row['Predicted_Positive_Rate'] / reference_ppr if reference_ppr > 0 else 0
                
                # Equal opportunity difference (difference in TPR)
                equal_opportunity_diff = tpr_disparity
                
                # Equalized odds violation (max of abs differences in TPR and FPR)
                equalized_odds_violation = max(abs(tpr_disparity), abs(fpr_disparity))
                
                disparity_results.append({
                    'Demographic_Group': group,
                    'Reference': reference_row['Subgroup'].values[0],
                    'Subgroup': row['Subgroup'],
                    'Disease': label,
                    'TPR_Disparity': tpr_disparity,
                    'FPR_Disparity': fpr_disparity,
                    'Disparate_Impact': disparate_impact,
                    'Equal_Opportunity_Diff': equal_opportunity_diff,
                    'Equalized_Odds_Violation': equalized_odds_violation,
                    'Reference_Sample_Count': reference_row['Sample_Count'].values[0],
                    'Subgroup_Sample_Count': row['Sample_Count']
                })
    
    disparity_df = pd.DataFrame(disparity_results)
    
    return fairness_df, disparity_df


In [5]:
# Main training function with fairness analysis
def train_and_evaluate(train_loader, val_loader, test_loader, label_columns, num_epochs=25):
    """
    Complete pipeline for training, evaluation, and subgroup analysis
    
    Args:
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        test_loader: DataLoader for test data
        label_columns: List of disease labels
        num_epochs: Number of training epochs
    """
    # Create output directories
    os.makedirs('predictions', exist_ok=True)
    os.makedirs('subgroup_analysis', exist_ok=True)
    os.makedirs('tpr_analysis', exist_ok=True)
    os.makedirs('fairness_analysis', exist_ok=True)
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    model = ChestXrayClassifier(input_dim=1376, output_dim=len(label_columns))
    model.to(device)
    
    # Initialize loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    
    # Training setup
    train_losses = []
    val_losses = []
    val_aucs = []
    best_val_auc = 0
    best_model_path = 'best_model.pth'
    
    # Training loop
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            # Move data to device
            inputs = batch['embedding'].to(device)
            targets = batch['labels'].to(device)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Update statistics
            train_loss += loss.item() * inputs.size(0)
            progress_bar.set_postfix({'loss': loss.item()})
        
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0.0
        all_outputs = []
        all_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['embedding'].to(device)
                targets = batch['labels'].to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item() * inputs.size(0)
                all_outputs.append(outputs.cpu().numpy())
                all_targets.append(targets.cpu().numpy())
        
        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)
        
        # Calculate AUC for each label
        all_outputs = np.vstack(all_outputs)
        all_targets = np.vstack(all_targets)
        all_probs = 1 / (1 + np.exp(-all_outputs))  # sigmoid
        
        aucs = {}
        for i, label in enumerate(label_columns):
            if sum(all_targets[:, i]) > 0:  # Only if there are positive examples
                aucs[label] = roc_auc_score(all_targets[:, i], all_probs[:, i])
        
        mean_auc = np.mean(list(aucs.values()))
        val_aucs.append(mean_auc)
        
        # Update learning rate based on validation loss
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Mean AUC: {mean_auc:.4f}")
        
        # Save best model
        if mean_auc > best_val_auc:
            best_val_auc = mean_auc
            torch.save(model.state_dict(), best_model_path)
            print(f"  Saved new best model with AUC: {mean_auc:.4f}")
        
        # Save predictions at final epoch
        if epoch == num_epochs - 1:
            print("\nGenerating final predictions...")
            
            # Load best model
            model.load_state_dict(torch.load(best_model_path))
            
            # Generate predictions
            train_probs, train_targets, train_pred_df = save_predictions_to_csv(
                train_loader, model, device, label_columns, 
                'predictions/train_predictions_final.csv'
            )
            
            val_probs, val_targets, val_pred_df = save_predictions_to_csv(
                val_loader, model, device, label_columns, 
                'predictions/val_predictions_final.csv'
            )
            
            test_probs, test_targets, test_pred_df = save_predictions_to_csv(
                test_loader, model, device, label_columns, 
                'predictions/test_predictions_final.csv'
            )
            
            # Perform TPR subgroup analysis on test predictions
            print("\nPerforming TPR subgroup analysis...")
            plot_subgroup_tpr_and_disparities(test_pred_df, label_columns)
            
            # Perform statistical analysis of TPR disparities
            # Perform statistical analysis of TPR disparities
            print("Analyzing TPR disparities across subgroups...")
            tpr_disparities = analyze_tpr_fairness(test_pred_df, label_columns)
            tpr_disparities.to_csv('tpr_analysis/significant_tpr_disparities.csv', index=False)
            
            # Print significant disparities
            significant_disparities = tpr_disparities[tpr_disparities['Significant']]
            
            # With this:
            # Perform statistical analysis of TPR disparities
            print("Analyzing TPR disparities across subgroups...")
            tpr_disparities = analyze_tpr_fairness(test_pred_df, label_columns)
            tpr_disparities.to_csv('tpr_analysis/significant_tpr_disparities.csv', index=False)

            # Print significant disparities
            if not tpr_disparities.empty and 'Significant' in tpr_disparities.columns:
                significant_disparities = tpr_disparities[tpr_disparities['Significant']]
                if len(significant_disparities) > 0:
                    print("\nSignificant TPR disparities found:")
                    for _, row in significant_disparities.sort_values('TPR_Disparity', ascending=False).head(5).iterrows():
                        print(f"  {row['Disease']} - {row['Demographic_Group']}: {row['Reference']} vs {row['Subgroup']}")
                        print(f"    TPR disparity: {row['TPR_Disparity']:.4f} (p-value: {row['P_Value']:.4f})")
            else:
                print("No significant TPR disparities found or not enough data for analysis")
            # Perform comprehensive fairness analysis
            print("\nPerforming comprehensive fairness analysis...")
            fairness_metrics, disparity_metrics = analyze_fairness_metrics(test_pred_df, label_columns)
            
            # Save fairness metrics
            fairness_metrics.to_csv('fairness_analysis/fairness_metrics.csv', index=False)
            disparity_metrics.to_csv('fairness_analysis/disparity_metrics.csv', index=False)
            
            # Generate subgroup performance visualizations
            print("Creating subgroup performance visualizations...")
            demographic_cols = ['gender', 'race', 'insurance', 'anchor_age']
            demographic_cols = [g for g in demographic_cols if g in test_pred_df.columns]
            
            subgroup_results = analyze_by_subgroups(test_pred_df, label_columns, demographic_cols)
            
            for metric in ['f1', 'recall', 'precision', 'accuracy', 'tpr', 'fpr']:
                visualize_subgroup_performance(subgroup_results, metric=metric)
                
            # Visualize fairness metrics
            print("Creating fairness metric visualizations...")
            for fairness_metric in ['TPR_Disparity', 'FPR_Disparity', 'Equal_Opportunity_Diff', 'Equalized_Odds_Violation']:
                if fairness_metric in disparity_metrics.columns:
                    plt.figure(figsize=(12, 8))
                    sns.boxplot(x='Disease', y=fairness_metric, hue='Demographic_Group', data=disparity_metrics)
                    plt.title(f'{fairness_metric} by Disease and Demographic Group')
                    plt.xticks(rotation=45, ha='right')
                    plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
                    plt.tight_layout()
                    plt.savefig(f"fairness_analysis/{fairness_metric}_boxplot.png")
                    plt.close()
    
    # Plot training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(val_aucs, label='Validation AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.legend()
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.show()
    
    # Final evaluation
    model.load_state_dict(torch.load(best_model_path))
    model.eval()
    test_outputs = []
    test_targets_list = []
    
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['embedding'].to(device)
            targets = batch['labels'].to(device)
            
            outputs = model(inputs)
            test_outputs.append(outputs.cpu().numpy())
            test_targets_list.append(targets.cpu().numpy())
    
    test_outputs = np.vstack(test_outputs)
    test_targets = np.vstack(test_targets_list)
    test_probs = 1 / (1 + np.exp(-test_outputs))
    test_preds = (test_probs >= 0.5).astype(int)
    
    # Calculate final metrics
    metrics_df = pd.DataFrame(columns=['Label', 'AUC', 'Accuracy', 'Precision', 'Recall', 'F1'])
    
    print("\nFinal metrics by condition:")
    for i, label in enumerate(label_columns):
        # Skip if no positive examples
        if sum(test_targets[:, i]) > 0:
            auc = roc_auc_score(test_targets[:, i], test_probs[:, i])
            accuracy = accuracy_score(test_targets[:, i], test_preds[:, i])
            precision = precision_score(test_targets[:, i], test_preds[:, i], zero_division=0)
            recall = recall_score(test_targets[:, i], test_preds[:, i], zero_division=0)
            f1 = f1_score(test_targets[:, i], test_preds[:, i], zero_division=0)
            
            print(f"{label}:")
            print(f"  AUC: {auc:.4f}")
            print(f"  Accuracy: {accuracy:.4f}")
            print(f"  Precision: {precision:.4f}")
            print(f"  Recall: {recall:.4f}")
            print(f"  F1: {f1:.4f}")
            
            # Add to metrics dataframe
            metrics_df = pd.concat([
                metrics_df, 
                pd.DataFrame({
                    'Label': [label],
                    'AUC': [auc],
                    'Accuracy': [accuracy],
                    'Precision': [precision],
                    'Recall': [recall],
                    'F1': [f1]
                })
            ], ignore_index=True)
    
    print(f"\nOverall Mean AUC: {metrics_df['AUC'].mean():.4f}")
    metrics_df.to_csv('predictions/metrics_summary.csv', index=False)
    
    # Print fairness summary
    print("\nFairness Analysis Summary:")
    
    # Get the largest disparities
    if len(disparity_metrics) > 0:
        worst_disparities = disparity_metrics.nlargest(5, 'TPR_Disparity')
        print("\nLargest TPR Disparities (Recall differences):")
        for _, row in worst_disparities.iterrows():
            print(f"  {row['Disease']} - {row['Demographic_Group']}: {row['Reference']} vs {row['Subgroup']}")
            print(f"    TPR disparity: {row['TPR_Disparity']:.4f}")
        
        if 'Equalized_Odds_Violation' in disparity_metrics.columns:
            worst_eq_odds = disparity_metrics.nlargest(5, 'Equalized_Odds_Violation')
            print("\nLargest Equalized Odds Violations:")
            for _, row in worst_eq_odds.iterrows():
                print(f"  {row['Disease']} - {row['Demographic_Group']}: {row['Reference']} vs {row['Subgroup']}")
                print(f"    Violation: {row['Equalized_Odds_Violation']:.4f}")
    
    return model, metrics_df, fairness_metrics, disparity_metrics

In [6]:

# Define label columns
label_columns = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 
                'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 
                'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 
                'Pleural Other', 'Fracture', 'Support Devices', 'No Finding']

# Train and evaluate the model with fairness analysis
model, metrics_df, fairness_metrics, disparity_metrics = train_and_evaluate(
    train_loader, val_loader, test_loader, label_columns, num_epochs=10
)

Using device: cuda




Epoch 1/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 1/10:
  Train Loss: 0.2585
  Val Loss: 0.2507
  Mean AUC: 0.8184
  Saved new best model with AUC: 0.8184


Epoch 2/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 2/10:
  Train Loss: 0.2514
  Val Loss: 0.2487
  Mean AUC: 0.8242
  Saved new best model with AUC: 0.8242


Epoch 3/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 3/10:
  Train Loss: 0.2495
  Val Loss: 0.2484
  Mean AUC: 0.8260
  Saved new best model with AUC: 0.8260


Epoch 4/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 4/10:
  Train Loss: 0.2481
  Val Loss: 0.2473
  Mean AUC: 0.8291
  Saved new best model with AUC: 0.8291


Epoch 5/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 5/10:
  Train Loss: 0.2470
  Val Loss: 0.2474
  Mean AUC: 0.8276


Epoch 6/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 6/10:
  Train Loss: 0.2459
  Val Loss: 0.2464
  Mean AUC: 0.8310
  Saved new best model with AUC: 0.8310


Epoch 7/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 7/10:
  Train Loss: 0.2448
  Val Loss: 0.2465
  Mean AUC: 0.8316
  Saved new best model with AUC: 0.8316


Epoch 8/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 8/10:
  Train Loss: 0.2440
  Val Loss: 0.2458
  Mean AUC: 0.8321
  Saved new best model with AUC: 0.8321


Epoch 9/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 9/10:
  Train Loss: 0.2431
  Val Loss: 0.2458
  Mean AUC: 0.8336
  Saved new best model with AUC: 0.8336


Epoch 10/10:   0%|          | 0/1458 [00:00<?, ?it/s]

Epoch 10/10:
  Train Loss: 0.2422
  Val Loss: 0.2460
  Mean AUC: 0.8321

Generating final predictions...


  model.load_state_dict(torch.load(best_model_path))


Generating predictions:   0%|          | 0/1458 [00:00<?, ?it/s]

Predictions saved to predictions/train_predictions_final.csv


Generating predictions:   0%|          | 0/162 [00:00<?, ?it/s]

Predictions saved to predictions/val_predictions_final.csv


Generating predictions:   0%|          | 0/169 [00:00<?, ?it/s]

Predictions saved to predictions/test_predictions_final.csv

Performing TPR subgroup analysis...
Analyzing TPR disparities across subgroups...


KeyError: 'Significant'

In [None]:
import os
import zipfile


new_zip_path = "/kaggle/working/optimized_files.zip"
items_to_zip = [
    "chest_xray_model.pth",
    "predictions",
    "subgroup_analysis", 
    "tpr_analysis",
    "training_curves.png"
]

# Use a memory-efficient approach
with zipfile.ZipFile(new_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add files one by one
    for item in items_to_zip:
        item_path = os.path.join("/kaggle/working", item)
        
        # Add file
        if os.path.isfile(item_path):
            zipf.write(item_path, arcname=item)
            print(f"Added file: {item}")
        
        # Add directory contents
        elif os.path.isdir(item_path):
            # Get base name for path calculations
            base_name = os.path.basename(item_path)
            
            # Walk through directory
            for dir_path, _, files in os.walk(item_path):
                # Skip empty directories
                if not files:
                    continue
                    
                # Process each file
                for file in files:
                    file_path = os.path.join(dir_path, file)
                    # Calculate relative path for the archive
                    arc_path = os.path.join(base_name, os.path.relpath(file_path, item_path))
                    # Add to zip with correct path structure
                    zipf.write(file_path, arcname=arc_path)

print(f"Created optimized zip file: optimized_files.zip")