In [2]:
import math
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchaudio
from torch.utils.data import DataLoader, Dataset, random_split
from torchaudio import transforms
from pathlib import Path
from IPython.display import Audio
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

def install_libraries():
    """Install the third-party packages used by this script.

    Runs pip through the current interpreter (``python -m pip``) so the
    packages are installed into the same environment that executes this code.
    This replaces the IPython-only ``!pip`` shell escape, which is a syntax
    error when the file is run as a plain Python script.

    Like ``!pip``, this is best-effort: a failed install is reported by pip on
    the console but does not raise.
    """
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "-m", "pip", "install", "pandas", "torch", "torchaudio", "torchvision"],
        check=False,
    )

def prepare_metadata(directory_path: str) -> pd.DataFrame:
    """Load the UrbanSound8K metadata CSV and build each clip's relative path.

    Args:
        directory_path: Dataset root (``str`` or ``Path``). The CSV is read
            from ``<root>/metadata/UrbanSound8K.csv``.

    Returns:
        A DataFrame with columns ``relative_path`` (``"/fold<fold>/<file>"``)
        and ``classID``, in CSV row order and with the CSV's default index.
    """
    metadata_file = Path(directory_path) / "metadata" / "UrbanSound8K.csv"
    metadata = pd.read_csv(metadata_file)
    # Build every path in one vectorized pass from the row's own fold. The
    # previous per-row lookup of the fold by file name rescanned the whole
    # frame for each row (O(n^2)).
    metadata["relative_path"] = (
        "/fold"
        + metadata["fold"].astype(str)
        + "/"
        + metadata["slice_file_name"].astype(str)
    )
    return metadata[["relative_path", "classID"]]

class AudioAugment:
    """Stateless helpers for loading, normalizing and augmenting audio clips.

    Audio is passed between helpers as a ``(signal, sampling_rate)`` tuple.
    ``signal`` is the 2-D ``[channel, sample]`` tensor returned by
    ``torchaudio.load``.
    """

    @staticmethod
    def open(wav_file):
        """Load an audio file and return ``(signal, sampling_rate)``."""
        sig, sampling_rate = torchaudio.load(wav_file)
        return (sig, sampling_rate)

    @staticmethod
    def rechannel(aud, new_channel):
        """Convert ``aud`` to exactly ``new_channel`` channels.

        Surplus channels are dropped. Stereo -> mono therefore keeps only the
        first channel rather than averaging. Missing channels are filled by
        tiling the existing ones, so mono -> stereo duplicates the single
        channel.
        """
        sig, sampling_rate = aud
        num_channels = sig.shape[0]

        if num_channels == new_channel:
            return aud

        if num_channels > new_channel:
            resig = sig[:new_channel, :]
        else:
            # Ceil-divide so the tiled tensor has at least new_channel rows,
            # then trim to the exact count.
            repeats = -(-new_channel // num_channels)
            resig = sig.repeat(repeats, 1)[:new_channel, :]

        return (resig, sampling_rate)

    @staticmethod
    def resample(aud, newsr):
        """Resample every channel of ``aud`` to ``newsr`` Hz."""
        sig, sampling_rate = aud

        if sampling_rate == newsr:
            return aud

        # Resample accepts (..., time) tensors, so all channels go through a
        # single transform (one filter kernel) instead of one per channel.
        resig = transforms.Resample(sampling_rate, newsr)(sig)
        return (resig, newsr)

    @staticmethod
    def pad_trunc(aud, maximum_audio_length):
        """Truncate or zero-pad ``aud`` to ``maximum_audio_length`` milliseconds.

        Longer signals keep their leading samples. Shorter signals get zeros
        added before and after, with a random split between the two sides.
        """
        sig, sampling_rate = aud
        _, input_signal_length = sig.shape
        # Multiply before dividing. `sampling_rate // 1000 * ms` truncates the
        # per-millisecond rate (44100 Hz -> 44) and silently drops samples.
        maximum_length = sampling_rate * maximum_audio_length // 1000

        if input_signal_length > maximum_length:
            # Reduce the signal to the specified length.
            sig = sig[:, :maximum_length]

        elif input_signal_length < maximum_length:
            padding_total = maximum_length - input_signal_length
            padding_begin_length = random.randint(0, padding_total)
            padding_end_length = padding_total - padding_begin_length
            # F.pad zero-pads the last (sample) dimension and keeps the
            # signal's dtype and device.
            sig = F.pad(sig, (padding_begin_length, padding_end_length))

        return (sig, sampling_rate)

    @staticmethod
    def time_shift(aud, shift_limit):
        """Circularly shift ``aud`` along time by a random fraction of its length.

        The shift is ``int(u * shift_limit * length)`` samples, with ``u``
        drawn uniformly from ``[0, 1)``. Samples pushed past the end wrap
        around to the beginning.
        """
        sig, sampling_rate = aud
        _, input_signal_length = sig.shape
        shift_amt = int(random.random() * shift_limit * input_signal_length)
        # Roll along the time axis only. Without `dims`, roll flattens the
        # tensor and moves the tail of one channel into the head of the next.
        return (sig.roll(shift_amt, dims=-1), sampling_rate)

    @staticmethod
    def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None, top_db=80):
        """Return the decibel-scaled mel spectrogram of ``aud``.

        The result has shape ``[channel, n_mels, time]``. ``hop_len=None``
        lets torchaudio choose its default hop length. ``top_db`` is forwarded
        to ``AmplitudeToDB`` to limit the dynamic range.
        """
        sig, sampling_rate = aud

        spec = transforms.MelSpectrogram(
            sampling_rate, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels
        )(sig)

        # Decibel conversion
        return transforms.AmplitudeToDB(top_db=top_db)(spec)

    @staticmethod
    def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
        """Apply random frequency and time masking to a ``[channel, mel, time]`` spectrogram.

        The masks are horizontal (frequency) and vertical (time) bars, each up
        to ``max_mask_pct`` of the axis size. Masked cells are filled with the
        spectrogram's overall mean value.
        """
        _, n_mels, n_steps = spec.shape
        mask_value = spec.mean()
        aug_spec = spec

        freq_mask_param = max_mask_pct * n_mels
        for _ in range(n_freq_masks):
            aug_spec = transforms.FrequencyMasking(freq_mask_param)(
                aug_spec, mask_value
            )

        time_mask_param = max_mask_pct * n_steps
        for _ in range(n_time_masks):
            aug_spec = transforms.TimeMasking(time_mask_param)(aug_spec, mask_value)

        return aug_spec

