In [1]:
from hw1 import Composer
from midi2seq import process_midi_seq, seq2piano, random_piano, piano2seq, segment
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset 
import torch.nn as nn
import numpy as np
import random
from sklearn.preprocessing import MinMaxScaler
import os
import gdown

In [2]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps:0')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: mps:0


In [3]:
# Load the event sequences. The shuffle seed is pinned so that repeated runs
# (and therefore the label set derived below) stay identical while debugging.
sequence = process_midi_seq(maxlen=50, n=15000, shuffle_seed=3)
print(sequence.shape)

# Distinct event values; their sorted order defines the class index of each label.
notes = np.unique(sequence)
print(f'number of unique notes are {len(notes)} notes')

# Scale every event value into [0, 1]. The scaler is fitted on the whole
# dataset flattened to a single feature column, then applied to that same column.
scaler = MinMaxScaler(feature_range=(0,1))
flat_events = sequence.reshape((-1,1))
scaler.fit(flat_events)
normalized_sequence = scaler.transform(flat_events).reshape(sequence.shape)
print(f'max feature is {scaler.data_max_}')
print(f'min feature is {scaler.data_min_}')

# Normalized counterpart of `notes`, in the same sorted order.
normalized_notes = np.unique(normalized_sequence)
print(f'number of unique notes after normalization are {len(normalized_notes)}')

(15734, 51)
number of unique notes are 302 notes
max feature is [381.]
min feature is [21.]
number of unique notes after normalization are 302


In [4]:
# Inputs: every event of a row except the last one, normalized, with a trailing
# feature axis of size 1 so each row becomes (timesteps, 1).
X_train = torch.tensor(normalized_sequence[:, :-1, np.newaxis]).float()

# Targets: the last event of each row, kept in its raw (un-normalized) value
# so it can be matched against `notes`; stored as a column of shape (rows, 1).
Y_train = torch.tensor(sequence[:, -1:]).float()

X_train.shape, Y_train.shape

(torch.Size([15734, 50, 1]), torch.Size([15734, 1]))

In [5]:
class MidiComposerDataset(Dataset):
    """Pairs each input sequence with a one-hot encoding of its next event.

    Args:
        labels: Array of all distinct event values; position in this array is
            the class index used for the one-hot vector.
        x_sequence: Indexable collection of input sequences.
        y_next: Indexable collection of targets, where ``y_next[i][0]`` is the
            event value that follows ``x_sequence[i]``.
    """

    def __init__(self,labels, x_sequence, y_next):
        self.labels = labels
        self.x_sequence = x_sequence
        self.y_next = y_next

    def __len__(self):
        """Number of samples, taken from the target collection."""
        return len(self.y_next)

    def one_hot_encode(self, note):
        """Return a float tensor that is 1.0 where ``labels`` equals ``note``."""
        matches = (note == self.labels)
        return torch.tensor(matches).float()

    def __getitem__(self, idx):
        """Return ``{'sequence': input, 'action': one-hot target}`` for sample ``idx``."""
        target_note = self.y_next[idx][0].item()
        target_one_hot = self.one_hot_encode(target_note)
        return {
            'sequence': self.x_sequence[idx],
            'action': target_one_hot,
        }

In [6]:
# `notes` supplies the class order used to one-hot encode each raw target in Y_train.
train_dataset = MidiComposerDataset(labels=notes, x_sequence=X_train, y_next=Y_train)

In [7]:
BATCH_SIZE = 100

# Serve the training set in shuffled mini-batches of BATCH_SIZE samples.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [8]:
# Sanity check: pull a single batch and show the input / target shapes.
for batch in train_loader:
    sequence_batch = batch['sequence'].to(device)
    action_batch = batch['action'].to(device)
    print(sequence_batch.shape, action_batch.shape)
    break

torch.Size([100, 50, 1]) torch.Size([100, 302])


