In [None]:
from google.colab import drive, runtime

# Attach Google Drive at /content/drive so the project code and data stored there are reachable
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [None]:
import os
import time
from datetime import datetime
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from sklearn.metrics import precision_recall_curve, auc, recall_score, precision_score, f1_score
# import wandb
from tqdm import tqdm
import gc
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

# Add autoreload support
%load_ext autoreload
%autoreload 2


In [None]:
import sys
parent_dir = '/content/drive/My Drive/Colab Notebooks/wildfire'
# Make the project package (`v4.*`) importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Change to the working directory
os.chdir(os.path.join(parent_dir, 'v4'))
!pwd

/content/drive/MyDrive/Colab Notebooks/wildfire/v4


In [None]:
from v4.model.dataset_v4 import create_dataloaders, create_test_loader, create_baseline_dataloaders, create_baseline_test_loader, create_multi_target_dataloaders, create_multi_target_test_loader, create_multi_target_baseline_dataloaders, create_multi_target_baseline_test_loader
from v4.model.models_v4 import FireTransformer, BaselineModels, NNBaselineModels
from v4.model.loss_v4 import improved_weighted_focal_loss_v1, improved_weighted_focal_loss_v2, weighted_focal_loss
from v4.model.utils import calculate_metrics

In [None]:
NEED_TEST_DATA = True
# Global Configuration
BASE_PATH = '/content/drive/My Drive/Colab Notebooks/wildfire'
TRAIN_DATA_PATH = os.path.join(BASE_PATH, 'new_data/train/sequences_y2019-2024_w15_o0_r4.npz')
VAL_DATA_PATH = None
TEST_DATA_PATH = os.path.join(BASE_PATH, 'new_data/test_detect/sequences_y2025-2025_w15_o0_r4.npz')

model_type = 'transformer'  # one of: 'transformer', 'traditional', 'nn', 'all'

# Reproducibility: training adds random logit noise, weights are randomly
# initialised and evaluation may add numpy noise, so seed both RNGs up front.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model Parameters
TASK = 'predict'
N_LAYERS = 4
D_FF = 256  # Feed-forward dimension hyperparameter
TARGET_COL = 'targetY_o1_prob'
DROPOUT = 0.4
TRAIN_RATIO = 0.97
USE_FEATURE_NORM = True
USE_WARMUP_SCHEDULER = False
BASE_POS_WEIGHT = 4
FN_PENALTY = 30.0
FP_PENALTY = 3
MAX_NORM = 1.0  # Gradient-clipping norm
CONFIDENCE_MARGIN = 0.1
CONFIDENCE_WEIGHT = 5.0
EPOCHS = 20

D_MODEL = 128
N_HEADS = 8
BATCH_SIZE = 5120
# Linear LR scaling relative to a reference batch size of 512.
LEARNING_RATE = 0.0001 * (BATCH_SIZE / 512)
LOG_INTERVAL = 100
EVAL_ADD_NOISE = False
EVAL_NOISE_STD = 1e-3
IS_PROB_TARGET = '_prob' in TARGET_COL  # Automatically set based on target column
FOCAL_GAMMA = 0

LOSS_VERSION = 'v1'  # Options: 'v1' for original, 'v2' for new anti-clustering version

# Hyperparameters of the NN baselines. They are referenced by the 'nn'/'all'
# model names below and were previously undefined (NameError for those modes).
# NOTE(review): defaults mirror the transformer settings — adjust if the NN
# baselines should use their own values.
NN_EPOCHS = EPOCHS
NN_LEARNING_RATE = LEARNING_RATE
NN_DROPOUT = DROPOUT

# Extract target_type from TARGET_COL
if TARGET_COL.endswith('_o1_prob'):
    TARGET_TYPE = 'o1'
elif TARGET_COL.endswith('_o2_prob'):
    TARGET_TYPE = 'o2'
elif TARGET_COL.endswith('_o3_prob'):
    TARGET_TYPE = 'o3'
else:
    TARGET_TYPE = 'basic'  # Default for binary targets

# Directory where test-set prediction CSVs are written
TEST_OUTPUT_DIR = os.path.join(BASE_PATH, 'new_data/test_detect/predictions')

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Create model name with hyperparameters (also used as the run/log directory suffix)
if model_type == 'transformer':
    MODEL_NAME = f"transformer_{TASK}_{model_type}_L{N_LAYERS}_D{D_MODEL}_H{N_HEADS}_G{FOCAL_GAMMA}_PW{BASE_POS_WEIGHT}_FN{FN_PENALTY}_FP{FP_PENALTY}_T{TARGET_COL}_DR{DROPOUT}_BN{USE_FEATURE_NORM}_WU_{USE_WARMUP_SCHEDULER}_LOSS{LOSS_VERSION}"
elif model_type == 'traditional':
    MODEL_NAME = f"TRAD_PW{BASE_POS_WEIGHT}_FN{FN_PENALTY}_FP{FP_PENALTY}_{timestamp}"
elif model_type == 'nn':
    MODEL_NAME = f"NN_PW{BASE_POS_WEIGHT}_FN{FN_PENALTY}_FP{FP_PENALTY}_E{NN_EPOCHS}_LR{NN_LEARNING_RATE}_DO{NN_DROPOUT}_{timestamp}"
elif model_type == 'all':
    MODEL_NAME = f"ALL_PW{BASE_POS_WEIGHT}_FN{FN_PENALTY}_FP{FP_PENALTY}_E{NN_EPOCHS}_LR{NN_LEARNING_RATE}_DO{NN_DROPOUT}_{timestamp}"
else:
    MODEL_NAME = f"UNKNOWN_{timestamp}"

LOG_DIR = os.path.join(BASE_PATH, 'runs', timestamp + '_' + MODEL_NAME)

# Create necessary directories
os.makedirs(LOG_DIR, exist_ok=True)

