In [30]:
import functools
import json
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import sys
sys.path.append("..")
from melolib.notation import DurationClass, Rest, Duration
from melolib.music import Note

In [31]:
# Lookup table from the textual duration symbol (the part of a duration string
# that remains after trailing dots are stripped in parse_duration) to the
# corresponding DurationClass member.
duration_repr_to_class = {
    "w": DurationClass.Whole,
    "h": DurationClass.Half,
    "q": DurationClass.Quarter,
    "8": DurationClass.Eighth,
    "16": DurationClass.Sixteenth,
    "32": DurationClass.ThirtySecond,
    "64": DurationClass.SixtyFourth,
}

def parse_duration(duration):
    """Turn a duration string such as ``"q"`` or ``"8.."`` into a ``Duration``.

    The text before any trailing ``"."`` characters is looked up in
    ``duration_repr_to_class``; the number of trailing dots is passed along as
    the dot count.

    Raises:
        KeyError: if the symbol before the dots is not in the lookup table.
    """
    symbol = duration.rstrip(".")
    return Duration(duration_repr_to_class[symbol], len(duration) - len(symbol))

def parse_note(note):
    """Build a ``Note`` from a ``"<name>/<octave>"`` string.

    The string must contain exactly one ``"/"``; the part after it is converted
    with ``int`` and handed to ``Note.from_name`` together with the name part.
    """
    pitch_name, octave_text = note.split("/")
    octave_number = int(octave_text)
    return Note.from_name(pitch_name, octave_number)

def duration_to_tensor(duration):
    """Encode a duration string as two concatenated one-hot vectors.

    Returns a 1-D tensor of length ``len(DurationClass) + 3``: first the
    one-hot duration class, then a 3-slot one-hot of the dot count.
    """
    parsed = parse_duration(duration)

    # Indexed by the enum member's ``.value`` -- assumes DurationClass values
    # are valid indices into a vector of len(DurationClass); TODO confirm.
    class_one_hot = torch.zeros(len(DurationClass))
    class_one_hot[parsed.duration_class.value] = 1

    # Three slots for the dot count (presumably 0, 1 or 2 dots).
    dots_one_hot = torch.zeros(3)
    dots_one_hot[parsed.dots] = 1

    return torch.cat((class_one_hot, dots_one_hot))

def note_to_tensor(note, is_rest):
    """Encode a pitch as a 22-element tensor: 12 pitch-class slots + 10 octave slots.

    Args:
        note: ``"<name>/<octave>"`` string; ignored when ``is_rest`` is true.
        is_rest: when true, the result is all zeros (no pitch, no octave).
    """
    if not is_rest:
        name_and_octave = parse_note(note).get_name()

        pitch_class = torch.zeros(12)
        pitch_class[Note.chromatic_sharps.index(name_and_octave[0])] = 1

        octave_slot = torch.zeros(10)
        octave_slot[name_and_octave[1]] = 1

        return torch.cat((pitch_class, octave_slot))

    # Rests carry no pitch information at all.
    return torch.zeros(22)

def notation_to_tensor(notation):
    """Encode one notation dict as ``[pitch encoding | duration encoding]``.

    ``notation`` must provide ``"key"`` (the pitch string) and ``"duration"``.
    An ``"r"`` anywhere in the duration marks the event as a rest; every
    ``"r"`` is removed before the duration itself is parsed.
    """
    pitch_text = notation["key"]
    duration_text = notation["duration"]

    pitch_part = note_to_tensor(pitch_text, "r" in duration_text)
    duration_part = duration_to_tensor(duration_text.replace("r", ""))
    return torch.cat((pitch_part, duration_part))

class MidiMelodyDataset(Dataset):
    """Fixed-length melody lines for next-event prediction.

    The JSON file is read as a list of songs, each a list of tracks, each track
    a dict with ``"data"`` (a list of notation dicts) and ``"key_signature"``.
    Every track is cut into consecutive, non-overlapping lines of
    ``line_length`` events; a trailing partial line is discarded.

    Each item is a pair ``(x, y)`` of tensors with ``line_length - 1`` rows:
    ``x`` holds events ``0 .. n-2`` of a line and ``y`` holds events
    ``1 .. n-1``, so ``y[t]`` is the event that follows ``x[t]``. Every row is
    the 12-slot key one-hot followed by ``notation_to_tensor`` of the event.

    Args:
        filename: path of the JSON file to load.
        line_length: number of events per line (must be at least 2).

    Raises:
        ValueError: if ``line_length`` is smaller than 2.
    """

    def __init__(self, filename="../../out/melody.json", line_length=25):
        if line_length < 2:
            raise ValueError("line_length must be at least 2")

        with open(filename) as f:
            songs = json.load(f)

        self.data = []
        for song in songs:
            for track in song:
                lines = self._split_into_lines(track["data"], line_length)
                if not lines:
                    continue

                # The key is constant for the whole track, so encode it once.
                # Only the first character of the first key_signature element
                # is used -- presumably the tonic letter; TODO confirm that
                # accidentals are not meant to be part of the lookup.
                key_signature = track["key_signature"]
                key_tensor = torch.zeros(12)
                key_tensor[Note.chromatic_sharps.index(key_signature[0][0])] = 1

                for line in lines:
                    rows = [torch.cat((key_tensor, notation_to_tensor(x))) for x in line]
                    # Input = all but the last event, target = the same line
                    # shifted one step into the future.
                    self.data.append((torch.stack(rows[:-1]), torch.stack(rows[1:])))

    @staticmethod
    def _split_into_lines(data, line_length):
        """Cut ``data`` into consecutive full chunks of ``line_length`` items."""
        usable = len(data) - len(data) % line_length
        return [data[start:start + line_length] for start in range(0, usable, line_length)]

    def __len__(self):
        """Number of (input, target) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` tensor pair stored at position ``idx``."""
        return self.data[idx]

