In [1]:
import torch
import numpy as np
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchaudio.transforms import MelSpectrogram
from torch.utils.data import Dataset

# Training hyper-parameters.
BATCH_SIZE = 128        # samples per mini-batch
EPOCHS = 10             # full passes over the training set
LEARNING_RATE = 1e-3    # optimiser step size

# Data locations, relative to the notebook's working directory.
LABELS_PATH = "Data/train_label.txt"        # one integer label per line
TRAINING_DATA_PATH = "Data/train_mp3s"      # holds <index>.mp3 audio files

In [None]:
class MusicDataset(Dataset):
    """Map-style dataset pairing numbered MP3 files with integer labels.

    Sample ``i`` is the mel spectrogram of ``<audio_dir_path>/<i>.mp3`` and its
    label is the integer on line ``i`` of ``labels_path`` (blank lines ignored).

    Args:
        audio_dir_path: Directory containing the ``<index>.mp3`` files.
        labels_path: Text file with one integer label per line.
        mel_spectrogram: Callable (e.g. ``torchaudio.transforms.MelSpectrogram``)
            applied to each loaded waveform.
    """

    def __init__(self, audio_dir_path, labels_path, mel_spectrogram):
        self.labels = self.read_file_to_numpy_array(labels_path).astype(int)
        self.audio_dir_path = audio_dir_path
        self.mel_spectrogram = mel_spectrogram

    def __len__(self):
        """Number of samples; required by DataLoader's default samplers."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(spectrogram, label)`` for sample ``index``."""
        mp3_path = f'{self.audio_dir_path}/{index}.mp3'
        spectrogram = self.audio_to_spectrogram(mp3_path)
        label = self.labels[index]
        return spectrogram, int(label)

    def read_file_to_numpy_array(self, filename):
        """Read ``filename`` into a 1-D numpy array of strings, one per non-blank line.

        Blank or whitespace-only lines (e.g. a trailing empty line) are dropped
        so the caller's ``.astype(int)`` does not fail on an empty string.
        """
        with open(filename, 'r') as file:
            lines = [line.strip() for line in file.read().splitlines() if line.strip()]
        return np.array(lines)

    def audio_to_spectrogram(self, file_path):
        """Load ``file_path`` and return its mel spectrogram with a single channel.

        Multi-channel audio is averaged down to mono because the CNN's first
        conv layer is built with ``in_channels=1``.
        """
        # torchaudio.load returns (channels, frames); the sample rate is not used here.
        waveform, _ = torchaudio.load(file_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        return self.mel_spectrogram(waveform)

In [None]:
class CNNNetwork(nn.Module):
    """Small 4-block CNN classifier for single-channel spectrogram inputs.

    Architecture: 4 x (Conv2d 3x3 -> ReLU -> MaxPool2d 2) -> flatten -> linear.
    Each conv preserves the spatial size and each max-pool halves it (floor),
    so an input of shape ``(batch, 1, H, W)`` reaches the linear layer with
    ``128 * (H // 16) * (W // 16)`` features. The default
    ``flattened_features=128 * 8 * 16`` therefore needs ``H // 16 == 8`` and
    ``W // 16 == 16`` (e.g. 128 mel bins x 256-271 frames).

    ``forward`` returns raw logits, which is what ``nn.CrossEntropyLoss``
    expects (it applies log-softmax internally). Use ``predict`` for class
    probabilities.

    Args:
        num_classes: Number of output classes (default 4).
        flattened_features: Size of the flattened conv output fed to the
            linear layer (default ``128 * 8 * 16``).
    """

    def __init__(self, num_classes=4, flattened_features=128 * 8 * 16):
        super().__init__()
        # Attribute names conv1..conv4 / linear are kept so state_dict keys
        # (e.g. "conv1.0.weight") match checkpoints saved from this model.
        self.conv1 = self._conv_block(1, 16)
        self.conv2 = self._conv_block(16, 32)
        self.conv3 = self._conv_block(32, 64)
        self.conv4 = self._conv_block(64, 128)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(flattened_features, num_classes)
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2)."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, input_data):
        """Return unnormalised class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        return self.linear(x)

    def predict(self, input_data):
        """Return class probabilities: softmax of the logits along dim 1."""
        return self.softmax(self.forward(input_data))

In [None]:

def create_data_loader(train_data, batch_size, shuffle=True):
    """Wrap ``train_data`` in a DataLoader suitable for training.

    Args:
        train_data: Map-style dataset (must implement ``__len__`` and ``__getitem__``).
        batch_size: Number of samples per batch.
        shuffle: Reshuffle samples every epoch (default ``True``). Without it,
            every epoch visits samples in the same fixed file order, which can
            bias gradient-based training when that order correlates with labels.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``train_data``.
    """
    return DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)


def train_single_epoch(model, data_loader, loss_fn, optimiser, device=None):
    """Run one optimisation pass over ``data_loader`` and report the mean loss.

    Args:
        model: Module to train; switched to training mode for the pass.
        data_loader: Iterable of ``(inputs, target)`` tensor batches.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor; assumed to average over the batch (the ``nn`` loss default).
        optimiser: Optimiser over ``model``'s parameters.
        device: Optional device to move each batch to. The model must already
            live there. ``None`` leaves batches where they are.

    Returns:
        Sample-weighted mean loss over the epoch as a float, or ``nan`` when
        ``data_loader`` yields no batches (also printed as ``loss: ...``).
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for inputs, target in data_loader:
        if device is not None:
            inputs, target = inputs.to(device), target.to(device)

        # calculate loss
        prediction = model(inputs)
        loss = loss_fn(prediction, target)

        # backpropagate error and update weights
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # Weight by batch size so a short final batch doesn't skew the mean.
        batch_size = len(target)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    # Guard the empty-loader case, which previously hit an unbound `loss`.
    mean_loss = total_loss / n_samples if n_samples else float("nan")
    print(f"loss: {mean_loss}")
    return mean_loss


def train(model, data_loader, loss_fn, optimiser, epochs):
    """Train ``model`` for ``epochs`` passes over ``data_loader``.

    Each pass is delegated to ``train_single_epoch`` and is framed by an
    ``Epoch <n>`` header (1-based) and a dashed separator. A completion
    message is printed once all passes are done.
    """
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}")
        train_single_epoch(model, data_loader, loss_fn, optimiser)
        print("---------------------------")
    print("Finished training")

In [None]:
# Seed torch's RNG before weights are initialised and batches are shuffled
# so runs are reproducible.
torch.manual_seed(42)

# NOTE(review): waveforms are not resampled, so sample_rate should match the
# MP3s' native rate — TODO confirm the files are 44.1 kHz.
mel_spectrogram = MelSpectrogram(
    sample_rate=44100,
    n_fft=1024,
    hop_length=512,
    n_mels=128
)
mds = MusicDataset(TRAINING_DATA_PATH, LABELS_PATH, mel_spectrogram)

train_dataloader = create_data_loader(mds, BATCH_SIZE)

cnn = CNNNetwork()
print(cnn)

# NOTE: nn.CrossEntropyLoss applies log-softmax internally and expects
# unnormalised logits from the model, not probabilities.
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

train(cnn, train_dataloader, loss_fn, optimiser, EPOCHS)

# Save relative to the working directory; an absolute path such as "/model"
# targets the filesystem root, which is normally not writable.
MODEL_PATH = "model.pth"
torch.save(cnn.state_dict(), MODEL_PATH)
print(f"Saved model state_dict to {MODEL_PATH}")