# Audio Weld Defect Classification — Training Notebook

Instantiate the dataset, the `AudioCNN` classifier, the loss, and the optimizer,
then train the model with the existing `run_training` loop.

In [None]:
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

## 1. Load Config

In [None]:
# Load the experiment configuration. The encoding is given explicitly so the
# file is decoded the same way on every platform instead of depending on the
# locale's default encoding.
with open("configs/audio_config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Per-section handles into the config; a missing section raises KeyError here.
audio_cfg = cfg["audio"]
model_cfg = cfg["model"]
optim_cfg = cfg["optimizer"]
train_cfg = cfg["training"]
data_cfg = cfg["data"]

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print(json.dumps(cfg, indent=2))

## 2. Dataset

In [None]:
from audio_processing import AudioDataset

# Labeled dataset built from the configured data root and the "audio" section
# of the config.
full_dataset = AudioDataset(
    data_cfg["data_root"],
    cfg=audio_cfg,
    labeled=True,
)

# The number of classes is the size of the dataset's label -> index mapping.
label_to_idx = full_dataset.label_to_idx
num_classes = len(label_to_idx)

print(f"Total samples: {len(full_dataset)}")
print(f"Classes ({num_classes}): {label_to_idx}")

# Keep the index -> label mapping inside the config dict so it travels with it.
# NOTE(review): if the mapping has integer keys, a JSON dump of `cfg` will turn
# them into strings - confirm that consumers of the saved config expect that.
cfg["label_map"] = full_dataset.idx_to_label

In [None]:
# Train / validation split: the validation share is `val_split` of the dataset
# (truncated to an integer); the remaining samples are used for training.
n_samples = len(full_dataset)
val_size = int(n_samples * train_cfg["val_split"])
train_size = n_samples - val_size

# A dedicated, seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(train_cfg["seed"])
train_dataset, val_dataset = random_split(
    full_dataset,
    lengths=[train_size, val_size],
    generator=generator,
)
print(f"Train: {train_size} | Val: {val_size}")

# Collate: AudioDataset returns dicts -> (inputs, targets) tuples
def train_collate_fn(batch):
    """Collate a list of sample dicts into an ``(inputs, targets)`` pair.

    Each item must provide an ``"audio"`` tensor and a ``"label"`` value.
    The audio tensors are stacked along a new leading batch dimension (so
    they must all have the same shape) and the labels are packed into a
    single 1-D tensor of length ``len(batch)``.
    """
    spectrograms = []
    class_ids = []
    for sample in batch:
        spectrograms.append(sample["audio"])
        class_ids.append(sample["label"])
    # Inputs: (B, ...) stacked audio; targets: (B,)
    return torch.stack(spectrograms), torch.tensor(class_ids)

# Both loaders share the batch size, worker count and collate function; they
# differ only in the dataset they wrap and in whether the order is shuffled.
loader_kwargs = {
    "batch_size": train_cfg["batch_size"],
    "num_workers": train_cfg["num_workers"],
    "collate_fn": train_collate_fn,
}

# Shuffle the training samples; keep the validation order fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## 3. Model

In [None]:
from audio_model import AudioCNN

# Build the classifier and place it on the selected device right away, so the
# optimizer created afterwards is constructed from parameters that already
# live where training will run. Moving a module to the device it is already
# on is a no-op, so this is safe even if the training loop moves it again.
# NOTE(review): assumes AudioCNN is a torch.nn.Module - confirm.
model = AudioCNN(num_classes=num_classes, dropout=model_cfg["dropout"])
model = model.to(device)
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Loss & Optimizer

In [None]:
# Classification loss with PyTorch's default settings.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters; learning rate and weight decay are read from
# the "optimizer" section of the config.
adam_kwargs = {
    "lr": optim_cfg["lr"],
    "weight_decay": optim_cfg["weight_decay"],
}
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

## 5. Train

In [None]:
from run_train import run_training

# Make sure the checkpoint directory exists before anything is written to it.
checkpoint_dir = train_cfg["checkpoint_dir"]
os.makedirs(checkpoint_dir, exist_ok=True)

# Save config alongside checkpoints for reproducibility
with open(os.path.join(checkpoint_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)

# Only a positive patience value is forwarded. A missing key, a JSON null,
# zero, or a negative number all become None instead of raising
# KeyError/TypeError on the comparison.
# NOTE(review): None presumably disables early stopping in run_training -
# confirm against its implementation.
patience = train_cfg.get("patience")
if patience is None or patience <= 0:
    patience = None

# Hand everything to the shared training loop; `results` is indexed below for
# the loss histories and the best epoch.
results = run_training(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=train_cfg["num_epochs"],
    checkpoint_dir=checkpoint_dir,
    patience=patience,
    seed=train_cfg["seed"],
)

## 6. Training Curves

In [None]:
import matplotlib.pyplot as plt

# Plot the train and validation loss histories on a single set of axes.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(results["train_losses"], label="Train")
ax.plot(results["val_losses"], label="Val")
ax.set_title("AudioCNN — CrossEntropy Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()

# Store the figure next to the checkpoints, then display it.
curve_path = os.path.join(checkpoint_dir, "audio_training_curve.png")
fig.savefig(curve_path, dpi=150)
plt.show()

print(f"Best epoch: {results['best_epoch']}")