In [None]:
def plot_training_history(losses, val_metrics, save_path,
                          metric_names=('auc_pr', 'accuracy', 'recall', 'precision')):
    """Save a two-panel figure of per-epoch training loss and validation metrics.

    Args:
        losses: sequence of per-epoch training losses (left panel).
        val_metrics: sequence of per-epoch metric dicts (right panel), e.g. as
            returned by ``evaluate``.
        save_path: file path the figure is written to.
        metric_names: keys of each metrics dict to plot. An epoch whose dict
            lacks a key is skipped for that metric instead of raising KeyError.
    """
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        # Training loss
        loss_ax.plot(losses)
        loss_ax.set_title('Training Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')

        # Validation metrics; x is the epoch index, so gaps stay aligned.
        for name in metric_names:
            points = [(epoch, epoch_metrics[name])
                      for epoch, epoch_metrics in enumerate(val_metrics)
                      if name in epoch_metrics]
            if points:
                epochs, values = zip(*points)
                metric_ax.plot(epochs, values, label=name)
        metric_ax.set_title('Validation Metrics')
        metric_ax.set_xlabel('Epoch')
        metric_ax.set_ylabel('Score')
        # Avoid matplotlib's "No artists with labels" warning when nothing was plotted.
        if metric_ax.get_legend_handles_labels()[0]:
            metric_ax.legend()

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close this specific figure so repeated runs don't accumulate open figures.
        plt.close(fig)

def evaluate(model, dataloader, device, add_noise=False, noise_std=1e-3, is_prob_target=False):
    """Run ``model`` over ``dataloader`` and compute metrics on sigmoid probabilities.

    Args:
        model: torch module mapping an input batch ``X`` to logits.
        dataloader: iterable of ``(X, y)`` tensor batches.
        device: device the model lives on; only the inputs are moved there.
        add_noise: if True and some predictions are exactly 0.5, add Gaussian
            noise to all predictions to break ties before computing metrics.
            Noisy predictions are clipped back into [0, 1].
        noise_std: standard deviation of that noise.
        is_prob_target: forwarded to ``calculate_metrics``.

    Returns:
        The result of ``calculate_metrics(preds, labels, is_prob_target=...)``
        (callers use it as a dict of metric name -> value).
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for X, y in dataloader:
            # Labels are only needed on the CPU, so they are not sent to `device`.
            logits = model(X.to(device))
            pred_chunks.append(np.atleast_1d(torch.sigmoid(logits).cpu().numpy()))
            label_chunks.append(np.atleast_1d(y.cpu().numpy()))

    # One concatenate instead of extending a Python list element by element.
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.float32)
    all_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.float32)

    # Log prediction distribution around 0.5
    near_half_range = 0.01
    if all_preds.size:
        exact_half = np.mean(all_preds == 0.5) * 100
        near_half = np.mean((all_preds > 0.5 - near_half_range) &
                            (all_preds < 0.5 + near_half_range)) * 100
    else:
        # np.mean of an empty array would warn and yield NaN.
        exact_half = near_half = 0.0

    print(f"\nPrediction distribution analysis:")
    print(f"Predictions exactly 0.5: {exact_half:.2f}%")
    print(f"Predictions within ±{near_half_range} of 0.5: {near_half:.2f}%")

    # Optional prediction stabilization
    if add_noise and exact_half > 0:
        print(f"Adding noise (std={noise_std}) to break symmetry")
        all_preds = all_preds + np.random.normal(0, noise_std, all_preds.shape)
        if np.all(all_preds == 0.5):
            print("WARNING: All predictions still exactly 0.5 after noise, forcing split")
            all_preds = np.random.choice([0.49, 0.51], size=all_preds.shape)
        # Noise can push values outside the probability range; keep them valid.
        all_preds = np.clip(all_preds, 0.0, 1.0)

    return calculate_metrics(all_preds, all_labels, is_prob_target=is_prob_target)

def _find_event_file(run_path):
    """Return the first TensorBoard event file under ``run_path`` in os.walk order, or None."""
    for root, _, files in os.walk(run_path):
        for file_name in files:
            if file_name.startswith('events.out.tfevents'):
                return os.path.join(root, file_name)
    return None


def _parse_run_params(run_dir):
    """Extract numeric hyperparameters encoded in a run directory name.

    The name is split on ``_``; a token that is exactly a known prefix followed
    by a number (``L4``, ``D128``, ``H8``, ``G0``, ``PW4``, ``FN30.0``, ``FP3``,
    ``DR0.4``) becomes ``{prefix: float(number)}``. Tokens that merely start
    with one of the letters (``LOSSv1``, ``TtargetY``) are ignored.
    """
    import re

    pattern = re.compile(r'(PW|FN|FP|DR|L|D|H|G)(\d+(?:\.\d+)?)')
    params = {}
    for token in run_dir.split('_'):
        match = pattern.fullmatch(token)
        if match:
            params[match.group(1)] = float(match.group(2))
    return params


def analyze_training_history(runs_dir):
    """Plot validation metrics of every run directory under ``runs_dir`` into one figure.

    Each sub-directory of ``runs_dir`` that contains a TensorBoard event file
    contributes one line per metric, labelled by the hyperparameters parsed
    from its directory name. The figure is saved as ``runs_dir/comparison.png``.
    """
    # EventAccumulator lives in the `tensorboard` package (which SummaryWriter
    # already requires); torch.utils.tensorboard does not provide `backend`.
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    metrics = ['val/auc_pr', 'val/accuracy', 'val/recall', 'val/precision', 'val/f1']
    runs = []

    # Collect data from all runs (sorted for a stable legend order)
    for run_dir in sorted(os.listdir(runs_dir)):
        run_path = os.path.join(runs_dir, run_dir)
        if not os.path.isdir(run_path):
            continue

        event_path = _find_event_file(run_path)
        if event_path is None:
            continue

        ea = EventAccumulator(event_path)
        ea.Reload()
        available = ea.scalars.Keys()

        run_data = {
            'name': run_dir,
            'params': _parse_run_params(run_dir),
            'metrics': {},
        }
        for metric in metrics:
            if metric in available:
                events = ea.Scalars(metric)
                run_data['metrics'][metric] = {
                    'steps': [e.step for e in events],
                    'values': [e.value for e in events],
                }
        runs.append(run_data)

    # Plot metrics: 5 metrics on a 3x2 grid
    fig, axes = plt.subplots(3, 2, figsize=(20, 15))
    try:
        for ax, metric in zip(axes.flat, metrics):
            for run in runs:
                series = run['metrics'].get(metric)
                if series is None:
                    continue
                # Fall back to the directory name when no hyperparameters were parsed.
                label = '_'.join(f"{k}{v}" for k, v in run['params'].items()) or run['name']
                ax.plot(series['steps'], series['values'], label=label)
            ax.set_title(metric.split('/')[-1].upper())
            ax.set_xlabel('Steps')
            ax.set_ylabel('Value')
            if ax.get_legend_handles_labels()[0]:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            ax.grid(True)

        # Hide the unused grid cell(s).
        for ax in list(axes.flat)[len(metrics):]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(runs_dir, 'comparison.png'), bbox_inches='tight')
    finally:
        plt.close(fig)

def _make_transformer_loss_fn():
    """Return ``loss_fn(logits, targets)`` for the configured LOSS_VERSION ('v2', else v1)."""
    if LOSS_VERSION == 'v2':
        def loss_fn(y_pred, y_true):
            return improved_weighted_focal_loss_v2(
                y_pred, y_true,
                base_pos_weight=BASE_POS_WEIGHT,
                fn_penalty=FN_PENALTY,
                fp_penalty=FP_PENALTY,
                is_prob_target=IS_PROB_TARGET,
                target_type=TARGET_TYPE,
                confidence_margin=CONFIDENCE_MARGIN,
                confidence_weight=CONFIDENCE_WEIGHT
            )
    else:  # 'v1' or default
        def loss_fn(y_pred, y_true):
            return improved_weighted_focal_loss_v1(
                y_pred, y_true,
                base_pos_weight=BASE_POS_WEIGHT,
                fn_penalty=FN_PENALTY,
                fp_penalty=FP_PENALTY,
                is_prob_target=IS_PROB_TARGET
            )
    return loss_fn


def _print_transformer_epoch_summary(epoch, epoch_loss, val_metrics):
    """Print one epoch's validation metrics (``epoch`` is 0-based; printed 1-based like the progress bar)."""
    print(f"Epoch {epoch + 1}: Loss = {epoch_loss:.4f}, "
          f"Val AUC-PR = {val_metrics['auc_pr']:.4f} (fixed: {val_metrics['fixed_auc_pr']:.4f}), "
          f"Val Accuracy = {val_metrics['accuracy']:.4f} (fixed: {val_metrics['fixed_accuracy']:.4f}), "
          f"Val Recall = {val_metrics['recall']:.4f} (fixed: {val_metrics['fixed_recall']:.4f}), "
          f"Val Precision = {val_metrics['precision']:.4f} (fixed: {val_metrics['fixed_precision']:.4f}), "
          f"Val F1 = {val_metrics['f1']:.4f} (fixed: {val_metrics['fixed_f1']:.4f}), "
          f"FNR = {val_metrics['false_negative_rate']:.4f} (fixed: {val_metrics['fixed_fnr']:.4f}), "
          f"FPR = {val_metrics['false_positive_rate']:.4f} (fixed: {val_metrics['fixed_fpr']:.4f}), "
          f"opt Threshold = {val_metrics['optimal_threshold']:.4f}, "
          f"TP = {val_metrics['true_positives']} (fixed: {val_metrics['fixed_true_positives']}), "
          f"TN = {val_metrics['true_negatives']} (fixed: {val_metrics['fixed_true_negatives']}), "
          f"FP = {val_metrics['false_positives']} (fixed: {val_metrics['fixed_false_positives']}), "
          f"FN = {val_metrics['false_negatives']} (fixed: {val_metrics['fixed_false_negatives']}), "
          f"Uncertain%: {val_metrics['uncertain_ratio']:.4f}, "
          f"Uncertain Acc: {val_metrics['uncertain_accuracy']:.4f}")

    # Add probability metrics only if they exist
    if 'mse' in val_metrics:
        print(f"MSE = {val_metrics['mse']:.4f}, "
              f"MAE = {val_metrics['mae']:.4f}, "
              f"RMSE = {val_metrics['rmse']:.4f}, "
              f"Prob Correlation = {val_metrics['prob_correlation']:.4f}, "
              f"Mean Pred Prob = {val_metrics['mean_pred_prob']:.4f}, "
              f"Mean True Prob = {val_metrics['mean_true_prob']:.4f}")