In [9]:
class ComposerModel(nn.Module):
    """Stacked LSTM followed by a linear layer that scores the next event.

    Args:
        n_classes: Number of output logits (one per possible next event).
        n_input: Number of features per timestep.
        n_hidden: Hidden size of every LSTM layer.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, n_classes, n_input=1, n_hidden=256, n_layers=2):
        super().__init__()
        self.num_stacked_layers = n_layers
        self.hidden_size = n_hidden

        # nn.LSTM applies dropout only between stacked layers; asking for it
        # with a single layer has no effect and only emits a warning.
        self.lstm = nn.LSTM(
            input_size=n_input,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.2 if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(0.2)
        # Output layer
        self.linear = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Return logits of shape (batch, n_classes) for input (batch, timesteps, n_input)."""
        batch_size = x.size(0)

        # Zero initial states, allocated on the input's own device and dtype.
        # This keeps the module independent of any notebook-level `device`
        # variable and correct wherever the model and its inputs are placed.
        h0 = x.new_zeros(self.num_stacked_layers, batch_size, self.hidden_size)
        c0 = x.new_zeros(self.num_stacked_layers, batch_size, self.hidden_size)

        lstm_out, _ = self.lstm(x, (h0, c0))
        # take only the last output
        out = lstm_out[:, -1, :]
        # produce output
        out = self.linear(self.dropout(out))
        return out

In [10]:
# One output logit per distinct event value found in the dataset.
classes = len(notes)
model = ComposerModel(n_classes=classes, n_input=1, n_hidden=256, n_layers=2)
model.to(device)

ComposerModel(
  (lstm): LSTM(1, 256, num_layers=2, batch_first=True, dropout=0.2)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=302, bias=True)
)

