In [None]:
import os, time, math
import sys
from pathlib import Path

# Use PyTorch + torchaudio for faster on-device spectrogram computation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T

import numpy as np
from sklearn.model_selection import train_test_split

sys.path.append('..')
# We will not use the old mel_spectrogram here; we compute mel on-device with torchaudio
from preprocess.wav_helper import trim_audio_to_np_float  # keep helper if needed


In [None]:

# Dataset root; expected layout: <DATASET_DIR>/{train,val}/<label>/*.wav
DATASET_DIR = '../../.tstdata/dataset'

# Clips are expected to be exactly 4 s at 48 kHz (16-bit PCM). TorchAudioDataset
# resamples, downmixes, trims or zero-pads anything that deviates.
TARGET_SR = 48000
SEGMENT_SECONDS = 4.0
NUM_SAMPLES = int(TARGET_SR * SEGMENT_SECONDS)

# Mel-spectrogram parameters (passed to torchaudio.transforms.MelSpectrogram)
N_MELS = 128
HOP_LENGTH = 320
N_FFT = 2048

# Frequency band in Hz covered by the mel filterbank: (f_min, f_max)
FREQ_RANGE = (3500, 8000)

# Class names = sub-directories of the training split, sorted so the
# label -> index mapping is stable. Plain files (e.g. .DS_Store) and hidden
# directories (e.g. .ipynb_checkpoints) are skipped: otherwise they would add
# phantom classes to the model's output layer.
_train_dir = os.path.join(DATASET_DIR, 'train')
LABELS = sorted(
    name for name in os.listdir(_train_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(_train_dir, name))
)
assert LABELS, f"no class sub-directories found in {_train_dir!r}"


In [None]:

class TorchAudioDataset(Dataset):
    """Fixed-length mono waveform dataset laid out as ``<root_dir>/<label>/*.wav``.

    Each item is ``(wav, label_idx)``: ``wav`` is a CPU tensor of shape
    ``(1, num_samples)`` and ``label_idx`` is the label's position in ``labels``.
    Spectrograms are deliberately NOT computed here. They are computed batch-wise
    on the training device in the main process, so DataLoader worker processes
    never touch CUDA.

    Args:
        root_dir: Split directory containing one sub-directory per label.
        labels: Ordered label names; a label's position is its class index.
            Labels without a matching sub-directory contribute no samples.
        sr: Target sample rate; files recorded at another rate are resampled.
        segment_seconds: Clip length. Longer files are truncated; shorter ones
            are zero-padded at the end.
        transform: Stored for API compatibility only; never applied.
        device: Stored for API compatibility only; items always stay on CPU.
    """

    def __init__(self, root_dir, labels, sr=TARGET_SR, segment_seconds=SEGMENT_SECONDS, transform=None, device='cpu') -> None:
        self.root_dir = root_dir
        self.sr = sr
        self.segment_seconds = segment_seconds
        self.num_samples = int(sr * segment_seconds)
        self.transform = transform  # kept for compatibility; do NOT apply it inside worker processes
        self.device = device
        # One Resample module per source sample rate, built lazily. Building the
        # resampling kernel on every __getitem__ call would be wasted work.
        self._resamplers = {}

        # collect (path, label index) pairs
        self.label_names = labels
        self.label_to_idx = {name: i for i, name in enumerate(self.label_names)}
        self.samples = []
        for label in self.label_names:
            label_dir = os.path.join(self.root_dir, label)
            if not os.path.isdir(label_dir):
                continue
            # os.listdir order is filesystem-dependent. Sort it so sample indices
            # (and therefore the un-shuffled validation order) are reproducible.
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith('.wav'):
                    self.samples.append((os.path.join(label_dir, fname), self.label_to_idx[label]))

    def __len__(self):
        return len(self.samples)

    def _load_wav(self, path):
        """Load ``path`` as a mono ``(1, num_samples)`` tensor at ``self.sr``."""
        # torchaudio.load returns (waveform, sr) with waveform shape (channels, samples)
        wav, sr = torchaudio.load(path)
        if sr != self.sr:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = T.Resample(orig_freq=sr, new_freq=self.sr)
                self._resamplers[sr] = resampler
            wav = resampler(wav)
        # downmix to mono by averaging channels
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        # truncate, or zero-pad at the end, to exactly num_samples
        n_samples = wav.shape[1]
        if n_samples > self.num_samples:
            wav = wav[:, :self.num_samples]
        elif n_samples < self.num_samples:
            wav = F.pad(wav, (0, self.num_samples - n_samples))
        return wav

    def __getitem__(self, idx):
        """Return ``(wav, label_idx)``, where ``wav`` is a CPU tensor of shape ``(1, num_samples)``."""
        path, label_idx = self.samples[idx]
        # Returned as-is on CPU (no device move, no transform) so workers never use CUDA.
        wav = self._load_wav(path)
        return wav, label_idx
    
# Spectrogram front-end. It lives on the training device and is applied in the
# main process (never in DataLoader workers) to batches of raw waveforms.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

_mel_spec = T.MelSpectrogram(
    sample_rate=TARGET_SR,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    center=True,
    power=2.0,
    f_min=FREQ_RANGE[0],
    f_max=FREQ_RANGE[1],
)
# power=2.0 above produces a power spectrogram, hence stype='power'. top_db
# limits the dynamic range to 80 dB below the peak.
_power_to_db = T.AmplitudeToDB(stype='power', top_db=80.0)
mel_transform = nn.Sequential(_mel_spec, _power_to_db).to(DEVICE)


In [None]:
import matplotlib.pyplot as plt

def plot_spectrogram(spec, sr=TARGET_SR, hop_length=HOP_LENGTH, title=None, cmap='viridis'):
    """Show one mel spectrogram as an image with a dB colour bar.

    Args:
        spec: torch.Tensor or array-like of shape (n_mels, time), optionally with
            extra size-1 dims such as (1, n_mels, time). Values are assumed to be
            in dB (the colour bar is labelled that way).
        sr: Sample rate of the source audio. Together with ``hop_length`` it maps
            frame indices to seconds on the x-axis.
        hop_length: Number of samples between consecutive spectrogram frames.
        title: Optional figure title.
        cmap: Matplotlib colormap name.

    Raises:
        ValueError: if ``spec`` is not 2-D after dropping size-1 dims.
    """
    if isinstance(spec, torch.Tensor):
        spec = spec.detach().cpu().numpy()
    spec = np.squeeze(spec)  # drop batch/channel dims of size 1
    if spec.ndim != 2:
        raise ValueError(f"expected a single (n_mels, time) spectrogram, got shape {spec.shape}")
    n_mels, n_frames = spec.shape
    duration_s = n_frames * hop_length / sr

    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        spec, aspect='auto', origin='lower', cmap=cmap,
        extent=(0.0, duration_s, 0, n_mels),
    )
    fig.colorbar(image, ax=ax, format='%+2.0f dB')
    if title:
        ax.set_title(title)
    ax.set_ylabel('Mel bin')
    ax.set_xlabel('Time (s)')
    fig.tight_layout()
    plt.show()

In [None]:
class SimpleCNN(nn.Module):
    """Three conv stages, global average pooling and a two-layer classifier head.

    Input: (batch, 1, n_mels, time) spectrograms. Output: (batch, n_classes) logits.
    Each stage is Conv 3x3 (padding 1) -> BatchNorm -> ReLU -> MaxPool 2x2.
    Adaptive global pooling makes the head independent of the input's spatial
    size, so ``n_mels`` is accepted for interface compatibility but not used.
    """

    def __init__(self, n_mels=N_MELS, n_classes=len(LABELS)) -> None:
        super().__init__()
        # Registers conv1, bn1, conv2, bn2, conv3, bn3 in exactly that order,
        # so parameter initialisation order and state_dict keys are the same as
        # with explicit attribute assignments.
        channels = (1, 32, 64, 128)
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
            setattr(self, f'conv{stage}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'bn{stage}', nn.BatchNorm2d(c_out))
        self.pool = nn.MaxPool2d(2, 2)
        # collapse whatever (mel, time) extent remains to 1x1
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(128, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stages:
            x = self.pool(F.relu(bn(conv(x))))
        pooled = self.global_pool(x)                  # (B, 128, 1, 1)
        flat = pooled.view(pooled.shape[0], -1)       # (B, 128)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

# Setup model, loss and optimizer.
# Seed torch's global RNG right before building the model, so weight
# initialisation and later torch-driven randomness (dropout, default-generator
# shuffling) repeat across Restart & Run All. Bitwise-identical CUDA runs
# additionally need deterministic algorithms to be enabled.
SEED = 42
torch.manual_seed(SEED)
model = SimpleCNN(n_mels=N_MELS, n_classes=len(LABELS)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lowered from the original 1e-3


In [None]:
BATCH_SIZE = 16
NUM_WORKERS_TRAIN = 4
NUM_WORKERS_VAL = 2
# Pinned host memory only speeds up host->CUDA copies. Without CUDA it is
# pointless, and recent PyTorch versions warn about it.
PIN_MEMORY = DEVICE == 'cuda'

# Datasets yield CPU waveforms; spectrograms are computed later on DEVICE.
train_dataset = TorchAudioDataset(os.path.join(DATASET_DIR, 'train'), LABELS, sr=TARGET_SR, transform=None, device=DEVICE)
val_dataset = TorchAudioDataset(os.path.join(DATASET_DIR, 'val'), LABELS, sr=TARGET_SR, transform=None, device=DEVICE)
# persistent_workers keeps worker processes alive between epochs instead of
# re-spawning them for every pass over the data.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS_TRAIN, pin_memory=PIN_MEMORY,
    persistent_workers=NUM_WORKERS_TRAIN > 0,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS_VAL, pin_memory=PIN_MEMORY,
    persistent_workers=NUM_WORKERS_VAL > 0,
)

# Simple training loop settings
EPOCHS = 30
checkpoint_path = '../../.tstdata/ckpt/torch_checkpoint.pth'

def batch_to_device_and_spec(batch_wavs, device, transform=None):
    """Move a waveform batch to ``device`` and turn it into CNN-ready spectrograms.

    Args:
        batch_wavs: Tensor of shape (B, 1, num_samples), as produced by the default
            DataLoader collate, or a list of (1, num_samples) tensors to stack.
        device: Device to run the transform on. The default ``mel_transform`` was
            moved to ``DEVICE``, so ``device`` should match it when no transform
            is given.
        transform: Callable applied to the float32 batch. Defaults to the
            module-level ``mel_transform``.

    Returns:
        Tensor of shape (B, 1, n_mels, time), i.e. the transform output with a
        channel axis added if the transform did not keep one.

    Raises:
        ValueError: if the transform output is neither 3-D nor 4-D.
    """
    if transform is None:
        transform = mel_transform
    batch = torch.stack(batch_wavs, dim=0) if isinstance(batch_wavs, list) else batch_wavs
    # non_blocking lets the host->GPU copy run asynchronously when the DataLoader
    # pinned the batch; it has no effect for CPU targets.
    batch = batch.to(device=device, dtype=torch.float32, non_blocking=True)
    # torchaudio transforms act on the trailing time axis and keep leading dims:
    # (B, 1, samples) -> (B, 1, n_mels, time) already has a channel axis, while
    # (B, samples) -> (B, n_mels, time) needs one added.
    specs = transform(batch)
    if specs.dim() == 3:
        specs = specs.unsqueeze(1)
    elif specs.dim() != 4:
        raise ValueError(f"expected 3-D or 4-D spectrograms, got shape {tuple(specs.shape)}")
    return specs

def evaluate(model, loader, device):
    """Compute the per-sample mean loss and accuracy of ``model`` over ``loader``.

    Switches ``model`` to eval mode (and leaves it there) and runs without
    gradients. Relies on the module-level ``criterion`` and
    ``batch_to_device_and_spec``.

    Args:
        model: Classifier mapping (B, 1, n_mels, time) spectrograms to logits.
        loader: Iterable of ``(wavs, labels)`` batches, e.g. a DataLoader.
        device: Device for the spectrograms, labels and forward pass.

    Returns:
        ``(mean_loss, accuracy)``. An empty loader returns ``(0.0, 0.0)``, matching
        the training loop's empty-epoch convention instead of dividing by zero.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for wavs, labels in loader:
            specs = batch_to_device_and_spec(wavs, device)
            labels = labels.to(device, dtype=torch.long, non_blocking=True)
            logits = model(specs)
            batch_size = labels.size(0)
            # criterion (CrossEntropyLoss, default reduction='mean') returns the
            # batch mean, so weight it by batch size for a per-sample average.
            loss_sum += criterion(logits, labels).item() * batch_size
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


def _train_one_epoch(model, loader, device):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss.

    Relies on the module-level ``criterion``, ``optimizer`` and
    ``batch_to_device_and_spec``. Returns 0.0 when ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for wavs, labels in loader:
        specs = batch_to_device_and_spec(wavs, device)
        labels = labels.to(device, dtype=torch.long, non_blocking=True)
        optimizer.zero_grad()
        logits = model(specs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size for a per-sample average
        loss_sum += loss.item() * specs.size(0)
        n_seen += specs.size(0)
    return loss_sum / n_seen if n_seen > 0 else 0.0


# torch.save does not create parent directories, so create the checkpoint dir up front.
_ckpt_dir = os.path.dirname(checkpoint_path)
if _ckpt_dir:
    os.makedirs(_ckpt_dir, exist_ok=True)

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint, even when validation accuracy is 0.0.
best_val = float('-inf')
for epoch in range(EPOCHS):
    train_loss = _train_one_epoch(model, train_loader, DEVICE)
    val_loss, val_acc = evaluate(model, val_loader, DEVICE)
    print(f"Epoch {epoch+1}/{EPOCHS} - train_loss: {train_loss:.4f} val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}")
    # checkpoint the best model so far by validation accuracy
    if val_acc > best_val:
        best_val = val_acc
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'epoch': epoch
        }, checkpoint_path)



In [None]:
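# Sanity check on one validation batch (uncomment to run after training):
# wavs, labels = next(iter(val_loader))
# specs = batch_to_device_and_spec(wavs, DEVICE)[:1]  # keep the batch dim: (1, 1, n_mels, time)
# print(specs.shape)
#
# model.eval()
# with torch.no_grad():
#     prediction = model(specs).argmax(dim=1).item()
# print(LABELS[prediction])
#
# plot_spectrogram(specs[0], title=f"True label: {LABELS[labels[0]]}; Predicted: {LABELS[prediction]}")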