def _fit_transformer(device, writer, start_time, checkpoint_path):
    """Build loaders and model, train with early stopping, plot history; return the trained model.

    The best model by validation AUC-PR is saved to ``checkpoint_path``.
    """
    # Early stopping parameters
    early_stopping_patience = 10
    early_stopping_counter = 0
    best_val_auc = 0

    # Create dataloaders with validation data path
    print("\nCreating dataloaders...")
    dataloader_start = time.time()
    train_loader, val_loader = create_multi_target_dataloaders(
        TRAIN_DATA_PATH,
        target_col=TARGET_COL,
        batch_size=BATCH_SIZE,
        train_ratio=TRAIN_RATIO,
        val_data_path=VAL_DATA_PATH
    )
    print(f"Dataloader creation took {time.time() - dataloader_start:.2f} seconds")

    # Check class distribution (full pass over both loaders; labels truncated to int)
    train_labels = np.concatenate([y.numpy() for _, y in train_loader])
    val_labels = np.concatenate([y.numpy() for _, y in val_loader])
    print(f"Training class distribution: {np.bincount(train_labels.astype(int))}")
    print(f"Validation class distribution: {np.bincount(val_labels.astype(int))}")

    # Create model; a single batch is enough to read the input shape
    # (dim 1 = window size, last dim = feature size).
    print("\nInitializing model...")
    model_start = time.time()
    sample_X, _ = next(iter(train_loader))
    model = FireTransformer(
        feature_size=sample_X.shape[-1],
        window_size=sample_X.shape[1],
        n_layers=N_LAYERS,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        dropout=DROPOUT,
        use_feature_norm=USE_FEATURE_NORM
    ).to(device)
    del sample_X
    print(f"Model initialization took {time.time() - model_start:.2f} seconds")

    # Setup training
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.01,  # Weight decay for larger batch size
        betas=(0.9, 0.999)
    )

    if USE_WARMUP_SCHEDULER:
        # Stepped once per batch inside the training loop.
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                                  max_lr=LEARNING_RATE,
                                                  epochs=EPOCHS,
                                                  steps_per_epoch=len(train_loader),
                                                  pct_start=0.2,  # 20% warmup
                                                  div_factor=3.0,
                                                  final_div_factor=100)
    else:
        # Stepped once per epoch on the validation FNR (lower is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5,  # Increased patience for large batch size
            min_lr=1e-6,
            threshold=1e-4
        )

    loss_fn = _make_transformer_loss_fn()

    losses = []
    val_metrics_history = []

    print("\nStarting training loop...")
    for epoch in range(EPOCHS):
        epoch_start = time.time()
        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}')

        # Rolling window of batch false-negative rates, sampled every LOG_INTERVAL batches.
        running_fnr = []
        fnr_window = 100

        try:
            for i, (X, y) in enumerate(pbar):
                X = X.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                logits = model(X)

                # Small logit noise during training
                logits = logits + torch.randn_like(logits) * 0.01

                loss = loss_fn(logits, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=MAX_NORM)

                optimizer.step()
                if USE_WARMUP_SCHEDULER:
                    scheduler.step()  # OneCycleLR update
                total_loss += loss.item()
                pbar.set_postfix({
                    'loss': f'{total_loss / (i + 1):.4f}',
                })

                # Debug: Check predictions
                if i % LOG_INTERVAL == 0:
                    with torch.no_grad():
                        probs = torch.sigmoid(logits)
                        pos_preds = (probs > 0.5).sum().item()
                        avg_prob = probs.mean().item()
                        pos_labels = y.sum().item()
                        # FNR = missed positives / actual positives. Batches without
                        # positives carry no FNR information and are not recorded.
                        positives = y == 1
                        n_pos = positives.sum().item()
                        if n_pos > 0:
                            missed = (positives & (probs < 0.5)).sum().item()
                            running_fnr.append(missed / n_pos)
                            running_fnr = running_fnr[-fnr_window:]

                        avg_fnr = sum(running_fnr) / len(running_fnr) if running_fnr else float('nan')
                        print(f"Batch {i}: Pos Pred: {pos_preds}, Pos Labels: {pos_labels}, "
                              f"Avg Prob: {avg_prob:.4f}, Avg FNR = {avg_fnr:.4f}")

        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            # Abandon the rest of this epoch (no evaluation) and try the next one.
            print("WARNING: out of memory")
            if torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
            continue

        # Record epoch loss
        epoch_loss = total_loss / len(train_loader)
        losses.append(epoch_loss)

        # Evaluate
        val_metrics = evaluate(model, val_loader, device,
                               add_noise=EVAL_ADD_NOISE,
                               noise_std=EVAL_NOISE_STD,
                               is_prob_target=IS_PROB_TARGET)
        val_metrics_history.append(val_metrics)

        for metric_name, metric_value in val_metrics.items():
            writer.add_scalar(f'val/{metric_name}', metric_value, epoch)

        # NOTE: only for ReduceLROnPlateau
        if not USE_WARMUP_SCHEDULER:
            scheduler.step(val_metrics['false_negative_rate'])  # Monitor FNR instead of AUC-PR

        # Save best model (by validation AUC-PR) with hyperparameters
        if val_metrics['auc_pr'] > best_val_auc:
            best_val_auc = val_metrics['auc_pr']
            torch.save({
                'model_state_dict': model.state_dict(),
                'hyperparameters': {
                    'n_layers': N_LAYERS,
                    'd_model': D_MODEL,
                    'n_heads': N_HEADS,
                    'focal_gamma': FOCAL_GAMMA,
                    'base_pos_weight': BASE_POS_WEIGHT,
                    'fn_penalty': FN_PENALTY,
                    'fp_penalty': FP_PENALTY,
                    'is_prob_target': IS_PROB_TARGET,
                    'target_col': TARGET_COL,
                    'dropout': DROPOUT,
                    'loss_version': LOSS_VERSION,
                    'target_type': TARGET_TYPE
                }
            }, checkpoint_path)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        # Report this epoch before a possible early stop so its metrics are not lost.
        _print_transformer_epoch_summary(epoch, epoch_loss, val_metrics)

        if early_stopping_counter >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break

        epoch_time = time.time() - epoch_start
        print(f"Total epoch time: {epoch_time:.2f}s")

    total_time = time.time() - start_time
    print(f"\nTotal training time: {total_time:.2f} seconds")

    plot_training_history(
        losses,
        val_metrics_history,
        os.path.join(LOG_DIR, 'training_history.png')
    )
    return model


