In [59]:
from midi2seq import process_midi_seq, seq2piano, random_piano, piano2seq, segment
import torch
import os
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import gdown

In [60]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps:0' if torch.backends.mps.is_available() else 'cpu')
device

device(type='mps', index=0)

In [61]:
seed = 10
expert_seq = process_midi_seq(maxlen=50,n=15000, shuffle_seed=seed)
expert_seq.shape

(17677, 51)

In [62]:
fake_midix = [random_piano(seed) for i in range(20000)]
fake_seq = process_midi_seq(all_midis=fake_midix,maxlen=50,n=15000)
fake_seq.shape

(15005, 51)

In [63]:
critic_data = np.zeros((expert_seq.shape[0] + fake_seq.shape[0], expert_seq.shape[1]+1))

critic_data[:expert_seq.shape[0],:expert_seq.shape[1]] = expert_seq
critic_data[expert_seq.shape[0]:,:expert_seq.shape[1]] = fake_seq

critic_data[:expert_seq.shape[0],expert_seq.shape[1]] = 1

critic_data, critic_data.shape

(array([[257., 355., 256., ...,  70., 256.,   1.],
        [ 53., 362.,  58., ..., 181., 264.,   1.],
        [190., 364.,  50., ..., 191., 269.,   1.],
        ...,
        [256., 368.,  65., ..., 284., 256.,   0.],
        [150., 282., 256., ..., 261., 256.,   0.],
        [256., 187., 257., ..., 364.,   7.,   0.]]),
 (32682, 52))

In [64]:
train_sequences, test_sequences = train_test_split(critic_data , test_size=0.2)

In [65]:
train_sequences[:,51].tolist().count(0),  train_sequences[:,51].tolist().count(1)

(12084, 14061)

In [71]:
X_train = train_sequences[:,:51]
X_train = X_train.reshape((-1,51,1))

Y_train = train_sequences[:,51]
Y_train = Y_train.reshape((-1,1))

X_test = test_sequences[:,:51]
X_test = X_test.reshape((-1,51,1))

Y_test = test_sequences[:,51]
Y_test = Y_test.reshape((-1,1))

X_train = torch.tensor(X_train).float().to(device) 
Y_train = torch.tensor(Y_train).float().to(device) 

X_test = torch.tensor(X_test).float().to(device) 
Y_test = torch.tensor(Y_test).float().to(device) 

X_train.shape, X_test.shape,  Y_train.shape, Y_test.shape

(torch.Size([26145, 51, 1]),
 torch.Size([6537, 51, 1]),
 torch.Size([26145, 1]),
 torch.Size([6537, 1]))

In [31]:
class MidiCriticDataset(Dataset):
    def __init__(self, X_sequence, Y_critic):
        self.X_sequence = X_sequence
        self.Y_critic = Y_critic

    def __len__(self):
        return len(self.Y_critic)
        
    def __getitem__(self, idx):
        sequence, label =  self.X_sequence[idx] ,self.Y_critic[idx]
        label = torch.tensor([1, 0]).float() if label else torch.tensor([0, 1]).float()
        return dict(
            sequence = sequence,
            label = label
        )

In [32]:
train_dataset = MidiCriticDataset(X_train,Y_train)
test_dataset = MidiCriticDataset(X_test,Y_test)

In [33]:
BATCH_SIZE = 100

train_loader = DataLoader(train_dataset,batch_size = BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size = BATCH_SIZE, shuffle=False)

In [34]:
for _, batch in enumerate(train_loader):
    sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
    print(sequence_batch.shape, label_batch.shape)
    break

torch.Size([100, 51, 1]) torch.Size([100, 2])


