In [64]:
from midi2seq import process_midi_seq, seq2piano, random_piano, piano2seq, segment
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import gdown

In [65]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'mps:0' if torch.backends.mps.is_available() else 'cpu')
device

device(type='mps', index=0)

In [47]:
expert_seq = process_midi_seq(maxlen=50,n=15000, shuffle_seed=3)
expert_seq.shape

(15734, 51)

In [48]:
seed = 10
fake_midix = [random_piano(seed) for i in range(20000)]
fake_seq = process_midi_seq(all_midis=fake_midix,maxlen=50,n=15000)
fake_seq.shape

(15006, 51)

In [49]:
critic_data = np.zeros((expert_seq.shape[0] + fake_seq.shape[0], expert_seq.shape[1]+1))

scaler = MinMaxScaler(feature_range=(0,1))

critic_data[:expert_seq.shape[0],:expert_seq.shape[1]] = expert_seq
critic_data[expert_seq.shape[0]:,:expert_seq.shape[1]] = fake_seq

# music = critic_data[:,:expert_seq.shape[1]]

# normalized_music = scaler.fit_transform(music.reshape((-1,1))).reshape(music.shape)
# print(f'max feature is {scaler.data_max_}')
# print(f'min feature is {scaler.data_min_}')

# critic_data[:,:expert_seq.shape[1]] = normalized_music

critic_data[:expert_seq.shape[0],expert_seq.shape[1]] = 1

critic_data, critic_data.shape

(array([[257., 346., 256., ..., 181., 277.,   1.],
        [ 51., 361.,  39., ..., 278., 256.,   1.],
        [256., 183., 299., ..., 359.,  43.,   1.],
        ...,
        [271., 256., 160., ...,  12., 268.,   0.],
        [266., 359.,  49., ..., 256., 362.,   0.],
        [268., 256., 140., ..., 365.,  94.,   0.]]),
 (30740, 52))

In [50]:
train_sequences, test_sequences = train_test_split(critic_data , test_size=0.2)

In [51]:
train_sequences[:,51].tolist().count(0),  train_sequences[:,51].tolist().count(1)

(12004, 12588)

In [52]:
X_train = train_sequences[:,:51]
X_train = X_train.reshape((-1,51,1))

Y_train = train_sequences[:,51]
Y_train = Y_train.reshape((-1,1))

X_test = test_sequences[:,:51]
X_test = X_test.reshape((-1,51,1))

Y_test = test_sequences[:,51]
Y_test = Y_test.reshape((-1,1))



X_train = torch.tensor(X_train).float().to(device) 
Y_train = torch.tensor(Y_train).float().to(device) 

X_test = torch.tensor(X_test).float().to(device) 
Y_test = torch.tensor(Y_test).float().to(device) 

X_train.shape, X_test.shape,  Y_train.shape, Y_test.shape

(torch.Size([24592, 51, 1]),
 torch.Size([6148, 51, 1]),
 torch.Size([24592, 1]),
 torch.Size([6148, 1]))

In [53]:
class MidiCriticDataset(Dataset):
    def __init__(self, X_sequence, Y_critic):
        self.X_sequence = X_sequence
        self.Y_critic = Y_critic

    def __len__(self):
        return len(self.Y_critic)
        
    def __getitem__(self, idx):
        sequence, label =  self.X_sequence[idx] ,self.Y_critic[idx]
        label = torch.tensor([1, 0]).float() if label else torch.tensor([0, 1]).float()
        return dict(
            sequence = sequence,
            label = label
        )

In [54]:
train_dataset = MidiCriticDataset(X_train,Y_train)
test_dataset = MidiCriticDataset(X_test,Y_test)

In [55]:
BATCH_SIZE = 100

train_loader = DataLoader(train_dataset,batch_size = BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset,batch_size = BATCH_SIZE, shuffle=False)

In [56]:
for _, batch in enumerate(train_loader):
    sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
    print(sequence_batch.shape, label_batch.shape)
    break

torch.Size([100, 51, 1]) torch.Size([100, 2])


