<a href="https://colab.research.google.com/github/LogicOber/LogicOber/blob/main/v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import copy
import os
import math
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create directory for model checkpoints
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('results', exist_ok=True)

# Sensitivity-optimized Focal Loss with improved numerical stability
class SensitivityFocalLoss(nn.Module):
    """Focal Loss with increased penalty for false negatives"""
    def __init__(self, alpha=0.25, gamma=2.0, sensitivity_weight=2.0, epsilon=1e-7):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.sensitivity_weight = sensitivity_weight  # Controls penalty for false negatives
        self.epsilon = epsilon  # For numerical stability

    def forward(self, inputs, targets):
        # Clamp inputs to prevent extreme values
        inputs = torch.clamp(inputs, -50, 50)

        # Apply sigmoid manually with stable computation
        probs = torch.sigmoid(inputs)
        probs = torch.clamp(probs, self.epsilon, 1.0 - self.epsilon)

        # Calculate BCE loss
        BCE_loss = -targets * torch.log(probs) - (1 - targets) * torch.log(1 - probs)

        # Compute focal weights with clamping for stability
        pt = torch.exp(-torch.clamp(BCE_loss, 0, 50))

        # Apply weights with safeguards
        weights = torch.ones_like(targets)
        weights[targets > 0.5] = self.sensitivity_weight

        # Calculate focal loss with checks for NaN
        focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss * weights

        # Replace any NaN values
        if torch.isnan(focal_loss).any():
            print("Warning: NaN in loss, replacing with small value")
            focal_loss = torch.nan_to_num(focal_loss, nan=0.1)

        return focal_loss.mean()

def load_and_preprocess_data(file_path, debug=True):
    print("Loading data from", file_path)
    data = pd.read_csv(file_path)

    if debug:
        print(f"Data shape: {data.shape}")
        print(f"Columns: {data.columns.tolist()}")
        print("\nFirst few rows:")
        print(data.head())

    # Identify continuous features
    cont_features = ['creatinine', 'bicarbonate', 'chloride', 'glucose',
                    'magnesium', 'potassium', 'sodium', 'urea_nitrogen',
                    'hemoglobin', 'platelet_count', 'wbc_count', 'lactate',
                    'heart_rate', 'resp_rate', 'temperature', 'spo2',
                    'nbp_sys', 'nbp_dias', 'nbp_mean', 'gcs_total',
                    'urine_output_ml']

    # Convert gender to numeric efficiently
    data['gender_numeric'] = data['gender'].map({'M': 1, 'F': 0}).fillna(0.5)

    # Handle missing values and clip extreme values
    for col in cont_features:
        if col in data.columns:
            # Create missing indicator
            data[f'is_{col}_missing'] = data[col].isnull().astype(int)

            # Impute with median for missing values
            data[col] = data[col].fillna(data[col].median())

            # Clip extreme values
            q01 = data[col].quantile(0.01)
            q99 = data[col].quantile(0.99)
            data[col] = data[col].clip(q01, q99)

    # Create creatinine ratio based on KDIGO guidelines if baseline is available
    if 'baseline_creatinine' in data.columns and 'creatinine' in data.columns:
        data['baseline_creatinine'] = data['baseline_creatinine'].fillna(data['baseline_creatinine'].median())

        # Clip baseline creatinine as well
        q01 = data['baseline_creatinine'].quantile(0.01)
        q99 = data['baseline_creatinine'].quantile(0.99)
        data['baseline_creatinine'] = data['baseline_creatinine'].clip(q01, q99)

        # Avoid division by zero
        data['creatinine_ratio'] = data['creatinine'] / (data['baseline_creatinine'] + 1e-8)

        # Clip ratio to reasonable range
        data['creatinine_ratio'] = data['creatinine_ratio'].clip(0.1, 10.0)

        cont_features.append('creatinine_ratio')

        if debug:
            print("\nCreatinine ratio statistics:")
            print(data['creatinine_ratio'].describe())

    if debug:
        print("\nMissing values after imputation:")
        print(data[cont_features].isnull().sum())

        print("\nTarget distribution:")
        if 'aki_label' in data.columns:
            print(data['aki_label'].value_counts(normalize=True))

    return data, cont_features