def _evaluate_transformer_on_test(model, device, writer, checkpoint_path):
    """Evaluate the best checkpoint on TEST_DATA_PATH, print/log metrics and save predictions to CSV."""
    test_loader = create_multi_target_test_loader(
        TEST_DATA_PATH,
        target_col=TARGET_COL,  # Use the same target column as training
        batch_size=BATCH_SIZE
    )

    print("\nEvaluating best model on test set...")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # No epoch ever improved on the initial AUC-PR (e.g. NaN), so nothing was saved.
        print(f"WARNING: no checkpoint at {checkpoint_path}; evaluating the final model weights")

    test_metrics = evaluate(model, test_loader, device,
                            add_noise=EVAL_ADD_NOISE,
                            noise_std=EVAL_NOISE_STD,
                            is_prob_target=IS_PROB_TARGET)

    print("\nTest Set Metrics:")
    print(f"AUC-PR = {test_metrics['auc_pr']:.4f} (fixed: {test_metrics['fixed_auc_pr']:.4f})")
    print(f"Accuracy = {test_metrics['accuracy']:.4f} (fixed: {test_metrics['fixed_accuracy']:.4f})")
    print(f"Recall = {test_metrics['recall']:.4f} (fixed: {test_metrics['fixed_recall']:.4f})")
    print(f"Precision = {test_metrics['precision']:.4f} (fixed: {test_metrics['fixed_precision']:.4f})")
    print(f"F1 = {test_metrics['f1']:.4f} (fixed: {test_metrics['fixed_f1']:.4f})")
    print(f"FNR = {test_metrics['false_negative_rate']:.4f} (fixed: {test_metrics['fixed_fnr']:.4f})")
    print(f"FPR = {test_metrics['false_positive_rate']:.4f} (fixed: {test_metrics['fixed_fpr']:.4f})")
    print(f"opt Threshold = {test_metrics['optimal_threshold']:.4f}")
    print(f"TP = {test_metrics['true_positives']} (fixed: {test_metrics['fixed_true_positives']})")
    print(f"TN = {test_metrics['true_negatives']} (fixed: {test_metrics['fixed_true_negatives']})")
    print(f"FP = {test_metrics['false_positives']} (fixed: {test_metrics['fixed_false_positives']})")
    print(f"FN = {test_metrics['false_negatives']} (fixed: {test_metrics['fixed_false_negatives']})")
    print(f"Uncertain%: {test_metrics['uncertain_ratio']:.4f}")
    print(f"Uncertain Acc: {test_metrics['uncertain_accuracy']:.4f}")

    if IS_PROB_TARGET and 'mse' in test_metrics:
        print(f"\nProbability Metrics:")
        print(f"MSE = {test_metrics['mse']:.4f}")
        print(f"MAE = {test_metrics['mae']:.4f}")
        print(f"RMSE = {test_metrics['rmse']:.4f}")
        print(f"Prob Correlation = {test_metrics['prob_correlation']:.4f}")
        print(f"Mean Pred Prob = {test_metrics['mean_pred_prob']:.4f}")
        print(f"Mean True Prob = {test_metrics['mean_true_prob']:.4f}")

    # Save raw (noise-free) predictions and true values
    os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
    prediction_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(TEST_OUTPUT_DIR, f"predictions_{TARGET_COL}_{prediction_timestamp}.csv")

    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for X, y in test_loader:
            probs = torch.sigmoid(model(X.to(device)))
            # Flatten so each CSV cell holds a scalar even for (batch, 1) outputs.
            pred_chunks.append(probs.cpu().numpy().reshape(-1))
            label_chunks.append(y.cpu().numpy().reshape(-1))

    results_df = pd.DataFrame({
        'predictions': np.concatenate(pred_chunks) if pred_chunks else np.empty(0),
        'targets': np.concatenate(label_chunks) if label_chunks else np.empty(0)
    })
    results_df.to_csv(output_path, index=False)
    print(f"\nSaved predictions to {output_path}")

    for metric_name, metric_value in test_metrics.items():
        writer.add_scalar(f'test/{metric_name}', metric_value, 0)


