In [20]:
import os
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


In [21]:
AUDIO_DIR = 'audio'
SPEC_DIR = 'spec_images'
META_PATH = 'metadata.csv'

WINDOW_SIZE = 3.0  # seconds
HOP_SIZE = 1.5     # seconds
SR = 22050         # sampling rate
IMG_SIZE = 128     # spectrogram image size


In [None]:
os.makedirs(SPEC_DIR, exist_ok=True)
meta_df = pd.read_csv(META_PATH)  # columns: filename,label

label_map = {label: idx for idx, label in enumerate(meta_df['diagnosis'].unique())}
label_df = meta_df.set_index('id')['diagnosis'].to_dict()

def save_spec_patch(y, sr, out_path):
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, n_mels=128)
    S_dB = librosa.power_to_db(S, ref=np.max)
    S_img = Image.fromarray(S_dB).resize((IMG_SIZE, IMG_SIZE))
    S_img = S_img.convert('L')  # grayscale
    S_img.save(out_path)

for fname in tqdm(os.listdir(AUDIO_DIR)):
    if not fname.endswith('.wav'): continue
    file_id = os.path.splitext(fname)[0]
    y, _ = librosa.load(os.path.join(AUDIO_DIR, fname), sr=SR)
    duration = librosa.get_duration(y=y, sr=SR)
    
    win_len = int(WINDOW_SIZE * SR)
    hop_len = int(HOP_SIZE * SR)
    output_dir = os.path.join(SPEC_DIR, file_id)
    os.makedirs(output_dir, exist_ok=True)

    for i, start in enumerate(np.arange(0, duration - WINDOW_SIZE, HOP_SIZE)):
        s = int(start * SR)
        y_win = y[s : s + win_len]
        out_path = os.path.join(output_dir, f"{file_id}_{i}.png")
        save_spec_patch(y_win, SR, out_path)


100%|██████████| 1843/1843 [03:56<00:00,  7.80it/s]


In [23]:
class SpecPatchDataset(Dataset):
    def __init__(self, root_dir, label_dict, transform=None):
        self.samples = []
        self.labels = []
        self.transform = transform or transforms.ToTensor()

        for file_id in os.listdir(root_dir):
            class_label = label_dict[file_id + ".wav"]
            label_idx = label_map[class_label]
            file_folder = os.path.join(root_dir, file_id)
            for img in os.listdir(file_folder):
                self.samples.append(os.path.join(file_folder, img))
                self.labels.append(label_idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert('L')
        img = self.transform(img)
        label = self.labels[idx]
        return img, label


In [24]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2,2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(64 * 13 * 13, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)  # ← flatten
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


In [25]:
print(meta_df['diagnosis'].value_counts())

diagnosis
COPD              64
Healthy           26
URTI              14
Bronchiectasis     7
Bronchiolitis      6
Pneumonia          6
LRTI               2
Asthma             1
Name: count, dtype: int64


In [26]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Get all WAV files
file_ids = [f for f in os.listdir(AUDIO_DIR) if f.endswith('.wav')]

# Extract patient IDs from filenames (first part before '_')
patient_ids = [f.split('_')[0] for f in file_ids]

# Map each file to its patient's diagnosis
stratify_labels = []
for pid in patient_ids:
    try:
        diagnosis = meta_df.loc[meta_df['id'] == int(pid), 'diagnosis'].values[0]
        stratify_labels.append(diagnosis)
    except:
        print(f"Warning: No diagnosis found for patient {pid}")
        stratify_labels.append('Unknown')  # Handle missing cases

# Split with aligned labels
train_ids, val_ids = train_test_split(
    file_ids,
    test_size=0.2, 
    stratify=stratify_labels
)

# Update label dictionaries
train_labels = {fname: meta_df.loc[meta_df['id'] == int(fname.split('_')[0]), 'diagnosis'].values[0]
                for fname in train_ids}
val_labels = {fname: meta_df.loc[meta_df['id'] == int(fname.split('_')[0]), 'diagnosis'].values[0]
              for fname in val_ids}

# Create datasets
train_ds = SpecPatchDataset(SPEC_DIR, train_labels, transform)
val_ds = SpecPatchDataset(SPEC_DIR, val_labels, transform)

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(label_map)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 16):
    model.train()
    total_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch:02d} | Train Loss: {total_loss / len(train_loader):.4f}")


In [None]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Validation Accuracy: {correct / total:.4f}")
