In [1]:
import os
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as T
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report

# --- CONFIGURATION ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFM_CSV = "sfm_features.csv"
BATCH_SIZE = 16
LEARNING_RATE = 0.001
EPOCHS = 20
IMG_SIZE = 128 # Resize spectrograms to 128x128 squares

print(f"ðŸš€ Running Basic CNN on: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


ðŸš€ Running Basic CNN on: cuda


In [2]:
class CNNDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
        
        # Audio Processing
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
            n_mels=128
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB()
        self.resize = T.Resize((IMG_SIZE, IMG_SIZE)) # Force square shape

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_path = row['file_path']
        
        # 1. Load Audio
        try:
            waveform, sr = torchaudio.load(file_path)
        except:
            waveform = torch.zeros(1, 16000*3)
            sr = 16000
            
        # Resample
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
            
        # 2. Spectrogram -> Image
        mel_spec = self.mel_transform(waveform)
        log_mel = self.db_transform(mel_spec)
        
        # Resize to (1, 128, 128)
        img = self.resize(log_mel)
        
        # Normalize (approximate mean/std for stability)
        img = (img - img.mean()) / (img.std() + 1e-6)

        # 3. Label
        label = torch.tensor(row['label_encoded'], dtype=torch.long)
        
        return img, label

In [3]:
class BasicCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        # Block 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2) # Output: 64x64
        )
        
        # Block 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2) # Output: 32x32
        )
        
        # Block 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2) # Output: 16x16
        )
        
        # Classifier
        # Input size = 128 channels * 16 * 16 spatial dim
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(128 * 16 * 16, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        logits = self.fc(x)
        return logits

In [4]:
# 1. Setup Data
full_df = pd.read_csv(SFM_CSV)
label_encoder = LabelEncoder()
full_df['label_encoded'] = label_encoder.fit_transform(full_df['label_name'])
NUM_CLASSES = len(label_encoder.classes_)

train_ds = CNNDataset(full_df[full_df['split'] == 'train'])
val_ds = CNNDataset(full_df[full_df['split'] == 'val'])
test_ds = CNNDataset(full_df[full_df['split'] == 'test'])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# 2. Setup Model
cnn_model = BasicCNN(num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(cnn_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

print("ðŸ”¥ Starting Basic CNN Training...")
best_acc = 0.0

# 3. Training Loop
for epoch in range(EPOCHS):
    cnn_model.train()
    total_loss = 0
    
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        logits = cnn_model(imgs)
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    # Validation
    cnn_model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            logits = cnn_model(imgs)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
    val_acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss/len(train_loader):.4f} | Val Acc: {val_acc:.4f}")
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(cnn_model.state_dict(), "baseline_basic_cnn.pth")

# --- FINAL EVALUATION ---
print("\nðŸ“Š Evaluating Basic CNN on Test Set...")
cnn_model.load_state_dict(torch.load("baseline_basic_cnn.pth"))
cnn_model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        logits = cnn_model(imgs)
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print(classification_report(all_labels, all_preds, target_names=label_encoder.classes_))

ðŸ”¥ Starting Basic CNN Training...




Epoch 1/20 | Loss: 2.0427 | Val Acc: 0.7404




Epoch 2/20 | Loss: 1.0061 | Val Acc: 0.7668




Epoch 3/20 | Loss: 0.8992 | Val Acc: 0.7788




Epoch 4/20 | Loss: 0.8609 | Val Acc: 0.8029




Epoch 5/20 | Loss: 0.8249 | Val Acc: 0.8077




Epoch 6/20 | Loss: 0.7723 | Val Acc: 0.8269




Epoch 7/20 | Loss: 0.7455 | Val Acc: 0.8293




Epoch 8/20 | Loss: 0.7475 | Val Acc: 0.8173




Epoch 9/20 | Loss: 0.7145 | Val Acc: 0.8269




Epoch 10/20 | Loss: 0.6769 | Val Acc: 0.8197




Epoch 11/20 | Loss: 0.6554 | Val Acc: 0.8269




Epoch 12/20 | Loss: 0.6310 | Val Acc: 0.8341




Epoch 13/20 | Loss: 0.6239 | Val Acc: 0.8413




Epoch 14/20 | Loss: 0.5953 | Val Acc: 0.8293




Epoch 15/20 | Loss: 0.5523 | Val Acc: 0.8534




Epoch 16/20 | Loss: 0.5183 | Val Acc: 0.8486




Epoch 17/20 | Loss: 0.5343 | Val Acc: 0.8486




Epoch 18/20 | Loss: 0.4916 | Val Acc: 0.8582




Epoch 19/20 | Loss: 0.4948 | Val Acc: 0.8558




Epoch 20/20 | Loss: 0.4651 | Val Acc: 0.8822

ðŸ“Š Evaluating Basic CNN on Test Set...




                        precision    recall  f1-score   support

      Cysts_Structural       0.80      0.18      0.30        22
             Dysarthia       1.00      1.00      1.00        42
            Laryngitis       0.60      0.83      0.70        42
           Vox senilis       0.79      0.89      0.84        93
             parkinson       0.98      0.97      0.98       200
spasmodische_dysphonie       0.80      0.40      0.53        20

              accuracy                           0.88       419
             macro avg       0.83      0.71      0.72       419
          weighted avg       0.88      0.88      0.86       419