class AudioDataset(Dataset):
    """Map-style dataset yielding ``(augmented_mel_spectrogram, classID)`` pairs.

    Each access loads the clip from disk, then resamples it, rechannels it,
    pads or truncates it to a fixed duration, randomly time-shifts it, and
    converts it to a randomly masked mel spectrogram. Augmentation is applied
    on every access, for any split that wraps this dataset.

    Args:
        df: DataFrame with ``relative_path`` and ``classID`` columns. Rows are
            addressed by position, so the DataFrame's index does not have to
            be a RangeIndex.
        audio_dir: Directory prepended to each ``relative_path``. Plain string
            concatenation is used, so ``relative_path`` is expected to start
            with a separator (as in ``"/fold1/x.wav"``).
        length: Target clip duration in milliseconds.
        rate: Target sampling rate in Hz.
        channels: Target number of audio channels.
        shift_percentage: Maximum time shift, as a fraction of the clip length.
    """

    def __init__(self, df, audio_dir, length=4000, rate=44100, channels=2, shift_percentage=0.4):
        self.df = df
        self.audio_dir = str(audio_dir)
        self.length = length
        self.rate = rate
        self.channels = channels
        self.shift_percentage = shift_percentage

    def __len__(self):
        return len(self.df)

    def _locate(self, index):
        """Return ``(audio_filepath, label)`` for the row at position ``index``."""
        audio_filepath = self.audio_dir + self.df["relative_path"].iloc[index]
        label = self.df["classID"].iloc[index]
        return audio_filepath, label

    def __getitem__(self, index):
        audio_filepath, label = self._locate(index)

        audio = AudioAugment.open(audio_filepath)
        resampled_audio = AudioAugment.resample(audio, self.rate)
        rechanneled_audio = AudioAugment.rechannel(resampled_audio, self.channels)

        padded_audio = AudioAugment.pad_trunc(rechanneled_audio, self.length)
        shifted_audio = AudioAugment.time_shift(padded_audio, self.shift_percentage)
        spectrogram = AudioAugment.spectro_gram(shifted_audio, n_mels=64, n_fft=1024, hop_len=None)
        augmented_spectrogram = AudioAugment.spectro_augment(
            spectrogram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2
        )

        return augmented_spectrogram, label

def generate_data_loaders(df, a_path, batch_size=16, train_fraction=0.8, generator=None):
    """Randomly split the metadata into training and validation DataLoaders.

    Args:
        df: Metadata DataFrame passed to ``AudioDataset``.
        a_path: Audio root directory passed to ``AudioDataset``.
        batch_size: Batch size used by both loaders.
        train_fraction: Fraction of items (rounded) assigned to training; the
            remainder goes to validation.
        generator: Optional ``torch.Generator`` for a reproducible split.
            ``None`` keeps the default global RNG.

    Returns:
        ``(train_data_loader, val_data_loader)``. Only the training loader shuffles.
    """
    # NOTE(review): both subsets share one AudioDataset, so validation items
    # also get AudioDataset's random augmentation.
    dataset = AudioDataset(df, a_path)

    total_items = len(dataset)
    train_items = round(total_items * train_fraction)
    val_items = total_items - train_items
    split_kwargs = {} if generator is None else {"generator": generator}
    training_set, validation_set = random_split(dataset, [train_items, val_items], **split_kwargs)

    train_data_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    val_data_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)

    return train_data_loader, val_data_loader

