In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torch.nn as nn

import torch.utils.data as data
import os
import random
import numpy as np
from tqdm import tqdm

import pypianoroll

In [3]:
#some constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 0.001
TRAIN_BATCH_SIZE = 30
VAL_BATCH_SIZE = 30
DATA_PATH = '../data/Nottingham/'
NUM_EPOCHS = 5
POSITIVE_WEIGHT = 2
CLIP_VALUE = 1.0 # clip value for the gradient clipping

In [4]:
def path_to_pianoroll(path, poisson=True, resolution = 4):
    #Resolution is set to 3 so that the sequences are not that long
    midi_data = pypianoroll.read(path, resolution=resolution)
    
    piano_roll = midi_data.blend()[:, 21:109] #Taking just 81 usefull notes
    
    #we want to perform multilabel classification at each step so we need to binaryze the roll
    piano_roll[piano_roll > 0] = 1
    
    if poisson:
        current_roll = piano_roll[np.newaxis,0, :] # to have shape (1, num_of_notes)
        count = 1
        counts = []
        new_piano_roll = current_roll
        for i in range(1, piano_roll.shape[0]):
            next_roll = piano_roll[np.newaxis, i, :]
            if np.all(current_roll == next_roll):
                count += 1
            else:
                counts.append(count)
                count = 1
                
                new_piano_roll = np.concatenate((new_piano_roll, next_roll), axis=0)
                
                current_roll = next_roll
                
        counts.append(count)
        new_piano_roll = np.concatenate((new_piano_roll, np.array(counts)[:,np.newaxis]), axis=1)
        return new_piano_roll 
                           
    return piano_roll
    

In [5]:
midi_path = os.path.join(DATA_PATH, "train", "ashover_simple_chords_21.mid")

roll = path_to_pianoroll(midi_path, False,resolution = 8)

roll2 = path_to_pianoroll(midi_path, True,resolution = 8)


In [6]:
print(roll.shape)
print(roll2.shape) # This has less timesteps and the difference will be bigger w.r.t resolution, and this has one additional entry in 2ndim dim which are counts
print(roll2[:,-1])
print(np.sum(roll2[:,-1])) # Now the improvement is much bigger

(512, 88)
(294, 89)
[2 1 1 1 2 1 2 1 2 1 1 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 2 1 2
 2 1 7 1 2 1 1 1 2 1 7 1 2 1 1 1 2 1 5 2 1 2 1 2 2 1 2 1 2 2 1 7 1 7 1 2 1
 1 1 2 1 2 1 2 1 1 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 2 1 2 2 1
 7 1 2 1 1 1 2 1 7 1 2 1 1 1 2 1 5 2 1 2 1 2 2 1 2 1 2 2 1 7 1 7 1 2 1 1 1
 2 1 2 1 2 1 1 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 2 1 2 2 1 7 1
 2 1 1 1 2 1 7 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 7 1 7 1 2 1 1
 1 2 1 2 1 2 1 1 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 2 1 2 2 1 7
 1 2 1 1 1 2 1 7 1 2 1 1 1 2 1 2 1 2 1 1 1 2 1 2 2 1 2 1 2 2 1 7 1 7 1]
512


In [7]:
def collate(batch):
    #Helper function for DataLoader
    #Batch is a list of tuple in the form (input, target)
    #We do not have to padd everything thanks to pack_sequence
    data = [item[0] for item in batch] #
    data = nn.utils.rnn.pack_sequence(data, enforce_sorted=False)
    targets = [item[1] for item in batch]
    targets = nn.utils.rnn.pack_sequence(targets, enforce_sorted=False)
    return [data, targets]

In [8]:
class NotesGenerationDataset(data.Dataset):
    
    def __init__(self, path,):
        
        self.path = path
        self.full_filenames = []
        
        #Here we assume that all midi files are valid, we do not check anything here.
        for root, subdirs, files in os.walk(path):
            for f in files:
                self.full_filenames.append(os.path.join(root, f))
                    
                        
    def __len__(self):
        return len(self.full_filenames)
    
    
    def __getitem__(self, index):
        full_filename = self.full_filenames[index]
        
        piano_roll = path_to_pianoroll(full_filename, poisson=True, resolution=8)
        
        #input and gt are shifted by one step w.r.t one another.
        #we transpose it since piano_roll has shape [num_of_notes, number_of_event] we want to have format [number of events, num_of_notes]
        input_sequence = piano_roll[:-1, :]
        ground_truth_sequence = piano_roll[1:, :]
        
        return torch.tensor(input_sequence, dtype=torch.float32), torch.tensor(ground_truth_sequence, dtype=torch.float32)

