In [None]:
import argparse
import json
import os
import random
import sys
from collections import defaultdict

# Path/environment setup has to run BEFORE the imports below: the original
# cell appended './src' to sys.path only after `models`, `dataloader` and
# `traintest` had already been imported, so the append could not help resolve
# them. NOTE(review): assumes those three modules live in ./src — confirm.
sys.path.append('./src')
# Cache directory for pretrained weights; set before torch is imported so
# every library that consults TORCH_HOME sees the same value.
os.environ['TORCH_HOME'] = '../pretrained_models'

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402

from models import ASTModel  # noqa: E402
import dataloader  # noqa: E402
from traintest import train, validate  # noqa: E402

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def prepare_ast_dataset(root_path, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2,
                        seed=42, output_dir="."):
    """Split cough recordings by patient and write train/val/test JSON manifests.

    Recordings are read from ``<root_path>/Cough_PTB`` (label ``"1"``) and
    ``<root_path>/Cough_Non-PTB`` (label ``"0"``). Only files ending in
    ``.wav`` are used, and the patient ID is the part of the file name before
    the first underscore. All recordings of one patient land in the same
    split, so no patient is shared between train, validation and test.

    Three files are written into ``output_dir``: ``train_data.json``,
    ``val_data.json`` and ``test_data.json``, each shaped as
    ``{"data": [{"wav": <absolute path>, "labels": "0" | "1"}, ...]}``.
    A per-split summary is printed to stdout.

    Args:
        root_path: Directory containing the two class folders.
        train_ratio: Fraction of patients assigned to the training split.
        val_ratio: Fraction of patients assigned to the validation split.
        test_ratio: Fraction of patients for the test split. It is only used
            to validate that the ratios sum to 1; the test split receives
            every patient left over after train and validation are cut.
        seed: Seed for the patient shuffle. The global ``random`` generator is
            seeded, exactly as the original implementation did.
        output_dir: Directory that receives the three JSON files (created if
            missing). Defaults to the current working directory.

    Returns:
        None.

    Raises:
        AssertionError: If the three ratios do not sum to 1.0.
    """
    assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-5, "Ratios must sum to 1.0"

    # 1. Class folders and the label string stored for each of them.
    class_map = {"Cough_PTB": "1", "Cough_Non-PTB": "0"}
    patient_groups = {folder_name: defaultdict(list) for folder_name in class_map}

    # 2. Collect the audio files, grouped by class folder and patient ID.
    for folder_name in class_map:
        folder_path = os.path.join(root_path, folder_name)
        if not os.path.isdir(folder_path):
            print(f"‚ö†Ô∏è Warning: Folder {folder_path} not found.")
            continue

        # os.listdir order is filesystem-dependent; sort for stable manifests.
        for file_name in sorted(os.listdir(folder_path)):
            if file_name.endswith(".wav"):
                patient_id = file_name.split('_')[0]
                full_path = os.path.abspath(os.path.join(folder_path, file_name))
                patient_groups[folder_name][patient_id].append(full_path)

    # 3. Shuffle the patients and cut the list into three consecutive slices.
    # The IDs are sorted first: iteration order of a set of strings changes
    # between interpreter sessions, so seeding alone would not reproduce the
    # same split from one run to the next.
    all_patients = sorted(
        {p_id for groups in patient_groups.values() for p_id in groups}
    )
    random.seed(seed)
    random.shuffle(all_patients)

    total_patients = len(all_patients)
    train_split_idx = int(total_patients * train_ratio)
    val_split_idx = train_split_idx + int(total_patients * val_ratio)

    split_ids = {
        "train": all_patients[:train_split_idx],
        "val": all_patients[train_split_idx:val_split_idx],
        "test": all_patients[val_split_idx:],
    }
    # Patient -> split lookup, so assigning a recording is a single dict access
    # instead of scanning the ID lists once per file.
    split_of = {p_id: split for split, ids in split_ids.items() for p_id in ids}

    # 4. Build the manifest entries and remember which patients of which class
    #    ended up in each split (used only for the printed summary).
    split_items = {split: [] for split in split_ids}
    summary_ids = {
        f"{split}_{class_key}": set()
        for split in split_ids
        for class_key in ("ptb", "non_ptb")
    }

    for folder_name, groups in patient_groups.items():
        class_key = "ptb" if folder_name == "Cough_PTB" else "non_ptb"
        label = class_map[folder_name]
        for p_id, paths in groups.items():
            split = split_of[p_id]
            summary_ids[f"{split}_{class_key}"].add(p_id)
            split_items[split].extend({"wav": path, "labels": label} for path in paths)

    # 5. Write one JSON manifest per split.
    os.makedirs(output_dir, exist_ok=True)
    for split, items in split_items.items():
        manifest_path = os.path.join(output_dir, f"{split}_data.json")
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump({"data": items}, f, indent=4)

    # 6. Print a per-split summary: patient IDs per class and file counts.
    print(" AST DATASET PREPARATION SUMMARY (3-Way Patient-Level Split)")
    print("=" * 70)

    headings = (
        ("üîπ TRAIN SET", "train"),
        ("üî∏ VALIDATION SET", "val"),
        ("üî¥ TEST SET", "test"),
    )
    for heading, split in headings:
        items = split_items[split]
        ptb_ids = sorted(summary_ids[f"{split}_ptb"])
        non_ptb_ids = sorted(summary_ids[f"{split}_non_ptb"])
        ptb_audio = sum(1 for item in items if item['labels'] == "1")
        non_ptb_audio = sum(1 for item in items if item['labels'] == "0")

        print(f"{heading} (Total Patients: {len(split_ids[split])})")
        print(f"    PTB ID: {ptb_ids}")
        print(f"    Non-PTB ID: {non_ptb_ids}")
        print(f"    Total Audio Files: {len(items)} (PTB: {ptb_audio} files | Non-PTB: {non_ptb_audio} files)")
    print("=" * 70)