# Fixed seed so the train/test partition is identical on every run.
SPLIT_SEED = 0

dataset = MidiMelodyDataset()
train_dataset, test_dataset = random_split(
    dataset,
    [0.80, 0.20],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Evaluation only accumulates a loss total, so batch order is irrelevant.
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Per-event feature layout: key one-hot (12) + pitch class (12) + octave (10)
# + duration class (len(DurationClass)) + dot count (3).
OUTPUT_SIZE = 12 + 12 + 10 + len(DurationClass) + 3
INPUT_SIZE = OUTPUT_SIZE

print(f"Input size: {INPUT_SIZE}")

Input size: 44


In [32]:
class GRUNetwork(torch.nn.Module):
    """Single-layer GRU followed by two dense layers and grouped log-softmax.

    The output vector is treated as several independent categorical
    distributions (see ``SEGMENT_SIZES``); ``log_softmax`` is applied to each
    group separately and the groups are concatenated back together.

    Args:
        input_size: number of features per input step.
        output_size: number of features per output step; must equal
            ``sum(SEGMENT_SIZES)``.
        hidden_size: width of the dense layer after the GRU.
        state_size: width of the GRU's recurrent state.

    Note:
        ``state_size`` used to be ignored (``hidden_size`` was used for the GRU
        as well). Checkpoints saved with ``hidden_size != state_size`` under
        the old behaviour have different tensor shapes and must be retrained.
    """

    # Widths of the one-hot groups in the output: key, pitch class, octave,
    # duration class, dot count.
    SEGMENT_SIZES = (12, 12, 10, 7, 3)

    def __init__(self, input_size, output_size, hidden_size, state_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.output_size = output_size

        self.gru = torch.nn.GRU(self.input_size, self.state_size, num_layers=1, batch_first=True)
        # The GRU emits state_size features per step; project them to hidden_size.
        self.h2h = torch.nn.Linear(self.state_size, self.hidden_size)
        self.h2o = torch.nn.Linear(self.hidden_size, self.output_size)

    def forward(self, i: torch.Tensor, state: torch.Tensor):
        """Run the network over a sequence.

        Args:
            i: input features, batched (batch first) or unbatched.
            state: initial GRU state, e.g. from ``init_hidden``.

        Returns:
            ``(log_probs, new_state)`` where ``log_probs`` has the same leading
            dimensions as ``i`` and ``output_size`` features, each group of
            ``SEGMENT_SIZES`` being a log-probability distribution.
        """
        h, s = self.gru(i, state)
        h2 = self.h2h(torch.relu(h))
        logits = self.h2o(torch.relu(h2))

        segments = torch.split(logits, list(self.SEGMENT_SIZES), dim=-1)
        log_probs = torch.cat([F.log_softmax(segment, dim=-1) for segment in segments], dim=-1)
        return log_probs, s

    def init_hidden(self, batch_size=None):
        """Return an all-zero initial GRU state on the model's own device.

        Args:
            batch_size: ``None`` for unbatched input (state shape
                ``(1, state_size)``), otherwise the batch size (state shape
                ``(1, batch_size, state_size)``).
        """
        # Allocate next to the weights so the state matches a model that has
        # been moved with ``.to(device)``.
        device = self.h2o.weight.device
        if batch_size is None:
            return torch.zeros((1, self.state_size), device=device)
        return torch.zeros((1, batch_size, self.state_size), device=device)

In [33]:
def train_loop(dataloader, model, loss, optimizer, epoch):
    """Run one training pass over ``dataloader`` and return the summed batch loss.

    Args:
        dataloader: iterable of ``(x, y)`` batches shaped (batch, seq, features).
        model: module exposing ``init_hidden(batch_size)`` and returning
            ``(log_probs, state)`` from ``model(x, state)``.
        loss: criterion applied per feature group to ``(log_probs, class_idx)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch index (unused; kept for the call signature).

    Returns:
        Sum of the per-batch losses (not averaged over batches).
    """
    model.train()
    total_loss = 0

    # Batches come off the DataLoader on CPU; move them to wherever the model
    # lives, otherwise a cuda/mps model fails with a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Widths of the one-hot groups inside each feature vector.
    segment_sizes = [12, 12, 10, 7, 3]

    for batch_idx, (x, y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)
        state = model.init_hidden(x.size(dim=0)).to(device)

        # Clear gradients before the backward pass so nothing stale is applied.
        optimizer.zero_grad()
        pred, state = model(x, state)

        # Flatten (batch, seq) into one axis so each step is one sample.
        pred = torch.reshape(pred, (-1, pred.size(-1)))
        y = torch.reshape(y, (-1, y.size(-1)))

        split_predictions = torch.split(pred, segment_sizes, dim=-1)
        split_y = torch.split(y, segment_sizes, dim=-1)

        # One-hot targets are converted to class indices with argmax; an
        # all-zero group therefore maps to class 0.
        cost = sum(loss(p, torch.argmax(y_, dim=-1)) for p, y_ in zip(split_predictions, split_y))
        cost.backward()
        optimizer.step()

        batch_loss = cost.item()
        total_loss += batch_loss

        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx} loss: {batch_loss}")

    return total_loss


def test_loop(dataloader, model, loss, epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(dataloader):
            state = model.init_hidden(x.size(dim=0))
            pred, state = model(x, state)
            
            pred = torch.reshape(pred, (-1, pred.size(-1)))
            y = torch.reshape(y, (-1, y.size(-1)))
            
            split_predictions = torch.split(pred, [12, 12, 10, 7, 3], dim=-1)
            split_y= torch.split(y, [12, 12, 10, 7, 3], dim=-1)
            cost = sum(loss(p, torch.argmax(y_, dim=-1)) for p, y_ in zip(split_predictions, split_y))

            batch_loss = cost.item()
            test_loss += batch_loss
    return test_loss


In [None]:
# Sizes passed to GRUNetwork as its hidden_size / state_size arguments.
HIDDEN_SIZE = 10
STATE_SIZE = 8


def main():
    """Train a GRUNetwork for 200 epochs and save its weights.

    Picks cuda, then mps, then cpu; after every epoch prints the summed
    training and test losses; finally writes the state dict to
    ``../../out/rnn_melody.pth``.
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using {device} device")

    network = GRUNetwork(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE, STATE_SIZE).to(device)
    criterion = torch.nn.NLLLoss()
    sgd = torch.optim.SGD(network.parameters(), lr=0.0005)

    num_epochs = 200
    for epoch in range(num_epochs):
        train_loss = train_loop(train_dataloader, network, criterion, sgd, epoch)
        test_loss = test_loop(test_dataloader, network, criterion, epoch)

        separator = "---------------"
        print(separator)
        print(f"Epoch: {epoch+1}")
        print("Loss in training: ", train_loss)
        print("Loss in test:", test_loss)
        print(separator + "\n")

    torch.save(network.state_dict(), "../../out/rnn_melody.pth")


main()

In [35]:
def sandbox():
    """Interactively query the trained model.

    Loads the weights from ``../../out/rnn_melody.pth``, then repeatedly asks
    for a key and a whitespace-separated list of ``<note>-<duration>`` items
    (e.g. ``C/4-q``). The entered events are encoded, padded to 24 steps by
    repeating the last event, fed through the model, and the resulting
    probabilities (plus their total) are printed. An empty answer, Ctrl-C or
    end of input ends the loop.
    """
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")
    model = GRUNetwork(INPUT_SIZE, OUTPUT_SIZE, HIDDEN_SIZE, STATE_SIZE).to(device)
    # map_location lets a checkpoint saved on one device load on another.
    model.load_state_dict(torch.load("../../out/rnn_melody.pth", map_location=device))
    model.eval()

    # Print full tensors; set once instead of on every query.
    np.set_printoptions(threshold=10_000)
    torch.set_printoptions(profile="full")

    # Sequences shorter than this are padded by repeating the last event.
    sequence_length = 24

    with torch.no_grad():
        while True:
            try:
                key = input("Enter key: ").strip()
                notations = input("Enter notation list: ").strip()
            except (KeyboardInterrupt, EOFError):
                break

            if key == "" or notations == "":
                break

            key_tensor = torch.zeros(12)
            key_tensor[Note.chromatic_sharps.index(key)] = 1

            input_tensors = []
            # split() without an argument tolerates repeated spaces.
            for item in notations.split():
                note, duration = item.split("-")
                notation_tensor = notation_to_tensor({"key": note, "duration": duration})
                input_tensors.append(torch.cat((key_tensor, notation_tensor)))

            input_tensors.extend([input_tensors[-1]] * (sequence_length - len(input_tensors)))
            # Inputs and state must be on the same device as the model.
            input_sequence_tensors = torch.stack(input_tensors).to(device)

            state = model.init_hidden().to(device)
            pred, _ = model(input_sequence_tensors, state)
            # The model emits log-probabilities; exponentiate for display.
            prob = torch.exp(pred)
            print(prob, torch.sum(prob))


sandbox()