In [1]:
import os
import copy
import pickle
import random
from math import ceil

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from pprint import pprint

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

cuda:0


In [3]:
chord_to_index = {
    '000000000000': 48,
    '000000100001': 105,
    '000001000010': 106,
    '000001001001': 70,
    '000010000001': 100,
    '000010000100': 107,
    '000010001001': 1,
    '000010010001': 22,
    '000010010010': 71,
    '000100000010': 101,
    '000100001000': 108,
    '000100001001': 18,
    '000100010010': 2,
    '000100011001': 55,
    '000100100001': 6,
    '000100100010': 23,
    '000100100011': 98,
    '000100100100': 72,
    '000100100101': 37,
    '000100101001': 28,
    '000101001001': 87,
    '000110001001': 93,
    '000110010001': 59,
    '001000000100': 102,
    '001000001001': 67,
    '001000010000': 109,
    '001000010001': 10,
    '001000010010': 19,
    '001000100001': 15,
    '001000100011': 52,
    '001000100100': 3,
    '001000100101': 25,
    '001000101001': 84,
    '001000110001': 90,
    '001000110010': 56,
    '001001000001': 64,
    '001001000010': 7,
    '001001000100': 12,
    '001001000101': 81,
    '001001000110': 99,
    '001001001000': 73,
    '001001001001': 51,
    '001001001010': 38,
    '001001010001': 41,
    '001001010010': 29,
    '001010001001': 44,
    '001010010001': 32,
    '001010010010': 76,
    '001010101101': 113,
    '001010110101': 128,
    '001100010010': 94,
    '001100100010': 60,
    '010000001000': 103,
    '010000010010': 68,
    '010000100000': 110,
    '010000100010': 11,
    '010000100100': 20,
    '010001000010': 16,
    '010001000110': 53,
    '010001001000': 4,
    '010001001001': 47,
    '010001001010': 26,
    '010001010010': 85,
    '010001100010': 91,
    '010001100100': 57,
    '010010000010': 65,
    '010010000100': 8,
    '010010001000': 13,
    '010010001001': 35,
    '010010001010': 82,
    '010010001100': 88,
    '010010010000': 74,
    '010010010001': 79,
    '010010010010': 49,
    '010010010100': 39,
    '010010100010': 42,
    '010010100100': 30,
    '010010101011': 123,
    '010010101101': 126,
    '010100010010': 45,
    '010100100010': 33,
    '010100100100': 77,
    '010100101011': 124,
    '010101011010': 114,
    '010101101001': 116,
    '010101101010': 129,
    '010110100101': 118,
    '010110101001': 131,
    '011000100100': 95,
    '011001000100': 61,
    '011010010101': 120,
    '011010100101': 133,
    '100000010000': 104,
    '100000100100': 69,
    '100001000000': 111,
    '100001000100': 0,
    '100001001000': 21,
    '100010000100': 17,
    '100010001100': 54,
    '100010010000': 5,
    '100010010001': 97,
    '100010010010': 36,
    '100010010100': 27,
    '100010100100': 86,
    '100011000100': 92,
    '100011001000': 58,
    '100100000100': 66,
    '100100001000': 9,
    '100100010000': 14,
    '100100010001': 63,
    '100100010010': 24,
    '100100010100': 83,
    '100100011000': 89,
    '100100100000': 75,
    '100100100010': 80,
    '100100100100': 50,
    '100100101000': 40,
    '100101000100': 43,
    '100101001000': 31,
    '100101010110': 112,
    '100101011010': 127,
    '101000100100': 46,
    '101001000100': 34,
    '101001001000': 78,
    '101001010101': 122,
    '101001010110': 125,
    '101010010101': 135,
    '101010110100': 115,
    '101011010010': 117,
    '101011010100': 130,
    '101101001010': 119,
    '101101010010': 132,
    '110001001000': 96,
    '110010001000': 62,
    '110100101010': 121,
    '110101001010': 134
}