In [9]:
trainset = NotesGenerationDataset(os.path.join(DATA_PATH, "train"))

#ofc we want big batch_size. However, one training sample takes quite a lot of memory.
#We will use torch.cuda.amp.autocast() so that we can make bigger batches
trainset_loader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE,
                                              shuffle=True, drop_last=True, collate_fn=collate)

valset = NotesGenerationDataset(os.path.join(DATA_PATH, "valid"))

valset_loader = torch.utils.data.DataLoader(valset, batch_size=VAL_BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate)

In [10]:
#Small sanity check that our sets do not intersect at any moment
train_songs = set(trainset.full_filenames)
for song in valset.full_filenames:
    assert not song in train_songs

In [11]:
print(trainset.__len__())
assert len(os.listdir(os.path.join(DATA_PATH, "train"))) == trainset.__len__()

694


In [19]:
class RNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, num_classes, n_layers=2):
        
        super(RNN, self).__init__()
        
        self.input_size = input_size # amount of different notes
        self.hidden_size = hidden_size
        self.num_classes = num_classes 
        self.n_layers = n_layers
        
        #At first we need layer that will encode our vector with only once to better representation
        self.notes_encoder = nn.Linear(in_features=input_size, out_features=hidden_size)
        
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers)
        
        #At the end we want to get vector with logits of all notes
        self.logits_fc = nn.Linear(hidden_size, num_classes)
        
        self.poisson_fc = nn.Linear(hidden_size, 1) # I know that I could merge these into one layer, but this is more readable
        #If this solution will work well for the small dataset then I'll try to optimize it more 
    
    
    def forward(self, inp, hidden=None):
        
        if isinstance(inp, nn.utils.rnn.PackedSequence):
            #If we have Packed sequence we proceed a little bit differently
            batch_sizes = inp.batch_sizes
            notes_encoded = self.notes_encoder(inp.data) #PackedSequence.data is a tensor representation of shape [samples, num_of_notes]
            rnn_in = nn.utils.rnn.PackedSequence(notes_encoded,batch_sizes) #This is not recommended in PyTorch documentation.
            #However this saves a day here. Since otherwise we would have to create padded sequences 
            outputs, hidden = self.lstm(rnn_in, hidden)
            
            class_logits = self.logits_fc(outputs.data) #Again we go from packedSequence to tensor.
            poisson_logits = self.poisson_fc(outputs.data)
            
        else:
            #If we have tensor at the input this is pretty straightforward
            notes_encoded = self.notes_encoder(inp)
            outputs, hidden = self.lstm(notes_encoded, hidden)
            class_logits = self.logits_fc(outputs)
            poisson_logits = self.poisson_fc(outputs)
        
        return class_logits, poisson_logits, hidden

In [20]:
#Now sanity check about Packed Sequences. So I check if Unpacking -> packing the packed Sequence will lead to exactly the same Object.
inp, targets = next(iter(trainset_loader))

batch_sizes = inp.batch_sizes
inp2 = nn.utils.rnn.PackedSequence(inp.data, batch_sizes)
assert torch.all(torch.eq(inp.data, inp2.data))



In [21]:
rnn = RNN(input_size=89, hidden_size=256, num_classes=88) # 88 notes + 1 count is the input
rnn = rnn.to(DEVICE)

class_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full((88,), POSITIVE_WEIGHT, device=DEVICE))
poiss_criterion  = torch.nn.PoissonNLLLoss(log_input=True) # So this loss expect log(lambda) = x * b. Then it transforms it using exp.
#So it expcects some linear function, that's why we can give the output of the neural network directly

optimizer = torch.optim.Adam(rnn.parameters(), lr=LEARNING_RATE)

scaler = torch.cuda.amp.GradScaler()