# Usage: patient-level split with 60% train, 20% validation and 20% test.
prepare_ast_dataset("./Data", train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

 AST DATASET PREPARATION SUMMARY (3-Way Patient-Level Split)
üîπ TRAIN SET (Total Patients: 9)
    PTB ID: ['006', '008', '012', '015']
    Non-PTB ID: ['001', '002', '003', '013', '014']
    Total Audio Files: 392 (PTB: 151 files | Non-PTB: 241 files)
üî∏ VALIDATION SET (Total Patients: 3)
    PTB ID: ['009', '011']
    Non-PTB ID: ['005']
    Total Audio Files: 37 (PTB: 5 files | Non-PTB: 32 files)
üî¥ TEST SET (Total Patients: 3)
    PTB ID: ['016']
    Non-PTB ID: ['004', '007']
    Total Audio Files: 83 (PTB: 31 files | Non-PTB: 52 files)


In [5]:
print("üöÄ STARTING AST-P SINGLE RUN (80/20 SPLIT) WITH VISUALIZATION")
# 1. ‡∏Å‡∏≥‡∏´‡∏ô‡∏î‡πÇ‡∏ü‡∏•‡πÄ‡∏î‡∏≠‡∏£‡πå‡πÄ‡∏Å‡πá‡∏ö‡∏ú‡∏•‡∏•‡∏±‡∏û‡∏ò‡πå
exp_dir = './exp/tb_ast_p_single_run'
os.makedirs(f'{exp_dir}/models', exist_ok=True)
MEAN_NORM = -3.2406702
STD_NORM = 4.9710765
# 2. ‡∏ï‡∏±‡πâ‡∏á‡∏Ñ‡πà‡∏≤ Data Config
train_audio_conf = {'num_mel_bins': 128, 'target_length': 100, 'freqm': 24, 'timem': 20, 'mixup': 0.0, 'dataset': 'audioset', 'mode': 'train', 'mean': MEAN_NORM, 'std': STD_NORM, 'noise': False, 'skip_norm': False}
eval_audio_conf = {'num_mel_bins': 128, 'target_length': 100, 'freqm': 0, 'timem': 0, 'mixup': 0.0, 'dataset': 'audioset', 'mode': 'evaluation', 'mean': MEAN_NORM, 'std': STD_NORM, 'noise': False, 'skip_norm': False}
BATCH_SIZE = 8
# 3. ‡∏™‡∏£‡πâ‡∏≤‡∏á DataLoader
train_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset('train_data.json', label_csv='class_labels_indices.csv', audio_conf=train_audio_conf),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True
)
eval_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset('eval_data.json', label_csv='class_labels_indices.csv', audio_conf=eval_audio_conf),
    batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=0, pin_memory=True
)
# --- Spectrogram preview ---
# Draw one batch from the training loader and plot its first sample.
sample_inputs, sample_labels = next(iter(train_loader))
# NOTE(review): the transpose below assumes each sample is laid out as
# (time frames, frequency bins) — confirm against dataloader.AudiosetDataset.
spec_data = sample_inputs[0].numpy()
label_data = torch.argmax(sample_labels[0]).item()
class_name = "Non-PTB" if label_data != 1 else "PTB"