def train_transformer():
    """Train the FireTransformer and optionally evaluate it on the test set.

    All hyperparameters come from the module-level configuration. Side effects:
    TensorBoard scalars in LOG_DIR (``val/*`` per epoch, ``test/*`` once), the
    best checkpoint at ``LOG_DIR/best_<MODEL_NAME>.pt``, ``training_history.png``
    in LOG_DIR and, if NEED_TEST_DATA, a predictions CSV in TEST_OUTPUT_DIR.

    Returns:
        The model (holding the best checkpoint's weights if the test set was evaluated).
    """
    start_time = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint_path = os.path.join(LOG_DIR, f'best_{MODEL_NAME}.pt')
    writer = SummaryWriter(LOG_DIR)
    try:
        model = _fit_transformer(device, writer, start_time, checkpoint_path)
        if NEED_TEST_DATA:
            _evaluate_transformer_on_test(model, device, writer, checkpoint_path)
        else:
            print("Skipping test evaluation.")
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    return model

def train_traditional_baselines():
    """Train and evaluate traditional baseline models (Logistic, XGBoost).

    Uses the module-level configuration (TRAIN_DATA_PATH, TARGET_COL, ...).
    Validation (and, if NEED_TEST_DATA, test) metrics are printed and logged to
    TensorBoard under ``<model>/val/*`` and ``<model>/test/*``. A model is saved
    to LOG_DIR whenever its validation AUC-PR beats every model trained before it.

    Returns:
        The fitted ``BaselineModels`` instance.
    """
    start_time = time.time()
    writer = SummaryWriter(LOG_DIR)
    try:
        best_val_auc = 0

        print("\nCreating dataloaders for traditional models...")
        dataloader_start = time.time()

        # Use multi-target dataloaders if TARGET_COL is defined, otherwise use regular dataloaders
        if 'TARGET_COL' in globals() and TARGET_COL:
            print(f"Using multi-target dataloaders with target column: {TARGET_COL}")
            train_loader, val_loader = create_multi_target_baseline_dataloaders(
                TRAIN_DATA_PATH,
                target_col=TARGET_COL,
                batch_size=BATCH_SIZE,
                train_ratio=TRAIN_RATIO,
                val_data_path=VAL_DATA_PATH
            )
            is_prob_target = '_prob' in TARGET_COL
        else:
            print("Using standard dataloaders")
            train_loader, val_loader = create_baseline_dataloaders(
                TRAIN_DATA_PATH,
                batch_size=BATCH_SIZE,
                train_ratio=TRAIN_RATIO,
                val_data_path=VAL_DATA_PATH
            )
            is_prob_target = False

        print(f"Dataloader creation took {time.time() - dataloader_start:.2f} seconds")

        # Materialise the training set in ONE pass so features and labels stay
        # aligned even if the loader shuffles (separate passes would permute them
        # independently). Done once, since it does not depend on the model.
        X_chunks, y_chunks = [], []
        for X, y in train_loader:
            X_chunks.append(X.numpy())
            y_chunks.append(y.numpy())
        X_train = np.concatenate(X_chunks)
        y_train = np.concatenate(y_chunks)
        del X_chunks, y_chunks

        # Check class distribution
        val_labels = np.concatenate([y.numpy() for _, y in val_loader])
        print(f"Training class distribution: {np.bincount(y_train.astype(int))}")
        print(f"Validation class distribution: {np.bincount(val_labels.astype(int))}")

        input_size = X_train.shape[1]
        print(f"Input feature size: {input_size}")

        baselines = BaselineModels(
            pos_weight=BASE_POS_WEIGHT,
            fn_penalty=FN_PENALTY,
            fp_penalty=FP_PENALTY
        )

        models_info = [
            ('Logistic', 'logistic'),
            ('XGBoost', 'xgb')
        ]

        for display_name, model_name in models_info:
            print(f"\nTraining {display_name}...")
            model_start = time.time()

            fit = baselines.fit_logistic if model_name == 'logistic' else baselines.fit_xgb
            fit(X_train, y_train, is_prob_target=is_prob_target)
            print(f"{display_name} training took {time.time() - model_start:.2f} seconds")

            # Bind model_name now so the callable never sees a later loop value.
            predict_fn = lambda X, name=model_name: baselines.predict(name, X)

            val_metrics = evaluate_baseline(predict_fn, val_loader,
                                            add_noise=EVAL_ADD_NOISE,
                                            noise_std=EVAL_NOISE_STD,
                                            is_prob_target=is_prob_target)

            for metric_name, metric_value in val_metrics.items():
                writer.add_scalar(f'{model_name.lower()}/val/{metric_name}', metric_value, 0)

            # Save model if it's the best so far (compared across model types)
            if val_metrics['auc_pr'] > best_val_auc:
                best_val_auc = val_metrics['auc_pr']
                model_save_path = os.path.join(LOG_DIR, f'best_{model_name.lower()}_{MODEL_NAME}.pt')
                baselines.save_model(model_name, model_save_path)

            print(f"\n{display_name} Validation Metrics:")
            print(f"AUC-PR = {val_metrics['auc_pr']:.4f} (fixed: {val_metrics['fixed_auc_pr']:.4f})")
            print(f"Accuracy = {val_metrics['accuracy']:.4f} (fixed: {val_metrics['fixed_accuracy']:.4f})")
            print(f"Recall = {val_metrics['recall']:.4f} (fixed: {val_metrics['fixed_recall']:.4f})")
            print(f"Precision = {val_metrics['precision']:.4f} (fixed: {val_metrics['fixed_precision']:.4f})")
            print(f"F1 = {val_metrics['f1']:.4f} (fixed: {val_metrics['fixed_f1']:.4f})")
            print(f"FNR = {val_metrics['false_negative_rate']:.4f} (fixed: {val_metrics['fixed_fnr']:.4f})")
            print(f"FPR = {val_metrics['false_positive_rate']:.4f} (fixed: {val_metrics['fixed_fpr']:.4f})")
            print(f"opt Threshold = {val_metrics['optimal_threshold']:.4f}")
            print(f"TP = {val_metrics['true_positives']} (fixed: {val_metrics['fixed_true_positives']})")
            print(f"TN = {val_metrics['true_negatives']} (fixed: {val_metrics['fixed_true_negatives']})")
            print(f"FP = {val_metrics['false_positives']} (fixed: {val_metrics['fixed_false_positives']})")
            print(f"FN = {val_metrics['false_negatives']} (fixed: {val_metrics['fixed_false_negatives']})")
            print(f"Uncertain%: {val_metrics['uncertain_ratio']:.4f}")
            print(f"Uncertain Acc: {val_metrics['uncertain_accuracy']:.4f}")

        if NEED_TEST_DATA:
            if 'TARGET_COL' in globals() and TARGET_COL:
                test_loader = create_multi_target_baseline_test_loader(
                    TEST_DATA_PATH,
                    target_col=TARGET_COL,
                    batch_size=BATCH_SIZE
                )
            else:
                test_loader = create_baseline_test_loader(
                    TEST_DATA_PATH,
                    batch_size=BATCH_SIZE
                )

            print("\nEvaluating traditional models on test set...")
            for display_name, model_name in models_info:
                predict_fn = lambda X, name=model_name: baselines.predict(name, X)
                test_metrics = evaluate_baseline(predict_fn, test_loader,
                                                 add_noise=EVAL_ADD_NOISE,
                                                 noise_std=EVAL_NOISE_STD,
                                                 is_prob_target=is_prob_target)

                print(f"\n{display_name} Test Metrics:")
                for metric_name, metric_value in test_metrics.items():
                    writer.add_scalar(f'{model_name.lower()}/test/{metric_name}', metric_value, 0)
                    print(f"{metric_name}: {metric_value:.4f}")
        else:
            print("Skipping test evaluation.")
    finally:
        # Flush and release the TensorBoard event file even on failure.
        writer.close()

    total_time = time.time() - start_time
    print(f"\nTotal traditional models training time: {total_time:.2f} seconds")

    return baselines