index_to_chord = {
    0: '100001000100',
    1: '000010001001',
    2: '000100010010',
    3: '001000100100',
    4: '010001001000',
    5: '100010010000',
    6: '000100100001',
    7: '001001000010',
    8: '010010000100',
    9: '100100001000',
    10: '001000010001',
    11: '010000100010',
    12: '001001000100',
    13: '010010001000',
    14: '100100010000',
    15: '001000100001',
    16: '010001000010',
    17: '100010000100',
    18: '000100001001',
    19: '001000010010',
    20: '010000100100',
    21: '100001001000',
    22: '000010010001',
    23: '000100100010',
    24: '100100010010',
    25: '001000100101',
    26: '010001001010',
    27: '100010010100',
    28: '000100101001',
    29: '001001010010',
    30: '010010100100',
    31: '100101001000',
    32: '001010010001',
    33: '010100100010',
    34: '101001000100',
    35: '010010001001',
    36: '100010010010',
    37: '000100100101',
    38: '001001001010',
    39: '010010010100',
    40: '100100101000',
    41: '001001010001',
    42: '010010100010',
    43: '100101000100',
    44: '001010001001',
    45: '010100010010',
    46: '101000100100',
    47: '010001001001',
    48: '000000000000',
    49: '010010010010',
    50: '100100100100',
    51: '001001001001',
    52: '001000100011',
    53: '010001000110',
    54: '100010001100',
    55: '000100011001',
    56: '001000110010',
    57: '010001100100',
    58: '100011001000',
    59: '000110010001',
    60: '001100100010',
    61: '011001000100',
    62: '110010001000',
    63: '100100010001',
    64: '001001000001',
    65: '010010000010',
    66: '100100000100',
    67: '001000001001',
    68: '010000010010',
    69: '100000100100',
    70: '000001001001',
    71: '000010010010',
    72: '000100100100',
    73: '001001001000',
    74: '010010010000',
    75: '100100100000',
    76: '001010010010',
    77: '010100100100',
    78: '101001001000',
    79: '010010010001',
    80: '100100100010',
    81: '001001000101',
    82: '010010001010',
    83: '100100010100',
    84: '001000101001',
    85: '010001010010',
    86: '100010100100',
    87: '000101001001',
    88: '010010001100',
    89: '100100011000',
    90: '001000110001',
    91: '010001100010',
    92: '100011000100',
    93: '000110001001',
    94: '001100010010',
    95: '011000100100',
    96: '110001001000',
    97: '100010010001',
    98: '000100100011',
    99: '001001000110',
    100: '000010000001',
    101: '000100000010',
    102: '001000000100',
    103: '010000001000',
    104: '100000010000',
    105: '000000100001',
    106: '000001000010',
    107: '000010000100',
    108: '000100001000',
    109: '001000010000',
    110: '010000100000',
    111: '100001000000',
    112: '100101010110',
    113: '001010101101',
    114: '010101011010',
    115: '101010110100',
    116: '010101101001',
    117: '101011010010',
    118: '010110100101',
    119: '101101001010',
    120: '011010010101',
    121: '110100101010',
    122: '101001010101',
    123: '010010101011',
    124: '010100101011',
    125: '101001010110',
    126: '010010101101',
    127: '100101011010',
    128: '001010110101',
    129: '010101101010',
    130: '101011010100',
    131: '010110101001',
    132: '101101010010',
    133: '011010100101',
    134: '110101001010',
    135: '101010010101'
}

In [4]:
class MelodyToChordSequenceDataset(Dataset):
    """
    This dataset loads the training data linking a 128 * 13 (optionally 14 with start bit)
    melody index input to a 128 * 12 multi-hot or 128 * 136 one-hot chord sequence output
    """

    def __init__(self, data_dir, seq_length=128, batch_size=24,
                 transpose=None, strip_rest_edge_measures=False, 
                 include_start_bit=False, chords_as_one_hot=False):
        """
        Args:
            data_dir (string):
                directory where .txt files of training data are stored
            transpose (list[int]):
                list of semitones to transpose the data by
            strip_rest_edge_measures:
                whether or not to skip measures before the first and after the last melody note
        """
        
        self.data_dir = data_dir
        self.transpose = transpose
        self.strip_rest_edge_measures = strip_rest_edge_measures
        self.include_start_bit = include_start_bit