# Transposed so that frequency runs along the Y axis and time along the X axis.
fig, ax = plt.subplots(figsize=(10, 4))
spec_image = ax.imshow(spec_data.T, aspect='auto', origin='lower', cmap='viridis')
ax.set_title(f"Example Log Mel Filterbank Spectrogram (Class: {class_name})")
ax.set_ylabel("Frequency Bins (128)")
ax.set_xlabel("Time Frames (100)")
fig.colorbar(spec_image, ax=ax, format='%+2.0f dB')
fig.tight_layout()
plt.show()
print("‚òùÔ∏è ‡∏†‡∏≤‡∏û‡∏î‡πâ‡∏≤‡∏ô‡∏ö‡∏ô‡∏Ñ‡∏∑‡∏≠‡∏ü‡∏µ‡πÄ‡∏à‡∏≠‡∏£‡πå Spectrogram 128 ‡∏°‡∏¥‡∏ï‡∏¥‡∏ó‡∏µ‡πà‡∏ú‡πà‡∏≤‡∏ô‡∏Å‡∏≤‡∏£‡∏ó‡∏≥ Normalization ‡πÅ‡∏•‡πâ‡∏ß ‡∏ã‡∏∂‡πà‡∏á‡∏à‡∏∞‡∏ñ‡∏π‡∏Å‡∏™‡πà‡∏á‡πÄ‡∏Ç‡πâ‡∏≤ AST Model")

# 4. Build the AST model. The constructor arguments that both the trainable
#    model and the reloaded checkpoint must agree on are kept in one place.
ast_shape_kwargs = dict(
    label_dim=2,
    fstride=10,
    tstride=10,
    input_fdim=128,
    input_tdim=100,
    model_size='base384',
)
ast_model = ASTModel(imagenet_pretrain=True, audioset_pretrain=True, **ast_shape_kwargs)

# 5. Training arguments, passed to train()/validate() as a namespace.
training_options = {
    'exp_dir': exp_dir,
    'dataset': 'audioset',
    'n_class': 2,
    'lr': 1e-5,
    'n_epochs': 30,
    'batch_size': BATCH_SIZE,
    'n_print_steps': 10,
    'save_model': True,
    'loss': 'CE',
    'metrics': 'mAP',
    'lrscheduler_start': 10,
    'lrscheduler_step': 5,
    'lrscheduler_decay': 0.5,
    'warmup': True,
    'wa': True,
    'wa_start': 1,
    'wa_end': 30,
}
args = argparse.Namespace(**training_options)

# 6. Run training.
print("\nStarting Model Training...")
train(ast_model, train_loader, eval_loader, args)

# 7. Rebuild the architecture without pretrained weights, load the saved best
#    checkpoint into it and score it with validate().
best_model_path = args.exp_dir + '/models/best_audio_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model is wrapped in DataParallel before loading the state dict.
# NOTE(review): presumably because the checkpoint was saved from a wrapped
# model and its keys carry the wrapper's prefix — confirm in traintest.train.
best_model = torch.nn.DataParallel(
    ASTModel(imagenet_pretrain=False, audioset_pretrain=False, verbose=False, **ast_shape_kwargs)
).to(device)

best_state_dict = torch.load(best_model_path, map_location=device)
best_model.load_state_dict(best_state_dict)

stats, eval_loss = validate(best_model, eval_loader, args, epoch='best')


üöÄ STARTING AST-P SINGLE RUN (80/20 SPLIT) WITH VISUALIZATION


NameError: name 'torch' is not defined

In [None]:
# ==========================================
# Confusion matrix on the evaluation set
# ==========================================
# seaborn is imported here because this cell draws the heatmap; the original
# called `sys.heatmap`, which does not exist and raised AttributeError.
import seaborn as sns
from sklearn.metrics import confusion_matrix


print("\nüìà Generating Confusion Matrix...")
all_preds = []
all_targets = []

best_model.eval()  # switch to inference mode
with torch.no_grad():
    for audio, labels in eval_loader:
        audio = audio.to(device)
        logits = best_model(audio)
        # The predicted class is the index of the largest model output.
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_preds.extend(preds)

        # Reduce each label vector to a single class index (0 or 1).
        # NOTE(review): assumes the loader yields one-hot label vectors whose
        # column order matches the class IDs — confirm against the label CSV.
        all_targets.extend(torch.argmax(labels, dim=1).cpu().numpy())