def create_enhanced_sequences_optimized(df, cont_features, debug=True):
    print("Creating enhanced sequences with clinical features...")

    sequences = []
    static_features_list = []
    labels = []
    stay_ids = []

    # Pre-define feature lists to avoid repeated operations
    dynamic_features = cont_features.copy()

    # Add missing indicators only once
    missing_indicators = [f'is_{col}_missing' for col in cont_features if f'is_{col}_missing' in df.columns]
    dynamic_features.extend(missing_indicators)

    # Define static features
    static_features = ['age', 'gender_numeric', 'weight', 'height']
    if 'baseline_creatinine' in df.columns:
        static_features.append('baseline_creatinine')

    if debug:
        print(f"Dynamic features: {len(dynamic_features)} total")
        print(f"Static features: {static_features}")

    # Pre-compute features for the entire dataset to avoid per-patient calculation
    if 'creatinine' in df.columns:
        # Create a group key for efficient operations
        df['_group_key'] = df['stay_id'].astype(str)

        # Pre-sort the DataFrame
        df = df.sort_values(['stay_id', 'hour_from_icu'])

        # Pre-compute shifts and diffs for the entire dataset
        df['cr_delta'] = df.groupby('_group_key')['creatinine'].diff().fillna(0)
        # Clip delta to prevent extreme values
        df['cr_delta'] = df['cr_delta'].clip(-2.0, 2.0)  # Reasonable range for creatinine change

        df['cr_pct_change'] = df.groupby('_group_key')['creatinine'].pct_change().fillna(0) * 100
        df['cr_pct_change'] = df['cr_pct_change'].clip(-200, 200)  # Clip percentage change

        # Add additional temporal features
        df['cr_rolling_mean'] = df.groupby('_group_key')['creatinine'].transform(
            lambda x: x.rolling(window=min(3, len(x)), min_periods=1).mean()
        )

        df['cr_rolling_std'] = df.groupby('_group_key')['creatinine'].transform(
            lambda x: x.rolling(window=min(3, len(x)), min_periods=1).std()
        ).fillna(0)

        # Calculate acceleration (2nd derivative)
        df['cr_accel'] = df.groupby('_group_key')['cr_delta'].diff().fillna(0)
        df['cr_accel'] = df['cr_accel'].clip(-1.0, 1.0)  # Clip acceleration

        # Add these to the dynamic features list
        dynamic_features.extend(['cr_delta', 'cr_pct_change', 'cr_rolling_mean',
                              'cr_rolling_std', 'cr_accel'])

        if 'baseline_creatinine' in df.columns:
            # Compute KDIGO features across the entire dataset
            baseline_cr = df.groupby('_group_key')['baseline_creatinine'].transform('first')
            df['kdigo_abs_increase'] = (df['creatinine'] - baseline_cr >= 0.3).astype(float)
            df['kdigo_rel_increase'] = (df['creatinine'] / (baseline_cr + 1e-8) >= 1.5).astype(float)

            # Add BUN/Creatinine ratio if available
            if 'urea_nitrogen' in df.columns:
                df['bun_cr_ratio'] = df['urea_nitrogen'] / (df['creatinine'] + 1e-8)
                df['bun_cr_ratio'] = df['bun_cr_ratio'].clip(1.0, 100.0)  # Clip to reasonable range
                dynamic_features.append('bun_cr_ratio')

            # Add to dynamic features
            dynamic_features.extend(['kdigo_abs_increase', 'kdigo_rel_increase'])

    # Progress bar for sequence creation
    unique_stay_ids = df['stay_id'].unique()
    progress_bar = tqdm(unique_stay_ids, desc="Creating patient sequences")

    for stay_id in progress_bar:
        # Get data for this stay (already sorted)
        stay_data = df[df['stay_id'] == stay_id]

        # Skip if less than 2 time points
        if len(stay_data) < 2:
            continue

        # Extract dynamic features efficiently
        dynamic_matrix = []
        for _, row in stay_data.iterrows():
            # Create a vector of just the dynamic features we need
            feature_vector = [row[col] if col in row.index and not np.isnan(row[col]) else 0.0
                             for col in dynamic_features]

            # Add time information
            feature_vector.append(row['hour_from_icu'])
            if 'time_window' in row.index:
                feature_vector.append(row['time_window'])

            dynamic_matrix.append(feature_vector)

        # Convert directly to numpy array (once)
        dynamic_seq = np.array(dynamic_matrix, dtype=np.float32)

        # Replace any remaining NaN values with zeros
        dynamic_seq = np.nan_to_num(dynamic_seq, nan=0.0)

        # Extract static features (from first row - more efficient)
        first_row = stay_data.iloc[0]
        static_feat = np.array([first_row[col] if col in first_row.index and not np.isnan(first_row[col]) else 0.0
                               for col in static_features], dtype=np.float32)

        # Replace any NaN values in static features
        static_feat = np.nan_to_num(static_feat, nan=0.0)

        # Extract label directly
        if 'aki_label' in stay_data.columns:
            label = stay_data['aki_label'].iloc[-1]
        elif 'aki_48h' in stay_data.columns and 'aki_7day' in stay_data.columns:
            label = 1 if stay_data['aki_48h'].iloc[-1] == 1 or stay_data['aki_7day'].iloc[-1] == 1 else 0
        else:
            continue  # Skip if no label available

        # Ensure label is a simple integer, not numpy type
        label = int(label)

        # Append to lists
        sequences.append(dynamic_seq)
        static_features_list.append(static_feat)
        labels.append(label)
        stay_ids.append(stay_id)

        # Update progress bar
        progress_bar.set_postfix({'Sequences': len(sequences)})

    if debug and sequences:
        print(f"\nCreated {len(sequences)} sequences")
        print(f"Dynamic feature dimension: {sequences[0].shape}")
        print(f"Static feature dimension: {static_features_list[0].shape}")
        print(f"Label distribution: {np.bincount(labels)}")

    return sequences, static_features_list, labels, stay_ids, dynamic_features, static_features

class AKISequenceDataset(Dataset):
    def __init__(self, sequences, static_features, labels, transform=None):
        self.sequences = sequences
        self.static_features = static_features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        static = self.static_features[idx]
        label = self.labels[idx]

        # Apply transformations if any
        if self.transform:
            seq = self.transform(seq)

        return (seq, static), label

def collate_fn(batch):
    # Separate sequences, static features and labels
    seqs = [item[0][0] for item in batch]
    static_feats = [item[0][1] for item in batch]
    labels = [item[1] for item in batch]

    # Get sequence lengths
    seq_lengths = torch.LongTensor([len(seq) for seq in seqs])

    # Find maximum sequence length
    max_len = max(seq_lengths).item()

    # Get feature dimensions
    feat_dim = seqs[0].shape[1]

    # Prepare padded sequences tensor
    padded_seqs = torch.zeros((len(seqs), max_len, feat_dim))

    # Fill padded tensor
    for i, seq in enumerate(seqs):
        end = seq_lengths[i]
        padded_seqs[i, :end] = torch.FloatTensor(seq[:end])

    # Convert to tensors
    static_feats = torch.FloatTensor(static_feats)
    labels = torch.FloatTensor(labels)

    return (padded_seqs, static_feats, seq_lengths), labels