def _print_nn_val_metrics(display_name, val_metrics):
    """Print one model's validation report (tuned-threshold metrics, fixed-threshold values in parentheses)."""
    print(f"\n{display_name} Validation Metrics:")
    print(f"AUC-PR = {val_metrics['auc_pr']:.4f} (fixed: {val_metrics['fixed_auc_pr']:.4f})")
    print(f"Accuracy = {val_metrics['accuracy']:.4f} (fixed: {val_metrics['fixed_accuracy']:.4f})")
    print(f"Recall = {val_metrics['recall']:.4f} (fixed: {val_metrics['fixed_recall']:.4f})")
    print(f"Precision = {val_metrics['precision']:.4f} (fixed: {val_metrics['fixed_precision']:.4f})")
    print(f"F1 = {val_metrics['f1']:.4f} (fixed: {val_metrics['fixed_f1']:.4f})")
    print(f"FNR = {val_metrics['false_negative_rate']:.4f} (fixed: {val_metrics['fixed_fnr']:.4f})")
    print(f"FPR = {val_metrics['false_positive_rate']:.4f} (fixed: {val_metrics['fixed_fpr']:.4f})")
    print(f"opt Threshold = {val_metrics['optimal_threshold']:.4f}")
    print(f"TP = {val_metrics['true_positives']} (fixed: {val_metrics['fixed_true_positives']})")
    print(f"TN = {val_metrics['true_negatives']} (fixed: {val_metrics['fixed_true_negatives']})")
    print(f"FP = {val_metrics['false_positives']} (fixed: {val_metrics['fixed_false_positives']})")
    print(f"FN = {val_metrics['false_negatives']} (fixed: {val_metrics['fixed_false_negatives']})")
    print(f"Uncertain%: {val_metrics['uncertain_ratio']:.4f}")
    print(f"Uncertain Acc: {val_metrics['uncertain_accuracy']:.4f}")


