In [1]:
import sys
import os

sys.path.append('./src')
os.environ['TORCH_HOME'] = '../pretrained_models'

from models import ASTModel
import dataloader
import torch
import matplotlib.pyplot as plt
import numpy as np
import json
import random
from collections import defaultdict
import argparse
from traintest import train, validate

# Seed every RNG this notebook touches so that shuffling, noise augmentation
# and weight initialisation are repeatable under Restart Kernel -> Run All.
# torch.manual_seed also seeds the CUDA generators when a GPU is present.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


 AST DATASET PREPARATION SUMMARY (3-Way Patient-Level Split)
 TRAIN SET (Total Patients: 9)
    PTB ID: ['009', '011', '015']
    Non-PTB ID: ['001', '002', '003', '007', '013', '014']
    Total Audio Files: 294 (PTB: 42 files | Non-PTB: 252 files)
 VALIDATION SET (Total Patients: 3)
    PTB ID: ['008']
    Non-PTB ID: ['004', '005']
    Total Audio Files: 115 (PTB: 42 files | Non-PTB: 73 files)
 TEST SET (Total Patients: 3)
    PTB ID: ['006', '012', '016']
    Non-PTB ID: []
    Total Audio Files: 103 (PTB: 103 files | Non-PTB: 0 files)


In [None]:
print("üöÄ STARTING AST-P SINGLE RUN (80/20 SPLIT) WITH VISUALIZATION")
# 1. Directory that receives this run's outputs; the models/ sub-folder is
#    created up front so checkpoints can be written into it.
exp_dir = './exp/tb_ast_p_single_run'
os.makedirs(f'{exp_dir}/models', exist_ok=True)

# Normalisation statistics handed to the dataloader as 'mean' / 'std'.
MEAN_NORM = -4.27
STD_NORM = 4.57

# Number of time frames per clip. This has to equal the `input_tdim=128` that
# both ASTModel instances in this cell are built with; the previous value of
# 100 did not match the model's expected input length.
TARGET_LENGTH = 128

# 2. Data config, modelled on the Speechcommands V2 recipe
#    (recipe reference: target_length=128, freqm=48, timem=48, mixup=0.6, noise=True).
#    This run sets freqm, timem and mixup to 0, so only `noise` differs between
#    the training and evaluation configurations besides `mode`.
train_audio_conf = {
    'num_mel_bins': 128,
    'target_length': TARGET_LENGTH,
    'freqm': 0,
    'timem': 0,
    'mixup': 0.0,
    'dataset': 'audioset',
    'mode': 'train',
    'mean': MEAN_NORM,
    'std': STD_NORM,
    'noise': True,
    'skip_norm': False,
}
eval_audio_conf = {
    'num_mel_bins': 128,
    'target_length': TARGET_LENGTH,
    'freqm': 0,
    'timem': 0,
    'mixup': 0.0,
    'dataset': 'audioset',
    'mode': 'evaluation',
    'mean': MEAN_NORM,
    'std': STD_NORM,
    'noise': False,
    'skip_norm': False,
}

BATCH_SIZE = 8

# 3. Build the DataLoaders.
# val_data.json is used for validation during training, keeping the validation
# split separate from the test split.
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; on a
# CPU-only machine requesting it merely triggers a warning, so enable it only
# when CUDA is available.
use_pinned_memory = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset('train_data.json', label_csv='class_labels_indices.csv', audio_conf=train_audio_conf),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=use_pinned_memory
)
# Evaluation uses a batch twice as large and keeps the file order fixed.
eval_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset('val_data.json', label_csv='class_labels_indices.csv', audio_conf=eval_audio_conf),
    batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=0, pin_memory=use_pinned_memory
)

# Preview one example spectrogram.
# Draw a single batch from train_loader purely for inspection.
sample_inputs, sample_labels = next(iter(train_loader))
# Per the original note, sample_inputs is laid out as
# [batch_size, temporal_frame_num, frequency_bin_num]; the first item of the
# batch is plotted. NOTE(review): layout is taken from that note, not verified here.
spec_data = sample_inputs[0].numpy()
label_data = torch.argmax(sample_labels[0]).item()
# NOTE(review): assumes label index 1 means PTB - confirm against class_labels_indices.csv.
class_name = "PTB" if label_data == 1 else "Non-PTB"

# Read the axis sizes from the tensor itself so the labels stay truthful when
# target_length or num_mel_bins change.
n_time_frames, n_freq_bins = spec_data.shape