class ImprovedBiLSTM(nn.Module):
    """Improved Bidirectional LSTM with robust numerical stability"""
    def __init__(self, dynamic_dim, static_dim, hidden_dim=128, n_layers=2, dropout=0.3):
        super().__init__()

        # Parameter initialization - use Xavier for better stability
        self.init_scale = 0.1

        # Bidirectional LSTM with stable initialization
        self.lstm = nn.LSTM(
            input_size=dynamic_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0,
            bidirectional=True
        )

        # Initialize LSTM weights with Xavier uniform for better stability
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=self.init_scale)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

        # Improved attention mechanism with small initialization
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

        # Initialize attention weights carefully
        for layer in self.attention:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=self.init_scale)
                nn.init.constant_(layer.bias, 0)

        # Static features encoder with batch normalization
        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout)
        )

        # Initialize static encoder weights
        nn.init.xavier_uniform_(self.static_encoder[0].weight, gain=self.init_scale)
        nn.init.constant_(self.static_encoder[0].bias, 0)

        # Clinical features encoder with stable initialization
        self.clinical_encoder = nn.Sequential(
            nn.Linear(3, hidden_dim//4),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim//4),
            nn.Dropout(dropout*0.5)
        )

        # Initialize clinical encoder weights
        nn.init.xavier_uniform_(self.clinical_encoder[0].weight, gain=self.init_scale)
        nn.init.constant_(self.clinical_encoder[0].bias, 0)

        # Feature fusion layer
        self.fusion = nn.Linear(hidden_dim*2 + hidden_dim + hidden_dim//4, hidden_dim)
        nn.init.xavier_uniform_(self.fusion.weight, gain=self.init_scale)
        nn.init.constant_(self.fusion.bias, 0)

        # Output layers
        self.fc1 = nn.Linear(hidden_dim, hidden_dim//2)
        self.bn1 = nn.BatchNorm1d(hidden_dim//2)
        self.dropout1 = nn.Dropout(dropout)

        self.fc2 = nn.Linear(hidden_dim//2, hidden_dim//4)
        self.bn2 = nn.BatchNorm1d(hidden_dim//4)
        self.dropout2 = nn.Dropout(dropout*0.5)

        self.fc3 = nn.Linear(hidden_dim//4, 1)

        # Initialize output layers
        nn.init.xavier_uniform_(self.fc1.weight, gain=self.init_scale)
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.xavier_uniform_(self.fc2.weight, gain=self.init_scale)
        nn.init.constant_(self.fc2.bias, 0)
        nn.init.xavier_uniform_(self.fc3.weight, gain=self.init_scale)
        nn.init.constant_(self.fc3.bias, 0)

    def apply_attention(self, lstm_output, seq_lengths):
        batch_size, seq_len, hidden_dim = lstm_output.size()

        # Apply attention with stable masking approach
        attn_scores = self.attention(lstm_output.reshape(-1, hidden_dim)).view(batch_size, seq_len, 1)

        # Create mask for padded positions (0 for padding, 1 for actual data)
        mask = torch.zeros_like(attn_scores, device=lstm_output.device)
        for i, length in enumerate(seq_lengths):
            mask[i, :length] = 1

        # Apply mask by setting padding positions to large negative value
        attn_scores = attn_scores * mask + (mask - 1) * 1e9  # -1e9 for padding

        # Apply softmax with better numerical stability
        attn_weights = F.softmax(attn_scores, dim=1)

        # Apply attention to get context vector
        context = torch.bmm(lstm_output.transpose(1, 2), attn_weights)

        # Check for NaN and replace with zeros
        context = torch.nan_to_num(context, nan=0.0)

        return context.squeeze(2), attn_weights

    def forward(self, dynamic_seq, static_features, seq_lengths):
        batch_size = dynamic_seq.size(0)

        # Check and handle NaN values in input
        dynamic_seq = torch.nan_to_num(dynamic_seq, nan=0.0)
        static_features = torch.nan_to_num(static_features, nan=0.0)

        # Process dynamic features with bidirectional LSTM
        try:
            packed_sequence = nn.utils.rnn.pack_padded_sequence(
                dynamic_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
            )

            packed_output, _ = self.lstm(packed_sequence)
            lstm_output, _ = nn.utils.rnn.pad_packed_sequence(
                packed_output, batch_first=True
            )

            # Check for NaN in LSTM output
            lstm_output = torch.nan_to_num(lstm_output, nan=0.0)
        except Exception as e:
            print(f"Error in LSTM processing: {e}")
            # Provide fallback output if LSTM fails
            lstm_output = torch.zeros((batch_size, seq_lengths.max().item(), self.lstm.hidden_size * 2),
                                     device=dynamic_seq.device)

        # Apply attention with error handling
        try:
            dynamic_context, attention_weights = self.apply_attention(lstm_output, seq_lengths)
        except Exception as e:
            print(f"Error in attention mechanism: {e}")
            # Fallback to simple pooling if attention fails
            dynamic_context = torch.mean(lstm_output, dim=1)
            attention_weights = torch.ones((batch_size, lstm_output.size(1), 1),
                                          device=dynamic_seq.device) / lstm_output.size(1)

        # Process static features with batch normalization
        try:
            static_encoded = self.static_encoder(static_features)
            static_encoded = torch.nan_to_num(static_encoded, nan=0.0)
        except Exception as e:
            print(f"Error in static encoder: {e}")
            static_encoded = torch.zeros((batch_size, self.static_encoder[0].out_features),
                                        device=dynamic_seq.device)

        # Extract clinical features with robust implementation
        try:
            cr_idx = 0  # Creatinine index in feature vector
            clinical_features = torch.zeros(batch_size, 3, device=dynamic_seq.device)

            for i, length in enumerate(seq_lengths):
                # Safely extract creatinine values
                if length > 0:
                    # Last creatinine value
                    clinical_features[i, 0] = dynamic_seq[i, length-1, cr_idx]

                    # Extract sequence of creatinine values
                    cr_values = dynamic_seq[i, :length, cr_idx]

                    # Maximum creatinine value
                    clinical_features[i, 1] = torch.max(cr_values)

                    # Maximum increase in creatinine
                    if length > 1:
                        cr_diffs = cr_values[1:] - cr_values[:-1]
                        clinical_features[i, 2] = torch.max(cr_diffs)

            # Check for NaN values
            clinical_features = torch.nan_to_num(clinical_features, nan=0.0)

            # Process through clinical encoder
            clinical_encoded = self.clinical_encoder(clinical_features)
            clinical_encoded = torch.nan_to_num(clinical_encoded, nan=0.0)
        except Exception as e:
            print(f"Error in clinical feature extraction: {e}")
            clinical_encoded = torch.zeros((batch_size, self.clinical_encoder[0].out_features),
                                          device=dynamic_seq.device)

        # Combine all features with safe concatenation
        try:
            combined = torch.cat([dynamic_context, static_encoded, clinical_encoded], dim=1)
            combined = torch.nan_to_num(combined, nan=0.0)

            fused = F.relu(self.fusion(combined))
            fused = torch.nan_to_num(fused, nan=0.0)
        except Exception as e:
            print(f"Error in feature fusion: {e}")
            fused = torch.zeros((batch_size, self.fusion.out_features), device=dynamic_seq.device)

        # Output layers with robust implementation
        try:
            x = self.fc1(fused)
            x = self.bn1(x)
            x = F.relu(x)
            x = self.dropout1(x)

            x = self.fc2(x)
            x = self.bn2(x)
            x = F.relu(x)
            x = self.dropout2(x)

            output = self.fc3(x)
            output = torch.clamp(output, -50, 50)  # Prevent extreme output values
        except Exception as e:
            print(f"Error in output layers: {e}")
            output = torch.zeros((batch_size, 1), device=dynamic_seq.device)

        return output.squeeze(), attention_weights

def train_enhanced_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                        n_epochs=50, patience=15, device='cpu', debug=True, save_dir='checkpoints'):
    # Initialize tracking variables
    train_losses = []
    val_losses = []
    train_aucs = []
    val_aucs = []
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    best_auc = 0.0

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Training loop
    for epoch in range(n_epochs):
        # Training phase
        model.train()
        epoch_train_losses = []
        all_train_preds = []
        all_train_labels = []

        # Create progress bar for training batches
        train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs} [Train]")

        for ((dynamic, static, lengths), labels) in train_progress:
            try:
                # Check for NaN in input data
                if torch.isnan(dynamic).any() or torch.isnan(static).any():
                    print("Warning: NaN found in input data. Replacing with zeros.")
                    dynamic = torch.nan_to_num(dynamic, nan=0.0)
                    static = torch.nan_to_num(static, nan=0.0)

                # Move data to device
                dynamic = dynamic.to(device)
                static = static.to(device)
                lengths = lengths.to(device)
                labels = labels.to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs, _ = model(dynamic, static, lengths)

                # Check for NaN in outputs
                if torch.isnan(outputs).any():
                    print("Warning: NaN in model outputs. Skipping batch.")
                    continue

                # Collect for AUC calculation
                probs = torch.sigmoid(torch.clamp(outputs, -50, 50))
                all_train_preds.extend(probs.detach().cpu().numpy())
                all_train_labels.extend(labels.cpu().numpy())

                # Calculate loss
                loss = criterion(outputs, labels)

                # Check for NaN in loss
                if torch.isnan(loss).any():
                    print("Warning: NaN in loss. Skipping batch.")
                    continue

                # Backward pass
                loss.backward()

                # Add gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Check for NaN in gradients
                has_nan_grad = False
                for param in model.parameters():
                    if param.grad is not None and torch.isnan(param.grad).any():
                        has_nan_grad = True
                        break

                if has_nan_grad:
                    print("Warning: NaN in gradients. Skipping parameter update.")
                    optimizer.zero_grad()  # Clear the gradients
                else:
                    optimizer.step()

                # Record loss
                epoch_train_losses.append(loss.item())

                # Update progress bar
                train_progress.set_postfix({
                    'loss': f"{loss.item():.4f}"
                })

            except Exception as e:
                print(f"Error in training: {e}")
                continue

        # Calculate epoch training loss
        if epoch_train_losses:
            epoch_train_loss = np.mean(epoch_train_losses)
            train_losses.append(epoch_train_loss)
        else:
            epoch_train_loss = float('inf')
            train_losses.append(float('inf'))

        # Calculate training AUC if we have predictions
        if len(all_train_preds) > 0 and len(all_train_labels) > 0:
            # Convert to numpy arrays and check for NaN
            all_train_preds = np.array(all_train_preds)
            all_train_labels = np.array(all_train_labels)

            if np.isnan(all_train_preds).any():
                print("Warning: NaN in predictions, replacing with 0.5")
                all_train_preds = np.nan_to_num(all_train_preds, nan=0.5)

            try:
                train_auc = roc_auc_score(all_train_labels, all_train_preds)
                train_aucs.append(train_auc)
            except Exception as e:
                print(f"Error calculating training AUC: {e}")
                train_auc = 0.5
                train_aucs.append(0.5)
        else:
            train_auc = 0.5
            train_aucs.append(0.5)

        # Validation phase
        model.eval()
        epoch_val_losses = []
        all_val_preds = []
        all_val_labels = []

        # Create progress bar for validation batches
        val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{n_epochs} [Valid]")

        with torch.no_grad():
            for ((dynamic, static, lengths), labels) in val_progress:
                try:
                    # Check and handle NaN values
                    dynamic = torch.nan_to_num(dynamic, nan=0.0)
                    static = torch.nan_to_num(static, nan=0.0)

                    # Move data to device
                    dynamic = dynamic.to(device)
                    static = static.to(device)
                    lengths = lengths.to(device)
                    labels = labels.to(device)

                    # Forward pass
                    outputs, _ = model(dynamic, static, lengths)

                    # Check for NaN in outputs
                    if torch.isnan(outputs).any():
                        print("Warning: NaN in validation outputs. Skipping batch.")
                        continue

                    # Calculate loss
                    loss = criterion(outputs, labels)

                    # Check for NaN in loss
                    if torch.isnan(loss).any():
                        print("Warning: NaN in validation loss. Skipping batch.")
                        continue

                    # Record predictions and loss
                    epoch_val_losses.append(loss.item())
                    probs = torch.sigmoid(torch.clamp(outputs, -50, 50))
                    all_val_preds.extend(probs.cpu().numpy())
                    all_val_labels.extend(labels.cpu().numpy())

                    # Update progress bar
                    val_progress.set_postfix({
                        'loss': f"{loss.item():.4f}"
                    })

                except Exception as e:
                    print(f"Error in validation: {e}")
                    continue

        # Calculate validation metrics
        if epoch_val_losses:
            epoch_val_loss = np.mean(epoch_val_losses)
            val_losses.append(epoch_val_loss)
        else:
            epoch_val_loss = float('inf')
            val_losses.append(float('inf'))

        # Calculate AUC if we have predictions
        if len(all_val_preds) > 0 and len(all_val_labels) > 0:
            # Convert to numpy arrays and check for NaN
            all_val_preds = np.array(all_val_preds)
            all_val_labels = np.array(all_val_labels)

            if np.isnan(all_val_preds).any():
                print("Warning: NaN in validation predictions, replacing with 0.5")
                all_val_preds = np.nan_to_num(all_val_preds, nan=0.5)

            try:
                val_auc = roc_auc_score(all_val_labels, all_val_preds)
                val_aucs.append(val_auc)
            except Exception as e:
                print(f"Error calculating validation AUC: {e}")
                val_auc = 0.5
                val_aucs.append(0.5)
        else:
            val_auc = 0.5
            val_aucs.append(0.5)

        # Print epoch summary
        print(f"Epoch {epoch+1}/{n_epochs} Results:")
        print(f"  Train Loss: {epoch_train_loss:.4f}")
        print(f"  Train AUC:  {train_auc:.4f}")
        print(f"  Valid Loss: {epoch_val_loss:.4f}")
        print(f"  Valid AUC:  {val_auc:.4f}")
        print("-" * 60)

        # Save checkpoint for every epoch
        checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'val_loss': epoch_val_loss,
            'val_auc': val_auc,
            'train_loss': epoch_train_loss,
            'train_auc': train_auc
        }, checkpoint_path)

        # Save best model based on validation AUC (primary) and loss (secondary)
        if val_auc > best_auc or (val_auc == best_auc and epoch_val_loss < best_val_loss):
            best_val_loss = epoch_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            best_auc = val_auc
            patience_counter = 0

            # Save best model
            best_model_path = os.path.join(save_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'val_loss': best_val_loss,
                'val_auc': best_auc,
                'train_loss': epoch_train_loss,
                'train_auc': train_auc
            }, best_model_path)
            print(f"Best model saved with Val AUC: {best_auc:.4f}, Val Loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1

        # Update learning rate scheduler
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

        # Early stopping check
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Memory management for GPU
        if device != 'cpu' and (epoch + 1) % 5 == 0:
            torch.cuda.empty_cache()

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Return training history
    history = {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_auc': train_aucs,
        'val_auc': val_aucs,
        'best_auc': best_auc
    }

    return model, history

def evaluate_model(model, test_loader, device='cpu'):
    """Standard evaluation with robust error handling"""
    model.eval()
    all_preds = []
    all_labels = []
    all_attentions = []

    # Progress bar for test evaluation
    test_progress = tqdm(test_loader, desc="Evaluating model")

    with torch.no_grad():
        for ((dynamic, static, lengths), labels) in test_progress:
            try:
                # Handle NaN values
                dynamic = torch.nan_to_num(dynamic, nan=0.0)
                static = torch.nan_to_num(static, nan=0.0)

                # Move data to device
                dynamic = dynamic.to(device)
                static = static.to(device)
                lengths = lengths.to(device)

                # Get predictions
                outputs, attention_weights = model(dynamic, static, lengths)

                # Handle any NaN in outputs
                outputs = torch.nan_to_num(outputs, nan=0.0)

                # Apply sigmoid to get probabilities
                probs = torch.sigmoid(torch.clamp(outputs, -50, 50))

                # Store predictions, labels and attention weights
                all_preds.extend(probs.cpu().numpy())
                all_labels.extend(labels.numpy())

                # Store attention weights for analysis
                all_attentions.append(attention_weights.cpu().numpy())
            except Exception as e:
                print(f"Error in evaluation: {e}")
                continue

    # Convert to numpy arrays and check for NaN
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    if np.isnan(all_preds).any():
        print("Warning: NaN in test predictions, replacing with 0.5")
        all_preds = np.nan_to_num(all_preds, nan=0.5)

    return all_preds, all_labels, all_attentions

def find_optimal_threshold(y_true, y_pred):
    """Find threshold that maximizes F1 score"""
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

    # Calculate F1 scores
    f1_scores = []
    for i in range(len(precision)):
        if i < len(thresholds):
            if recall[i] == 0 or precision[i] == 0:
                f1_scores.append(0)
            else:
                f1_scores.append(2 * precision[i] * recall[i] / (precision[i] + recall[i]))

    if f1_scores:
        # Find threshold with best F1 score
        best_idx = np.argmax(f1_scores)
        best_threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
        best_f1 = f1_scores[best_idx]

        # Also find threshold with better sensitivity (recall >= 0.65)
        high_recall_idx = np.where(recall >= 0.65)[0]
        if len(high_recall_idx) > 0:
            high_recall_f1 = [f1_scores[i] for i in high_recall_idx if i < len(f1_scores)]
            if high_recall_f1:
                best_high_recall_idx = high_recall_idx[np.argmax(high_recall_f1)]
                sensitivity_threshold = thresholds[best_high_recall_idx] if best_high_recall_idx < len(thresholds) else 0.3
            else:
                sensitivity_threshold = 0.3
        else:
            sensitivity_threshold = 0.3
    else:
        best_threshold = 0.5
        best_f1 = 0.0
        sensitivity_threshold = 0.3

    return best_threshold, sensitivity_threshold

def evaluate_model_at_threshold(y_true, y_pred, threshold=0.5):
    """Calculate performance metrics at a given threshold"""
    # Convert probabilities to binary predictions
    y_pred_binary = (y_pred > threshold).astype(int)

    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred_binary)
    precision = precision_score(y_true, y_pred_binary, zero_division=0)  # PPV
    recall = recall_score(y_true, y_pred_binary, zero_division=0)  # Sensitivity
    f1 = f1_score(y_true, y_pred_binary, zero_division=0)

    # Calculate confusion matrix
    cm = confusion_matrix(y_true, y_pred_binary)
    tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # Specificity
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative predictive value

    print(f"Model Performance Metrics at threshold {threshold:.4f}:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision/PPV: {precision:.4f}")
    print(f"Recall/Sensitivity: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"NPV: {npv:.4f}")
    print(f"Confusion Matrix:")
    print(f"TN: {tn}, FP: {fp}")
    print(f"FN: {fn}, TP: {tp}")

    return {
        'threshold': threshold,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'npv': npv,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp
    }

def main(data_path, debug=True, n_epochs=50):
    print("Starting Enhanced AKI Prediction with Improved BiLSTM...")

    # 1. Load and preprocess data
    data, cont_features = load_and_preprocess_data(data_path, debug=debug)

    # 2. Create enhanced sequences with clinical features (optimized version)
    sequences, static_features, labels, stay_ids, dynamic_features, static_features_names = create_enhanced_sequences_optimized(
        data, cont_features, debug=debug
    )

    # Check if we have sequences
    if not sequences:
        print("Error: No sequences were created from the data. Check your dataset.")
        return None, None, None

    # 3. Split data
    X_seq_train, X_seq_test, X_static_train, X_static_test, y_train, y_test, ids_train, ids_test = train_test_split(
        sequences, static_features, labels, stay_ids, test_size=0.2, random_state=SEED, stratify=labels
    )

    # Further split training into train and validation
    X_seq_train, X_seq_val, X_static_train, X_static_val, y_train, y_val, ids_train, ids_val = train_test_split(
        X_seq_train, X_static_train, y_train, ids_train, test_size=0.2, random_state=SEED, stratify=y_train
    )

    if debug:
        print("\nData split:")
        print(f"Training samples: {len(y_train)}")
        print(f"Validation samples: {len(y_val)}")
        print(f"Testing samples: {len(y_test)}")
        print(f"Label distribution in train: {np.bincount(np.array(y_train).astype(int))}")
        print(f"Label distribution in val: {np.bincount(np.array(y_val).astype(int))}")
        print(f"Label distribution in test: {np.bincount(np.array(y_test).astype(int))}")

    # 4. Create datasets and dataloaders
    train_dataset = AKISequenceDataset(X_seq_train, X_static_train, y_train)
    val_dataset = AKISequenceDataset(X_seq_val, X_static_val, y_val)
    test_dataset = AKISequenceDataset(X_seq_test, X_static_test, y_test)

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, drop_last=False
    )
    val_loader = DataLoader(
        val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, drop_last=False
    )
    test_loader = DataLoader(
        test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, drop_last=False
    )

    # 5. Initialize improved model
    # Calculate input dimensions
    dynamic_dim = X_seq_train[0].shape[1]  # Number of dynamic features
    static_dim = X_static_train[0].shape[0]  # Number of static features

    # Create improved BiLSTM model with more stable initialization
    model = ImprovedBiLSTM(
        dynamic_dim=dynamic_dim,
        static_dim=static_dim,
        hidden_dim=128,  # More conservative size for stability
        n_layers=2,      # More conservative depth for stability
        dropout=0.3
    ).to(device)

    if debug:
        print("\nImproved Model details:")
        print(f"Dynamic features dimension: {dynamic_dim}")
        print(f"Static features dimension: {static_dim}")
        print(model)

    # 6. Set up training components
    # Convert to numpy arrays
    y_train_np = np.array(y_train).astype(int)

    # Calculate weighted loss
    pos_count = np.sum(y_train_np == 1)
    neg_count = np.sum(y_train_np == 0)
    weight_ratio = neg_count / max(pos_count, 1)  # Avoid division by zero

    # Use BCEWithLogitsLoss for stability
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([weight_ratio]).to(device))

    # Print the class weights
    print(f"Class weight ratio (neg/pos): {weight_ratio:.2f}")

    # AdamW optimizer with conservative learning rate
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=1e-5)

    # Use more stable learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # 7. Train improved model - 50 epochs as requested
    model, history = train_enhanced_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        n_epochs=n_epochs,  # Set to 50 as requested
        patience=15,
        device=device,
        debug=debug,
        save_dir='checkpoints'
    )

    # 8. Plot training history
    plt.figure(figsize=(15, 5))

    # Plot losses
    plt.subplot(1, 3, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)

    # Plot AUC
    plt.subplot(1, 3, 2)
    plt.plot(history['train_auc'], label='Train AUC')
    plt.plot(history['val_auc'], label='Validation AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Training and Validation AUC')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('results/training_history.png')
    plt.close()

    # 9. Evaluate model
    print("\nEvaluating model...")
    test_preds, test_labels, test_attentions = evaluate_model(model, test_loader, device)

    # 10. Find optimal thresholds
    balanced_threshold, sensitivity_threshold = find_optimal_threshold(test_labels, test_preds)

    # 11. Plot ROC Curve
    try:
        fpr, tpr, _ = roc_curve(test_labels, test_preds)
        test_auc = roc_auc_score(test_labels, test_preds)

        plt.figure(figsize=(10, 8))
        plt.plot(fpr, tpr, label=f'AUC = {test_auc:.4f}')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve')
        plt.legend(loc='lower right')
        plt.grid(True)

        # Add markers for thresholds
        try:
            balanced_idx = np.argmin(np.abs(fpr - (1-balanced_threshold)))
            sensitivity_idx = np.argmin(np.abs(tpr - 0.65))

            plt.plot(fpr[balanced_idx], tpr[balanced_idx], 'ro', markersize=10,
                     label=f'Balanced Threshold = {balanced_threshold:.4f}')
            plt.plot(fpr[sensitivity_idx], tpr[sensitivity_idx], 'go', markersize=10,
                     label=f'Sensitivity Threshold = {sensitivity_threshold:.4f}')
            plt.legend()
        except Exception as e:
            print(f"Error adding threshold markers: {e}")

        plt.savefig('results/roc_curve.png')
        plt.close()
    except Exception as e:
        print(f"Error plotting ROC curve: {e}")
        test_auc = 0.5

    # 12. Calculate and display metrics for both thresholds
    print("\nBalanced Threshold Metrics (Maximizing F1):")
    balanced_metrics = evaluate_model_at_threshold(test_labels, test_preds, threshold=balanced_threshold)

    print("\nSensitivity-Optimized Threshold Metrics:")
    sensitivity_metrics = evaluate_model_at_threshold(test_labels, test_preds, threshold=sensitivity_threshold)

    # 13. Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'test_auc': test_auc,
        'balanced_threshold': balanced_threshold,
        'sensitivity_threshold': sensitivity_threshold,
        'balanced_metrics': balanced_metrics,
        'sensitivity_metrics': sensitivity_metrics,
        'dynamic_dim': dynamic_dim,
        'static_dim': static_dim
    }, 'final_improved_model.pt')
    print("Final model saved as 'final_improved_model.pt'")

    return model, history, test_auc

