In [2]:
# Install necessary libraries
# !pip install librosa torch torchvision numpy

# Import libraries
import os
import numpy as np
import librosa
import torch
import torchvision.transforms as T

In [6]:
# Define the EnhancedBinaryClassifier model
class EnhancedBinaryClassifier(torch.nn.Module):
    """Small CNN that maps a single-channel image to one sigmoid score.

    Two 3x3 convolutions (1 -> 32 -> 64 channels), each followed by ReLU and
    2x2 max pooling, feed one fully connected layer with a single output.

    Args:
        input_height: Height of the inputs the linear layer is sized for.
        input_width: Width of the inputs the linear layer is sized for.
    """

    def __init__(self, input_height=128, input_width=128):
        super().__init__()
        nn = torch.nn
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # The linear layer's input width depends on the image size, so it is
        # measured by pushing a dummy image through the convolutional stack.
        self.flattened_size = self._get_flattened_size(input_height, input_width)
        self.fc1 = nn.Linear(self.flattened_size, 1)

    def _conv_stack(self, x):
        """Apply conv -> ReLU -> pool twice and return the feature maps."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        return x

    def _get_flattened_size(self, height, width):
        """Return the feature count produced for one ``height`` x ``width`` input."""
        probe = torch.zeros(1, 1, height, width)
        with torch.no_grad():
            features = self._conv_stack(probe)
        return features.numel()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``x``."""
        features = self._conv_stack(x)
        flat = features.view(features.size(0), -1)
        return torch.sigmoid(self.fc1(flat))

In [8]:
# Load the trained model
model_path = 'trained_model.pth'  # Path to your trained model
model = EnhancedBinaryClassifier(input_height=128, input_width=128)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids the warning recent PyTorch releases emit
# when the argument is left unset. map_location='cpu' lets a checkpoint that was
# saved from a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
model.eval()  # Set the model to evaluation mode

  model.load_state_dict(torch.load(model_path))


EnhancedBinaryClassifier(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=65536, out_features=1, bias=True)
)

In [10]:
def process_and_predict(file_path):
    """Classify an audio file as "clean" or "distorted" from 5-second snippets.

    The audio is loaded at 44.1 kHz and cut into consecutive, non-overlapping
    5-second snippets (a trailing remainder shorter than 5 seconds is ignored).
    Each snippet is converted to a dB-scaled Mel spectrogram, resized to
    128x128 and scored by the module-level ``model``. The file label is derived
    from the mean of the snippet scores.

    Args:
        file_path: Path of the audio file to analyse.

    Returns:
        tuple: ``(label, predictions)`` where ``label`` is ``"clean"`` when the
        mean score is below 0.5 and ``"distorted"`` otherwise, and
        ``predictions`` is the list of per-snippet scores as Python floats.

    Raises:
        ValueError: If the file cannot be loaded, or if it holds less than one
            full 5-second snippet so that there is nothing to score.
    """
    # Parameters for spectrogram generation
    sr = 44100  # Sample rate
    n_mels = 128  # Number of Mel bands
    hop_length = 512  # Hop length
    snippet_duration = 5  # Duration of each snippet in seconds

    # Resize transform to ensure spectrograms are 128x128
    # NOTE(review): ToPILImage multiplies float tensors by 255 and casts them to
    # bytes, so the dB values are not preserved faithfully. It is left untouched
    # because the trained weights presumably expect this exact preprocessing --
    # confirm against the training pipeline before changing it.
    resize_transform = T.Compose([
        T.ToPILImage(),
        T.Resize((128, 128)),
        T.ToTensor()
    ])

    # Load the audio file
    try:
        y, _ = librosa.load(file_path, sr=sr)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise ValueError(f"Failed to load {file_path}: {e}") from e

    # Calculate number of samples per snippet
    snippet_samples = int(snippet_duration * sr)

    # Without this guard a too-short file yields no snippets, np.mean([]) is
    # NaN, and "NaN < 0.5" is False -- the file would silently be labelled
    # "distorted".
    if len(y) < snippet_samples:
        raise ValueError(
            f"{file_path} is shorter than one {snippet_duration}-second snippet "
            f"({len(y) / sr:.2f} s of audio); nothing to predict"
        )

    # Split audio into 5-second snippets and predict
    predictions = []
    with torch.no_grad():  # Inference only; no gradients needed
        for start_sample in range(0, len(y) - snippet_samples + 1, snippet_samples):
            snippet = y[start_sample:start_sample + snippet_samples]

            # Generate the Mel spectrogram and convert it to dB scale
            S = librosa.feature.melspectrogram(
                y=snippet, sr=sr, n_mels=n_mels, hop_length=hop_length
            )
            S_dB = librosa.power_to_db(S, ref=np.max)

            # Add a channel dimension for the transform, resize, then add the
            # batch dimension the model expects.
            spectrogram_tensor = torch.tensor(S_dB, dtype=torch.float32)
            spectrogram_tensor = resize_transform(spectrogram_tensor.unsqueeze(0))
            spectrogram_tensor = spectrogram_tensor.unsqueeze(0)

            # Predict with the model
            predictions.append(model(spectrogram_tensor).item())

    # Average the predictions
    average_prediction = np.mean(predictions)
    label = "clean" if average_prediction < 0.5 else "distorted"

    return label, predictions

In [14]:
# Example usage: point file_path at any 60-second .wav file (replace it with any
# .wav) to test whether the model guesses correctly.
file_path = "clean 1 (40 bpm) open G.wav"
result, snippet_predictions = process_and_predict(file_path)
print(
    f"The audio file is predicted to be: {result}",
    f"Snippet predictions: {snippet_predictions}",
    sep="\n",
)

The audio file is predicted to be: clean
Snippet predictions: [0.19373087584972382, 0.12644042074680328, 0.05742713063955307, 0.11891155689954758, 0.08172749727964401, 0.12470465898513794, 0.0909908190369606, 0.07392887771129608, 0.03425827622413635, 0.07056822627782822, 0.056865718215703964, 0.03679925948381424]