#         self.chords_as_one_hot = chords_as_one_hot
        self.seq_length = seq_length
        self.batch_size = batch_size
        
#         self.number_of_chords = 0
#         if self.chords_as_one_hot:
#             self.chord_to_index_map = {}
#             self.index_to_chord_map = {}
        
        self._song_tuples = []
        
        for file in os.listdir(data_dir):
            with open(os.path.join(data_dir, file)) as f_in:
                file_text = f_in.read()
                
            lines = file_text.split("\n")
            lines = [line for line in lines if line != ""]
            
            song_has_melody = False
            for line in lines:
                if line.split()[0] != "12":
                    song_has_melody = True
                    break
                        
            if song_has_melody:
#                 if strip_rest_edge_measures:
#                     while lines[0].split()[0] == "12":
#                         lines.pop(0)

#                     while lines[-1].split()[0] == "12":
#                         lines.pop(-1)
                    
                song_melody = []
                song_harmony = []

                melody_quintet = []
                for i, line in enumerate(lines):
                    melody_index, harmony_vector = line.split()
                    melody_index = int(melody_index)
                    melody_step_vector = [1 if j == melody_index else 0 for j in range(13)]
                    if include_start_bit:
                        melody_step_vector.append(1 if i in range(16) else 0)
                        
                    melody_quintet.append(melody_step_vector)
                    
                    if i % 4 == 3:
                        song_melody.append(melody_quintet)
                        song_harmony.append([int(val) for val in harmony_vector])
                        melody_quintet = []
                    
#                     if self.chords_as_one_hot:
#                         harmony_multihot = "".join([str(int(val)) for val in harmony_vector])
#                         if harmony_multihot not in self.chord_to_index_map.keys():
#                             self.chord_to_index_map[harmony_multihot] = self.number_of_chords
#                             self.index_to_chord_map[self.number_of_chords] = harmony_multihot
#                             self.number_of_chords += 1
#                             if transpose:
#                                 for semitone_amount in transpose:
#                                     shifted_harmony_multihot = []
#                                     for j in range(12):
#                                         shifted_harmony_multihot.append(harmony_multihot[(j + semitone_amount) % 12])
#                                     shifted_harmony_multihot = "".join(shifted_harmony_multihot)
#                                     if shifted_harmony_multihot not in self.chord_to_index_map.keys():
#                                         self.chord_to_index_map[shifted_harmony_multihot] = self.number_of_chords
#                                         self.index_to_chord_map[self.number_of_chords] = shifted_harmony_multihot
#                                         self.number_of_chords += 1
                    
                song_melody = torch.tensor(song_melody, dtype=torch.uint8, device=device)
                song_harmony = torch.tensor(song_harmony, dtype=torch.uint8, device=device)

                self._song_tuples.append([song_melody, song_harmony])
                    
                if transpose:
                    for semitone_amount in transpose:
                        shifted_melody = copy.deepcopy(song_melody)
                        shifted_melody[:,:,:12] = shifted_melody[:,:,:12].roll(semitone_amount, dims=2)
                        shifted_harmony = copy.deepcopy(song_harmony).roll(semitone_amount, dims=1)

                        self._song_tuples.append([shifted_melody, shifted_harmony])
                    
            else:
                print("Skipping entry with no melody")
                
#         if self.chords_as_one_hot:
#             for i in range(len(self._song_tuples)):
#                 if i % 100 == 0:
#                     print(i)
#                 song_one_hot = []
#                 har_vectors = self._song_tuples[i][1]
#                 for har_vec in har_vectors:
#                     har_index = self.chord_to_index_map["".join([str(int(val)) for val in har_vec])]
#                     song_one_hot.append(har_index)
                    
#                 self._song_tuples[i][1] = torch.tensor(song_one_hot, dtype=torch.uint8, device=device)
                
        print(len(self._song_tuples))
        print(self._song_tuples[0][0].shape)
        print(self._song_tuples[0][1].shape)
            
        random.shuffle(self._song_tuples)

