In [16]:
from IPython.lib.display import Audio as AudioDisplay
import os
import torch
import numpy as np
import random
import librosa
from pathlib import Path
from tqdm import tqdm

from datasets import load_from_disk
from transformers import (
    ASTFeatureExtractor,
    ASTConfig,
    ASTForAudioClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report,
)
# Set style
plt.style.use('ggplot')
sns.set_palette("husl")

print("All imports successful")


All imports successful


## 2. Configuration


In [7]:
# ============================
# CONFIGURATION
# ============================

# Paths
DATASET_PATH = "../../dataset/ds_3_raw_chunked.hf"
OUTPUT_DIR = "./runs/ast_drone_detection"
MODEL_SAVE_PATH = "./best_ast_drone_model.pt"

# Model
PRETRAINED_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Training hyperparameters
SEED = 42
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
WARMUP_RATIO = 0.1

# System
NUM_WORKERS = 24
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create output directory
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Output directory: {OUTPUT_DIR}")

Device: cuda
Output directory: ./runs/ast_drone_detection


## 3. Set Random Seeds


In [8]:
def set_seed(seed):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(SEED)
print(f"Random seed set to {SEED}")


Random seed set to 42


## 4. Load Dataset


In [28]:
# Load the preprocessed dataset
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}")

dataset = load_from_disk(DATASET_PATH)

print("Dataset loaded successfully!")
print(f"\nDataset structure:")
labels = dataset["train"].features["label"]
print(f"\nDataset splits:")
for split in dataset.keys():
    print(f"\t{split}: {len(dataset[split])} samples")

print(f"\nFeatures:")
print(dataset["train"].features)

# Check a sample
sample = dataset["train"][0]
print(f"\nSample data:")
print(f"\tAudio shape: {np.array(sample['audio']).shape}")
print(f"\tLabel: {sample['label']} ({CLASS_NAMES[sample['label']]})")

AudioDisplay(sample['audio'], rate=16000)


Dataset loaded successfully!

Dataset structure:

Dataset splits:
	train: 617500 samples
	val: 78726 samples
	test: 78483 samples

Features:
{'audio': List(Value('float32')), 'label': ClassLabel(names=['other', 'drone'])}

Sample data:
	Audio shape: (8000,)
	Label: 0 (other)


In [29]:
# Count the number of items per class
from collections import Counter
label_counts = Counter(dataset["train"]["label"])
print(label_counts)

Counter({0: 441442, 1: 176058})


## 5. Initialize Feature Extractor and Model


In [6]:
# Load feature extractor
print(f"Loading feature extractor from {PRETRAINED_MODEL}...")
feature_extractor = ASTFeatureExtractor.from_pretrained(PRETRAINED_MODEL)

SAMPLING_RATE = feature_extractor.sampling_rate
MODEL_INPUT_NAME = feature_extractor.model_input_names[0]

print(f"✅ Feature extractor loaded")
print(f"  Sampling rate: {SAMPLING_RATE} Hz")
print(f"  Model input name: {MODEL_INPUT_NAME}")
print(f"  Max length: {feature_extractor.max_length}")


Loading feature extractor from MIT/ast-finetuned-audioset-10-10-0.4593...
✅ Feature extractor loaded
  Sampling rate: 16000 Hz
  Model input name: input_values
  Max length: 1024


In [6]:
# Load and configure model
print(f"\nLoading model from {PRETRAINED_MODEL}...")

config = ASTConfig.from_pretrained(PRETRAINED_MODEL)
config.num_labels = NUM_LABELS
config.label2id = LABEL2ID
config.id2label = ID2LABEL

model = ASTForAudioClassification.from_pretrained(
    PRETRAINED_MODEL,
    config=config,
    ignore_mismatched_sizes=True  # Replaces the classification head
)

model.to(DEVICE)

print(f"✅ Model loaded and moved to {DEVICE}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")



