In [85]:
from hw1 import Composer
from midi2seq import process_midi_seq, seq2piano, random_piano, piano2seq, segment
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset 
import torch.nn as nn
import numpy as np
import random
from sklearn.preprocessing import MinMaxScaler
import os
import gdown

In [86]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps:0' if torch.backends.mps.is_available() else 'cpu')
print('Using device:', device)

Using device: mps:0


In [87]:
sequence = process_midi_seq(maxlen=50, n=15000, shuffle_seed=3) #fixed shuffle_seed for debugging purpose and get fixed labels
print(sequence.shape)

notes = np.unique(sequence)
print(f'number of unique notes are {len(notes)} notes')

scaler = MinMaxScaler(feature_range=(0,1))

# Fitting scaler with the complete space and transforming the whole dataset on the scaler
normalized_sequence = scaler.fit_transform(sequence.reshape((-1,1))).reshape(sequence.shape)
print(f'max feature is {scaler.data_max_}')
print(f'min feature is {scaler.data_min_}')

normalized_notes = np.unique(normalized_sequence)
print(f'number of unique notes after normalization are {len(normalized_notes)}')

(15734, 51)
number of unique notes are 302 notes
max feature is [381.]
min feature is [21.]
number of unique notes after normalization are 302


In [88]:
X_train = normalized_sequence[:,:-1]
X_train = X_train.reshape((-1,X_train.shape[1],1))

Y_train = sequence[:,-1]
Y_train = Y_train.reshape((-1,1))

X_train = torch.tensor(X_train).float()
Y_train = torch.tensor(Y_train).float()

X_train.shape, Y_train.shape

(torch.Size([15734, 50, 1]), torch.Size([15734, 1]))

In [89]:
class MidiComposerDataset(Dataset):
    def __init__(self,labels, x_sequence, y_next):
        self.x_sequence = x_sequence
        self.y_next = y_next
        self.labels = labels

    def __len__(self):
        return len(self.y_next)

    def one_hot_encode(self, note):
        return torch.tensor(note == self.labels).float()
        
    def __getitem__(self, idx):
        action = self.y_next[idx][0].item()
        encode_action = self.one_hot_encode(action)
        return dict(
            sequence = self.x_sequence[idx],
            action = encode_action
        )

In [90]:
train_dataset = MidiComposerDataset(notes, X_train, Y_train)

In [91]:
BATCH_SIZE = 100

train_loader = DataLoader(train_dataset,batch_size = BATCH_SIZE, shuffle=True)

In [92]:
for _, batch in enumerate(train_loader):
    sequence_batch , action_batch = batch['sequence'].to(device) , batch['action'].to(device) 
    print(sequence_batch.shape, action_batch.shape)
    break

torch.Size([100, 50, 1]) torch.Size([100, 302])


In [93]:
class ComposerModel(nn.Module):
    def __init__(self, n_classes, n_input=1, n_hidden=256, n_layers=2):
        super().__init__()
        self.num_stacked_layers = n_layers
        self.hidden_size = n_hidden
        
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden, num_layers=n_layers, batch_first=True, dropout=0.2)
        self.dropout = nn.Dropout(0.2)
        # Output layer
        self.linear = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        batch_size = x.size(0)

        h0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        
        lstm_out, _ = self.lstm(x, (h0, c0))
        # take only the last output
        out = lstm_out[:, -1, :]
        # produce output
        out = self.linear(self.dropout(out))
        return out

In [94]:
classes = len(notes)
model = ComposerModel(classes,1,256, 2)
model.to(device)

ComposerModel(
  (lstm): LSTM(1, 256, num_layers=2, batch_first=True, dropout=0.2)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=302, bias=True)
)

In [95]:
learning_rate = 0.0001
loss_function = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [96]:
def train_one_epoch():
    model.train(True)
    print(f'Epoch: {epoch + 1}')
    running_loss = 0.0
    
    for batch_index, batch in enumerate(train_loader):
        sequence_batch , action_batch = batch['sequence'].to(device) , batch['action'].to(device)
        
        output = model(sequence_batch)
        loss = loss_function(output, action_batch)
        running_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_index % 100 == 99:  # print every 100 batches
            avg_loss_across_batches = running_loss / 100
            print('Batch {0}, Loss: {1:.3f}'.format(batch_index+1,
                                                    avg_loss_across_batches))
            running_loss = 0.0
    print()

In [97]:
train = False

if train:
    num_epochs = 2000
    for epoch in range(num_epochs):
        train_one_epoch()
    torch.save(model, "composer.pth")
    state = {'epoch': num_epochs + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'losslogger': None}
    torch.save(state, "composer_checkpoint.pth.tar")
    
else:
    url = 'https://drive.google.com/uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k'
    output = 'composer_checkpoint.pth.tar'
    gdown.download(url, output, quiet=False)