# labels=[0, 1] keeps the matrix 2x2 (matching the tick labels below) even if
# one class never appears among the targets or the predictions.
cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-PTB (0)', 'PTB (1)'],
            yticklabels=['Non-PTB (0)', 'PTB (1)'])
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('AST-P Confusion Matrix (Eval Set)')
plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    roc_curve, auc, average_precision_score, confusion_matrix
)

# ==========================================
# 1. Find the best epoch in AST's result.csv
# ==========================================
# Column layout of the file (one row per epoch, no header):
# [mAP, mAUC, precision, recall, d_prime, train_loss, valid_loss, cum_mAP, cum_mAUC, lr]
csv_path = './exp/tb_ast_p_single_run/result.csv'
df_results = pd.read_csv(csv_path, header=None)

# The run was configured with metrics='mAP', so the epoch is picked from the
# mAP column (index 0) and the AUC logged for that same epoch is reported.
# The original took the maximum of the mAUC column, which can name a different
# epoch than the one stored in best_audio_model.pth.
# NOTE(review): assumes the training loop saves the best checkpoint according
# to args.metrics — confirm in traintest.train.
best_row = df_results[0].idxmax()
best_epoch = best_row + 1
best_mauc = df_results.loc[best_row, 1]
print(f"üåü Best Epoch: {best_epoch} (Validation AUC from log: {best_mauc:.4f})\n")

# ==========================================
# 2. Collect class-1 (PTB) probabilities from the best model
# ==========================================
print(" Evaluating Best Model on Evaluation Set...")
all_probs = []
all_targets = []

best_model.eval()  # switch to inference mode
with torch.no_grad():
    for audio, labels in eval_loader:
        audio = audio.to(device)
        logits = best_model(audio)

        # Softmax over the two outputs, then keep the probability of class 1.
        probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()

        # True class index of each sample.
        targets = torch.argmax(labels, dim=1).cpu().numpy()

        all_probs.extend(probs)
        all_targets.extend(targets)

all_probs = np.array(all_probs)
all_targets = np.array(all_targets)

# ==========================================
# 3. Threshold-free metrics: AUROC and AUPRC
# ==========================================
fpr, tpr, roc_thresholds = roc_curve(all_targets, all_probs)
auroc_val = auc(fpr, tpr)
auprc_val = average_precision_score(all_targets, all_probs)


# ==========================================
# 4. Sensitivity and specificity at a fixed decision threshold
# ==========================================
# No optimal threshold is searched for: predictions use the fixed cut-off
# below, and the same value is shown in the confusion-matrix title.
DECISION_THRESHOLD = 0.5
optimal_preds = (all_probs >= DECISION_THRESHOLD).astype(int)

# labels=[0, 1] guarantees a 2x2 matrix, so the four-way unpacking cannot fail
# when one class is absent from both targets and predictions.
cm_optimal = confusion_matrix(all_targets, optimal_preds, labels=[0, 1])
tn, fp, fn, tp = cm_optimal.ravel()

# Report NaN instead of dividing by zero when a class has no samples.
sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
specificity = tn / (tn + fp) if (tn + fp) else float('nan')

print(" Performance Metrics:")
print(f" - AUROC:             {auroc_val:.4f}")
print(f" - AUPRC:             {auprc_val:.4f}")
print(f" - Sensitivity (TPR): {sensitivity:.4f}  (TP:{tp}, FN:{fn})")
print(f" - Specificity (TNR): {specificity:.4f}  (TN:{tn}, FP:{fp})")

# 5. Plot the ROC curve and the confusion matrix side by side.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# ---- Panel 1: ROC curve ----
axes[0].plot(fpr, tpr, color='darkorange', lw=2.5, label=f'AUROC = {auroc_val:.4f}')
# Diagonal reference line from (0, 0) to (1, 1).
axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
axes[0].set_xlim([0.0, 1.0])
axes[0].set_ylim([0.0, 1.05])
axes[0].set_xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
axes[0].set_ylabel('True Positive Rate (Sensitivity)', fontsize=12)
axes[0].set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
axes[0].legend(loc="lower right", fontsize=11)
axes[0].grid(True, linestyle='--', alpha=0.6)

# ---- Panel 2: confusion matrix ----
sns.heatmap(cm_optimal, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-PTB (0)', 'PTB (1)'],
            yticklabels=['Non-PTB (0)', 'PTB (1)'],
            annot_kws={"size": 14}, ax=axes[1])
axes[1].set_xlabel('Predicted Label', fontsize=12)
axes[1].set_ylabel('True Label', fontsize=12)
axes[1].set_title(f'Confusion Matrix (@ Threshold {DECISION_THRESHOLD:.4f})', fontsize=14)

plt.tight_layout()
plt.show()