In [22]:
def validate(rnn, class_criterion, poiss_criterion, loader, device):
    rnn.eval()
    loop = tqdm(loader, leave=True)
    
    losses_class = []
    losses_poisson = []
    
    with torch.no_grad():
        for idx, (inp, target) in enumerate(loop):
            inp, target = inp.to(device), target.to(device)
            target = target.data
        
            target_notes, target_poiss = target[:,:-1], target[:, -1]
            target_notes, target_poiss = target[:,:-1], target[:, -1]
            logits, logits_poisson, _ = rnn(inp)

            loss_class = class_criterion(logits, target_notes).item()
            loss_poiss = poiss_criterion(logits_poisson, target_poiss).item()
            
            losses_class.append(loss_class)
            losses_poisson.append(loss_poiss)
            loop.set_postfix(loss_class=loss_class, loss_pois=loss_poiss)

    rnn.train()
    return sum(losses_class) / len(losses_class), sum(losses_poisson) / len(losses_poisson)

In [23]:
def train(rnn, optimizer, class_criterion, poiss_criterion, loader, device, clip_value):
    loop = tqdm(loader, leave=True)
    
    losses_class = []
    losses_poisson = []
    
    for idx, (inp, target) in enumerate(loop):
        inp, target = inp.to(device), target.to(device)
        target = target.data
        #print(target.shape)
        target_notes, target_poiss = target[:,:-1], target[:, -1]
        #print(target_notes)
        #print(target_poiss)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(): 
            logits, logits_poisson, _ = rnn(inp)
            
            loss_class = class_criterion(logits, target_notes)
            loss_poiss = poiss_criterion(logits_poisson, target_poiss)
            
            loss = loss_class + loss_poiss # This in general may be negative number.
             
        scaler.scale(loss).backward()
        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)
        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(rnn.parameters(), clip_value)
        
        scaler.step(optimizer)
        scaler.update()
    
        loss_class, loss_poiss = loss_class.item(), loss_poiss.item()
        losses_class.append(loss_class)
        losses_poisson.append(loss_poiss)
        loop.set_postfix(loss_class=loss_class, loss_pois=loss_poiss)
        
    return sum(losses_class) / len(losses_class), sum(losses_poisson) / len(losses_poisson)

In [24]:
clip = 1.0
best_val_loss = float("inf")

train_losses_class = []
train_losses_poiss = []
val_losses_class = []
val_losses_poiss = []

for epoch_number in range(NUM_EPOCHS):
    train_class_loss, train_poiss_loss = train(rnn, optimizer, class_criterion, poiss_criterion, trainset_loader, DEVICE, CLIP_VALUE)    

    train_losses_class.append(train_class_loss)
    train_losses_poiss.append(train_poiss_loss)
    
    val_class_loss, val_poiss_loss = validate(rnn, class_criterion, poiss_criterion, valset_loader, DEVICE)

    val_losses_class.append(val_class_loss)
    val_losses_poiss.append(val_poiss_loss)
    
    
    print(f"Epoch {epoch_number}:\ntrain_class_loss: {train_class_loss}, train_poiss_loss: {train_poiss_loss}\n val_class_loss: {val_class_loss}, val_poiss_loss: {val_poiss_loss}")
    # if current_val_loss < best_val_loss:
        
    #     torch.save(rnn.state_dict(), 'music_rnn.pth')
    #     best_val_loss = current_val_loss

100%|██████████| 23/23 [00:18<00:00,  1.24it/s, loss_class=0.4, loss_pois=0.106]    
100%|██████████| 6/6 [00:04<00:00,  1.46it/s, loss_class=0.358, loss_pois=-.86] 


Epoch 0:
train_class_loss: 0.6035348008508268, train_poiss_loss: 0.2711894668476737
 val_class_loss: 0.35883162915706635, val_poiss_loss: -0.012340746819972992


100%|██████████| 23/23 [00:18<00:00,  1.22it/s, loss_class=0.167, loss_pois=0.0645] 
100%|██████████| 6/6 [00:04<00:00,  1.47it/s, loss_class=0.156, loss_pois=-.823]


Epoch 1:
train_class_loss: 0.20855830415435458, train_poiss_loss: 0.03980000206755231
 val_class_loss: 0.15972882757584253, val_poiss_loss: -0.012017610172430674


