# LSTMusic with key and metro

In [1]:
import pandas as pd
import re
from music21 import *
import numpy as np

import time
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Data Formatting

In [2]:
f = open("../data/jiggs.txt", "r")
raw_input = f.read()

In [3]:
class Repertoir():
    def __init__(self, path):
        self.path = path
        f = open(path, "r")
        self.string = f.read()
        self.handler = abcFormat.ABCHandler()
        self.handler.process(self.string)
        self.songs_handlers = self.handler.splitByReferenceNumber()
        self.songs = {}
        self.__process()
    
    def __str__(self):
        return self.string
    
    
    def __process(self):
        for ref_number, handler in self.songs_handlers.items():
            self.songs[ref_number] = Song(handler)
            
    def get_part_vocab(self):
        tokens = []
        for ref_number, song in self.songs.items():
            tokens+= song.part
        tokens = list(set(tokens))            
        return tokens
    
    def get_metadata_vocab(self, key):
        tokens = []
        for ref_number, song in self.songs.items():
            tokens+= [song.metadata[key]]
        tokens = list(set(tokens))            
        return tokens    

In [4]:
class Song():
    def __init__(self, handler):
        self.handler = handler
        self.metadata = {
            'X':1,
            'T':'Unknown',
            'S':'Unknown',
            'M':'none',
            'L':'',
            'Q':'',
            'K':''
        }
        self.part = []
        self.__process()
        
    def __process(self):
        for token in self.handler.tokens:
            meta_data_ended=False
            if isinstance(token, abcFormat.ABCMetadata):
                if token.tag in self.metadata.keys():
                    if self.metadata[token.tag]=='' or not meta_data_ended:
                        self.metadata[token.tag] = token.data
                else:
                    self.metadata[token.tag] = token.data
            elif isinstance(token, abcFormat.ABCNote ) or isinstance(token, abcFormat.ABCBar):
                meta_data_ended = True
                self.part.append(token.src)
    
    def __str__(self):
        return self.to_abc()
    
    def to_abc(self):
        output = ''
        for key, value in self.metadata.items():
            output+= key+':'+value+"\n"
        for note in self.part:
            output+=note
        return output

In [5]:
def generate_char_idx_mappings(vocab):
    char2idx = {u:i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)
    return char2idx, idx2char

In [6]:
def get_input_tensors(part, k, m, part_char2idx, k_char2idx, m_char2idx):
    part_tensor = torch.tensor([part_char2idx[note] for note in part[0:-1]], dtype=torch.long)
    k_tensor = torch.tensor([k_char2idx[k] for note in part[0:-1]], dtype=torch.long)
    m_tensor = torch.tensor([m_char2idx[m] for note in part[0:-1]], dtype=torch.long)
    return part_tensor, k_tensor, m_tensor,

def get_target_tensor(part, part_char2idx):
    target_tensor = torch.tensor([part_char2idx[note] for note in part[1:]], dtype=torch.long)
    return target_tensor

In [7]:
rep = Repertoir('../data/jiggs.txt')

In [8]:
part_vocab = rep.get_part_vocab()
m_vocab = rep.get_metadata_vocab('M')
k_vocab = rep.get_metadata_vocab('K')

In [9]:
part_char2idx, part_idx2char = generate_char_idx_mappings(part_vocab)
k_char2idx, k_idx2char = generate_char_idx_mappings(k_vocab)
m_char2idx, m_idx2char = generate_char_idx_mappings(m_vocab)

## LSTMusic

In [10]:
class LSTMusic(nn.Module):

    def __init__(self, part_embedding_dim, k_embedding_dim, m_embedding_dim, lstm_dim, part_vocab_size, k_vocab_size, m_vocab_size):
        super(LSTMusic, self).__init__()
        self.lstm_dim = lstm_dim
        
        self.part_embeddings = nn.Embedding(part_vocab_size, part_embedding_dim)
        self.k_embeddings = nn.Embedding(k_vocab_size, k_embedding_dim)
        self.m_embeddings = nn.Embedding(m_vocab_size, m_embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(m_embedding_dim+k_embedding_dim+part_embedding_dim, lstm_dim)

        # The linear layer that maps from hidden state space to tag space
        self.dense = nn.Linear(lstm_dim, part_vocab_size)

    def forward(self, part, k, m, prev_state):
        part_embeds = self.part_embeddings(part)
        k_embeds = self.k_embeddings(k)
        m_embeds = self.m_embeddings(m)
        combined = torch.cat((m_embeds, k_embeds, part_embeds), 1)
        lstm_out, state = self.lstm(combined.view(len(part), 1, -1), prev_state)
        output = self.dense(lstm_out.view(len(part), -1))
        return output, state
    
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_dim),
                torch.zeros(1, batch_size, self.lstm_dim))