Loading model from MIT/ast-finetuned-audioset-10-10-0.4593...


Loading weights: 100%|██████████| 203/203 [00:00<00:00, 1194.61it/s, Materializing param=audio_spectrogram_transformer.embeddings.cls_token]                            
ASTForAudioClassification LOAD REPORT from: MIT/ast-finetuned-audioset-10-10-0.4593
Key                     | Status   |                                                                                       
------------------------+----------+---------------------------------------------------------------------------------------
classifier.dense.bias   | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527]) vs model:torch.Size([2])          
classifier.dense.weight | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527, 768]) vs model:torch.Size([2, 768])
classifier.dense.weight | MISC     | 'Linear' object has no attribute 'param_name'
Error when processing parameter cl      
classifier.dense.bias   | MISC     | 'Linear' object has no attribute 'param_name'
Error when processing parameter cl      

No

✅ Model loaded and moved to cuda
  Total parameters: 86,190,338
  Trainable parameters: 86,190,338


## 6. Data Preprocessing

We need to convert spectrograms back to waveforms for AST processing.


In [7]:
def spectrogram_to_waveform(spectrogram, sr=16000, n_fft=2048, hop_length=512):
    """
    Convert a magnitude spectrogram back to a waveform using Griffin-Lim.
    
    Args:
        spectrogram: 2D array (freq_bins, time_frames) in dB scale
        sr: sampling rate
        n_fft: FFT window size
        hop_length: hop length for STFT
    
    Returns:
        waveform: 1D array
    """
    # Convert from dB to linear magnitude
    magnitude = librosa.db_to_amplitude(spectrogram)
    
    # Use Griffin-Lim to reconstruct phase and waveform
    waveform = librosa.griffinlim(
        magnitude,
        n_iter=32,
        hop_length=hop_length,
        n_fft=n_fft
    )
    
    return waveform

def preprocess_function(batch):
    print(batch)
    """
    Preprocess batch of spectrograms for AST model.
    Converts spectrograms to waveforms and extracts features.
    """
    waveforms = []
    
    for spec in batch["audio"]:
        # Convert list to numpy array
        spec_array = np.array(spec, dtype=np.float32)
        
        # Convert spectrogram to waveform
        waveform = spectrogram_to_waveform(spec_array, sr=SAMPLING_RATE)
        
        waveforms.append(waveform)
    
    # Extract features using AST feature extractor
    inputs = feature_extractor(
        waveforms,
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=SAMPLING_RATE * 10  # 10 seconds max
    )
    
    return {
        MODEL_INPUT_NAME: inputs[MODEL_INPUT_NAME],
        "labels": batch["label"]
    }

print("✅ Preprocessing function defined")


✅ Preprocessing function defined


In [8]:
# Use a smaller subset for faster iteration during development
USE_SUBSET = True  # Set to False for full dataset
SUBSET_SIZE = 5000 if USE_SUBSET else None

if USE_SUBSET:
    print(f"⚠️  Using subset of {SUBSET_SIZE} samples for faster development")
    dataset["train"] = dataset["train"].select(range(min(SUBSET_SIZE, len(dataset["train"]))))
    dataset["val"] = dataset["val"].select(range(min(1000, len(dataset["val"]))))
    dataset["test"] = dataset["test"].select(range(min(1000, len(dataset["test"]))))

# NO set_transform! Preprocessing will be done in the collator
print("\n✅ Dataset subsets selected")
print(f"\nFinal dataset sizes:")
for split in dataset.keys():
    print(f"  {split}: {len(dataset[split])} samples")

⚠️  Using subset of 5000 samples for faster development

✅ Dataset subsets selected

Final dataset sizes:
  train: 5000 samples
  val: 1000 samples
  test: 1000 samples


## 7. Define Data Collator


In [9]:
from dataclasses import dataclass
from typing import Dict, List, Union