#         self.validation_set = self._song_tuples[int(1 * len(self._song_tuples)):]
        self._song_tuples = self._song_tuples[:int(1 * len(self._song_tuples))]
        
#         self.melody_sequence_val = torch.tensor([], dtype=torch.uint8, device=device)
#         self.harmony_sequence_val = torch.tensor([], dtype=torch.uint8, device=device)
        
#         for mel, har in self.validation_set:
#             self.melody_sequence_val = torch.cat((self.melody_sequence_val, mel))
#             self.harmony_sequence_val = torch.cat((self.harmony_sequence_val, har))
            
#         print(self.melody_sequence_val.shape)
#         print(self.harmony_sequence_val.shape)

#         self.melody_sequence_val = self.melody_sequence_val.view(1, -1)
#         self.harmony_sequence_val = self.harmony_sequence_val.view(1, -1)
        
#         print(self.melody_sequence_val.shape)
#         print(self.harmony_sequence_val.shape)
        
#         self.melody_sequence_val = self.melody_sequence_val.view(
#             self.batch_size, -1, 14 if self.include_start_bit else 13
#         )
#         self.harmony_sequence_val = self.harmony_sequence_val.view(
#             self.batch_size, -1, 1 if self.chords_as_one_hot else 12
#         )
        
#         print(self.melody_sequence_val.shape)
#         print(self.harmony_sequence_val.shape)
        
        self.melody_sequence = torch.tensor([], dtype=torch.uint8, device=device)
        self.harmony_sequence = torch.tensor([], dtype=torch.uint8, device=device)
                
        for mel, har in self._song_tuples:
            self.melody_sequence = torch.cat((self.melody_sequence, torch.flatten(mel, start_dim=1, end_dim=2)))
            self.harmony_sequence = torch.cat((self.harmony_sequence, har))
            
        print(self.melody_sequence.shape)
        print(self.harmony_sequence.shape)
        

        self.melody_sequence = self.melody_sequence.view(
            self.batch_size, -1, 52
        )
        self.harmony_sequence = self.harmony_sequence.view(
            self.batch_size, -1, 12
        )
        
        print(self.melody_sequence.shape)
        print(self.harmony_sequence.shape)
        
        
    def reshuffle_sequence(self):
        self.melody_sequence = torch.tensor([], dtype=torch.uint8, device=device)
        self.harmony_sequence = torch.tensor([], dtype=torch.uint8, device=device)
        
        random.shuffle(self._song_tuples)
        
        for mel, har in self._song_tuples:
            self.melody_sequence = torch.cat((self.melody_sequence, mel))
            self.harmony_sequence = torch.cat((self.harmony_sequence, har))

        self.melody_sequence = self.melody_sequence.view(
            self.batch_size, -1, 52
        )
        self.harmony_sequence = self.harmony_sequence.view(
            self.batch_size, -1, 12
        )
        
        
    def __len__(self):
        return ceil(self.melody_sequence.shape[1] / self.seq_length)
    
    
    def __getitem__(self, idx):
        melody, harmony = (self.melody_sequence[:, idx*self.seq_length:(idx+1)*self.seq_length],
                           self.harmony_sequence[:, idx*self.seq_length:(idx+1)*self.seq_length])
        
        return {"melody": melody, "harmony": harmony}
    
    
    def get_validation_set(self, idx):
        melody, harmony = (self.melody_sequence_val[:, idx*self.seq_length:(idx+1)*self.seq_length],
                           self.harmony_sequence_val[:, idx*self.seq_length:(idx+1)*self.seq_length])
        
        return {"melody": melody, "harmony": harmony}

In [5]:
training_data_fp = input("Input the directory where training files are stored: ")

Input the directory where training files are stored: ../../data/training_data/training_dt_longnotes


In [26]:
model_path = input("Input the name of the file you would like to store the model in: ")

Input the name of the file you would like to store the model in: ../../data/models/model10.pt


In [7]:
dataset = MelodyToChordSequenceDataset(
    training_data_fp,
    batch_size=24,
    transpose=[i for i in range(1, 12)],
)

dataloader = DataLoader(dataset)

len(dataloader)