fig, ax = plt.subplots(figsize=(10, 4))
# Transpose so that frequency runs along the Y axis and time along the X axis.
im = ax.imshow(spec_data.T, aspect='auto', origin='lower', cmap='viridis')
ax.set_title(f"Example Log Mel Filterbank Spectrogram (Class: {class_name})")
ax.set_ylabel(f"Frequency Bins ({n_freq_bins})")
ax.set_xlabel(f"Time Frames ({n_time_frames})")
# The audio config requests normalisation (skip_norm=False), so the plotted
# values are unitless rather than dB.
fig.colorbar(im, ax=ax, format='%+.1f', label='Normalized log-mel energy')
fig.tight_layout()
plt.show()
print("‚òùÔ∏è ‡∏†‡∏≤‡∏û‡∏î‡πâ‡∏≤‡∏ô‡∏ö‡∏ô‡∏Ñ‡∏∑‡∏≠‡∏ü‡∏µ‡πÄ‡∏à‡∏≠‡∏£‡πå Spectrogram 128 ‡∏°‡∏¥‡∏ï‡∏¥‡∏ó‡∏µ‡πà‡∏ú‡πà‡∏≤‡∏ô‡∏Å‡∏≤‡∏£‡∏ó‡∏≥ Normalization ‡πÅ‡∏•‡πâ‡∏ß ‡∏ã‡∏∂‡πà‡∏á‡∏à‡∏∞‡∏ñ‡∏π‡∏Å‡∏™‡πà‡∏á‡πÄ‡∏Ç‡πâ‡∏≤ AST Model")
# ==========================================

# 4. Build the AST-P model following the Speechcommands V2 recipe.
#    - ImageNet pretraining only: audioset_pretrain stays False.
#    - input_tdim=128 is meant to match target_length in the audio config.
ast_model_kwargs = dict(
    label_dim=2,               # two output classes
    fstride=10,                # patch stride along frequency
    tstride=10,                # patch stride along time
    input_fdim=128,            # same value as num_mel_bins in the audio config
    input_tdim=128,
    imagenet_pretrain=True,
    audioset_pretrain=False,   # Speechcommands V2: audioset_pretrain=False
    model_size='base384',
)
ast_model = ASTModel(**ast_model_kwargs)

# 5. Training arguments following the Speechcommands V2 recipe, collected in a
#    dict first and then exposed as an argparse.Namespace for train()/validate().
training_options = {
    'exp_dir': exp_dir,
    'dataset': 'speechcommands',
    'n_class': 2,
    'lr': 2.5e-4,               # Speechcommands V2: lr=2.5e-4
    'n_epochs': 30,
    'batch_size': BATCH_SIZE,
    'n_print_steps': 10,
    'save_model': True,
    'loss': 'BCE',              # Speechcommands V2: loss=BCE
    'metrics': 'acc',           # Speechcommands V2: metrics=acc
    'lrscheduler_start': 5,     # Speechcommands V2: lrscheduler_start=5
    'lrscheduler_step': 1,      # Speechcommands V2: lrscheduler_step=1
    'lrscheduler_decay': 0.85,  # Speechcommands V2: lrscheduler_decay=0.85
    'warmup': False,            # Speechcommands V2: warmup=False
    'wa': False,                # Speechcommands V2 does not use weight averaging
    'wa_start': 1,
    'wa_end': 30,
}
args = argparse.Namespace(**training_options)

# 6. Run training.
print("\nStarting Model Training...")
train(ast_model, train_loader, eval_loader, args)

# 7. Rebuild the architecture without pretrained weights, load the best
#    checkpoint written under exp_dir/models, and evaluate it.
best_model_path = f'{args.exp_dir}/models/best_audio_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the DataParallel wrapper is presumably needed so the parameter
# names match those in the saved state dict - verify against traintest.train.
best_model = torch.nn.DataParallel(
    ASTModel(
        label_dim=2,
        fstride=10,
        tstride=10,
        input_fdim=128,
        input_tdim=128,
        imagenet_pretrain=False,
        audioset_pretrain=False,
        model_size='base384',
        verbose=False,
    )
).to(device)

# map_location lets a checkpoint saved on one device type load on another.
best_state = torch.load(best_model_path, map_location=device)
best_model.load_state_dict(best_state)

stats, eval_loss = validate(best_model, eval_loader, args, epoch='best')


In [None]:
# ==========================================
# [NEW] Plot the confusion matrix for the eval set
# ==========================================
import seaborn as sns
from sklearn.metrics import confusion_matrix


print("\nüìà Generating Confusion Matrix...")
all_preds = []
all_targets = []

best_model.eval()  # switch to evaluation mode
with torch.no_grad():
    for audio, labels in eval_loader:
        audio = audio.to(device)
        # The model output is treated as raw logits (no sigmoid/softmax applied here).
        logits = best_model(audio)
        # Predicted class = index of the largest logit.
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        all_preds.extend(preds)

        # Collapse each label row to a single class index (0 or 1) before storing it.
        all_targets.extend(torch.argmax(labels, dim=1).cpu().numpy())

# Pin the label order so the matrix is always 2x2 and lines up with the two
# tick labels below, even if one class never appears in targets or predictions.
cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-PTB (0)', 'PTB (1)'],
            yticklabels=['Non-PTB (0)', 'PTB (1)'],
            ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('AST-P Confusion Matrix (Val Set)')
fig.tight_layout()
plt.show()