@dataclass
class ASTDataCollatorWithPreprocessing:
    """Custom data collator that preprocesses spectrograms on-the-fly."""
    
    feature_extractor: ASTFeatureExtractor
    
    def __call__(self, features: List[Dict[str, any]]) -> Dict[str, torch.Tensor]:
        # Extract raw spectrograms and labels
        spectrograms = [f["audio"] for f in features]
        labels = [f["label"] for f in features]
        
        # Convert spectrograms to waveforms using Griffin-Lim
        waveforms = []
        for spec in spectrograms:
            # Convert to numpy array
            spec_array = np.array(spec, dtype=np.float32)
            
            # Convert from dB to linear magnitude
            magnitude = librosa.db_to_amplitude(spec_array)
            
            # Use Griffin-Lim to reconstruct waveform
            waveform = librosa.griffinlim(
                magnitude,
                n_iter=32,
                hop_length=512,
                n_fft=2048
            )
            
            waveforms.append(waveform)
        
        # Extract AST features
        inputs = self.feature_extractor(
            waveforms,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=SAMPLING_RATE * 10  # 10 seconds max
        )
        
        # Prepare batch
        batch = {
            MODEL_INPUT_NAME: inputs[MODEL_INPUT_NAME],
            "labels": torch.tensor(labels, dtype=torch.long)
        }
        
        return batch

# Create the new collator
collator = ASTDataCollatorWithPreprocessing(feature_extractor=feature_extractor)
print("✅ Custom data collator with preprocessing defined")

✅ Custom data collator with preprocessing defined


## 8. Define Metrics


In [10]:
# Load metrics
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Compute metrics for evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    
    metrics = {}
    metrics.update(accuracy_metric.compute(predictions=predictions, references=labels))
    metrics.update(precision_metric.compute(predictions=predictions, references=labels, average="binary"))
    metrics.update(recall_metric.compute(predictions=predictions, references=labels, average="binary"))
    metrics.update(f1_metric.compute(predictions=predictions, references=labels, average="binary"))
    
    return metrics

print("✅ Metrics defined")


✅ Metrics defined


## 9. Configure Training


In [11]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    
    # Training
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    
    # Optimization
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    
    # Evaluation
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    
    # Logging
    logging_strategy="steps",
    logging_steps=50,
    report_to="none",  # Disable tensorboard (install tensorboard if you want logging)
    
    # System
    dataloader_num_workers=NUM_WORKERS,
    fp16=torch.cuda.is_available(),
    dataloader_pin_memory=True if NUM_WORKERS > 0 else False,
    
    # Saving
    save_total_limit=3,
    
    # Seed
    seed=SEED,
)

print("✅ Training arguments configured")
print(f"\nTraining configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"  Warmup ratio: {WARMUP_RATIO}")
print(f"  FP16: {training_args.fp16}")
print(f"  Workers: {NUM_WORKERS}")


warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


✅ Training arguments configured

Training configuration:
  Epochs: 10
  Batch size: 32
  Learning rate: 5e-05
  Weight decay: 0.0001
  Warmup ratio: 0.1
  FP16: True
  Workers: 24


## 10. Initialize Trainer


In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

print("✅ Trainer initialized")


✅ Trainer initialized


## 11. Train Model


In [13]:
print("Starting training...\n")
print("=" * 80)

train_result = trainer.train()

print("\n" + "=" * 80)
print("Training complete!\n")

# Save final model
trainer.save_model(OUTPUT_DIR + "/final_model")
feature_extractor.save_pretrained(OUTPUT_DIR + "/final_model")

# Save training metrics
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

print(f"✅ Model saved to {OUTPUT_DIR}/final_model")


Starting training...



KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ~~~~~~~~~~~~~~~^^^^^^
  File "/tmp/ipykernel_74369/2595779619.py", line 12, in __call__
    spectrograms = [f["audio"] for f in features]
                    ~^^^^^^^^^
KeyError: 'audio'


## 12. Evaluate on Validation Set


In [None]:
print("Evaluating on validation set...\n")

eval_metrics = trainer.evaluate(eval_dataset=dataset["val"])

trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)