Skipping entry with no melody
Skipping entry with no melody
Skipping entry with no melody
Skipping entry with no melody
Skipping entry with no melody
Skipping entry with no melody
2328
torch.Size([732, 4, 13])
torch.Size([732, 12])
torch.Size([989136, 52])
torch.Size([989136, 12])
torch.Size([24, 41214, 52])
torch.Size([24, 41214, 12])


322

In [8]:
positive_examples = torch.zeros([12], device=device)

for datum in dataloader:
    datum["harmony"].squeeze_()
    data_batch = datum["harmony"].to(torch.float).sum(1).sum(0)
    positive_examples += data_batch

negative_examples = (len(dataloader) * 24 * 128) - positive_examples

positive_weights = torch.div(negative_examples, positive_examples).flatten()

print(positive_weights)

tensor([2.9463, 2.9463, 2.9463, 2.9463, 2.9463, 2.9463, 2.9463, 2.9463, 2.9463,
        2.9463, 2.9463, 2.9463], device='cuda:0')


In [9]:
# # Define recurrent prediction model
# class LSTMGenerator(nn.Module):
    
#     def __init__(self):
#         super(LSTMGenerator, self).__init__()
#         self.embedding = nn.Embedding(13, 100)
#         self.lstm = nn.LSTM(input_size=100, hidden_size=256, num_layers=2, batch_first=True)
#         self.fc1 = nn.Linear(256, 12)
        
#     def forward(self, x, hidden_in):
#         x = self.embedding(x)
#         x = x.view(48, -1, 100)
#         x, h_out = self.lstm(x, hidden_in)
#         x = self.fc1(x)
#         return x, h_out

# net = LSTMGenerator()

# net.to(device)
# print(net)

In [10]:
# # Define recurrent prediction model
# # NOTE: TOOK OUT EMBEDDING LAYER, MODEL 5 wont work with this spec
# class LSTMGenerator_v2(nn.Module):
    
#     def __init__(self):
#         super(LSTMGenerator_v2, self).__init__()
#         self.lstm = nn.LSTM(input_size=13, 
#                             hidden_size=256, 
#                             num_layers=2, 
#                             batch_first=True, 
#                             bidirectional=True, 
#                             dropout=0.2)
#         self.fc1 = nn.Linear(512, 256)
#         self.fc2 = nn.Linear(256, 12)
        
#     def forward(self, x, hidden_in):
#         x, h_out = self.lstm(x, hidden_in)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.fc2(x)
        
#         return x, h_out

# net = LSTMGenerator_v2()

# net.to(device)
# print(net)

In [11]:
# # Define recurrent prediction model
# class LSTMGenerator_v3(nn.Module):
#     def __init__(self):
#         super(LSTMGenerator_v3, self).__init__()
#         self.lstm = nn.LSTM(input_size=14, 
#                             hidden_size=256, 
#                             num_layers=2, 
#                             batch_first=True, 
#                             bidirectional=True, 
#                             dropout=0.2)
#         self.fc1 = nn.Linear(512, 256)
#         self.fc2 = nn.Linear(256, 12)
        
#     def forward(self, x, hidden_in):
#         x, h_out = self.lstm(x, hidden_in)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.fc2(x)
        
#         return x, h_out

# net = LSTMGenerator_v3()

# net.to(device)
# print(net)

In [12]:
# # Define recurrent prediction model
# class LSTMGenerator_v4(nn.Module):
#     def __init__(self):
#         super(LSTMGenerator_v4, self).__init__()
#         self.lstm = nn.LSTM(input_size=13, 
#                             hidden_size=256, 
#                             num_layers=2, 
#                             batch_first=True, 
#                             bidirectional=True, 
#                             dropout=0.2)
#         self.fc1 = nn.Linear(512, 256)
#         self.fc2 = nn.Linear(256, 136)
#         self.ls = nn.LogSoftmax()
        
#     def forward(self, x, hidden_in):
#         x, h_out = self.lstm(x, hidden_in)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.fc2(x)
#         x = self.ls(x)
        
#         return x, h_out

# net = LSTMGenerator_v4()