class SoundClassifier(nn.Module):
    """Small CNN that maps a ``[batch, channel, mel, time]`` spectrogram to class logits.

    Four stride-2 conv blocks (conv -> ReLU -> BatchNorm) each roughly halve
    the spatial size. A global average pool and a linear head follow.

    Args:
        in_channels: Number of input spectrogram channels.
        num_classes: Number of output logits.
        The defaults reproduce the original hard-coded 2-channel input and
        10-class head.
    """

    def __init__(self, in_channels=2, num_classes=10):
        super().__init__()
        self.convs = nn.Sequential(
            self._conv_block(in_channels, 8, 5, 2, 2),
            self._conv_block(8, 16, 3, 2, 1),
            self._conv_block(16, 32, 3, 2, 1),
            self._conv_block(32, 64, 3, 2, 1),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def _conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one conv -> ReLU -> BatchNorm block with Kaiming-initialized weights and zero bias."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        relu = nn.ReLU()
        bn = nn.BatchNorm2d(out_channels)
        init.kaiming_normal_(conv.weight, a=0.1)
        init.zeros_(conv.bias)
        return nn.Sequential(conv, relu, bn)

    def forward(self, x):
        x = self.convs(x)
        x = self.avg_pool(x)
        # Collapse the pooled [batch, 64, 1, 1] map to [batch, 64].
        x = x.flatten(1)
        return self.fc(x)

def fit_model(mdl, trn_ldr, n_epochs, dev, max_lr=1e-3):
    """Train ``mdl`` in place with Adam and a linear one-cycle LR schedule.

    Each batch is standardized by its own mean and standard deviation before
    the forward pass. The mean batch loss and training accuracy are printed
    after every epoch.

    Args:
        mdl: Model to train. It is switched to training mode first.
        trn_ldr: Iterable of ``(inputs, labels, ...)`` batches that also
            supports ``len()`` (used for the scheduler's steps per epoch).
        n_epochs: Number of passes over ``trn_ldr``.
        dev: Device each batch is moved to.
        max_lr: Peak learning rate of the one-cycle schedule.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(mdl.parameters(), lr=max_lr)
    sched = OneCycleLR(
        opt, max_lr=max_lr, steps_per_epoch=len(trn_ldr), epochs=n_epochs, anneal_strategy="linear"
    )

    # Make BatchNorm/Dropout use training behaviour even if the model was
    # left in eval mode (for example, after a previous evaluation).
    mdl.train()

    for ep in range(n_epochs):
        total_loss, correct, total_preds = 0.0, 0, 0

        for batch_data in trn_ldr:
            inputs, batch_labels = batch_data[0].to(dev), batch_data[1].to(dev)

            # Per-batch standardization of the inputs.
            mean, std = inputs.mean(), inputs.std()
            inputs = (inputs - mean) / std

            opt.zero_grad()
            preds = mdl(inputs)
            loss = loss_fn(preds, batch_labels)
            loss.backward()
            opt.step()
            sched.step()

            total_loss += loss.item()
            predicted_class = preds.argmax(dim=1)
            correct += (predicted_class == batch_labels).sum().item()
            total_preds += predicted_class.shape[0]

        avg_loss = total_loss / len(trn_ldr)
        acc = correct / total_preds
        print(f"Epoch: {ep}, Loss: {avg_loss:.2f}, Accuracy: {acc:.2f}")

    print("Training Complete")

def test_inference(model, validation_dataloader, device):
    number_correct = 0
    num_examples = 0

    # Turn off gradient updates.
    with torch.no_grad():
        for batch in validation_dataloader:
            # Place the input features and target labels on the device (GPU or CPU).
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            # Standardize inputs.
            inputs = (inputs - inputs.mean()) / inputs.std()

            # Get predicted outputs.
            outputs = model(inputs)

            # Get the class with the highest predicted score.
            predictions = torch.argmax(outputs, dim=1)

            # The number of correct predictions.
            number_correct += (predictions == labels).sum().item()

            # The total number of examples.
            num_examples += inputs.size(0)

    accuracy = number_correct / num_examples
    print(f"Accuracy: {accuracy:.2f}, Total items: {num_examples}")


def main():
    """Run the full pipeline: install dependencies, build loaders, train, then evaluate."""
    install_libraries()

    dataset_root = Path.cwd() / 'UrbanSound8K'
    train_loader, val_loader = generate_data_loaders(
        prepare_metadata(dataset_root),
        dataset_root / 'audio',
    )

    model = SoundClassifier()
    if torch.cuda.is_available():
        run_device = torch.device("cuda:0")
    else:
        run_device = torch.device("cpu")
    model = model.to(run_device)

    epochs = 2
    fit_model(model, train_loader, epochs, run_device)
    test_inference(model, val_loader, run_device)


# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Epoch: 0, Loss: 1.88, Accuracy: 0.33
Epoch: 1, Loss: 1.51, Accuracy: 0.48
Training Complete
Accuracy: 0.51, Total items: 1746