In [47]:
class CriticModel(nn.Module):
    def __init__(self, n_classes, n_input=1, n_hidden=256, n_layers=3):
        super().__init__()
        self.num_stacked_layers = n_layers
        self.hidden_size = n_hidden
        
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden, num_layers=n_layers, batch_first=True, dropout=0.7)
        # Output layer
        self.fc = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        
        lstm_out, _ = self.lstm(x, (h0, c0))
        out = lstm_out[:, -1, :]
        out = self.fc(out)
        return out

In [48]:
model = CriticModel(2,1,64,3)
model.to(device)

CriticModel(
  (lstm): LSTM(1, 64, num_layers=3, batch_first=True, dropout=0.7)
  (fc): Linear(in_features=64, out_features=2, bias=True)
)

In [37]:
learning_rate = 0.0001
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [38]:
def train_one_epoch():
    model.train(True)
    print(f'Epoch: {epoch + 1}')
    running_loss = 0.0
    
    for batch_index, batch in enumerate(train_loader):
        sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
        output = model(sequence_batch)
        loss = loss_function(output, label_batch)
        running_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_index % 100 == 99:  # print every 100 batches
            avg_loss_across_batches = running_loss / 100
            print('Batch {0}, Loss: {1:.3f}'.format(batch_index+1,
                                                    avg_loss_across_batches))
            running_loss = 0.0
    print()

In [39]:
def validate_one_epoch():
    model.train(False)
    running_loss = 0.0
    
    for batch_index, batch in enumerate(test_loader):
        sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
        
        with torch.no_grad():
            output = model(sequence_batch)
            loss = loss_function(output, label_batch)
            running_loss += loss.item()

    avg_loss_across_batches = running_loss / len(test_loader)
    
    print('Val Loss: {0:.3f}'.format(avg_loss_across_batches))
    print('***************************************************')
    print()

In [56]:
train = False

if train:
    num_epochs = 100
    for epoch in range(num_epochs):
        train_one_epoch()
        validate_one_epoch()
    torch.save(model, 'critic.pth')
else:
    url = 'https://drive.google.com/uc?id=1Yla0ZkFQtPNZww8mdcPKDWNn7UfCDVJq'
    output = 'critic.pth'
    gdown.download(url, output, quiet=False)
    
# Model class must be defined somewhere
# model.load_state_dict(torch.load('critic.pth'))
# model.eval()

state_dict = torch.load('critic.pth').state_dict()
model.load_state_dict(state_dict)

DEBUG:Starting new HTTPS connection (1): drive.google.com:443
DEBUG:https://drive.google.com:443 "GET /uc?id=1Yla0ZkFQtPNZww8mdcPKDWNn7UfCDVJq HTTP/1.1" 303 0
DEBUG:Starting new HTTPS connection (1): doc-0o-8c-docs.googleusercontent.com:443
DEBUG:https://doc-0o-8c-docs.googleusercontent.com:443 "GET /docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/m2m9bqus6p29t90alnptp7iegj2ffhgp/1696884975000/02584426154643755225/*/1Yla0ZkFQtPNZww8mdcPKDWNn7UfCDVJq?uuid=118dfca9-0a56-4bb9-b393-a5c9af125b2e HTTP/1.1" 200 340890
Downloading...
From: https://drive.google.com/uc?id=1Yla0ZkFQtPNZww8mdcPKDWNn7UfCDVJq
To: /Users/edwardmorgan/Documents/dev/deeplearning/PianoGen/critic.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 341k/341k [00:00<00:00, 3.44MB/s]


<All keys matched successfully>

In [57]:
with torch.no_grad():
    output = model(X_test.to(device))
    predicted_index = torch.argmax(output, dim=1)
    predicted_index ^= 1 # index 0 is good and index 1 is bad 

In [58]:
arr = (predicted_index == torch.flatten(Y_test)).to('cpu').numpy() #copy to cpu before convert to numpy
final_test_acc = sum(arr)/len(arr)
final_test_acc

0.9888379204892966

In [78]:
torch.flatten(Y_test)

tensor([0., 0., 1.,  ..., 1., 0., 0.], device='mps:0')