In [28]:
def generate(model, part, k, m, length):
    model = model.eval()
    notes = []
    with torch.no_grad():  # no need to track history in sampling
        state_h, state_c = model.zero_state(1)
        for note in part:
            note_tensor, k_tensor, m_tensor = get_input_tensors([note,' '], k, m, part_char2idx, k_char2idx, m_char2idx)
            output, (state_h, state_c) = model(note_tensor, k_tensor, m_tensor, (state_h, state_c))
        _, top_ix = torch.topk(output[0], k=5)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        notes.append(part_idx2char[choice])
    for _ in range(length):
        note_tensor = torch.tensor([choice])
        output, (state_h, state_c) = model(note_tensor, k_tensor, m_tensor, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=5)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        notes.append(part_idx2char[choice])

    abc = "M:{}\nK:{}\n".format(m,k) + ''.join(part) + ''.join(notes)
    print(abc)
    return abc, notes

In [29]:
def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [76]:
min(600, round(1.6 * k_vocab_size ** .56))

6

In [80]:
part_vocab_size = len(part_vocab)
k_vocab_size = len(k_vocab)
m_vocab_size = len(m_vocab)

part_embedding_dim = min(600, round(1.6 * part_vocab_size ** .56))
k_embedding_dim = min(600, round(1.6 * k_vocab_size ** .56))
m_embedding_dim = min(600, round(1.6 * m_vocab_size ** .56))
lstm_dim = 512

nb_epoch = 10000
lr = 0.001
max_norm = 5

start_part = rep.songs[15].part[0:25]
start_k = rep.songs[15].metadata['K']
start_m = rep.songs[15].metadata['M']
length = 100

In [None]:
model = LSTMusic(part_embedding_dim, k_embedding_dim, m_embedding_dim, lstm_dim, part_vocab_size, k_vocab_size, m_vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

start = time.time()
epoch = 0
for epoch in range(1,nb_epoch+1):
    print('\n========= Epoch {} out of {} ({} %) =========\n'.format(epoch, nb_epoch, (epoch)/nb_epoch*100))
    nb_iters = len(rep.songs)
    iteration = 0
    print_every = nb_iters//10
    state_h, state_c = model.zero_state(1)
    songs_idxs = list(rep.songs.keys())
    
    for iteration in range(1, len(songs_idxs)+1):
        song = rep.songs[np.random.choice(songs_idxs)]
        
        model = model.train()
        
        optimizer.zero_grad()
        
        part, k, m  = get_input_tensors(song.part, song.metadata['K'], song.metadata['M'], part_char2idx, k_char2idx, m_char2idx)
        target = get_target_tensor(song.part, part_char2idx)
        
        output, (state_h, state_c) = model(part, k, m, (state_h, state_c))
        
        state_h = state_h.detach()
        state_c = state_c.detach()
        
        loss = criterion(output, target)
        
        loss.backward()
        
        _ = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm)
        
        optimizer.step()
        
        if iteration % print_every == 0:
            print('%s (%d %d%%) %.4f' % (timeSince(start), iteration, iteration / nb_iters * 100, loss.item()))
    print('\nGenerating example :\n')
    abc, op_notes = generate(model, start_part, start_k, start_m, length)    
    torch.save(model.state_dict(),'../models/model-{}.pth'.format(epoch))



0m 22s (34 10%) 3.9763
0m 48s (68 20%) 3.5880
1m 14s (102 30%) 3.4523
1m 40s (136 40%) 3.4755
2m 4s (170 50%) 3.4396
2m 30s (204 60%) 3.1974
2m 57s (238 70%) 3.5050
3m 22s (272 80%) 3.2567
3m 48s (306 90%) 2.7417
4m 11s (340 100%) 3.0368

Generating example :