def _create_nn_test_loader(target_col):
    """Build the test DataLoader; the multi-target variant is used when target_col is truthy."""
    if target_col:
        return create_multi_target_baseline_test_loader(
            TEST_DATA_PATH,
            target_col=target_col,
            batch_size=BATCH_SIZE
        )
    return create_baseline_test_loader(
        TEST_DATA_PATH,
        batch_size=BATCH_SIZE
    )


def train_nn_baselines():
    """Train and evaluate neural network baseline models (MLP, CNN).

    Reads notebook-level configuration: TRAIN_DATA_PATH, VAL_DATA_PATH,
    TEST_DATA_PATH, BATCH_SIZE, TRAIN_RATIO, LOG_DIR, MODEL_NAME,
    BASE_POS_WEIGHT, FN_PENALTY, FP_PENALTY, EVAL_ADD_NOISE, EVAL_NOISE_STD,
    NEED_TEST_DATA, plus the optional TARGET_COL, NN_EPOCHS (default 10) and
    NN_LEARNING_RATE (default 0.001).

    For each model it trains on the training loader and evaluates on the
    validation loader. It logs the validation metrics to TensorBoard under
    ``<model>/val/<metric>`` and saves the model to
    ``LOG_DIR/best_<model>_<MODEL_NAME>.pt``. When NEED_TEST_DATA is true, every
    model is also evaluated on the test set and logged under ``<model>/test/``.

    Returns:
        The NNBaselineModels instance holding the trained models.
    """
    from functools import partial  # local import keeps this cell self-contained

    start_time = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device} for neural network training")

    # Optional notebook-level settings. A missing or empty TARGET_COL selects
    # the single-target loaders.
    target_col = globals().get('TARGET_COL') or None
    nn_epochs = globals().get('NN_EPOCHS', 10)
    nn_lr = globals().get('NN_LEARNING_RATE', 0.001)

    print("\nCreating dataloaders for neural network models...")
    dataloader_start = time.time()

    if target_col:
        print(f"Using multi-target dataloaders with target column: {target_col}")
        train_loader, val_loader = create_multi_target_baseline_dataloaders(
            TRAIN_DATA_PATH,
            target_col=target_col,
            batch_size=BATCH_SIZE,
            train_ratio=TRAIN_RATIO,
            val_data_path=VAL_DATA_PATH
        )
        # Columns whose name contains '_prob' hold probability targets.
        is_prob_target = '_prob' in target_col
    else:
        print("Using standard dataloaders")
        train_loader, val_loader = create_baseline_dataloaders(
            TRAIN_DATA_PATH,
            batch_size=BATCH_SIZE,
            train_ratio=TRAIN_RATIO,
            val_data_path=VAL_DATA_PATH
        )
        is_prob_target = False

    print(f"Dataloader creation took {time.time() - dataloader_start:.2f} seconds")

    # Class distribution diagnostics (labels are truncated to int for bincount).
    train_labels = np.concatenate([y.numpy() for _, y in train_loader])
    val_labels = np.concatenate([y.numpy() for _, y in val_loader])
    print(f"Training class distribution: {np.bincount(train_labels.astype(int))}")
    print(f"Validation class distribution: {np.bincount(val_labels.astype(int))}")

    # Feature count is taken from the second dimension of the first batch.
    sample_batch = next(iter(train_loader))
    input_size = sample_batch[0].shape[1]
    print(f"Input feature size: {input_size}")

    nn_baselines = NNBaselineModels(
        input_size=input_size,
        pos_weight=BASE_POS_WEIGHT,
        fn_penalty=FN_PENALTY,
        fp_penalty=FP_PENALTY
    )

    nn_models_info = [
        ('MLP', 'mlp'),
        ('CNN', 'cnn')
    ]
    best_display_name, best_val_auc = None, float('-inf')

    # The context manager closes (and flushes) the TensorBoard writer even if
    # training or evaluation raises.
    with SummaryWriter(LOG_DIR) as writer:
        for display_name, model_name in nn_models_info:
            print(f"\nTraining {display_name}...")
            model_start = time.time()

            nn_baselines.fit_model(
                model_name,
                train_loader,
                val_loader,
                epochs=nn_epochs,
                lr=nn_lr
            )
            print(f"{display_name} training took {time.time() - model_start:.2f} seconds")

            # partial binds model_name eagerly (no late-binding closure).
            predict_fn = partial(nn_baselines.predict, model_name)

            val_metrics = evaluate_baseline(predict_fn, val_loader,
                                            add_noise=EVAL_ADD_NOISE,
                                            noise_std=EVAL_NOISE_STD,
                                            is_prob_target=is_prob_target)

            for metric_name, metric_value in val_metrics.items():
                writer.add_scalar(f'{model_name.lower()}/val/{metric_name}', metric_value, 0)

            # Each model has its own checkpoint path, so every model is saved.
            # Gating on "better than the previously trained (different) model"
            # would silently skip the CNN checkpoint whenever the MLP scored higher.
            model_save_path = os.path.join(LOG_DIR, f'best_{model_name.lower()}_{MODEL_NAME}.pt')
            nn_baselines.save_model(model_name, model_save_path)

            if val_metrics['auc_pr'] > best_val_auc:
                best_val_auc = val_metrics['auc_pr']
                best_display_name = display_name

            _print_nn_val_metrics(display_name, val_metrics)

        if best_display_name is not None:
            print(f"\nBest validation AUC-PR: {best_display_name} ({best_val_auc:.4f})")

        if NEED_TEST_DATA:
            test_loader = _create_nn_test_loader(target_col)

            print("\nEvaluating neural network models on test set...")
            for display_name, model_name in nn_models_info:
                predict_fn = partial(nn_baselines.predict, model_name)
                test_metrics = evaluate_baseline(predict_fn, test_loader,
                                                 add_noise=EVAL_ADD_NOISE,
                                                 noise_std=EVAL_NOISE_STD,
                                                 is_prob_target=is_prob_target)

                print(f"\n{display_name} Test Metrics:")
                for metric_name, metric_value in test_metrics.items():
                    writer.add_scalar(f'{model_name.lower()}/test/{metric_name}', metric_value, 0)
                    print(f"{metric_name}: {metric_value:.4f}")
        else:
            print("Skipping test evaluation.")

    total_time = time.time() - start_time
    print(f"\nTotal neural network models training time: {total_time:.2f} seconds")

    return nn_baselines

