In [2]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from IPython.display import Audio
import ipywidgets as widgets

# -------------------- Config --------------------
# Audio front-end settings. These presumably have to match the values used when
# the checkpoint was trained — TODO confirm against the training code.
n_mels = 64            # number of mel filterbank channels fed to the CNN
sample_rate = 22050    # uploaded audio is resampled to this rate (Hz) before scoring
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CNNVoiceClassifier(nn.Module):
    """Small three-stage CNN that classifies a single-channel spectrogram.

    Each stage is conv(3x3, stride 1, padding 1) -> BatchNorm -> ReLU -> pool.
    The first two stages halve the spatial size with 2x2 max-pooling. The last
    stage uses adaptive average pooling to a fixed 4x4 grid, so the dense head
    always sees 64 * 4 * 4 features regardless of input height/width.

    NOTE: the attribute names (conv1, bn1, ..., fc2) are the state_dict keys
    of the saved checkpoint; renaming them breaks ``load_state_dict``.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Stage 1: 1 -> 16 channels
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 3: 32 -> 64 channels, then a fixed 4x4 spatial summary
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.AdaptiveAvgPool2d(output_size=(4, 4))
        # Dense head
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map x of shape (batch, 1, H, W) to raw logits of shape (batch, num_classes)."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        hidden = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.fc2(hidden)


# Rebuild the architecture, then load the trained weights into it. The file is
# loaded with load_state_dict (parameter/buffer tensors keyed by attribute name),
# not as a pickled model object, so CNNVoiceClassifier must match the
# architecture the checkpoint was trained with.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# On PyTorch >= 1.13, passing weights_only=True restricts loading to tensors.
checkpoint_path = "cnn_voice_classifier_state.pth"
model = CNNVoiceClassifier(num_classes=2)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()  # BatchNorm uses its running statistics during inference

# Mel spectrogram front end. The STFT/mel settings presumably must match those
# used to build the training spectrograms — TODO confirm.
mel_kwargs = dict(
    sample_rate=sample_rate,
    n_fft=1024,       # FFT size in samples
    hop_length=512,   # samples between successive frames
    n_mels=n_mels,
)
mel_transform = torchaudio.transforms.MelSpectrogram(**mel_kwargs)

# -------------------- Helper functions --------------------
def chunk_waveform(waveform, sr, chunk_sec=5):
    """Split a [channel, sample] waveform into equal-length, non-overlapping chunks.

    The signal is cut into consecutive windows of ``chunk_sec * sr`` samples
    (rounded to the nearest integer). The final window is right-padded with
    zeros so that every chunk has the same length.

    NOTE(review): a short tail is padded to a full chunk, and predict_parkinsons
    then averages it with equal weight. For example, a 5.1 s recording yields a
    second chunk that is ~98% padding. Consider dropping or down-weighting
    short tails.

    Args:
        waveform: 2-D tensor indexed as [channel, sample].
        sr: sample rate of ``waveform`` in Hz.
        chunk_sec: chunk duration in seconds; may be fractional.

    Returns:
        List of tensors of shape (channels, chunk_len), with the same dtype and
        device as ``waveform``. The list is empty if ``waveform`` has no samples.

    Raises:
        ValueError: if ``chunk_sec * sr`` rounds to fewer than one sample.
    """
    chunk_len = int(round(chunk_sec * sr))
    if chunk_len < 1:
        raise ValueError(
            f"chunk length must be at least one sample, got chunk_sec={chunk_sec!r}, sr={sr!r}"
        )

    total_len = waveform.shape[-1]
    chunks = []
    for start in range(0, total_len, chunk_len):
        chunk = waveform[:, start:start + chunk_len]
        missing = chunk_len - chunk.shape[-1]
        if missing > 0:
            # F.pad keeps the input's dtype and device. A fresh torch.zeros(...)
            # would use the default dtype on the CPU, and torch.cat would fail
            # for CUDA input.
            chunk = F.pad(chunk, (0, missing))
        chunks.append(chunk)
    return chunks

def predict_parkinsons(waveform):
    """Estimate the class-1 probability for a recording by averaging over 5 s chunks.

    The waveform is split with ``chunk_waveform`` at the module-level
    ``sample_rate``, and each chunk is converted to a mel spectrogram in dB.
    All chunks are then scored by ``model`` in a single batched forward pass.

    Args:
        waveform: tensor indexed as [channel, sample], already at
            ``sample_rate``. Presumably it must be mono ([1, num_samples]),
            because the CNN's first conv takes one input channel.

    Returns:
        (avg_prob, chunk_probs): the mean probability of output class 1 and the
        list of per-chunk class-1 probabilities as Python floats. Class 1 is
        treated as "unhealthy"; confirm this against the training labels.

    Raises:
        ValueError: if ``waveform`` has no samples (there is nothing to average).
    """
    chunks = chunk_waveform(waveform, sample_rate)
    if not chunks:
        raise ValueError("waveform contains no samples; cannot make a prediction")

    with torch.no_grad():
        batch = torch.stack(chunks)   # (n_chunks, channels, chunk_len)
        mel = mel_transform(batch)    # (n_chunks, channels, n_mels, frames)
        # MelSpectrogram yields power by default, hence multiplier=10. With
        # top_db left at None the dB conversion is element-wise, so converting
        # the whole batch matches converting each chunk separately.
        mel_db = torchaudio.functional.amplitude_to_DB(
            mel, multiplier=10.0, amin=1e-10, db_multiplier=0
        )
        logits = model(mel_db.to(device))
        chunk_probs = torch.softmax(logits, dim=1)[:, 1].cpu().tolist()

    avg_prob = np.mean(chunk_probs)
    return avg_prob, chunk_probs


In [3]:
# Single-file picker restricted to .wav; a later cell reads uploader.value.
uploader = widgets.FileUpload(multiple=False, accept='.wav')
display(uploader)

FileUpload(value=(), accept='.wav', description='Upload')

In [9]:
# Read the uploaded .wav, convert it to mono at `sample_rate`, score it, and play it back.
if uploader.value:
    # ipywidgets >= 8 exposes uploads as a tuple of dicts; ipywidgets 7.x used a
    # dict keyed by file name. Support both.
    upload_value = uploader.value
    if isinstance(upload_value, tuple):
        uploaded_item = upload_value[0]
    else:
        uploaded_item = next(iter(upload_value.values()))
    audio_bytes = uploaded_item['content']

    # Write the upload to disk so torchaudio.load can read it by path
    temp_path = "temp_user_voice.wav"
    with open(temp_path, "wb") as f:
        f.write(audio_bytes)

    waveform, sr = torchaudio.load(temp_path)

    # Down-mix to mono before resampling: the CNN takes a single input channel
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)

    # waveform is now [1, num_samples] at sample_rate
    avg_prob, chunk_probs = predict_parkinsons(waveform)

    display(Audio(temp_path))
    print(f"⚡ Parkinson's likelihood: {avg_prob*100:.2f}%")
else:
    print("No file uploaded yet: choose a .wav file with the upload widget, then re-run this cell.")


⚡ Parkinson's likelihood: 0.27%