if __name__ == "__main__":
    # Set your data path
    data_path = "latest.csv"

    # Run the pipeline - train for 50 epochs as requested
    model, history, auc = main(data_path, debug=True, n_epochs=50)

Using device: cuda
Starting Enhanced AKI Prediction with Improved BiLSTM...
Loading data from latest.csv
Data shape: (485580, 48)
Columns: ['subject_id', 'stay_id', 'hour_from_icu', 'creatinine', 'bicarbonate', 'chloride', 'glucose', 'magnesium', 'potassium', 'sodium', 'urea_nitrogen', 'hemoglobin', 'platelet_count', 'wbc_count', 'lactate', 'paco2', 'ph', 'pao2', 'albumin', 'anion_gap', 'hematocrit', 'inr', 'pt', 'aptt', 'aki_48h', 'aki_7day', 'aki_label', 'heart_rate', 'resp_rate', 'temperature', 'spo2', 'nbp_sys', 'nbp_dias', 'nbp_mean', 'gcs_total', 'urine_or', 'urine_pacu', 'urine_output_ml', 'intime', 'gender', 'age', 'race', 'admission_type', 'weight', 'height', 'baseline_creatinine', 'time_window', 'time_window_str']

First few rows:
   subject_id   stay_id  hour_from_icu  creatinine  bicarbonate  chloride  \
