In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import numpy as np
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import re
from tqdm.notebook import tqdm
import ipywidgets as widgets
widgets.IntProgress(value=50, min=0, max=100)
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib notebook

Load Dataset

In [2]:
class BinauralCueDataset(Dataset):
    """Binaural-cue feature dataset backed by a directory of ``.npz`` files.

    Only files whose names fully match ``main_audio_<id>_azi<deg>.npz`` and whose
    ``<id>`` is contained in ``audio_ids`` are used. Each file must provide the
    arrays ``"itd"``, ``"ild"`` and ``"ic"``. They must share one shape (required by
    ``np.stack``) and are stacked into a float32 array of shape ``[3, *itd.shape]``.

    Labels are azimuth classes ``AZIMUTH_STEP_DEG`` (5) degrees wide:
    ``label = (azimuth % 360) // 5``, giving class indices 0-71. The ``% 360`` maps
    a file named ``azi360`` onto class 0 instead of an out-of-range class 72.

    Args:
        npz_dir: Directory containing the feature files.
        audio_ids: Container of audio ids to keep (default ``range(1, 101)``).
    """

    FILENAME_PATTERN = re.compile(r'main_audio_(\d+)_azi(\d+)\.npz')
    AZIMUTH_STEP_DEG = 5

    def __init__(self, npz_dir, audio_ids=range(1, 101)):
        self.dir = npz_dir
        self.files = []
        self.labels = []
        # Sort up front: os.listdir order is arbitrary, and a stable file order keeps
        # index-based splits (random_split with a fixed seed) reproducible.
        for fname in sorted(os.listdir(npz_dir)):
            match = self.FILENAME_PATTERN.fullmatch(fname)
            if match is None or int(match.group(1)) not in audio_ids:
                continue
            self.files.append(fname)
            # Parse the label once here rather than re-running a regex on every item.
            azimuth = int(match.group(2))
            self.labels.append((azimuth % 360) // self.AZIMUTH_STEP_DEG)

        # "Loaded N .npz files, N samples in total."
        print(f"📁 已加载 {len(self.files)} 个 .npz 文件，共 {len(self)} 个样本。")

    def __len__(self):
        """Number of samples (one per matching ``.npz`` file)."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(cue, label)`` for sample ``idx``.

        ``cue`` is float32 with the ITD, ILD and IC arrays stacked on axis 0;
        ``label`` is the azimuth class index computed in ``__init__``.
        """
        path = os.path.join(self.dir, self.files[idx])
        # An .npz load returns an NpzFile that keeps the file handle open until it is
        # closed; the context manager releases it as soon as the arrays are read.
        with np.load(path) as data:
            itd = data["itd"].astype(np.float32)
            ild = data["ild"].astype(np.float32)
            ic = data["ic"].astype(np.float32)

        cue = np.stack([itd, ild, ic], axis=0)  # [3, *itd.shape]
        return cue, self.labels[idx]

In [3]:
class ResBlock(nn.Module):
    """Two-convolution residual block: 3x3 conv → BN → ReLU → 3x3 conv → BN, plus skip.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        downsample: When True, the first 3x3 convolution and the skip projection
            use stride 2.

    The skip path is a 1x1 conv + BN projection whenever ``downsample`` is set or
    the channel count changes; otherwise it is ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        first_stride = 2 if downsample else 1

        # Keep the conv1 → conv2 → shortcut-conv creation order: it determines how
        # random weight initialisation draws from the RNG in seeded runs.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=first_stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = downsample
        needs_projection = downsample or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=first_stride),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Apply the residual block; output has ``out_channels`` channels."""
        skip = self.shortcut(x)
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        return self.relu(y + skip)

class AzimuthResNetCNN(nn.Module):
    """Small ResNet-style classifier from a 3-channel cue map to azimuth-class logits.

    Pipeline: 3x3 conv stem (3→32, BN, ReLU) → ResBlock(32→64, stride 2) →
    ResBlock(64→128, stride 2) → ResBlock(128→128) → global average pool →
    dropout(p=0.3) → linear head.

    Args:
        num_classes: Number of output logits (default 72).
    """

    def __init__(self, num_classes=72):
        super().__init__()
        stem_ch, mid_ch, top_ch = 32, 64, 128
        self.layer0 = nn.Sequential(
            nn.Conv2d(3, stem_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(stem_ch),
            nn.ReLU(),
        )
        self.layer1 = ResBlock(stem_ch, mid_ch, downsample=True)
        self.layer2 = ResBlock(mid_ch, top_ch, downsample=True)
        self.layer3 = ResBlock(top_ch, top_ch)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(top_ch, num_classes)

    def forward(self, x):
        """Return logits of shape [B, num_classes] for input of shape [B, 3, H, W]."""
        features = x
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.pool(features).flatten(start_dim=1)  # [B, 128]
        return self.fc(self.dropout(pooled))

In [4]:
# Experiment setup: dataset split, data loaders, model, optimizer and loss.
# NOTE(review): absolute, user-specific path — adjust DATA_DIR on another machine.
DATA_DIR = r"C:\Users\TIANY1\OneDrive - Trinity College Dublin\Documents\SoundSourceLocalization\features"
SEED = 42

# Seed the global torch RNG so weight initialisation, dropout and DataLoader
# shuffling are reproducible on Restart & Run All (random_split below uses its own
# explicitly seeded generator).
torch.manual_seed(SEED)

full_dataset = BinauralCueDataset(DATA_DIR)
# NOTE(review): this is a per-file random split, so files of the same audio id at
# different azimuths can land in both train and val. To measure generalisation to
# unseen sources, split by audio id instead (e.g. via the `audio_ids` argument).
train_dataset, val_dataset = random_split(full_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED))

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("The device type is: ", device)
model = AzimuthResNetCNN(num_classes=72).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.02)
criterion = nn.CrossEntropyLoss()

