In [None]:
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import torch
import json

In [None]:
# Variables
# Everything a reader might want to tune lives in this cell.
num_epochs = 50                   # number of full passes over the training set
device = 'cpu'                    # set to 'cuda' if a CUDA-capable GPU is available
modelName = 'model.pth'           # file the best (lowest validation loss) checkpoint is saved to
historyFileName = 'history.json'  # file the per-epoch training metrics are written to

# Seed every RNG the notebook relies on (NumPy drives the random audio
# augmentations, PyTorch drives weight initialisation and DataLoader shuffling)
# so that Restart & Run All produces repeatable results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# CRNN model architecture
# The model uses 3 CNN layers and 2 LSTM layers

class MusicGenreCRNN(nn.Module):
    """Convolutional-recurrent network (CRNN) for music genre classification.

    Pipeline: three conv blocks (Conv2d -> BatchNorm2d -> ReLU -> 2x2 MaxPool)
    extract local time-frequency features from the spectrogram. A 2-layer
    bidirectional LSTM models how those features evolve over time. An
    attention layer pools the LSTM outputs over time, and a small MLP maps the
    pooled vector to genre logits.

    Args:
        num_genres: Number of output classes.
        input_channels: Channels of the input spectrogram (1 for a single
            mel spectrogram).
        n_mels: Frequency bins (height) of the input spectrogram. Each of the
            three pooling layers halves it, so every LSTM time step sees
            ``256 * (n_mels // 8)`` features. The default of 128 reproduces
            the previously hard-coded ``256 * 16``.

    Raises:
        ValueError: if ``n_mels`` is smaller than 8 (no frequency bins would
            survive the three pooling layers).
    """

    def __init__(self, num_genres, input_channels, n_mels=128):
        super(MusicGenreCRNN, self).__init__()

        # CNN layers - feature extraction from the spectrogram.
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Three MaxPool2d(2, 2) layers floor-halve the frequency axis three times.
        pooled_freq_bins = n_mels // 8
        if pooled_freq_bins < 1:
            raise ValueError(f"n_mels must be at least 8, got {n_mels}")
        self.rnn_input_size = 256 * pooled_freq_bins

        # RNN layers - temporal analysis of the CNN features.
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        # Scores each time step (256 = 2 directions x 128 hidden units) and
        # normalises the scores over the time axis (dim=1) with softmax.
        self.attention = nn.Sequential(
            nn.Linear(256, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
            nn.Softmax(dim=1)
        )

        # Classification layers.
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_genres)
        )

    def forward(self, x):
        """Classify a batch of spectrograms.

        Args:
            x: Tensor of shape (batch, input_channels, n_mels, time).

        Returns:
            (logits, attention_weights):
                logits has shape (batch, num_genres).
                attention_weights has shape (batch, time // 8, 1) and sums to 1
                over the time axis.
        """
        batch_size = x.size(0)
        x = self.cnn_layers(x)  # (batch, 256, n_mels // 8, time // 8)
        # Make time the sequence axis, then flatten channels x frequency per step.
        x = x.permute(0, 3, 1, 2)
        # An explicit sequence length makes a mismatched n_mels fail loudly
        # instead of silently re-slicing features across time steps.
        x = x.reshape(batch_size, x.size(1), self.rnn_input_size)

        rnn_out, _ = self.rnn(x)  # (batch, steps, 256)

        attention_weights = self.attention(rnn_out)
        x = torch.sum(attention_weights * rnn_out, dim=1)

        x = self.classifier(x)
        return x, attention_weights

In [None]:
# We trained the model on a laptop without a dedicated GPU, so we used the CPU.
# If you have a CUDA-capable GPU, set device='cuda' in the Variables cell; GPU training is generally much faster than CPU.
# You can also adjust num_epochs there: more epochs means longer training, but may give better results.

def _run_epoch(model, loader, criterion, device, optimizer=None, epoch=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, otherwise evaluate.

    Returns:
        (mean batch loss, accuracy in percent) over the pass.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for batch_idx, (inputs, targets) in enumerate(loader):
            # view(-1) instead of squeeze(): squeeze() turns a final batch of a
            # single sample into a 0-d tensor, which CrossEntropyLoss rejects
            # and on which targets.size(0) raises.
            inputs, targets = inputs.to(device), targets.to(device).view(-1)

            if training:
                optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if training and batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}, '
                      f'Acc: {100.*correct/total:.2f}%')

    return running_loss / len(loader), 100. * correct / total