0    10000032  39553978              7         0.5         21.0     102.0   
1    10000032  39553978             16         0.4         24.0     102.0   
2    10000032  395

Creating patient sequences:   0%|          | 0/29446 [00:00<?, ?it/s]


Created 28999 sequences
Dynamic feature dimension: (12, 53)
Static feature dimension: (5,)
Label distribution: [26158  2841]

Data split:
Training samples: 18559
Validation samples: 4640
Testing samples: 5800
Label distribution in train: [16741  1818]
Label distribution in val: [4185  455]
Label distribution in test: [5232  568]

Improved Model details:
Dynamic features dimension: 53
Static features dimension: 5
ImprovedBiLSTM(
  (lstm): LSTM(53, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (attention): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
  (static_encoder): Sequential(
    (0): Linear(in_features=5, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.3, inplace=False)
  )
  (clinical_encoder): Sequential(
    (0): Linear(in_features=3, out

Epoch 1/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 1/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 1/50 Results:
  Train Loss: 1.1498
  Train AUC:  0.6968
  Valid Loss: 1.1226
  Valid AUC:  0.7283
------------------------------------------------------------
Best model saved with Val AUC: 0.7283, Val Loss: 1.1226


Epoch 2/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 2/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 2/50 Results:
  Train Loss: 1.1299
  Train AUC:  0.7139
  Valid Loss: 1.1139
  Valid AUC:  0.7304
------------------------------------------------------------
Best model saved with Val AUC: 0.7304, Val Loss: 1.1139


Epoch 3/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 3/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 3/50 Results:
  Train Loss: 1.1299
  Train AUC:  0.7223
  Valid Loss: 1.1605
  Valid AUC:  0.7340
------------------------------------------------------------
Best model saved with Val AUC: 0.7340, Val Loss: 1.1605


Epoch 4/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 4/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 4/50 Results:
  Train Loss: 1.1513
  Train AUC:  0.7164
  Valid Loss: 1.1379
  Valid AUC:  0.7388
------------------------------------------------------------
Best model saved with Val AUC: 0.7388, Val Loss: 1.1379


Epoch 5/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 5/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 5/50 Results:
  Train Loss: 1.1417
  Train AUC:  0.7266
  Valid Loss: 1.1835
  Valid AUC:  0.7404
------------------------------------------------------------
Best model saved with Val AUC: 0.7404, Val Loss: 1.1835


Epoch 6/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 6/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 6/50 Results:
  Train Loss: 1.1360
  Train AUC:  0.7319
  Valid Loss: 1.2043
  Valid AUC:  0.7332
------------------------------------------------------------


Epoch 7/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 7/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 7/50 Results:
  Train Loss: 1.1306
  Train AUC:  0.7383
  Valid Loss: 1.1328
  Valid AUC:  0.7451
------------------------------------------------------------
Best model saved with Val AUC: 0.7451, Val Loss: 1.1328


Epoch 8/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 8/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 8/50 Results:
  Train Loss: 1.1374
  Train AUC:  0.7362
  Valid Loss: 1.0779
  Valid AUC:  0.7489
------------------------------------------------------------
Best model saved with Val AUC: 0.7489, Val Loss: 1.0779


Epoch 9/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 9/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 9/50 Results:
  Train Loss: 1.1290
  Train AUC:  0.7399
  Valid Loss: 1.0998
  Valid AUC:  0.7577
------------------------------------------------------------
Best model saved with Val AUC: 0.7577, Val Loss: 1.0998


Epoch 10/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 10/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 10/50 Results:
  Train Loss: 1.1270
  Train AUC:  0.7432
  Valid Loss: 1.1033
  Valid AUC:  0.7447
------------------------------------------------------------


Epoch 11/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 11/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 11/50 Results:
  Train Loss: 1.1238
  Train AUC:  0.7468
  Valid Loss: 1.1376
  Valid AUC:  0.7539
------------------------------------------------------------


Epoch 12/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 12/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 12/50 Results:
  Train Loss: 1.1315
  Train AUC:  0.7431
  Valid Loss: 1.1648
  Valid AUC:  0.7344
------------------------------------------------------------


Epoch 13/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 13/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 13/50 Results:
  Train Loss: 1.1395
  Train AUC:  0.7453
  Valid Loss: 1.0680
  Valid AUC:  0.7551
------------------------------------------------------------


Epoch 14/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 14/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 14/50 Results:
  Train Loss: 1.1187
  Train AUC:  0.7504
  Valid Loss: 1.1147
  Valid AUC:  0.7578
------------------------------------------------------------
Best model saved with Val AUC: 0.7578, Val Loss: 1.1147


Epoch 15/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 15/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 15/50 Results:
  Train Loss: 1.1191
  Train AUC:  0.7520
  Valid Loss: 1.0799
  Valid AUC:  0.7657
------------------------------------------------------------
Best model saved with Val AUC: 0.7657, Val Loss: 1.0799


Epoch 16/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 16/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 16/50 Results:
  Train Loss: 1.1257
  Train AUC:  0.7501
  Valid Loss: 1.0697
  Valid AUC:  0.7650
------------------------------------------------------------


Epoch 17/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 17/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 17/50 Results:
  Train Loss: 1.1172
  Train AUC:  0.7560
  Valid Loss: 1.2048
  Valid AUC:  0.7444
------------------------------------------------------------


Epoch 18/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 18/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 18/50 Results:
  Train Loss: 1.1345
  Train AUC:  0.7493
  Valid Loss: 1.1752
  Valid AUC:  0.7312
------------------------------------------------------------


Epoch 19/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 19/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 19/50 Results:
  Train Loss: 1.1279
  Train AUC:  0.7499
  Valid Loss: 1.1569
  Valid AUC:  0.7424
------------------------------------------------------------


Epoch 20/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 20/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 20/50 Results:
  Train Loss: 1.0962
  Train AUC:  0.7606
  Valid Loss: 1.0769
  Valid AUC:  0.7621
------------------------------------------------------------


Epoch 21/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 21/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 21/50 Results:
  Train Loss: 1.0881
  Train AUC:  0.7661
  Valid Loss: 1.1135
  Valid AUC:  0.7575
------------------------------------------------------------


Epoch 22/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 22/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 22/50 Results:
  Train Loss: 1.0964
  Train AUC:  0.7654
  Valid Loss: 1.1417
  Valid AUC:  0.7688
------------------------------------------------------------
Best model saved with Val AUC: 0.7688, Val Loss: 1.1417


Epoch 23/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 23/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 23/50 Results:
  Train Loss: 1.0885
  Train AUC:  0.7702
  Valid Loss: 1.0803
  Valid AUC:  0.7589
------------------------------------------------------------


Epoch 24/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 24/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 24/50 Results:
  Train Loss: 1.0858
  Train AUC:  0.7690
  Valid Loss: 1.1523
  Valid AUC:  0.7711
------------------------------------------------------------
Best model saved with Val AUC: 0.7711, Val Loss: 1.1523


Epoch 25/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 25/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 25/50 Results:
  Train Loss: 1.0916
  Train AUC:  0.7656
  Valid Loss: 1.0911
  Valid AUC:  0.7737
------------------------------------------------------------
Best model saved with Val AUC: 0.7737, Val Loss: 1.0911


Epoch 26/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 26/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 26/50 Results:
  Train Loss: 1.0755
  Train AUC:  0.7738
  Valid Loss: 1.1522
  Valid AUC:  0.7584
------------------------------------------------------------


Epoch 27/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 27/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 27/50 Results:
  Train Loss: 1.0661
  Train AUC:  0.7770
  Valid Loss: 1.0799
  Valid AUC:  0.7736
------------------------------------------------------------


Epoch 28/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 28/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 28/50 Results:
  Train Loss: 1.0714
  Train AUC:  0.7754
  Valid Loss: 1.0623
  Valid AUC:  0.7711
------------------------------------------------------------


Epoch 29/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 29/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 29/50 Results:
  Train Loss: 1.0657
  Train AUC:  0.7771
  Valid Loss: 1.1004
  Valid AUC:  0.7692
------------------------------------------------------------


Epoch 30/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 30/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 30/50 Results:
  Train Loss: 1.0618
  Train AUC:  0.7786
  Valid Loss: 1.1192
  Valid AUC:  0.7744
------------------------------------------------------------
Best model saved with Val AUC: 0.7744, Val Loss: 1.1192


Epoch 31/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 31/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 31/50 Results:
  Train Loss: 1.0597
  Train AUC:  0.7780
  Valid Loss: 1.0507
  Valid AUC:  0.7763
------------------------------------------------------------
Best model saved with Val AUC: 0.7763, Val Loss: 1.0507


Epoch 32/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 32/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 32/50 Results:
  Train Loss: 1.0531
  Train AUC:  0.7809
  Valid Loss: 1.1037
  Valid AUC:  0.7716
------------------------------------------------------------


Epoch 33/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 33/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 33/50 Results:
  Train Loss: 1.0451
  Train AUC:  0.7821
  Valid Loss: 1.0682
  Valid AUC:  0.7744
------------------------------------------------------------


Epoch 34/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 34/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 34/50 Results:
  Train Loss: 1.0470
  Train AUC:  0.7822
  Valid Loss: 1.0795
  Valid AUC:  0.7732
------------------------------------------------------------


Epoch 35/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 35/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 35/50 Results:
  Train Loss: 1.0489
  Train AUC:  0.7809
  Valid Loss: 1.1391
  Valid AUC:  0.7709
------------------------------------------------------------


Epoch 36/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 36/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 36/50 Results:
  Train Loss: 1.0551
  Train AUC:  0.7777
  Valid Loss: 1.0893
  Valid AUC:  0.7715
------------------------------------------------------------


Epoch 37/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 37/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 37/50 Results:
  Train Loss: 1.0526
  Train AUC:  0.7798
  Valid Loss: 1.0619
  Valid AUC:  0.7755
------------------------------------------------------------


Epoch 38/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 38/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 38/50 Results:
  Train Loss: 1.0356
  Train AUC:  0.7860
  Valid Loss: 1.0733
  Valid AUC:  0.7763
------------------------------------------------------------


Epoch 39/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 39/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 39/50 Results:
  Train Loss: 1.0331
  Train AUC:  0.7852
  Valid Loss: 1.0852
  Valid AUC:  0.7746
------------------------------------------------------------


Epoch 40/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 40/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 40/50 Results:
  Train Loss: 1.0396
  Train AUC:  0.7841
  Valid Loss: 1.0959
  Valid AUC:  0.7719
------------------------------------------------------------


Epoch 41/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 41/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 41/50 Results:
  Train Loss: 1.0326
  Train AUC:  0.7857
  Valid Loss: 1.1003
  Valid AUC:  0.7717
------------------------------------------------------------


Epoch 42/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 42/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 42/50 Results:
  Train Loss: 1.0286
  Train AUC:  0.7882
  Valid Loss: 1.0815
  Valid AUC:  0.7751
------------------------------------------------------------


Epoch 43/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 43/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 43/50 Results:
  Train Loss: 1.0354
  Train AUC:  0.7849
  Valid Loss: 1.0683
  Valid AUC:  0.7749
------------------------------------------------------------


Epoch 44/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 44/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 44/50 Results:
  Train Loss: 1.0261
  Train AUC:  0.7897
  Valid Loss: 1.0461
  Valid AUC:  0.7775
------------------------------------------------------------
Best model saved with Val AUC: 0.7775, Val Loss: 1.0461


Epoch 45/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 45/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 45/50 Results:
  Train Loss: 1.0253
  Train AUC:  0.7888
  Valid Loss: 1.0826
  Valid AUC:  0.7746
------------------------------------------------------------


Epoch 46/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 46/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 46/50 Results:
  Train Loss: 1.0286
  Train AUC:  0.7864
  Valid Loss: 1.0519
  Valid AUC:  0.7750
------------------------------------------------------------


Epoch 47/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 47/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 47/50 Results:
  Train Loss: 1.0279
  Train AUC:  0.7891
  Valid Loss: 1.0795
  Valid AUC:  0.7737
------------------------------------------------------------


Epoch 48/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 48/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 48/50 Results:
  Train Loss: 1.0371
  Train AUC:  0.7839
  Valid Loss: 1.0751
  Valid AUC:  0.7793
------------------------------------------------------------
Best model saved with Val AUC: 0.7793, Val Loss: 1.0751


Epoch 49/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 49/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 49/50 Results:
  Train Loss: 1.0298
  Train AUC:  0.7878
  Valid Loss: 1.0776
  Valid AUC:  0.7739
------------------------------------------------------------


Epoch 50/50 [Train]:   0%|          | 0/580 [00:00<?, ?it/s]

Epoch 50/50 [Valid]:   0%|          | 0/145 [00:00<?, ?it/s]

Epoch 50/50 Results:
  Train Loss: 1.0123
  Train AUC:  0.7945
  Valid Loss: 1.0600
  Valid AUC:  0.7761
------------------------------------------------------------

Evaluating model...


Evaluating model:   0%|          | 0/182 [00:00<?, ?it/s]


Balanced Threshold Metrics (Maximizing F1):
Model Performance Metrics at threshold 0.6304:
Accuracy: 0.8821
Precision/PPV: 0.4014
Recall/Sensitivity: 0.4155
F1 Score: 0.4083
Specificity: 0.9327
NPV: 0.9363
Confusion Matrix:
TN: 4880, FP: 352
FN: 332, TP: 236

Sensitivity-Optimized Threshold Metrics:
Model Performance Metrics at threshold 0.3975:
Accuracy: 0.7703
Precision/PPV: 0.2457
Recall/Sensitivity: 0.6496
F1 Score: 0.3565
Specificity: 0.7834
NPV: 0.9537
Confusion Matrix:
TN: 4099, FP: 1133
FN: 199, TP: 369
Final model saved as 'final_improved_model.pt'
