# NLP Genre Classifier

## Imports

In [15]:
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
from collections import Counter
import time

from mxl_tokenizer import MusicXML_to_tokens

## Data Loader

In [16]:
class MusicXMLDataset(Dataset):
    def __init__(self, csv_path, vocab=None, max_len=512):
        self.entries = []
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                for i in range(2):
                    if ("./mxl/0/" + str(i) + "/") in row['mxl']: # DEBUG: use a subset for now
                        self.entries.append(row)
        
        # Enumerate unique genres from the "primary_genre" field.
        unique_genres = set(entry['primary_genre'] for entry in self.entries)
        print("PRIMARY GENRES:", unique_genres)
        self.genre_to_idx = {genre: idx for idx, genre in enumerate(sorted(unique_genres))}

        self.max_len = max_len
        # Build vocabulary if not provided
        if vocab is None:
            self.vocab = self.build_vocab()
        else:
            self.vocab = vocab

    def build_vocab(self):
        counter = Counter()
        # For each MusicXML file, get tokens using the provided tokenizer.
        for entry in self.entries:
            path = entry['mxl'].replace("./mxl/", "../data/mxl/")
            try:
                tokens = MusicXML_to_tokens(path)
            except Exception as e:
                print(f"Error tokenizing {path}: {e}")
                tokens = []
            counter.update(tokens)
        # Create vocab with special tokens
        vocab = {'<PAD>': 0, '<UNK>': 1, '<CLS>': 2}
        for token, _ in counter.items():
            if token not in vocab:
                vocab[token] = len(vocab)
        return vocab

    def tokenize_and_pad(self, mxml_path):
        try:
            tokens = MusicXML_to_tokens(mxml_path)
        except Exception as e:
            print(f"Error tokenizing {mxml_path}: {e}")
            tokens = []
        # Prepend a <CLS> token for classification
        tokens = ['<CLS>'] + tokens
        token_ids = [self.vocab.get(tok, self.vocab['<UNK>']) for tok in tokens]
        # Truncate or pad to max_len
        if len(token_ids) > self.max_len:
            token_ids = token_ids[:self.max_len]
        else:
            token_ids = token_ids + [self.vocab['<PAD>']] * (self.max_len - len(token_ids))
        return torch.tensor(token_ids, dtype=torch.long)

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        mxml_path = entry['mxl'].replace("./mxl/", "../data/mxl/")
        # Tokenize using your custom tokenizer
        token_ids = self.tokenize_and_pad(mxml_path)
        
        # Convert genre string into an integer label using the mapping.
        genre_str = entry['primary_genre']
        genre = self.genre_to_idx[genre_str]
        
        return token_ids, genre

## Define Positional Encoding

In [17]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # if odd, handle last column
            pe[:, 1::2] = torch.cos(position * div_term[:pe[:, 1::2].shape[1]])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return x

## Define Model

In [18]:
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, num_classes=10, max_len=512, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.dropout = nn.Dropout(dropout)
        # Classifier head: you can use the <CLS> token embedding or a pooling over sequence outputs
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, src):
        # src shape: (batch_size, seq_len)
        embedded = self.embedding(src)  # (batch_size, seq_len, d_model)
        embedded = self.pos_encoder(embedded)
        # PyTorch transformer expects shape: (seq_len, batch_size, d_model)
        embedded = embedded.transpose(0, 1)
        transformer_output = self.transformer_encoder(embedded)  # (seq_len, batch_size, d_model)
        # Take the output corresponding to the <CLS> token (first token)
        cls_output = transformer_output[0]  # (batch_size, d_model)
        cls_output = self.dropout(cls_output)
        logits = self.fc(cls_output)
        return logits


## Train Model