# net.to(device)
# print(net)

In [22]:
# Define recurrent prediction model
class LSTMGenerator_v5(nn.Module):
    def __init__(self):
        super(LSTMGenerator_v5, self).__init__()
        self.lstm = nn.LSTM(input_size=52, 
                            hidden_size=256, 
                            num_layers=2, 
                            batch_first=True, 
                            bidirectional=True, 
                            dropout=0.2)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 12)
        
    def forward(self, x, hidden_in):
        x, h_out = self.lstm(x, hidden_in)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        
        return x, h_out

net = LSTMGenerator_v5()

net.to(device)
print(net)

pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(pytorch_total_params)

LSTMGenerator_v5(
  (lstm): LSTM(52, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=12, bias=True)
)
2346252


In [24]:
# Train recurrent model

criterion = nn.BCEWithLogitsLoss()
# criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())

for epoch in range(250): 
    # TODO, implement early stopping with a validation holdout
    running_loss = 0.0
    
    dataset.reshuffle_sequence()
    
    for i, data in enumerate(dataloader):
        hidden_out = (torch.rand(2 * 2, 24, 256, device=device),
                      torch.rand(2 * 2, 24, 256, device=device))
        
        # get the inputs; data is a list of [inputs, labels]
        inputs = data["melody"].squeeze_(dim=0).to(dtype=torch.float)
        labels = data["harmony"].squeeze_(dim=0).to(dtype=torch.float)
        
#         one_hot_expansion = []
#         for sequence in labels:
#             batch_expansion = []
#             for index in sequence:
#                 batch_expansion.append([1 if j == index else 0 for j in range(136)])
#             one_hot_expansion.append(batch_expansion)
            
#         labels = torch.tensor(one_hot_expansion, dtype=torch.float, device=device)
        
#         print(inputs.shape)
#         print(labels.shape)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, hidden_out = net(inputs, hidden_out)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        hidden_out = (hidden_out[0].detach(), hidden_out[1].detach())

        running_loss += loss.item()
        
#     running_val_loss = 0
#     for i in range(len(dataloader.validation_set)):
#             data = dataloader.get_validation_set(i)
#             inputs = data["melody"].squeeze_().to(dtype=torch.float)
#             labels = data["harmony"].squeeze_().to(dtype=torch.float)
            
#             outputs, _ = net(inputs, hidden_out)
#             val_loss = criterion(outputs, labels)
            
#             running_val_loss += loss.item()
    
    # print statistics    
    print('[{}] loss: {}'.format(epoch + 1, running_loss / len(dataloader)))
#     print('[{}] validation loss: {}'.format(epoch + 1, running_val_loss / len(dataloader.validation_set)))

print('Finished Training')

[1] loss: 0.49870256172574084
[2] loss: 0.45334619533571396
[3] loss: 0.4418628060299417
[4] loss: 0.4360727612276255
[5] loss: 0.43282970591731695
[6] loss: 0.42766388501069563
[7] loss: 0.42559942677154305
[8] loss: 0.4211236955771535
[9] loss: 0.41881091050479724
[10] loss: 0.4131768362492508
[11] loss: 0.40871451980208756
[12] loss: 0.40399421936606766
[13] loss: 0.4004759283169456
[14] loss: 0.39550783276928136
[15] loss: 0.3895345901295265
[16] loss: 0.3837540406062736
[17] loss: 0.3781624823253347
[18] loss: 0.3696305239978044
[19] loss: 0.3621245985445769
[20] loss: 0.3520688649660312
[21] loss: 0.3425037948796468
[22] loss: 0.3323477803735259
[23] loss: 0.32226525973644315
[24] loss: 0.31058716551857707
[25] loss: 0.3016244199427759
[26] loss: 0.29083006023805336
[27] loss: 0.28205288159921305
[28] loss: 0.27314736698725206
[29] loss: 0.2664947947545081
[30] loss: 0.2577383416409818
[31] loss: 0.2518176865411101
[32] loss: 0.24571940122368913
[33] loss: 0.23909556463083126
[34

In [27]:
torch.save(net.state_dict(), model_path)