# LSTMusic with key and metro

In [1]:
import pandas as pd
import re
from music21 import *
import numpy as np

import time
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Data Formatting

In [2]:
f = open("../data/jiggs.txt", "r")
raw_input = f.read()

In [3]:
class Repertoir():
    def __init__(self, path):
        self.path = path
        f = open(path, "r")
        self.string = f.read()
        self.handler = abcFormat.ABCHandler()
        self.handler.process(self.string)
        self.songs_handlers = self.handler.splitByReferenceNumber()
        self.songs = {}
        self.__process()
    
    def __str__(self):
        return self.string
    
    
    def __process(self):
        for ref_number, handler in self.songs_handlers.items():
            self.songs[ref_number] = Song(handler)
            
    def get_part_vocab(self):
        tokens = []
        for ref_number, song in self.songs.items():
            tokens+= song.part
        tokens = list(set(tokens))            
        return tokens
    
    def get_metadata_vocab(self, key):
        tokens = []
        for ref_number, song in self.songs.items():
            tokens+= [song.metadata[key]]
        tokens = list(set(tokens))            
        return tokens    

In [4]:
class Song():
    def __init__(self, handler):
        self.handler = handler
        self.metadata = {
            'X':1,
            'T':'Unknown',
            'S':'Unknown',
            'M':'none',
            'L':'',
            'Q':'',
            'K':''
        }
        self.part = []
        self.__process()
        
    def __process(self):
        for token in self.handler.tokens:
            meta_data_ended=False
            if isinstance(token, abcFormat.ABCMetadata):
                if token.tag in self.metadata.keys():
                    if self.metadata[token.tag]=='' or not meta_data_ended:
                        self.metadata[token.tag] = token.data
                else:
                    self.metadata[token.tag] = token.data
            elif isinstance(token, abcFormat.ABCNote ) or isinstance(token, abcFormat.ABCBar):
                meta_data_ended = True
                self.part.append(token.src)
    
    def __str__(self):
        return self.to_abc()
    
    def to_abc(self):
        output = ''
        for key, value in self.metadata.items():
            output+= key+':'+value+"\n"
        for note in self.part:
            output+=note
        return output

In [5]:
def generate_char_idx_mappings(vocab):
    char2idx = {u:i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)
    return char2idx, idx2char

In [6]:
def get_input_tensors(part, k, m, part_char2idx, k_char2idx, m_char2idx):
    part_tensor = torch.tensor([part_char2idx[note] for note in part[0:-1]], dtype=torch.long)
    k_tensor = torch.tensor([k_char2idx[k] for note in part[0:-1]], dtype=torch.long)
    m_tensor = torch.tensor([m_char2idx[m] for note in part[0:-1]], dtype=torch.long)
    return part_tensor, k_tensor, m_tensor,

def get_target_tensor(part, part_char2idx):
    target_tensor = torch.tensor([part_char2idx[note] for note in part[1:]], dtype=torch.long)
    return target_tensor

In [7]:
rep = Repertoir('../data/jiggs.txt')

In [8]:
part_vocab = rep.get_part_vocab()
m_vocab = rep.get_metadata_vocab('M')
k_vocab = rep.get_metadata_vocab('K')

In [9]:
part_char2idx, part_idx2char = generate_char_idx_mappings(part_vocab)
k_char2idx, k_idx2char = generate_char_idx_mappings(k_vocab)
m_char2idx, m_idx2char = generate_char_idx_mappings(m_vocab)

## LSTMusic

In [10]:
class LSTMusic(nn.Module):

    def __init__(self, part_embedding_dim, k_embedding_dim, m_embedding_dim, lstm_dim, part_vocab_size, k_vocab_size, m_vocab_size):
        super(LSTMusic, self).__init__()
        self.lstm_dim = lstm_dim
        
        self.part_embeddings = nn.Embedding(part_vocab_size, part_embedding_dim)
        self.k_embeddings = nn.Embedding(k_vocab_size, k_embedding_dim)
        self.m_embeddings = nn.Embedding(m_vocab_size, m_embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(m_embedding_dim+k_embedding_dim+part_embedding_dim, lstm_dim)

        # The linear layer that maps from hidden state space to tag space
        self.dense = nn.Linear(lstm_dim, part_vocab_size)

    def forward(self, part, k, m, prev_state):
        part_embeds = self.part_embeddings(part)
        k_embeds = self.k_embeddings(k)
        m_embeds = self.m_embeddings(m)
        combined = torch.cat((m_embeds, k_embeds, part_embeds), 1)
        lstm_out, state = self.lstm(combined.view(len(part), 1, -1), prev_state)
        output = self.dense(lstm_out.view(len(part), -1))
        return output, state
    
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_dim),
                torch.zeros(1, batch_size, self.lstm_dim))