In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    roc_curve, auc, average_precision_score, confusion_matrix
)

# ==========================================
# 1. Find the best epoch from AST's result.csv
# ==========================================
# Column layout according to the original note (the file is written by the
# training code, which is not visible here):
# [mAP, mAUC, precision, recall, d_prime, train_loss, valid_loss, cum_mAP, cum_mAUC, lr]
csv_path = './exp/tb_ast_p_single_run/result.csv'
df_results = pd.read_csv(csv_path, header=None)

# Column 1 is read as the per-epoch mAUC; rows are 0-based, epochs are 1-based.
# NOTE(review): the checkpoint loaded into best_model may have been chosen by a
# different metric (args.metrics is 'acc'), so this epoch is not guaranteed to
# be the one evaluated below - confirm against traintest.train.
mauc_per_epoch = df_results[1]
best_mauc = mauc_per_epoch.max()
best_epoch = mauc_per_epoch.idxmax() + 1
print(f"üåü Best Epoch: {best_epoch} (Validation AUC from log: {best_mauc:.4f})\n")

# ==========================================
# 2. Collect class-1 probabilities from the best model
# ==========================================
print(" Evaluating Best Model on Evaluation Set...")
all_probs = []
all_targets = []

best_model.eval()  # switch to evaluation mode
with torch.no_grad():
    for batch_audio, batch_labels in eval_loader:
        batch_logits = best_model(batch_audio.to(device))

        # The model emits one raw logit per class (n_class=2). Softmax turns
        # them into probabilities; column 1 is kept as the PTB score.
        ptb_scores = torch.softmax(batch_logits, dim=-1)[:, 1]

        # Ground-truth class index for every item in the batch.
        true_classes = torch.argmax(batch_labels, dim=1)

        all_probs.extend(ptb_scores.cpu().numpy())
        all_targets.extend(true_classes.cpu().numpy())

# Freeze the accumulated per-sample values into arrays for the metric code.
all_probs = np.array(all_probs)
all_targets = np.array(all_targets)

# ==========================================
# 3. Threshold-free metrics: AUROC and AUPRC
# ==========================================
fpr, tpr, roc_thresholds = roc_curve(all_targets, all_probs)
auroc_val = auc(fpr, tpr)
auprc_val = average_precision_score(all_targets, all_probs)


# ==========================================
# 4. Sensitivity and specificity at a fixed decision threshold
# ==========================================
# Predictions use a fixed cut-off; no threshold is tuned on this data, so
# `optimal_preds` is simply "probability >= DECISION_THRESHOLD".
DECISION_THRESHOLD = 0.5
optimal_preds = (all_probs >= DECISION_THRESHOLD).astype(int)

# labels=[0, 1] forces a 2x2 matrix, so the four-way unpack cannot fail when
# only one class is present in the targets and predictions.
tn, fp, fn, tp = confusion_matrix(all_targets, optimal_preds, labels=[0, 1]).ravel()

# Report NaN instead of dividing by zero when a class has no samples.
n_positive = tp + fn
n_negative = tn + fp
sensitivity = tp / n_positive if n_positive > 0 else float('nan')
specificity = tn / n_negative if n_negative > 0 else float('nan')

print(" Performance Metrics:")
print(f" - AUROC:             {auroc_val:.4f}")
print(f" - AUPRC:             {auprc_val:.4f}")
print(f" - Sensitivity (TPR): {sensitivity:.4f}  (TP:{tp}, FN:{fn})")
print(f" - Specificity (TNR): {specificity:.4f}  (TN:{tn}, FP:{fp})")

# 5. Draw the ROC curve and the confusion matrix side by side.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# ---- Panel 1: ROC curve ----
axes[0].plot(fpr, tpr, color='darkorange', lw=2.5, label=f'AUROC = {auroc_val:.4f}')
# Diagonal reference line (FPR == TPR).
axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
axes[0].set_xlim([0.0, 1.0])
axes[0].set_ylim([0.0, 1.05])
axes[0].set_xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
axes[0].set_ylabel('True Positive Rate (Sensitivity)', fontsize=12)
axes[0].set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
axes[0].legend(loc="lower right", fontsize=11)
axes[0].grid(True, linestyle='--', alpha=0.6)

# ---- Panel 2: confusion matrix ----
# labels=[0, 1] keeps the matrix 2x2 so it always matches the two tick labels,
# even when one class is missing from both targets and predictions.
cm_optimal = confusion_matrix(all_targets, optimal_preds, labels=[0, 1])
sns.heatmap(cm_optimal, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Non-PTB (0)', 'PTB (1)'],
            yticklabels=['Non-PTB (0)', 'PTB (1)'],
            annot_kws={"size": 14}, ax=axes[1])
axes[1].set_xlabel('Predicted Label', fontsize=12)
axes[1].set_ylabel('True Label', fontsize=12)
axes[1].set_title(f'Confusion Matrix (@ Threshold {0.5:.4f})', fontsize=14)

plt.tight_layout()
plt.show()