In [None]:
import functools
import json
import numpy as np
import torch
import torch.nn
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:

class UltimateGuitarSongDataset(Dataset):
    """Next-chord prediction dataset built from a JSON list of chord progressions.

    Every song longer than ``seq_len`` chords is truncated to its first
    ``seq_len`` chords. Each kept song yields one sample ``(x, y)`` where ``x``
    is the one-hot encoding of chords ``0 .. seq_len-2`` and ``y`` is the
    one-hot encoding of chords ``1 .. seq_len-1`` (the same window shifted by
    one), both of shape ``(seq_len - 1, len(unique_chords))`` and dtype long.

    Attributes:
        unique_chords: sorted list of the distinct chords found in the kept,
            truncated songs; a chord's position is its class index.
        unique_chords_to_tensors: maps each chord to its one-hot long tensor.
        data: list of ``(x, y)`` tensor pairs, one per kept song.

    Args:
        filename: path to a JSON file whose top-level value is a list of songs;
            each song is expected to be a list of hashable chord names
            (presumably strings).
        seq_len: number of leading chords kept per song. Must be at least 2 so
            that every sample has at least one (input, target) step.

    Raises:
        ValueError: if ``seq_len`` is smaller than 2.
    """

    def __init__(self, filename="../out/chord_progressions.json", seq_len=8):
        if seq_len < 2:
            raise ValueError(f"seq_len must be >= 2, got {seq_len}")

        with open(filename, encoding="utf-8") as f:
            songs = json.load(f)
        assert isinstance(songs, list), "expected a JSON list of songs"

        # NOTE(review): the strict '>' drops songs with exactly `seq_len` chords
        # even though they would fit. It is kept as-is because changing the
        # filter changes the vocabulary size and invalidates saved checkpoints.
        songs = [song[:seq_len] for song in songs if len(song) > seq_len]

        # The vocabulary is computed AFTER truncation, so chords that only
        # appear beyond the first `seq_len` positions are not included.
        self.unique_chords = sorted({chord for song in songs for chord in song})

        vocab_size = len(self.unique_chords)
        self.unique_chords_to_tensors = {}
        for index, chord in enumerate(self.unique_chords):
            one_hot = torch.zeros(vocab_size, dtype=torch.long)
            one_hot[index] = 1
            self.unique_chords_to_tensors[chord] = one_hot

        self.data = []
        for song in songs:
            inputs = [self.unique_chords_to_tensors[chord] for chord in song[:-1]]
            targets = [self.unique_chords_to_tensors[chord] for chord in song[1:]]
            self.data.append((torch.stack(inputs), torch.stack(targets)))

    def __len__(self):
        """Return the number of kept songs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` one-hot tensor pair for song ``idx``."""
        return self.data[idx]

# Seed for the train/test split so a fresh "Restart & Run All" reproduces the
# same partition.
SPLIT_SEED = 0

dataset = UltimateGuitarSongDataset()