M:6/8
K:G
"G"G2G"D"A2A|"G"B2cdBG|"C"E2Ec2B|"Am"ABG"D7"FED|"G"G2G"E7/b"e"Dm"=f"Dm"B=c3"F#m"Ac'2"A"A3^D"Am""c"e"G7"d2"F#m"A2"E7"g3"D"^f"D"A3"E7""b"g"A"AC3"A7"[e3-c3-]"Dm"F3"G/b"g3"G7"=B"B7"F2"A"C3"G"A2"B"e"D7/c"d3"D7"A"F"d3"G7"F3"A7/e"g2"G/b"b2A/2^D3"A"E2"Am/c"c2"E7"E"F#7"b2"F#m"A2c'2"F"F2e/2"Dm"e2"F#m"a2"C"f"F#m"e2"G/d"d"(A7)"F2"E7""b"g"B7"F2"D7/a"d"Eb"A"Bm"c2"D7"f3"C/e"eD"G/b"G3" ""f/="B"Am"g3"3"e"E"f3"D7"e3"F"d3"B"B,3=e"B7"B3"F#m"c"E7"=F"D"A2"G"F3"G""3"B"D7"c2"G/b"gf/4"C"e2"A7"f3"Am"A2"A7/e"a2"G/b"g2"C""e"g2^c"G7"b2"F#m"e"G""3"B"G"G4"D"=f"Bm"d2"Em"c"Gm"f2"Bb""f"f"D7"E2"Bm"d3"G7"g"A7"G3/2"Dm""f"A2"Bb"B,"C"e/2"Am"c2"A7"E"D7"G2"Bb"d2"C"f


4m 39s (34 10%) 3.1795
5m 4s (68 20%) 3.6589
5m 29s (102 30%) 3.0280
5m 54s (136 40%) 3.196

32m 19s (34 10%) 1.0565
32m 41s (68 20%) 1.2449
33m 3s (102 30%) 1.1323
33m 24s (136 40%) 0.6489
33m 44s (170 50%) 1.0618
34m 3s (204 60%) 0.9380
34m 22s (238 70%) 1.5350
34m 44s (272 80%) 1.0054
35m 3s (306 90%) 0.9471
35m 26s (340 100%) 0.9519

Generating example :

M:6/8
K:G
"G"G2G"D"A2A|"G"B2cdBG|"C"E2Ec2B|"Am"ABG"D7"FED|"G"G2G"D"D"E7/b"B"Gm"G3^G=gF3/2"Bb7"D"Em"b2"Bm"dA/2"Em"g/2"Bm""Bm"d2b/2"Bm"A"Am"A6"A"A3^c/2"A7"=g2"F#m"c"Am"c2"F"a2"B7"F"Eb"G3"Dm"g"A7/e"c2"Em""C"E3"F#m"e"D"D3/2"F"c'2[gA]"E"^G2"Am""F"A2"Am"A3"E7"F"G"b2[e2c2]"Am""F"A2"D"d/2B,=c"C"fc'/2"Dm"e2d'3/2"Am"f"Dm/f"f2"G"D3=B2"D7"d3/2"Gm"B"Am""F"A3"G7"A"D7/a"c2"F#m"a"Am"f"C7"c2"D"f3/2"Bm"F"C"A6"C"d2"F"=F3"A7"e3/2"C"g2"3"e"F"D"A7"[f2A2]"A7""F#m"a"A/c+"A"C"a2^F"E7"^G3"Am""F"cc'2"Bb"_b2" ""g"c"G"B2"Gm"f2"A"C3"C"^F3"Bm"e"A"c3"G7"^A2A/2"Eb"G3"Bb7"D"A7"e3"E"f3"Em"g2"F#m"e2"Bm"d3" ""d"B,2"E7"e3"C"A6=e"E7"A"Gm"G3"A7"d"Bm"D"D7"E2"G"a2"C"A6


35m 45s (34 10%) 0.5024
36m 3s (68 20%) 0.6337
36m 23s (102 30%) 1.1414
36m 43s (136 40%) 1.0

In [65]:
print(rep.songs[15])

X:15
T:The Barley Mow
S:Trad, arr Phil Rowe
M:6/8
L:
Q:
K:G
"G"G2G"D"A2A|"G"B2cdBG|"C"E2Ec2B|"Am"ABG"D7"FED|"G"G2G"D"A2A|"G"B2cdBG|"C"EcB"D7"AGF|"G"G6:|"G"d2d"C"e2e|"G"d2gdBG|"G"d2d"C"e2e|"G"dBG"D"A3|"G"d2d"C"e2e|"G"d2gg2f|"A7"egfed^c|"D7"d3D3||"G"G2G"D"A2A|"G"B2cdBG|"C"E2Ec2B|"Am"ABG"D7"FED|"G"G2G"D"A2A|"G"B2cdBG|"C"EcB"D7"AGF|"G"G6||


## Load model