100%|██████████| 23/23 [00:18<00:00,  1.23it/s, loss_class=0.154, loss_pois=0.158] 
100%|██████████| 6/6 [00:03<00:00,  1.52it/s, loss_class=0.15, loss_pois=-.733] 


Epoch 2:
train_class_loss: 0.15601511558760767, train_poiss_loss: 0.03707794506993631
 val_class_loss: 0.15209558109442392, val_poiss_loss: 0.003101050853729248


100%|██████████| 23/23 [00:18<00:00,  1.23it/s, loss_class=0.158, loss_pois=0.0703]
100%|██████████| 6/6 [00:04<00:00,  1.45it/s, loss_class=0.149, loss_pois=-.792]


Epoch 3:
train_class_loss: 0.15251916258231454, train_poiss_loss: 0.04380273640803669
 val_class_loss: 0.15106557806332907, val_poiss_loss: -0.009540058672428131


 22%|██▏       | 5/23 [00:05<00:19,  1.08s/it, loss_class=0.153, loss_pois=0.0543]


KeyboardInterrupt: 

In [47]:
def sample_from_piano_rnn(sample_length=4, temperature=1, starting_sequence=None, deterministic = False, threshold=0.5):

    if starting_sequence is None:
        current_sequence_input = torch.zeros(1,1, 89, dtype=torch.float32, device=DEVICE)
        current_sequence_input[0, 0, 40] = 1
        current_sequence_input[0, 0, 50] = 1
        current_sequence_input[0, 0, 56] = 1
        current_sequence_input[0, 0, 88] = 1

    final_output_sequence = [current_sequence_input.squeeze(1)]
    
    hidden = None
    with torch.no_grad():
        for i in range(sample_length):
            current_sequence_input = torch.zeros(1,1, 89, dtype=torch.float32, device=DEVICE)
            
            logits_class, logits_poiss ,hidden = rnn(current_sequence_input, hidden)
            probabilities = torch.sigmoid(logits_class.div(temperature)) # The less the temperature the bigger probabilities of 1 will be
            if deterministic and len(final_output_sequence) > 5:
                current_sequence_input[0,0,:-1] = (probabilities > threshold).to(torch.float32)                
            else:
                prob_of_0 = 1 - probabilities
                dist = torch.stack((prob_of_0, probabilities), dim=3).squeeze() #Here we will get tensor [num_of_notes, 2]
                
                #from multinomial we have [num_of_notes, 1]. But eventually we want to have [1,1,num_of_notes]
                current_sequence_input[0,0,:-1] = torch.multinomial(dist, 1).squeeze().unsqueeze(0).unsqueeze(1).to(torch.float32)
                #print(current_sequence_input)
                #break

            lambda_ = np.exp(logits_poiss[0].item())
            repetitions = max([1, np.random.poisson(lambda_,1)[0]])
            #print(repetitions)
            current_sequence_input[0,0,-1] = repetitions
            final_output_sequence.append(current_sequence_input.squeeze(1))

    sampled_sequence = torch.cat(final_output_sequence, dim=0).cpu().numpy()
    
    return sampled_sequence

In [86]:
def poisson_to_piano_roll(poiss_roll):
    notes, counts = poiss_roll[np.newaxis,:,:-1], poiss_roll[:,-1]
    roll = np.repeat(notes[np.newaxis,0, 0], counts[0], axis=0)
    
    for timestep, count in zip(notes[0,1:,:], counts[1:]):
        roll = np.concatenate((roll, np.repeat(timestep[np.newaxis,:], count, axis=0)), axis=0)
        
    return roll

In [100]:
sample = sample_from_piano_rnn(sample_length=200, temperature=0.5, deterministic=False, threshold=0.3)

In [101]:
sample_roll = poisson_to_piano_roll(sample)

In [102]:
print(np.sum(sample[:,-1]))
sample_roll.shape

531.0


(531, 88)

In [103]:
roll = np.zeros((sample_roll.shape[0],128))
roll[:, 21:109] = sample_roll
roll[roll == 1] = 100
track = pypianoroll.Multitrack(resolution=3)
track.append(pypianoroll.StandardTrack(pianoroll=roll))
pypianoroll.write("sample2.mid", track)