# Plain 80/20 split. A leading 0.0 fraction is not harmless: random_split hands
# out leftover samples round-robin starting at index 0, so a zero-sized first
# split can silently swallow a sample (and otherwise triggers an empty-split
# warning).
train_dataset, test_dataset = random_split(
    dataset,
    [0.80, 0.20],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Evaluation metrics do not depend on sample order, so no shuffling is needed.
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# One-hot chord vectors are both the network input and the prediction target,
# so input and output widths equal the vocabulary size.
INPUT_SIZE = OUTPUT_SIZE = len(dataset.unique_chords)
print(f"Input size: {INPUT_SIZE}")

In [None]:
class RNNetwork(torch.nn.Module):
    """Small hand-rolled recurrent network for next-chord classification.

    At every step the input vector is concatenated with the recurrent state.
    One linear layer maps that concatenation to the next state; a second path
    (linear -> ReLU -> linear -> ReLU -> linear) maps it to class scores, which
    are passed through dropout and a log-softmax.

    Inputs may be unbatched (``(input_size,)`` with state ``(state_size,)``) or
    carry leading batch dimensions, as long as input and state agree on them.

    Args:
        input_size: width of the per-step input vector.
        output_size: number of output classes.
        hidden_size: width of the two hidden layers on the output path.
        state_size: width of the recurrent state vector.
    """

    def __init__(self, input_size, output_size, hidden_size, state_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.state_size = state_size

        # Attribute names are part of the saved state_dict layout; keep them.
        self.i2s = torch.nn.Linear(self.input_size + self.state_size, self.state_size)
        self.i2h = torch.nn.Linear(self.input_size + self.state_size, self.hidden_size)
        self.h2h = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.h2o = torch.nn.Linear(self.hidden_size, self.output_size)
        self.dropout = torch.nn.Dropout(0.10)
        # Normalise over the class (last) dimension so batched input works too.
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, i: torch.Tensor, state: torch.Tensor):
        """Run one time step.

        Args:
            i: input vector(s) with ``input_size`` entries in the last
                dimension; integer one-hot tensors are accepted.
            state: recurrent state with ``state_size`` entries in the last
                dimension, e.g. from ``init_hidden`` or a previous step.

        Returns:
            Tuple ``(log_probs, next_state)``.
        """
        # One-hot inputs arrive as long tensors and may live on another device
        # than the state; align both before concatenating.
        i = i.to(device=state.device, dtype=state.dtype)
        i_ = torch.cat((i, state), dim=-1)
        s = self.i2s(i_)
        h = self.i2h(i_)
        h2 = self.h2h(torch.relu(h))
        o = self.h2o(torch.relu(h2))
        # NOTE(review): dropout is applied to the output logits rather than to a
        # hidden activation; kept as-is to preserve training behaviour.
        o = self.dropout(o)
        return self.softmax(o), s

    def init_hidden(self, batch_size=None):
        """Return an all-zero initial state on the same device as the weights.

        Args:
            batch_size: when given, the state has shape
                ``(batch_size, state_size)``; otherwise ``(state_size,)``.
        """
        reference = self.i2s.weight
        shape = (self.state_size,) if batch_size is None else (batch_size, self.state_size)
        return torch.zeros(shape, device=reference.device, dtype=reference.dtype)

In [None]:
def train_loop(dataloader, model, loss, optimizer, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is processed song by song and step by step: the recurrent state
    is reset per song, the per-step losses of the whole batch are summed, and a
    single optimizer step is taken per batch.

    Args:
        dataloader: iterable of ``(batched_x, batched_y)`` one-hot tensors of
            shape ``(batch, steps, n_classes)``.
        model: network exposing ``init_hidden()`` and
            ``model(x, state) -> (log_probs, state)``.
        loss: criterion called as ``loss(log_probs, target_index)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch index (unused; kept for call compatibility).

    Returns:
        Sum over batches of the mean per-step loss of each batch.
    """
    model.train()
    # Run on whatever device the model lives on; batches come from the
    # DataLoader on the CPU and would otherwise crash on a CUDA/MPS model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total_loss = 0.0

    for batch_idx, (batched_x, batched_y) in enumerate(dataloader):
        n_steps = batched_x.size(0) * batched_x.size(1)
        if n_steps == 0:
            continue

        batched_x = batched_x.to(device)
        # One-hot targets -> class indices, computed once for the whole batch.
        batched_targets = batched_y.to(device).argmax(dim=-1)

        optimizer.zero_grad()
        cost = 0
        for x, targets in zip(batched_x, batched_targets):
            state = model.init_hidden().to(device)
            for x_, target in zip(x, targets):
                pred, state = model(x_, state)
                cost = cost + loss(pred, target)
        cost.backward()
        optimizer.step()

        # Average over prediction steps only; the vocabulary size is not part
        # of the number of predictions made.
        batch_loss = cost.item() / n_steps
        total_loss += batch_loss

        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx} loss: {batch_loss}")

    return total_loss


def test_loop(dataloader, model, loss, epoch):
    model.eval()
    test_loss, correct = 0, 0
    total_items = 0
    with torch.no_grad():
        for batch_idx, (batched_x, batched_y) in enumerate(dataloader):
            batch_loss = 0
            for x, y in zip(batched_x, batched_y):
                state = model.init_hidden()
                for x_, y_ in zip(x, y):
                    pred, state = model(x_, state)
                    batch_loss += loss(pred, torch.argmax(y_))

                    # Accuracy count
                    _, indices = torch.topk(pred, k=3)
                    if torch.argmax(y_).item() in indices:
                        correct += 1
            batch_loss /= batched_x.size(0) * batched_x.size(1) * batched_x.size(2)
            test_loss += batch_loss
            total_items += batched_x.size(0) * batched_x.size(1)
        accuracy = 100 * correct / total_items
        print(f"Accuracy: {accuracy:.2f}")
    return test_loss.item()


In [None]:
# Width of the hidden layers handed to RNNetwork.
HIDDEN_SIZE = 30
# Width of the recurrent state vector handed to RNNetwork.
STATE_SIZE = 20


def main():
    """Train an RNNetwork for 10 epochs and save its weights to disk.

    After every epoch the training and test losses are printed. When training
    finishes the model's state dict is written to
    ``../out/rnn_chord_progressions_classification.pth``.
    """
    # Prefer CUDA, then Apple MPS, then fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using {device} device")

    network = RNNetwork(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE, STATE_SIZE).to(device)
    criterion = torch.nn.NLLLoss()
    sgd = torch.optim.SGD(network.parameters(), lr=0.005)

    num_epochs = 10
    separator = "---------------"
    for epoch in range(num_epochs):
        train_loss = train_loop(train_dataloader, network, criterion, sgd, epoch)
        test_loss = test_loop(test_dataloader, network, criterion, epoch)

        print(separator)
        print(f"Epoch: {epoch+1}")
        print("Loss in training: ", train_loss)
        print("Loss in test:", test_loss)
        print(separator + "\n")

    torch.save(network.state_dict(), "../out/rnn_chord_progressions_classification.pth")


main()

In [None]:
def sandbox():
    """Interactively predict the next chord for user-entered progressions.

    Loads the weights saved at
    ``../out/rnn_chord_progressions_classification.pth``, then repeatedly reads
    a whitespace-separated chord progression, feeds it through the network one
    chord at a time and prints the five most likely next chords (fewer if the
    vocabulary is smaller) followed by the single most likely continuation.

    The loop ends on an empty line, Ctrl-C, or end of input. Chords that are
    not in ``dataset.unique_chords_to_tensors`` are reported and the prompt is
    shown again instead of aborting the session.
    """
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Using {device} device")
    model = RNNetwork(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE, STATE_SIZE).to(device)
    # map_location lets a checkpoint written on one device (e.g. CUDA) be
    # loaded on a machine that only has another (e.g. CPU).
    model.load_state_dict(
        torch.load(
            "../out/rnn_chord_progressions_classification.pth",
            map_location=device,
        )
    )
    model.eval()

    with torch.no_grad():
        while True:
            try:
                chord_progression = input("Enter chord progression: ").strip()
            except (KeyboardInterrupt, EOFError):
                break

            if chord_progression == "":
                break

            chords = chord_progression.split()
            unknown = [
                chord for chord in chords
                if chord not in dataset.unique_chords_to_tensors
            ]
            if unknown:
                # A typo should not kill the interactive session.
                print(f"Unknown chord(s): {', '.join(unknown)}")
                continue

            # Inputs and state must live on the same device as the model.
            state = model.init_hidden().to(device)
            pred = None
            for chord in chords:
                tensor = dataset.unique_chords_to_tensors[chord].to(device)
                pred, state = model(tensor, state)

            # The network outputs log-probabilities; convert and renormalise.
            probs = torch.exp(pred)
            probs = probs / torch.sum(probs)
            top_probs, top_indices = torch.topk(probs, k=min(5, probs.numel()))
            for top_prob, top_index in zip(top_probs, top_indices):
                chord = dataset.unique_chords[top_index.item()]
                print(f"{chord}: {top_prob.item():.2f}")
            print("\nExpected: ", chord_progression + " " + dataset.unique_chords[torch.argmax(pred).item()])
sandbox()