📁 已加载 5360 个 .npz 文件，共 5360 个样本。
The device type is:  cuda


Training

In [5]:
# Training: `num_epochs` epochs of training + validation, live loss/accuracy curves,
# one checkpoint per epoch, and a final `best_model.pth` holding the weights of the
# epoch with the highest validation top-1 accuracy.
#
# "Top-5" is a tolerance metric, not plain top-5 accuracy: a sample counts as a hit
# when ANY of the 5 highest-scoring classes is within ±2 classes of the target, with
# class distance measured circularly (class 0 and the last class are adjacent).
import copy

CHECKPOINT_DIR = r"C:\Users\TIANY1\OneDrive - Trinity College Dublin\Documents\SoundSourceLocalization\checkpoints"
TOP_K = 5
TOLERANCE_CLASSES = 2

num_epochs = 20
best_acc1 = 0.0
best_epoch = 0
best_model_wts = None

train_loss_history, val_loss_history = [], []
train_acc1_history, val_acc1_history = [], []
train_acc5_history, val_acc5_history = [], []

os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _count_hits(outputs, targets, k=TOP_K, tolerance=TOLERANCE_CLASSES):
    """Count correct predictions in one batch.

    Args:
        outputs: Logits of shape [B, C].
        targets: Integer class indices of shape [B].
        k: Number of top-scoring classes considered by the tolerance metric.
        tolerance: Largest circular class distance that still counts as a hit.

    Returns:
        (top-1 hits, top-k-within-tolerance hits) as Python ints.
    """
    num_classes = outputs.size(1)
    top1_hits = (outputs.argmax(dim=1) == targets).sum().item()
    topk_idx = outputs.topk(min(k, num_classes), dim=1).indices
    dist = (topk_idx - targets.view(-1, 1)).abs()
    # Circular distance; the class count comes from the logits, not a hard-coded 72.
    dist = torch.minimum(dist, num_classes - dist)
    topk_hits = (dist <= tolerance).any(dim=1).sum().item()
    return top1_hits, topk_hits


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    Returns:
        (mean loss, top-1 accuracy, top-k-within-tolerance accuracy) over the loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, hits1, hitsk, total = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for cue, azi in tqdm(loader, desc=desc, leave=False):
            cue, azi = cue.to(device), azi.to(device)
            outputs = model(cue)
            loss = criterion(outputs, azi)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n = cue.size(0)
            loss_sum += loss.item() * n  # criterion returns the batch mean
            total += n
            batch_hits1, batch_hitsk = _count_hits(outputs.detach(), azi)
            hits1 += batch_hits1
            hitsk += batch_hitsk
    if total == 0:
        raise ValueError(f"{desc or 'loader'}: no samples to process")
    return loss_sum / total, hits1 / total, hitsk / total


fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title("Loss Curve")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
train_loss_line, = axs[0].plot([], [], label='Train Loss')
val_loss_line, = axs[0].plot([], [], label='Val Loss')
axs[0].legend()

axs[1].set_title("Accuracy Curve")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")
train_acc1_line, = axs[1].plot([], [], 'b-', label='Top-1 Train')
val_acc1_line, = axs[1].plot([], [], 'b--', label='Top-1 Val')
train_acc5_line, = axs[1].plot([], [], 'r-', label='Top-5 Train')
val_acc5_line, = axs[1].plot([], [], 'r--', label='Top-5 Val')
axs[1].legend()

plt.tight_layout()

for epoch in range(1, num_epochs + 1):
    train_loss_avg, train_acc1, train_acc5 = _run_epoch(
        model, train_loader, criterion, device, optimizer, desc=f"Epoch {epoch}/{num_epochs}"
    )
    val_loss_avg, val_acc1, val_acc5 = _run_epoch(
        model, val_loader, criterion, device, desc="Evaluating..."
    )

    # Record results
    train_loss_history.append(train_loss_avg)
    val_loss_history.append(val_loss_avg)
    train_acc1_history.append(train_acc1)
    val_acc1_history.append(val_acc1)
    train_acc5_history.append(train_acc5)
    val_acc5_history.append(val_acc5)

    # Update the live figure
    epochs_range = range(1, epoch + 1)
    for line, history in (
        (train_loss_line, train_loss_history),
        (val_loss_line, val_loss_history),
        (train_acc1_line, train_acc1_history),
        (val_acc1_line, val_acc1_history),
        (train_acc5_line, train_acc5_history),
        (val_acc5_line, val_acc5_history),
    ):
        line.set_data(epochs_range, history)

    # Rescale both axes to the new data
    for ax in axs:
        ax.relim()
        ax.autoscale_view()

    plt.pause(0.01)  # let the figure redraw

    print(f"Epoch {epoch}/{num_epochs}: Train Loss={train_loss_avg:.4f}, Top-1={train_acc1*100:.2f}%, Top-5={train_acc5*100:.2f}% | "
          f"Evaluation Loss={val_loss_avg:.4f}, Top-1={val_acc1*100:.2f}%, Top-5={val_acc5*100:.2f}%")

    # Per-epoch checkpoint (weights + optimizer state); saved immediately, so the
    # live state_dict references are fine here.
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"epoch_{epoch}.pth")
    torch.save(checkpoint, checkpoint_path)

    if val_acc1 > best_acc1:
        best_acc1 = val_acc1
        best_epoch = epoch
        # model.state_dict() returns references to the live parameter/buffer tensors,
        # which later optimizer steps overwrite in place. Deep-copy them; otherwise
        # best_model.pth would contain the LAST epoch's weights, not the best one's.
        best_model_wts = copy.deepcopy(model.state_dict())

if best_model_wts is not None:
    torch.save(best_model_wts, os.path.join(CHECKPOINT_DIR, "best_model.pth"))
    print(f"The best model appears in epoch {best_epoch}, and the Validation Top-1 Accuracy is {best_acc1*100:.2f}%, save as best_model.pth")
    print(f"The best model appears in epoch {best_epoch}, and the Validation Top-1 Accuracy is {best_acc1*100:.2f}%, save as best_model.pth")

<IPython.core.display.Javascript object>

Epoch 1.20:   0%|          | 0/268 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/67 [00:00<?, ?it/s]

Epoch 1/20: Train Loss=4.0021, Top-1=3.73%, Top-5=53.68% | Evaluation Loss=3.7844, Top-1=5.22%, Top-5=57.56%


Epoch 2.20:   0%|          | 0/268 [00:00<?, ?it/s]