In [11]:
# Optimisation setup: cross-entropy summed (not averaged) over each batch,
# minimised with Adam over all model parameters.
learning_rate = 1e-4
loss_function = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [12]:
def train_one_epoch(epoch=None):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Operates on the notebook-level ``model``, ``train_loader``,
    ``loss_function``, ``optimizer`` and ``device``.

    Args:
        epoch: Zero-based epoch index, used only for the ``Epoch: N`` header.
            When omitted, the notebook-level ``epoch`` loop variable is used
            (the original calling convention), or 0 if no such variable exists.
    """
    if epoch is None:
        # Backward-compatible fallback for callers that rely on a global loop variable.
        epoch = globals().get('epoch', 0)

    model.train(True)
    print(f'Epoch: {epoch + 1}')
    running_loss = 0.0

    for batch_index, batch in enumerate(train_loader):
        sequence_batch , action_batch = batch['sequence'].to(device) , batch['action'].to(device)

        output = model(sequence_batch)
        loss = loss_function(output, action_batch)
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_index % 100 == 99:  # print every 100 batches
            # Mean of the per-batch loss values accumulated since the last report.
            avg_loss_across_batches = running_loss / 100
            print('Batch {0}, Loss: {1:.3f}'.format(batch_index+1,
                                                    avg_loss_across_batches))
            running_loss = 0.0
    print()

In [34]:
# Set to True to train from scratch; otherwise a pre-trained checkpoint is fetched.
train = False

if train:
    num_epochs = 2000
    for epoch in range(num_epochs):
        train_one_epoch()
    # Full pickled model, plus a checkpoint with the state dicts that
    # load_checkpoint() reads back.
    torch.save(model, "composer.pth")
    state = {'epoch': num_epochs + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'losslogger': None}
    torch.save(state, "composer_checkpoint.pth.tar")

else:
    url = 'https://drive.google.com/uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k'
    output = 'composer_checkpoint.pth.tar'
    # The checkpoint is about 10 MB; reuse a local copy instead of downloading
    # it again on every run. Delete the file to force a fresh download.
    if os.path.isfile(output):
        print(f"Using existing checkpoint '{output}'")
    else:
        gdown.download(url, output, quiet=False)

DEBUG:Starting new HTTPS connection (1): drive.google.com:443
DEBUG:https://drive.google.com:443 "GET /uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k HTTP/1.1" 303 0
DEBUG:Starting new HTTPS connection (1): doc-0o-8c-docs.googleusercontent.com:443
DEBUG:https://doc-0o-8c-docs.googleusercontent.com:443 "GET /docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/igj56uf1leo26sr4um2l435krr6ghdo2/1696826250000/02584426154643755225/*/1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k?uuid=95bf70b8-3438-455e-967a-33941df1d697 HTTP/1.1" 200 10442462
Downloading...
From: https://drive.google.com/uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k
To: /Users/edwardmorgan/Documents/dev/deeplearning/PianoGen/composer_checkpoint.pth.tar
100%|███████████████████████████████████████████████████████████████████████████████████████████| 10.4M/10.4M [01:24<00:00, 124kB/s]


In [13]:
def load_checkpoint(model, optimizer, losslogger=None, filename='composer_checkpoint.pth.tar',
                    map_location='cpu'):
    """Restore model and optimizer state from a checkpoint file, if it exists.

    The input model and optimizer must already be constructed; this routine
    only updates their states in place.

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
        losslogger: Value returned unchanged when no checkpoint file is found.
        filename: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can be read on a machine without one;
            ``load_state_dict`` then copies the values into the existing
            parameters on whatever device they live.

    Returns:
        Tuple ``(model, optimizer, start_epoch, losslogger)``. ``start_epoch``
        is 0 when the file is missing, otherwise ``checkpoint['epoch']``.
    """
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a source you trust.
        checkpoint = torch.load(filename, map_location=map_location)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger

In [14]:
# Restore the saved weights and optimizer state, then place the model on the
# selected device.
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer)
model = model.to(device)

# Move every tensor held in the optimizer's per-parameter state to the same
# device; non-tensor entries are left as they are.
for param_state in optimizer.state.values():
    moved_buffers = {
        name: value.to(device)
        for name, value in param_state.items()
        if isinstance(value, torch.Tensor)
    }
    param_state.update(moved_buffers)

=> loading checkpoint 'composer_checkpoint.pth.tar'
=> loaded checkpoint 'composer_checkpoint.pth.tar' (epoch 2001)


In [67]:
# Nothing visible earlier in this notebook puts the model in eval mode, and
# modules start in training mode. Without this call the dropout layers stay
# active and randomly perturb every prediction during generation.
model.eval()

with torch.no_grad():
    # Use a random training sample as the prompt; print its index so the run
    # can be identified afterwards.
    rint = random.randint(0,sequence.shape[0]-1)
    print(rint)
    prompt_sequence = train_dataset[rint]['sequence']
    initial_sequence = prompt_sequence
    # Add a leading batch axis of size 1: (timesteps, 1) -> (1, timesteps, 1).
    prompt_sequence = prompt_sequence.reshape((-1,prompt_sequence.shape[0],1))

    # The output starts with the prompt; each generated event is collected
    # here and concatenated once at the end.
    generated_parts = [prompt_sequence]

    n_sequences = 50

    for i in range(n_sequences*50):
        logits = model(prompt_sequence.to(device))
        # Greedy decoding: take the highest-scoring class and map it back to
        # its normalized event value.
        predicted_index = int(torch.argmax(logits, dim=1))
        predicted_note = normalized_notes[predicted_index]
        new_value = torch.tensor([[[predicted_note]]], dtype=torch.float32)
        # Slide the window: drop the oldest event and append the new one so
        # the prompt length stays constant.
        prompt_sequence = torch.cat((prompt_sequence[:,1:,:], new_value), dim=1)

        generated_parts.append(new_value)

    generated_sequence = torch.cat(generated_parts, dim=1)

# Undo the [0, 1] scaling and round back to integer event values before
# converting them to MIDI.
generated_sequence = np.rint(scaler.inverse_transform(generated_sequence.reshape((-1,1)).numpy()))
midi = seq2piano(generated_sequence.reshape((-1,1)).flatten().astype(int))
midi.write('test436.midi')

10536


DEBUG:up without down for pitch 71 at time 0
DEBUG:up without down for pitch 73 at time 0
DEBUG:up without down for pitch 90 at time 0
DEBUG:consecutive downs for pitch 89 at time 0 and 1
DEBUG:up without down for pitch 57 at time 1
DEBUG:up without down for pitch 62 at time 1
DEBUG:up without down for pitch 62 at time 1
DEBUG:up without down for pitch 85 at time 1
DEBUG:consecutive downs for pitch 89 at time 0 and 1
DEBUG:consecutive downs for pitch 65 at time 1 and 1
DEBUG:up without down for pitch 92 at time 1
DEBUG:up without down for pitch 98 at time 1
DEBUG:up without down for pitch 74 at time 1
DEBUG:up without down for pitch 60 at time 1
DEBUG:consecutive downs for pitch 86 at time 1 and 1
DEBUG:consecutive downs for pitch 68 at time 1 and 1
DEBUG:consecutive downs for pitch 89 at time 0 and 1
DEBUG:up without down for pitch 86 at time 1
DEBUG:up without down for pitch 79 at time 1
DEBUG:consecutive downs for pitch 60 at time 1 and 1
DEBUG:up without down for pitch 72 at time 1