In [2]:
!pip install librosa 

Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Using cached audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.61.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-0.5.0.post1-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Collecting lazy_loader>=0.1 (from librosa)
  Downloading lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting llvmlite<0.45,>=0.44.0dev0 (

In [3]:
import os
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


In [4]:
AUDIO_DIR = 'audio'
SPEC_DIR = 'spec_images'
META_PATH = 'metadata.csv'

WINDOW_SIZE = 3.0  # seconds
HOP_SIZE = 1.5     # seconds
SR = 22050         # sampling rate
IMG_SIZE = 128     # spectrogram image size


In [7]:
os.makedirs(SPEC_DIR, exist_ok=True)
meta_df = pd.read_csv(META_PATH)  # columns: filename,label

label_map = {label: idx for idx, label in enumerate(meta_df['diagnosis'].unique())}
label_df = meta_df.set_index('id')['diagnosis'].to_dict()

def save_spec_patch(y, sr, out_path):
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, n_mels=128)
    S_dB = librosa.power_to_db(S, ref=np.max)
    S_img = Image.fromarray(S_dB).resize((IMG_SIZE, IMG_SIZE))
    S_img = S_img.convert('L')  # grayscale
    S_img.save(out_path)

for fname in tqdm(os.listdir(AUDIO_DIR)):
    if not fname.endswith('.wav'): continue
    file_id = os.path.splitext(fname)[0]
    y, _ = librosa.load(os.path.join(AUDIO_DIR, fname), sr=SR)
    duration = librosa.get_duration(y=y, sr=SR)
    
    win_len = int(WINDOW_SIZE * SR)
    hop_len = int(HOP_SIZE * SR)
    output_dir = os.path.join(SPEC_DIR, file_id)
    os.makedirs(output_dir, exist_ok=True)

    for i, start in enumerate(np.arange(0, duration - WINDOW_SIZE, HOP_SIZE)):
        s = int(start * SR)
        y_win = y[s : s + win_len]
        out_path = os.path.join(output_dir, f"{file_id}_{i}.png")
        save_spec_patch(y_win, SR, out_path)


100%|██████████| 1843/1843 [02:47<00:00, 10.99it/s]


In [8]:
class SpecPatchDataset(Dataset):
    def __init__(self, root_dir, label_dict, transform=None):
        self.samples = []
        self.labels = []
        self.transform = transform or transforms.ToTensor()

        for file_id in os.listdir(root_dir):
            class_label = label_dict[file_id + ".wav"]
            label_idx = label_map[class_label]
            file_folder = os.path.join(root_dir, file_id)
            for img in os.listdir(file_folder):
                self.samples.append(os.path.join(file_folder, img))
                self.labels.append(label_idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert('L')
        img = self.transform(img)
        label = self.labels[idx]
        return img, label


In [9]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(64 * 13 * 13, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)  # ← flatten
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

file_ids = os.listdir(AUDIO_DIR)
train_ids, val_ids = train_test_split(file_ids, test_size=0.2, stratify=meta_df['diagnosis'])

train_labels = {k: label_df[k] for k in train_ids}
val_labels = {k: label_df[k] for k in val_ids}

train_ds = SpecPatchDataset(SPEC_DIR, train_labels, transform)
val_ds   = SpecPatchDataset(SPEC_DIR, val_labels, transform)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False)


ValueError: Found input variables with inconsistent numbers of samples: [1843, 126]

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(label_map)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 16):
    model.train()
    total_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch:02d} | Train Loss: {total_loss / len(train_loader):.4f}")


In [None]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Validation Accuracy: {correct / total:.4f}")
