In [173]:
import torch
import numpy as np
import os
import time

In [174]:
# Select the compute device in a single place (previously a GPU cell was
# immediately overridden by a CPU cell, so Run-All always ended up on the CPU).
# Set USE_GPU = True to train on CUDA; it falls back to the CPU when no CUDA
# device is available.
USE_GPU = False
dev = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

In [176]:
# Custom Dataset class
class MIDIDataset(torch.utils.data.Dataset):
    """Fixed-length chunks of MIDI messages loaded from a directory of .npy files.

    CONSTRUCTOR: loads every ``.npy`` file in ``root_dir`` (e.g. recording0.npy to
    recordingM.npy) in sorted filename order and cuts each array into consecutive,
    non-overlapping chunks of exactly ``chunk_size`` entries along axis 0. A
    trailing remainder shorter than ``chunk_size`` is dropped. Prints the file
    index every 100 files as progress.

    ARGUMENTS
    root_dir: the directory to search
    chunk_size: number of messages per chunk
    transform: optional callable applied to each chunk in __getitem__

    NOTE(review): downstream collate_fn assumes each file holds a 1-D array of
    message ids -- confirm against the preprocessing script.
    """
    def __init__(self, root_dir, chunk_size, transform=None):
        self.chunks = []
        self.masks = []  # never populated; kept so existing references still resolve

        # os.listdir order is arbitrary: sort so chunk indices are reproducible,
        # and skip non-.npy files, which np.load cannot read.
        files = sorted(name for name in os.listdir(root_dir) if name.endswith('.npy'))
        for f, file in enumerate(files):
            data = np.load(os.path.join(root_dir, file))
            # Only start positions that leave room for a full chunk are used.
            for chunk_start in range(0, data.shape[0] - chunk_size + 1, chunk_size):
                self.chunks.append(torch.tensor(data[chunk_start:chunk_start + chunk_size]))

            if f % 100 == 0:
                print(f)

        self.transform = transform

    def __len__(self):
        """Return the number of chunks in the dataset."""
        return len(self.chunks)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a tensor, passed through ``transform`` if one was given."""
        instance = self.chunks[idx]

        if self.transform:
            instance = self.transform(instance)

        return instance
    
def collate_fn(batch):
    """Combine equal-length chunks into one (chunk_size, batch) long tensor on `dev`.

    Each chunk becomes one column, i.e. the result is time-major as the GRUs expect.
    """
    stacked = torch.stack(tuple(batch), dim=1)
    return stacked.long().to(dev)

In [177]:
chunk_size: int = 600  # messages per training example (the sequence length seen by the VAE)

In [178]:
# Training set: every recording in train_vae_baseline/, cut into chunk_size pieces.
train_data = MIDIDataset(root_dir='train_vae_baseline', chunk_size=chunk_size)
len(train_data)

0
100
200
300
400


2414

In [179]:
# Held-out set: every recording in test_vae_baseline/, cut into chunk_size pieces.
test_data = MIDIDataset(root_dir='test_vae_baseline', chunk_size=chunk_size)
len(test_data)

0


659

In [218]:
class VAEEncoder(torch.nn.Module):
    """Two-layer GRU encoder that maps embedded messages to posterior parameters.

    Arguments:
        message_dim: not used by the encoder (kept for a signature parallel to the decoder).
        hidden_size: size of the input embeddings and of the GRU hidden state.
        latent_size: dimension of the latent code.
    """
    def __init__(self, message_dim, hidden_size, latent_size):
        super(VAEEncoder, self).__init__()
        self.gru = torch.nn.GRU(hidden_size, hidden_size, num_layers=2)
        self.mean = torch.nn.Linear(hidden_size, latent_size)
        self.logvar = torch.nn.Linear(hidden_size, latent_size)

    def forward(self, message_embed):
        """Encode a (L, B, hidden_size) sequence.

        Returns (z_mean, z_logvar), each reshaped to (1, B, latent_size).
        """
        batch = message_embed.shape[1]
        # The top GRU layer's output at the final time step summarizes the sequence.
        summary = self.gru(message_embed)[0][-1]
        return (self.mean(summary).view(1, batch, -1),
                self.logvar(summary).view(1, batch, -1))

class VAEDecoder(torch.nn.Module):
    """Hierarchical GRU decoder.

    A conductor GRU, initialised from the latent code z and fed a constant zero
    input, emits one state per section; a lower-level GRU decodes each section of
    ``section_size`` messages starting from that state and from a zero input.

    Arguments:
        latent_size: dimension of z.
        hidden_size: hidden size of both GRUs and size of the input embeddings.
        message_dim: number of message ids (size of the output logits).
        section_size: number of messages decoded per conductor step.
    """
    def __init__(self, latent_size, hidden_size, message_dim, section_size):
        super(VAEDecoder, self).__init__()

        self.section_size = section_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.conductor_linear = torch.nn.Linear(latent_size, hidden_size)
        self.conductor_gru = torch.nn.GRU(1, hidden_size, num_layers=2)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, num_layers=2)
        self.logits = torch.nn.Linear(hidden_size, message_dim)

    def _initial_conductor_hidden(self, z):
        # Map z to the conductor's initial hidden state, repeated for both GRU layers.
        return torch.tanh(self.conductor_linear(z)).repeat(2, 1, 1)

    def forward(self, z, forced_input_embed):
        """Return teacher-forced logits.

        Arguments:
            z: latent code of shape (1, B, latent_size).
            forced_input_embed: (L, B, hidden_size) embeddings of the ground-truth
                messages to feed as inputs.
        Returns:
            logits of shape (L + 1, B, message_dim). Output position p gets a zero
            input when p is a multiple of section_size and forced_input_embed[p - 1]
            otherwise, so every section starts from the null input.
        """
        L = forced_input_embed.shape[0]
        B = forced_input_embed.shape[1]
        assert z.shape == (1, B, self.latent_size)

        # L + 1 outputs are produced, so size the sections from L + 1 (ceil(L / S)
        # silently dropped the last target whenever L was a multiple of S).
        nsections = int(np.ceil((L + 1) / self.section_size))

        conductor_hidden = self._initial_conductor_hidden(z)
        # Null inputs live on z's device/dtype rather than a global device.
        conductor_input = z.new_zeros((1, B, 1))
        init_gru_input = z.new_zeros((1, B, self.hidden_size))
        outputs = []
        for section in range(nsections):
            _, conductor_hidden = self.conductor_gru(conductor_input, conductor_hidden)

            start = section * self.section_size
            # The section's first input is the null vector, so it consumes at most
            # section_size - 1 forced inputs.
            end = min(start + self.section_size - 1, L)

            # Teacher forcing
            gru_inputs = torch.cat((init_gru_input, forced_input_embed[start:end]), dim=0)
            output, _ = self.gru(gru_inputs, conductor_hidden)
            outputs.append(output)

        return self.logits(torch.cat(outputs, dim=0))

    @torch.no_grad()
    def generate(self, z, nsections, embedding):
        """Autoregressively sample nsections * section_size messages from z.

        Arguments:
            z: latent code of shape (1, 1, latent_size).
            nsections: number of conductor steps to decode.
            embedding: module mapping a sampled id tensor to a hidden_size vector.
        Returns:
            list of 1-element tensors holding the sampled message ids.
        """
        assert z.shape == (1, 1, self.latent_size)

        conductor_hidden = self._initial_conductor_hidden(z)
        conductor_input = z.new_zeros((1, 1, 1))
        null_gru_input = z.new_zeros((1, 1, self.hidden_size))
        messages = []
        for section in range(nsections):
            _, conductor_hidden = self.conductor_gru(conductor_input, conductor_hidden)
            hidden = conductor_hidden
            # Match forward(): every section starts from the null input, not from
            # the previous section's last message.
            gru_input = null_gru_input

            for _ in range(self.section_size):
                output, hidden = self.gru(gru_input, hidden)
                probs = torch.nn.functional.softmax(self.logits(output), dim=2).flatten()
                messages.append(torch.multinomial(probs, 1))
                gru_input = embedding(messages[-1]).view(1, 1, -1)

        return messages
    
class BaselineVAE(torch.nn.Module):
    """Sequence VAE over message ids: embedding -> GRU encoder -> hierarchical decoder.

    Arguments:
        message_dim: number of message ids.
        hidden_size: embedding size and GRU hidden size.
        latent_size: dimension of the latent code.
        section_size: messages decoded per conductor step.
    """
    def __init__(self, message_dim, hidden_size, latent_size, section_size):
        super(BaselineVAE, self).__init__()

        self.message_dim = message_dim
        self.embedding = torch.nn.Embedding(message_dim, hidden_size)
        self.encoder = VAEEncoder(message_dim, hidden_size, latent_size)
        self.decoder = VAEDecoder(latent_size, hidden_size, message_dim, section_size)
        # Standard normal; its tensors are not registered buffers, so .to(device)
        # leaves them on the CPU.
        self.normal_sampler = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

    def _encode_and_sample(self, message_embed):
        # Reparameterization trick: z = mean + eps * std, eps ~ N(0, I) drawn on z's device.
        z_mean, z_logvar = self.encoder(message_embed)
        z = z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_logvar)
        return z, z_mean, z_logvar

    # Compute loss
    def forward(self, messages):
        """Return the negative ELBO per message (in nats) for a (L, B) batch of ids.

        Reconstruction cross-entropy and KL are both summed over the batch and then
        divided by the number of messages. Previously a per-token mean CE was added
        to a batch-summed KL, so the KL weight grew with batch size and length.
        """
        message_embed = self.embedding(messages[:-1])
        z, z_mean, z_logvar = self._encode_and_sample(message_embed)
        logits = self.decoder(z, message_embed)

        recon = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.message_dim), messages.flatten(), reduction='sum')
        kl = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - torch.exp(z_logvar))
        return (recon + kl) / messages.numel()

    # Reconstruct messages
    def reconstruct(self, messages, teacher_forcing=True):
        """Encode one (L, 1) sequence and sample a reconstruction of shape (L, 1).

        teacher_forcing=True (default, original behaviour) samples each message
        given the TRUE previous messages, so it partly copies the input.
        teacher_forcing=False decodes autoregressively from the sampled z only.
        """
        L = messages.shape[0]
        assert messages.shape[1] == 1

        with torch.no_grad():
            message_embed = self.embedding(messages[:-1])
            z, _, _ = self._encode_and_sample(message_embed)
            if teacher_forcing:
                logits = self.decoder(z, message_embed).view(-1, self.message_dim)
                return torch.multinomial(torch.nn.functional.softmax(logits, dim=1), 1)
            nsections = int(np.ceil(L / self.decoder.section_size))
            generated = self.decoder.generate(z, nsections, self.embedding)
            return torch.cat(generated)[:L].view(-1, 1)

In [181]:
# Optimization parameters:
epochs = 10
batch_size = 16
learning_rate = 1e-3

# Model parameters:
message_dim = 388  # size of the message vocabulary (presumably the MIDI event encoding; confirm)
hidden_size = 1024
latent_size = 512
section_size = 200

# Checkpoint location
checkpoint_dir = 'baseline_vae_checkpoints'

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Evaluation does not need shuffling; a fixed order keeps the batch-averaged test loss deterministic.
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
model = BaselineVAE(message_dim, hidden_size, latent_size, section_size).to(dev)
# learning_rate was previously defined but never passed to the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [68]:
def average_loss(model, dataloader):
    """Mean of the per-batch losses over ``dataloader``, computed without autograd."""
    total = 0.0
    with torch.no_grad():
        for batch in dataloader:
            total += model(batch).item()
    return total / len(dataloader)


# torch.save below fails if the checkpoint directory does not exist yet.
os.makedirs(checkpoint_dir, exist_ok=True)

train_losses = [0.0] * epochs
test_losses = [0.0] * epochs
for epoch in range(epochs):
    print('Starting epoch %d' %(epoch))
    model.train()
    for b, batch in enumerate(train_dataloader):
        if b%10 == 0:
            print('Starting iteration %d' %(b))
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), checkpoint_dir + '/epoch' + str(epoch) + '.pth')

    model.eval()
    print('Computing test loss')
    test_losses[epoch] = average_loss(model, test_dataloader)

    print('Computing train loss')
    train_losses[epoch] = average_loss(model, train_dataloader)

    print('Train loss: %f, Test loss: %f' %(train_losses[epoch], test_losses[epoch]))

Starting epoch 0
Starting iteration 0
Starting iteration 10
Starting iteration 20
Starting iteration 30
Starting iteration 40
Starting iteration 50
Starting iteration 60
Starting iteration 70
Starting iteration 80
Starting iteration 90
Starting iteration 100
Starting iteration 110
Starting iteration 120
Starting iteration 130
Starting iteration 140
Starting iteration 150
Starting iteration 160
Starting iteration 170
Starting iteration 180
Starting iteration 190
Starting iteration 200
Starting iteration 210
Starting iteration 220
Starting iteration 230
Starting iteration 240
Starting iteration 250
Starting iteration 260
Starting iteration 270
Starting iteration 280
Starting iteration 290
Starting iteration 300
Starting iteration 310
Starting iteration 320
Starting iteration 330
Starting iteration 340
Starting iteration 350
Starting iteration 360
Starting iteration 370
Starting iteration 380
Starting iteration 390
Starting iteration 400
Starting iteration 410
Starting iteration 420
Start

KeyboardInterrupt: 

In [None]:
# Persist both loss curves next to the checkpoints.
for loss_path, loss_curve in ((checkpoint_dir + '/train_losses.npy', train_losses),
                              (checkpoint_dir + '/test_losses.npy', test_losses)):
    np.save(loss_path, np.array(loss_curve))

In [None]:
# The loss files are plain float arrays, so pickle support is unnecessary
# (allow_pickle=True would permit arbitrary code execution from a tampered file).
train_losses = np.load(checkpoint_dir + '/train_losses.npy')
test_losses = np.load(checkpoint_dir + '/test_losses.npy')

In [None]:
import matplotlib.pyplot as plt
COLOR = 'white'
# White text/ticks (suited to a dark notebook theme) and large figures/fonts.
plt.rcParams.update({
    'text.color': COLOR,
    'axes.labelcolor': COLOR,
    'xtick.color': COLOR,
    'ytick.color': COLOR,
    'figure.figsize': [12, 8],
    'figure.dpi': 100,  # 200 e.g. is really fine, but slower
    'font.size': 22,
})
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set_xlabel('Epochs')
ax.set_ylabel('Average Loss (Nats)')
ax.set_title('Training Loss')
fig.savefig(checkpoint_dir + '/vae_baseline_train_loss.png')

In [None]:
plt.plot(test_losses)
plt.xlabel('Epochs')
plt.ylabel('Average Loss (Nats)')
# This figure shows the held-out curve; it was previously mislabelled 'Training Loss'.
plt.title('Test Loss')
plt.savefig(checkpoint_dir + '/vae_baseline_test_loss.png')

In [219]:
# Rebuild the model and load the weights saved after the final training epoch
# (derived from checkpoint_dir/epochs instead of a hard-coded path).
model = BaselineVAE(message_dim, hidden_size, latent_size, section_size).to(dev)
model.eval()
model.load_state_dict(torch.load(checkpoint_dir + '/epoch' + str(epochs - 1) + '.pth', map_location=dev))

<All keys matched successfully>

In [211]:
# Sample from model
def generate_music(model, gen_length=500):
    """Sample z from the N(0, I) prior and decode ``gen_length`` messages.

    Arguments:
        model: a BaselineVAE (uses model.decoder and model.embedding).
        gen_length: number of messages to return.
    Returns:
        list of ``gen_length`` 1-element tensors of sampled message ids.
    """
    decoder = model.decoder
    # Sample on the model's device and use the model's own sizes rather than globals.
    device = next(model.parameters()).device
    z = torch.randn((1, 1, decoder.latent_size), device=device)
    nsections = int(np.ceil(gen_length / decoder.section_size))
    with torch.no_grad():
        messages = decoder.generate(z, nsections, model.embedding)
    # generate() works in whole sections; drop anything past gen_length.
    return messages[:gen_length]

In [212]:
generated_music = generate_music(model, 600)  # one 600-message sample from the prior

In [213]:
# torch.stack keeps the (N, 1) layout of the sampled ids; .cpu() makes this work
# when the model ran on a GPU (np.array cannot convert CUDA tensors).
np.save('recording_vae.npy', torch.stack(generated_music).cpu().numpy())

In [221]:
# Test reconstruction
# NOTE: reconstruct() teacher-forces on the original messages by default, so the
# output resembling the original partly reflects the forced inputs, not just z.
original = train_data[0].unsqueeze(1).to(dev)
with torch.no_grad():
    reconstruction = model.reconstruct(original).flatten()
os.makedirs('baseline_vae_midis', exist_ok=True)
np.save('baseline_vae_midis/original10.npy', original.flatten().cpu().numpy())
np.save('baseline_vae_midis/reconstruction10.npy', reconstruction.cpu().numpy())