def train_model(model, train_loader, val_loader, num_epochs, device,
                model_path=None, history_path=None):
    """Train ``model`` with Adam (lr=0.001) and cross-entropy, tracking per-epoch metrics.

    After every epoch, a checkpoint (model state, optimizer state and the
    metric history so far) is written to ``revised_checkpoint_epoch_{epoch}.pth``
    in the working directory. The same checkpoint is also written to
    ``model_path`` whenever the validation loss is lower than in every earlier
    epoch. When training finishes, the history is dumped as JSON to
    ``history_path``.

    Args:
        model: Module whose forward returns ``(logits, extra)``; only the
            logits are used.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader yielding ``(inputs, targets)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        device: Device passed to ``.to()`` for the model and every batch.
        model_path: Where to save the best checkpoint. Defaults to the
            notebook-level ``modelName``.
        history_path: Where to write the JSON history. Defaults to the
            notebook-level ``historyFileName``.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` (accuracies in percent) and ``epoch``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if model_path is None:
        model_path = modelName
    if history_path is None:
        history_path = historyFileName
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'epoch': []
    }

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_val_loss = None

    for epoch in range(num_epochs):
        # Training phase, then validation phase (no gradients).
        epoch_train_loss, epoch_train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, epoch=epoch)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion, device)

        # Stores the training and validation metrics - used for plotting later.
        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)
        history['epoch'].append(epoch)

        # Prints metrics for each epoch so training can be followed live.
        print(f'\nEpoch: {epoch}')
        print(f'Training Loss: {epoch_train_loss:.4f}')
        print(f'Training Accuracy: {epoch_train_acc:.2f}%')
        print(f'Validation Loss: {epoch_val_loss:.4f}')
        print(f'Validation Accuracy: {epoch_val_acc:.2f}%\n')

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history,
        }
        torch.save(checkpoint, f'revised_checkpoint_epoch_{epoch}.pth')

        # The first epoch is always the best so far; afterwards only a strict
        # improvement on every earlier validation loss replaces the saved model.
        if best_val_loss is None or epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(checkpoint, model_path)

    with open(history_path, 'w') as f:
        json.dump(history, f)

    return history

# Functions to load and plot the saved training metrics.
def _plot_metric_pair(ax, history, key, metric_name, ylabel):
    """Draw the training and validation curves of one metric (``key`` = 'acc' or 'loss') on ``ax``."""
    ax.plot(history['epoch'], history[f'train_{key}'], label=f'Training {metric_name}', color='#1f77b4')
    ax.plot(history['epoch'], history[f'val_{key}'], label=f'Validation {metric_name}', color='#ff7f0e')
    ax.set_title(f'Model {metric_name} over Epochs', pad=20)
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_history(history_path=historyFileName):
    """Plot accuracy and loss curves from a JSON metric history.

    Args:
        history_path: JSON file with the lists ``epoch``, ``train_acc``,
            ``val_acc``, ``train_loss`` and ``val_loss`` (the format written
            at the end of ``train_model``).
    """
    with open(history_path, 'r') as f:
        history = json.load(f)

    # Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
    # later releases removed the old names, so 'seaborn' alone raises there.
    style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'

    # style.context keeps the style change local to this figure instead of
    # altering every later plot in the kernel.
    with plt.style.context(style):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
        _plot_metric_pair(ax1, history, 'acc', 'Accuracy', 'Accuracy (%)')
        _plot_metric_pair(ax2, history, 'loss', 'Loss', 'Loss')
        ax2.set_xlabel('Epochs')

        plt.tight_layout()
        plt.show()

# Training could also be resumed from one of the saved checkpoints,
# but we haven't implemented that since we had no need for it.

In [None]:
# This function gets the files for training and also testing
def get_files_and_labels(base_path):
    """Collect ``.wav`` files from a ``base_path/<genre>/`` folder layout.

    Each immediate sub-directory of ``base_path`` is a genre. Genres are
    sorted by name and numbered from 0 (e.g. 'blues': 0, 'classical': 1, ...)
    so that genre names can be converted to integer labels for training.

    Args:
        base_path: Directory whose sub-directories are genre folders.

    Returns:
        (files, labels, genre_to_label): parallel lists of file paths and
        integer labels, plus the genre-name -> label mapping.
    """
    # Only directories are genres. Counting stray files (e.g. .DS_Store or a
    # README) would shift the labels of every genre sorted after them.
    genres = sorted(
        entry for entry in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, entry))
    )
    genre_to_label = {genre: idx for idx, genre in enumerate(genres)}

    files = []
    labels = []
    for genre in genres:
        genre_path = os.path.join(base_path, genre)
        # os.listdir order is arbitrary; sorting keeps the file order stable
        # across machines.
        for file in sorted(os.listdir(genre_path)):
            if file.endswith('.wav'):
                files.append(os.path.join(genre_path, file))
                labels.append(genre_to_label[genre])

    return files, labels, genre_to_label


def prepare_dataset(train_path="./data/train_files",
                    val_path="./data/validation_files",
                    test_path="./data/test_files"):
    """Load file lists and labels for the training, validation and test splits.

    Each path must follow the ``<split>/<genre>/*.wav`` layout read by
    get_files_and_labels. The genre -> label mapping comes from the training
    split; the validation and test mappings are checked against it.

    Returns:
        ((train_files, train_labels), (val_files, val_labels),
         (test_files, test_labels), genre_mapping)

    Raises:
        ValueError: if a validation/test genre is missing from the training
            split or would receive a different label than in training.
    """
    train_files, train_labels, genre_mapping = get_files_and_labels(train_path)
    val_files, val_labels, val_mapping = get_files_and_labels(val_path)
    test_files, test_labels, test_mapping = get_files_and_labels(test_path)

    # Each split numbers its own genre folders, so a genre missing from one
    # split would silently shift the labels of every genre sorted after it.
    for split_name, split_mapping in (("validation", val_mapping), ("test", test_mapping)):
        mismatched = sorted(
            genre for genre, label in split_mapping.items()
            if genre_mapping.get(genre) != label
        )
        if mismatched:
            raise ValueError(
                f"{split_name} genres {mismatched} do not match the training "
                f"label mapping {genre_mapping}"
            )

    print("\nDataset statistics:")
    print(f"Training files: {len(train_files)}")
    print(f"Validation files: {len(val_files)}")
    print(f"Test files: {len(test_files)}")

    # Shows how many training files each genre has.
    print("\nTraining set genre distribution:")
    label_to_genre = {label: genre for genre, label in genre_mapping.items()}
    unique_labels, counts = np.unique(train_labels, return_counts=True)
    for label, count in zip(unique_labels, counts):
        print(f"{label_to_genre[label]}: {count} files")

    return (train_files, train_labels), (val_files, val_labels), (test_files, test_labels), genre_mapping

(train_files, train_labels), (val_files, val_labels), (test_files, test_labels), genre_mapping = prepare_dataset()

In [None]:
# Uses Librosa to load the music files and convert them to mel spectrograms
# We also add some random augmentations to the data to make the model more robust

class MusicDataset(Dataset):
    """Dataset that turns audio files into fixed-size, normalised log-mel spectrograms.

    Args:
        audio_files: List of audio file paths.
        labels: Integer label for each entry of ``audio_files``.
        augment: Apply random time-stretch / pitch-shift augmentation. Keep it
            True for training; pass False for validation/test so their
            metrics are not randomised.
        duration: Seconds of audio loaded from the start of each file.
        sr: Sample rate the audio is resampled to.
        n_mels: Number of mel bands (spectrogram height).
        n_fft: FFT window size of the STFT.
        hop_length: STFT hop size in samples.
        target_length: Number of spectrogram frames every item is padded or
            truncated to. Defaults to the frame count of exactly ``duration``
            seconds (646 for 15 s); the previous hard-coded 1292 matches
            ``duration=30``.
    """

    def __init__(self, audio_files, labels, augment=True, duration=15, sr=22050,
                 n_mels=128, n_fft=2048, hop_length=512, target_length=None):
        self.audio_files = audio_files
        self.labels = labels
        self.augment = augment
        self.duration = duration
        self.sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        if target_length is None:
            # A centred STFT of N samples yields 1 + N // hop_length frames.
            target_length = 1 + int(duration * sr) // hop_length
        self.target_length = target_length

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _augment(audio, sr):
        """Independently, each with probability ~1/2: time-stretch by 0.8-1.2x, pitch-shift by -2..+2 semitones."""
        # NOTE(review): uses NumPy's global RNG; with DataLoader num_workers > 0
        # each worker may inherit the same RNG state unless seeded per worker.
        if np.random.random() > 0.5:
            rate = np.random.uniform(0.8, 1.2)
            audio = librosa.effects.time_stretch(audio, rate=rate)

        if np.random.random() > 0.5:
            steps = np.random.randint(-2, 3)
            audio = librosa.effects.pitch_shift(audio, sr=sr, n_steps=steps)
        return audio

    def _fix_length(self, mel_spec):
        """Truncate or right-pad the frame axis (axis 1) to ``self.target_length``.

        Padding uses the spectrogram's own minimum (its quietest level) rather
        than 0 dB, so padded frames look like silence, not like a constant tone.
        """
        n_frames = mel_spec.shape[1]
        if n_frames > self.target_length:
            return mel_spec[:, :self.target_length]
        if n_frames < self.target_length:
            pad_value = mel_spec.min() if mel_spec.size else 0.0
            pad_width = self.target_length - n_frames
            return np.pad(mel_spec, ((0, 0), (0, pad_width)), mode='constant',
                          constant_values=pad_value)
        return mel_spec

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(mel_spec, label)``.

        mel_spec is a float tensor of shape (1, n_mels, target_length),
        z-score normalised per clip; label is a LongTensor of shape (1,).
        Errors are printed with the offending file path and re-raised.
        """
        try:
            audio, sr = librosa.load(self.audio_files[idx], duration=self.duration, sr=self.sr)

            if self.augment:
                audio = self._augment(audio, sr)

            # Mel spectrogram, converted to a log (dB) scale.
            mel_spec = librosa.feature.melspectrogram(
                y=audio,
                sr=sr,
                n_mels=self.n_mels,
                n_fft=self.n_fft,
                hop_length=self.hop_length
            )
            mel_spec = librosa.power_to_db(mel_spec)

            mel_spec = self._fix_length(mel_spec)

            # Per-clip standardisation; the epsilon avoids dividing by zero
            # for a constant (e.g. silent) clip.
            mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)

            # Converts to PyTorch tensors; the leading 1 is the channel axis.
            mel_spec = torch.FloatTensor(mel_spec).unsqueeze(0)
            label = torch.LongTensor([self.labels[idx]])

            return mel_spec, label

        except Exception as e:
            print(f"Error processing index {idx}, file {self.audio_files[idx]}: {str(e)}")
            raise

In [None]:
# Main method, entrypoint for the script
# Loads the training and validation datasets, creates loaders and initializes the model training

def main():
    """Build the data loaders and model, smoke-test one batch, train, then plot.

    Uses the notebook-level split lists and ``genre_mapping`` produced by
    ``prepare_dataset()``, plus the ``num_epochs``, ``device`` and
    ``historyFileName`` settings from the Variables cell.
    """
    batch_size = 16
    # Pinned host memory only speeds up host-to-GPU copies, so enable it only
    # when training on CUDA.
    pin_memory = str(device).startswith('cuda')

    train_dataset = MusicDataset(train_files, train_labels)
    # NOTE(review): the validation set goes through the same MusicDataset
    # pipeline as training, including its random augmentations, so validation
    # metrics vary from run to run.
    val_dataset = MusicDataset(val_files, val_labels)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory
    )

    # Size the output layer from the genre folders instead of hard-coding 10.
    model = MusicGenreCRNN(
        num_genres=len(genre_mapping),
        input_channels=1
    )

    # Load a single batch before training so data problems surface
    # immediately rather than in the middle of the first epoch.
    print("Testing data loader...")
    try:
        next(iter(train_loader))
        print("Successfully loaded batch 0")
    except Exception as e:
        print(f"Error loading batch: {str(e)}")
        raise

    train_model(model, train_loader, val_loader, num_epochs, device)
    plot_training_history(historyFileName)

if __name__ == "__main__":
    main()