print("\nValidation Metrics:")
print("=" * 50)
for key, value in eval_metrics.items():
    if key.startswith("eval_"):
        metric_name = key.replace("eval_", "")
        print(f"{metric_name:20s}: {value:.4f}")


## 13. Test Set Evaluation


In [None]:
print("Evaluating on test set...\n")

# Get predictions
predictions_output = trainer.predict(dataset["test"])
predictions = np.argmax(predictions_output.predictions, axis=1)
labels = predictions_output.label_ids

# Compute metrics
test_accuracy = accuracy_score(labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
    labels, predictions, average="binary"
)

print("\nTest Set Metrics:")
print("=" * 50)
print(f"Accuracy:  {test_accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")


## 14. Detailed Classification Report


In [None]:
print("\nDetailed Classification Report:")
print("=" * 80)
print(classification_report(
    labels,
    predictions,
    target_names=CLASS_NAMES,
    digits=4
))


## 15. Confusion Matrix


In [None]:
# Compute confusion matrix
cm = confusion_matrix(labels, predictions)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    cbar_kws={'label': 'Count'}
)
plt.title('Confusion Matrix - Test Set', fontsize=16, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Confusion matrix saved to {OUTPUT_DIR}/confusion_matrix.png")


## 16. Training History Visualization


In [None]:
# Load training history from logs
import json

log_history = trainer.state.log_history

# Extract metrics
train_loss = []
eval_loss = []
eval_accuracy = []
eval_f1 = []
epochs = []

for entry in log_history:
    if 'loss' in entry and 'epoch' in entry:
        train_loss.append(entry['loss'])
    if 'eval_loss' in entry:
        eval_loss.append(entry['eval_loss'])
        eval_accuracy.append(entry.get('eval_accuracy', 0))
        eval_f1.append(entry.get('eval_f1', 0))
        epochs.append(entry['epoch'])

# Plot
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# Loss plot
if eval_loss:
    axes[0].plot(epochs, eval_loss, marker='o', linewidth=2, label='Validation Loss')
    axes[0].set_title('Loss vs. Epochs', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)

# Metrics plot
if eval_accuracy and eval_f1:
    axes[1].plot(epochs, eval_accuracy, marker='o', linewidth=2, label='Accuracy')
    axes[1].plot(epochs, eval_f1, marker='s', linewidth=2, label='F1 Score')
    axes[1].set_title('Metrics vs. Epochs', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Score', fontsize=12)
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/training_history.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Training history saved to {OUTPUT_DIR}/training_history.png")


## 17. Per-Class Analysis


In [None]:
# Per-class metrics
precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
    labels, predictions, average=None
)

print("\nPer-Class Analysis:")
print("=" * 80)
print(f"{'Class':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 80)
for i, class_name in enumerate(CLASS_NAMES):
    print(f"{class_name:<15} {precision_per_class[i]:<12.4f} {recall_per_class[i]:<12.4f} "
          f"{f1_per_class[i]:<12.4f} {support[i]:<10}")

# Visualize per-class metrics
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(CLASS_NAMES))
width = 0.25

ax.bar(x - width, precision_per_class, width, label='Precision', alpha=0.8)
ax.bar(x, recall_per_class, width, label='Recall', alpha=0.8)
ax.bar(x + width, f1_per_class, width, label='F1-Score', alpha=0.8)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Metrics', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(CLASS_NAMES)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0, 1.1])

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/per_class_metrics.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✅ Per-class metrics saved to {OUTPUT_DIR}/per_class_metrics.png")


## 18. Prediction Confidence Analysis


In [None]:
# Get prediction probabilities
import torch.nn.functional as F

logits = predictions_output.predictions
probabilities = F.softmax(torch.tensor(logits), dim=1).numpy()

# Get max probability (confidence) for each prediction
confidences = np.max(probabilities, axis=1)

# Separate correct and incorrect predictions
correct_mask = predictions == labels
correct_confidences = confidences[correct_mask]
incorrect_confidences = confidences[~correct_mask]

print(f"\nPrediction Confidence Analysis:")
print("=" * 80)
print(f"Correct predictions:   {len(correct_confidences)} (avg confidence: {correct_confidences.mean():.4f})")
print(f"Incorrect predictions: {len(incorrect_confidences)} (avg confidence: {incorrect_confidences.mean() if len(incorrect_confidences) > 0 else 0:.4f})")

# Plot confidence distributions
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Histogram
if len(correct_confidences) > 0:
    axes[0].hist(correct_confidences, bins=30, alpha=0.7, label='Correct', color='green')
if len(incorrect_confidences) > 0:
    axes[0].hist(incorrect_confidences, bins=30, alpha=0.7, label='Incorrect', color='red')
axes[0].set_xlabel('Confidence', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Confidence Distribution', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Box plot
box_data = [correct_confidences]
box_labels = ['Correct']
if len(incorrect_confidences) > 0:
    box_data.append(incorrect_confidences)
    box_labels.append('Incorrect')

axes[1].boxplot(
    box_data,
    labels=box_labels,
    patch_artist=True,
    boxprops=dict(facecolor='lightblue', alpha=0.7)
)
axes[1].set_ylabel('Confidence', fontsize=12)
axes[1].set_title('Confidence Box Plot', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/confidence_analysis.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✅ Confidence analysis saved to {OUTPUT_DIR}/confidence_analysis.png")


## 19. Save Final Model


In [None]:
# Save model state dict
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"✅ Model state dict saved to {MODEL_SAVE_PATH}")

# Save complete model info
model_info = {
    'model_name': PRETRAINED_MODEL,
    'num_labels': NUM_LABELS,
    'class_names': CLASS_NAMES,
    'label2id': LABEL2ID,
    'id2label': ID2LABEL,
    'test_accuracy': float(test_accuracy),
    'test_precision': float(precision),
    'test_recall': float(recall),
    'test_f1': float(f1),
}

import json
with open(f"{OUTPUT_DIR}/model_info.json", 'w') as f:
    json.dump(model_info, f, indent=2)

print(f"✅ Model info saved to {OUTPUT_DIR}/model_info.json")


## 20. Summary


In [None]:
print("\n" + "=" * 80)
print("TRAINING SUMMARY")
print("=" * 80)
print(f"\nModel: {PRETRAINED_MODEL}")
print(f"Task: Binary Classification (Drone Detection)")
print(f"Classes: {CLASS_NAMES}")
print(f"\nDataset:")
print(f"  Train samples: {len(dataset['train'])}")
print(f"  Val samples:   {len(dataset['val'])}")
print(f"  Test samples:  {len(dataset['test'])}")
print(f"\nTraining Configuration:")
print(f"  Epochs:        {NUM_EPOCHS}")
print(f"  Batch size:    {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay:  {WEIGHT_DECAY}")
print(f"\nFinal Test Set Performance:")
print(f"  Accuracy:  {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1 Score:  {f1:.4f}")
print(f"\nOutput Files:")
print(f"  Model directory:     {OUTPUT_DIR}/final_model")
print(f"  Model state dict:    {MODEL_SAVE_PATH}")
print(f"  Confusion matrix:    {OUTPUT_DIR}/confusion_matrix.png")
print(f"  Training history:    {OUTPUT_DIR}/training_history.png")
print(f"  Per-class metrics:   {OUTPUT_DIR}/per_class_metrics.png")
print(f"  Confidence analysis: {OUTPUT_DIR}/confidence_analysis.png")
print(f"  Model info:          {OUTPUT_DIR}/model_info.json")
print("\n" + "=" * 80)
print("✅ TRAINING COMPLETE!")
print("=" * 80)