In [11]:
def generate(model, part, k, m, length):
    model.eval()
    notes = []
    with torch.no_grad():  # no need to track history in sampling
        state_h, state_c = model.zero_state(1)
        for note in part:
            note_tensor, k_tensor, m_tensor = get_input_tensors([note,' '], k, m, part_char2idx, k_char2idx, m_char2idx)
            output, (state_h, state_c) = model(note_tensor, k_tensor, m_tensor, (state_h, state_c))
        print(output)
        _, top_ix = torch.topk(output[0], k=5)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        notes.append(part_idx2char[choice])
    for _ in range(length):
        note_tensor = torch.tensor([choice])
        output, (state_h, state_c) = model(note_tensor, k_tensor, m_tensor, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=5)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        notes.append(part_idx2char[choice])

    abc = "M:{}\nK:{}\n".format(m,k) + ''.join(part) + ''.join(notes)
    print(abc)
    return abc, notes

In [12]:
def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [13]:
part_embedding_dim = 64
k_embedding_dim = 32
m_embedding_dim = 32
lstm_dim = 128
part_vocab_size = len(part_vocab)
k_vocab_size = len(k_vocab)
m_vocab_size = len(m_vocab)

nb_epoch = 1000
lr = 0.01
max_norm = 5

start_part = rep.songs[1].part[0:10]
start_k = rep.songs[1].metadata['K']
start_m = rep.songs[1].metadata['M']
length = 100

In [None]:
model = LSTMusic(part_embedding_dim, k_embedding_dim, m_embedding_dim, lstm_dim, part_vocab_size, k_vocab_size, m_vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

start = time.time()

for epoch in range(nb_epoch):
    print('========= Epoch {} out of {} ({}%) ========='.format(epoch+1, nb_epoch, (epoch+1)/nb_epoch*100))
    nb_iters = len(rep.songs)
    iteration = 0
    print_every = nb_iters//10
    state_h, state_c = model.zero_state(1)

    for ref_num, song in rep.songs.items():
        
        iteration+=1
        
        model.train()
        
        optimizer.zero_grad()
        
        part, k, m  = get_input_tensors(song.part, song.metadata['K'], song.metadata['M'], part_char2idx, k_char2idx, m_char2idx)
        target = get_target_tensor(song.part, part_char2idx)
        
        output, (state_h, state_c) = model(part, k, m, (state_h, state_c))
        
        state_h = state_h.detach()
        state_c = state_c.detach()
        
        loss = criterion(output, target)
        
        loss.backward()
        
        _ = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm)
        
        optimizer.step()
        
        if iteration % print_every == 0:
            print('%s (%d %d%%) %.4f' % (timeSince(start), iteration, iteration / nb_iters * 100, loss.item()))
        torch.save(model.state_dict(), '../models/model.pt')
    generate(model, start_part, start_k, start_m, length)    
    torch.save(model.state_dict(),'../models/model-{}.pth'.format(epoch))

0m 3s (34 10%) 4.4567
0m 5s (68 20%) 3.6046
0m 8s (102 30%) 3.8612
0m 10s (136 40%) 2.8873
0m 12s (170 50%) 2.7205
0m 15s (204 60%) 3.4639
0m 17s (238 70%) 3.4653
0m 20s (272 80%) 2.6637
0m 22s (306 90%) 2.9416
0m 25s (340 100%) 3.2469
tensor([[-6.2083, -6.3867, -6.8348,  ..., -8.6024, -6.1298, -8.2999]])
M:6/8
K:D
f|"A"eccc2f|"A"ecD/2"E7"D2"F#m"e"D/f+"a2"B"e"E7"d2"A7/e"g2"Gm"g2"Em"b2=F"B7"A3"C"G3"Gm"f2"D/a"A2"F#m"a2"B7"B3"A"f2"D7"f2"Em""C"e/2"Bb"_b2^G2"Am"c"D"d3/2g3/2"Em"g3/2"G""3"B"e"d"D"A2"F"c3"A"c3"C""e"g2a3/2"Dm""f"A2"D"D"Bb"B,"(D7)"B,2"Am"c2"G7"B2"C"G"Bb7"D"Bb"d2"D""Bm"F3"C"C6"D"g"Em"f3"D"g"F#7"e"A7"F2"E7/b"e"Bb"f"G"F3"A7"g2"D""Bm"F3"G"b3[FAd]"Em""C"e/2_B"D/f+"a2"F#m"e"f#"g" ""D"fF3"Eb"A"F"d3"F"c'2"Dm"G"Bm7"f"A"b"3"B"Dm"g"A7"G3"F"c3"C"a"Eb"A"Gm"G3=A^e3"G/b"g"E/dim"g2"Bm"d"Dm"F3"Dm"=f"D"[F3A3d3]"Em"g/2"A7"e3"D/f+"A"g"B"D7/a"c2"D7"e2"F#m"e"G""d"B3"E7"g3"G""c"G"B7"F2"F"B3b/2^e3"A7"E3"B7"B3f3"E7"=B
0m 27s (34 10%) 3.1889
0m 30s (68 20%) 2.3193
0m 32s (102 30%) 2.5637
0m 35s (136 40%)