In [19]:
def train_model(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.
    for token_ids, labels in dataloader:
        token_ids, labels = token_ids.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(token_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.
    correct = 0
    total = 0
    with torch.no_grad():
        for token_ids, labels in dataloader:
            token_ids, labels = token_ids.to(device), labels.to(device)
            logits = model(token_ids)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    return total_loss / len(dataloader), accuracy

## Test Model

In [20]:
# Testing function
def test_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for features, labels in dataloader:
            outputs = model(features)
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    accuracy = 100 * correct / total if total > 0 else 0
    print(f"Test Accuracy: {accuracy:.2f}%")

## Run

In [21]:
if __name__ == '__main__':
    # Hyperparameters
    csv_path = 'dataset.csv'
    max_len = 512
    batch_size = 32
    num_classes = 7  # Adjust according to your dataset
    d_model = 512
    nhead = 8
    num_layers = 6
    num_epochs = 100
    learning_rate = 1e-4

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create dataset and split into training/validation sets
    dataset = MusicXMLDataset(csv_path, max_len=max_len)
    vocab_size = len(dataset.vocab)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Instantiate model, optimizer, and loss function
    model = TransformerClassifier(vocab_size, d_model=d_model, nhead=nhead, 
                                  num_layers=num_layers, num_classes=num_classes, 
                                  max_len=max_len).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(num_epochs):
        start_time = time.time()  # Start timer for the epoch
    
        train_loss = train_model(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)
        
        epoch_time = time.time() - start_time  # Compute elapsed time for the epoch
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")

PRIMARY GENRES: {'Soundtrack', 'Religious', 'Rock & Metal', 'Electronic & Dance', 'Classical', 'Folk/World', 'Pop'}
Error tokenizing ../data/mxl/0/1/Qma2FqBY8JAg9sd28YkMptu2dkHgDzJPpeK5Xq4vtxs6x7.mxl: In part (Church Organ, Staff), measure (24): Cannot convert inexpressible durations to MusicXML.
Error tokenizing ../data/mxl/0/1/Qma2qXH7BTNA9mf19TPCYLmsP5sZBDxJ3SzdNi7oQVdqyT.mxl: Cannot insert None into a tag.
Error tokenizing ../data/mxl/0/1/Qma2M87ZgTYMGXmzKucfA1RN2gggRWdQu5JFVj3R84QvpF.mxl: In part (Part_1), measure (76): Cannot convert "2048th" duration to MusicXML (too short).




Error tokenizing ../data/mxl/0/1/Qma2M87ZgTYMGXmzKucfA1RN2gggRWdQu5JFVj3R84QvpF.mxl: In part (Part_1), measure (76): Cannot convert "2048th" duration to MusicXML (too short).
Error tokenizing ../data/mxl/0/1/Qma2FqBY8JAg9sd28YkMptu2dkHgDzJPpeK5Xq4vtxs6x7.mxl: In part (Church Organ, Staff), measure (24): Cannot convert inexpressible durations to MusicXML.
Error tokenizing ../data/mxl/0/1/Qma2qXH7BTNA9mf19TPCYLmsP5sZBDxJ3SzdNi7oQVdqyT.mxl: Cannot insert None into a tag.
Epoch 1: Train Loss = 1.9729, Val Loss = 1.3352, Val Acc = 0.6429
Error tokenizing ../data/mxl/0/1/Qma2M87ZgTYMGXmzKucfA1RN2gggRWdQu5JFVj3R84QvpF.mxl: In part (Part_1), measure (76): Cannot convert "2048th" duration to MusicXML (too short).
Error tokenizing ../data/mxl/0/1/Qma2FqBY8JAg9sd28YkMptu2dkHgDzJPpeK5Xq4vtxs6x7.mxl: In part (Church Organ, Staff), measure (24): Cannot convert inexpressible durations to MusicXML.
Error tokenizing ../data/mxl/0/1/Qma2qXH7BTNA9mf19TPCYLmsP5sZBDxJ3SzdNi7oQVdqyT.mxl: Cannot insert None 

KeyboardInterrupt: 