In [57]:
class SequenceModel(nn.Module):
    def __init__(self, n_classes, n_input=1, n_hidden=256, n_layers=3):
        super().__init__()
        self.num_stacked_layers = n_layers
        self.hidden_size = n_hidden
        
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden, num_layers=n_layers, batch_first=True, dropout=0.7)
        # Output layer
        self.linear = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_stacked_layers, batch_size, self.hidden_size).to(device)
        
        lstm_out, _ = self.lstm(x, (h0, c0))
        out = lstm_out[:, -1, :]
        out = self.linear(out)
        return out

In [58]:
model = SequenceModel(2,1,256,3)
model.to(device)

SequenceModel(
  (lstm): LSTM(1, 256, num_layers=3, batch_first=True, dropout=0.7)
  (linear): Linear(in_features=256, out_features=2, bias=True)
)

In [59]:
learning_rate = 0.0001
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [60]:
def train_one_epoch():
    model.train(True)
    print(f'Epoch: {epoch + 1}')
    running_loss = 0.0
    
    for batch_index, batch in enumerate(train_loader):
        sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
        output = model(sequence_batch)
        loss = loss_function(output, label_batch)
        running_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_index % 100 == 99:  # print every 100 batches
            avg_loss_across_batches = running_loss / 100
            print('Batch {0}, Loss: {1:.3f}'.format(batch_index+1,
                                                    avg_loss_across_batches))
            running_loss = 0.0
    print()

In [61]:
def validate_one_epoch():
    model.train(False)
    running_loss = 0.0
    
    for batch_index, batch in enumerate(test_loader):
        sequence_batch , label_batch = batch['sequence'].to(device) , batch['label'].to(device) 
        
        with torch.no_grad():
            output = model(sequence_batch)
            loss = loss_function(output, label_batch)
            running_loss += loss.item()

    avg_loss_across_batches = running_loss / len(test_loader)
    
    print('Val Loss: {0:.3f}'.format(avg_loss_across_batches))
    print('***************************************************')
    print()

In [62]:
train = True
import os
if train:
    num_epochs = 10
    for epoch in range(num_epochs):
        train_one_epoch()
        validate_one_epoch()
    torch.save(model, 'critic.pth')
elif os.path.isfile('critic.pth'):
    model = torch.load('critic.pth')
    model.eval()
else:
    url = 'https://drive.google.com/uc?id=1T0qNtWpvG-NO2lCOp_Y1YB6_Z3Xym0xn'
    output = 'critic.pth'
    gdown.download(url, output, quiet=False)
    
    # Model class must be defined somewhere
    model = torch.load('critic.pth')
    model.eval()

Epoch: 1
Batch 100, Loss: 0.518
Batch 200, Loss: 0.123

Val Loss: 0.097
***************************************************

Epoch: 2
Batch 100, Loss: 0.096
Batch 200, Loss: 0.072

Val Loss: 0.066
***************************************************

Epoch: 3
Batch 100, Loss: 0.068
Batch 200, Loss: 0.061

Val Loss: 0.055
***************************************************

Epoch: 4
Batch 100, Loss: 0.056
Batch 200, Loss: 0.051

Val Loss: 0.044
***************************************************

Epoch: 5
Batch 100, Loss: 0.045
Batch 200, Loss: 0.047

Val Loss: 0.044
***************************************************

Epoch: 6
Batch 100, Loss: 0.047
Batch 200, Loss: 0.047

Val Loss: 0.047
***************************************************

Epoch: 7
Batch 100, Loss: 0.043
Batch 200, Loss: 0.039

Val Loss: 0.037
***************************************************

Epoch: 8
Batch 100, Loss: 0.037
Batch 200, Loss: 0.041

Val Loss: 0.043
***************************************************



In [63]:
with torch.no_grad():
    output = model(X_test.to(device))
    predicted_index = torch.argmax(output, dim=1)
    predicted_index ^= 1 # index 0 is good and index 1 is bad 

RuntimeError: MPS backend out of memory (MPS allocated: 445.55 MB, other allocations: 5.14 GB, max allowed: 9.07 GB). Tried to allocate 3.59 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [95]:
arr = (predicted_index == torch.flatten(Y_test)).to('cpu').numpy() #copy to cpu before convert to numpy
final_test_acc = sum(arr)/len(arr)
final_test_acc

0.504037558685446