def _as_numpy(value):
    """Convert a torch.Tensor to a NumPy array (detached, moved to CPU); return other values unchanged."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def evaluate_baseline(predict_fn, dataloader, add_noise=False, noise_std=1e-3, is_prob_target=False):
    """Evaluate a baseline model's predictions over every batch of a dataloader.

    Args:
        predict_fn: Callable mapping a batch of features to predictions. It is
            first called with the batch exactly as the dataloader yields it
            (a torch.Tensor suits NN models). If that raises an ``Exception``
            and the batch is a tensor, it is retried once with the batch
            converted to NumPy (suits sklearn-style models).
        dataloader: Iterable of ``(X, y)`` batches. X and y may be torch
            tensors or array-likes.
        add_noise: If True and some predictions equal exactly 0.5, add
            zero-mean Gaussian noise to all predictions to break ties.
        noise_std: Standard deviation of that noise.
        is_prob_target: Forwarded to ``calculate_metrics``.

    Returns:
        The result of ``calculate_metrics`` on the concatenated predictions
        and labels.

    Raises:
        ValueError: If the dataloader yields no batches.
        Exception: Whatever ``predict_fn`` raises when no fallback applies, or
            when the NumPy retry also fails. In that case the tensor-path error
            is attached as ``__context__``. ``KeyboardInterrupt`` and
            ``SystemExit`` are never caught.
    """
    pred_batches = []
    label_batches = []

    for X, y in dataloader:
        try:
            pred = predict_fn(X)
        except Exception:
            if not isinstance(X, torch.Tensor):
                # Retrying with the identical object cannot succeed.
                raise
            # Fall back to NumPy input for traditional models.
            pred = predict_fn(_as_numpy(X))

        # detach/cpu so grad-tracking or CUDA outputs convert cleanly.
        pred_batches.append(np.asarray(_as_numpy(pred)))
        label_batches.append(np.asarray(_as_numpy(y)))

    if not pred_batches:
        raise ValueError("evaluate_baseline: dataloader yielded no batches")

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)

    # Log the prediction distribution around 0.5 (ties at the decision boundary).
    exact_half = np.mean(all_preds == 0.5) * 100
    near_half_range = 0.01
    near_half = np.mean((all_preds > 0.5 - near_half_range) &
                        (all_preds < 0.5 + near_half_range)) * 100

    print(f"\nPrediction distribution analysis:")
    print(f"Predictions exactly 0.5: {exact_half:.2f}%")
    print(f"Predictions within ±{near_half_range} of 0.5: {near_half:.2f}%")

    # Optional noise addition (uses NumPy's global RNG).
    if add_noise and exact_half > 0:
        print(f"Adding noise (std={noise_std}) to break symmetry")
        all_preds = all_preds + np.random.normal(0, noise_std, all_preds.shape)

    return calculate_metrics(all_preds, all_labels, is_prob_target=is_prob_target)

In [None]:
# Fail fast on a misconfigured model_type so Run All stops here instead of
# continuing to later cells without having trained anything.
VALID_MODEL_TYPES = ('transformer', 'traditional', 'nn', 'all')
if model_type not in VALID_MODEL_TYPES:
    raise ValueError(
        f"Invalid model type {model_type!r}. "
        "Please choose 'transformer', 'traditional', 'nn', or 'all'."
    )

print("Starting training...")
print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print(f"Train Data path: {TRAIN_DATA_PATH}")
print(f"Val Data path: {VAL_DATA_PATH}")
print(f"Test Data path: {TEST_DATA_PATH}")
print(f"Log directory: {LOG_DIR}")
print(f"Target column: {globals().get('TARGET_COL', 'Default')}")

if model_type == 'transformer':
    print("\n=== Training Transformer Model ===")
    model = train_transformer()
elif model_type == 'traditional':
    print("\n=== Training Traditional Baseline Models ===")
    traditional_models = train_traditional_baselines()
elif model_type == 'nn':
    print("\n=== Training Neural Network Baseline Models ===")
    nn_models = train_nn_baselines()
elif model_type == 'all':
    # 'all' means all *baseline* families; the transformer is not included.
    print("\n=== Training All Baseline Models ===")
    traditional_models = train_traditional_baselines()
    nn_models = train_nn_baselines()

In [None]:
# Release the Colab runtime assigned to this notebook once training is done.
# NOTE(review): anything held only in kernel memory (e.g. `model`,
# `traditional_models`, `nn_models` from the previous cell) is presumably lost
# after this call; only run it once checkpoints/logs are written to Drive.
runtime.unassign()