DEBUG:Starting new HTTPS connection (1): drive.google.com:443
DEBUG:https://drive.google.com:443 "GET /uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k HTTP/1.1" 303 0
DEBUG:Starting new HTTPS connection (1): doc-0o-8c-docs.googleusercontent.com:443
DEBUG:https://doc-0o-8c-docs.googleusercontent.com:443 "GET /docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ov7ee1rs9ggq1e5paa3utcgbf5h80rl1/1696819275000/02584426154643755225/*/1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k?uuid=b0c126a4-b1bd-4daa-8a8b-43c48b6eef6a HTTP/1.1" 200 10442462
Downloading...
From: https://drive.google.com/uc?id=1sd_YLCoHVqVqYhmkgnMbSAH1MzfmpO6k
To: /Users/edwardmorgan/Documents/dev/deeplearning/PianoGen/composer_checkpoint.pth.tar
100%|██████████████████████████████████████████████████████████████████████████████████████████| 10.4M/10.4M [00:02<00:00, 4.52MB/s]


In [98]:
def load_checkpoint(model, optimizer, losslogger=None, filename='composer_checkpoint.pth.tar'):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger

In [99]:
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer)
model = model.to(device)
# now individually transfer the optimizer parts...
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

=> loading checkpoint 'composer_checkpoint.pth.tar'
=> loaded checkpoint 'composer_checkpoint.pth.tar' (epoch 2001)


In [100]:
with torch.no_grad():
    rint = random.randint(0,sequence.shape[0]-1)
    prompt_sequence = train_dataset.__getitem__(rint)['sequence']

    prompt_sequence = prompt_sequence.reshape((-1,prompt_sequence.shape[0],1))

    generated_sequence = prompt_sequence

    n_sequences = 50
    
    for i in range(n_sequences*50):
        output = model(prompt_sequence.to(device))
        predicted_index = int(torch.argmax(output, dim=1))
        predicted_note = normalized_notes[predicted_index]
        # New value to append
        new_value = torch.tensor([[[predicted_note]]], dtype=torch.float32)
        # Append the new value to the original tensor
        prompt_sequence = torch.cat((prompt_sequence, new_value), dim=1)
        prompt_sequence = prompt_sequence[:,1:,:]

        generated_sequence = torch.cat((generated_sequence, new_value), dim=1)
        
generated_sequence = np.rint(scaler.inverse_transform(generated_sequence.reshape((-1,1))))
generated_sequence

array([[257.],
       [201.],
       [259.],
       ...,
       [369.],
       [ 74.],
       [370.]])

In [114]:
for e in generated_sequence.flatten():
    print(int(e))

257
201
259
256
373
35
372
47
260
256
175
163
263
374
44
373
32
373
63
372
68
373
71
258
256
191
196
172
160
199
262
373
46
373
34
259
256
174
257
256
162
263
375
72
377
84
373
75
372
42
373
70
373
48
373
39
368
38
257
200
171
202
202
179
208
189
181
215
193
260
373
62
256
372
41
256
35
259
256
178
190
372
56
258
256
40
257
256
369
38
259
256
166
256
256
256
182
367
39
257
360
66
257
185
258
183
257
256
368
70
259
369
64
373
77
256
202
258
197
262
366
76
257
256
256
362
54
256
201
366
53
258
182
257
367
60
263
256
195
257
256
367
50
256
165
258
182
257
367
67
366
72
257
256
257
197
263
367
62
367
65
258
370
66
257
256
189
259
256
369
69
256
369
69
261
256
183
185
261
256
366
52
256
256
194
257
256
362
60
263
201
257
256
195
259
256
367
79
259
256
205
258
256
206
257
256
368
74
373
61
262
256
363
84
257
256
360
78
200
365
57
261
256
362
89
262
256
206
278
256
210
263
196
267
256
267
209
267
256
92
257
256
364
68
288
360
60
267
256
256
256
256
361
66
365
48
363
256
192
258
256
363
67
268

In [117]:
midi = seq2piano(generated_sequence.flatten())

DEBUG:up without down for pitch 73 at time 0
DEBUG:up without down for pitch 43 at time 0
DEBUG:up without down for pitch 74 at time 0
DEBUG:up without down for pitch 74 at time 0
DEBUG:up without down for pitch 51 at time 0
DEBUG:up without down for pitch 80 at time 0
DEBUG:up without down for pitch 61 at time 0
DEBUG:up without down for pitch 53 at time 0
DEBUG:up without down for pitch 87 at time 0
DEBUG:up without down for pitch 65 at time 0
DEBUG:up without down for pitch 50 at time 0
DEBUG:consecutive downs for pitch 38 at time 0 and 0
DEBUG:up without down for pitch 54 at time 0
DEBUG:consecutive downs for pitch 39 at time 0 and 0
DEBUG:up without down for pitch 57 at time 0
DEBUG:up without down for pitch 55 at time 0
DEBUG:consecutive downs for pitch 70 at time 0 and 0
DEBUG:up without down for pitch 74 at time 0
DEBUG:up without down for pitch 69 at time 0
DEBUG:up without down for pitch 73 at time 1
DEBUG:up without down for pitch 67 at time 1
DEBUG:up without down for pitch