In [0]:
#@title Download repository and dataset
# download repo and data
!git clone https://github.com/DTU-VAE/VAE.git
%cd /content/VAE
# !git checkout ****
import os
os.environ['PYTHONPATH'] += ":/content/VAE"

%cd /content
import requests
url = "https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip"
zip_file = requests.get(url)
with open("midi.zip", "wb") as zfile:
  zfile.write(zip_file.content)

import zipfile
with zipfile.ZipFile("midi.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/VAE/data/")

%cd /content/VAE/midi

print('\n\nScript usage\n------------------------------------')
!python3 midi.py -h

In [0]:
# start training from scratch (no --bootstrap checkpoint) with default settings,
# except for the --transpose-key flag.
# NOTE(review): --transpose-key presumably enables key-transposition of the MIDI data — confirm via `midi.py -h` above.
!python3 midi.py --transpose-key

In [0]:
# start training with bootstrapping: load the checkpoint saved after epoch 1 and
# continue training from there (the script reports "Continuing training from epoch: 2").
# NOTE(review): confirm whether --epochs is the total epoch count or additional epochs.
!python3 midi.py --bootstrap ../model_states/model_epoch_1.tar --epochs 5 --log-interval 1000

Bootstrapping model from ../model_states/model_epoch_1.tar
Continuing training from epoch: 2

 40% 512/1282 [00:19<03:21,  3.82it/s]

In [0]:
# generate samples with the model
!python3 midi.py --bootstrap ../model_states/model_epoch_1.tar --generative

from IPython.display import Image, display
from pathlib import Path

path = '/content/VAE/results/sample/sample_epoch_generative.png'
my_file = Path(path)
if my_file.is_file():
    display(Image(path))

In [0]:
#@title Show reconstruction images. First half is original, second half is reconstruction
# Display every saved reconstruction image in ascending epoch order.
# Epochs are discovered from the file names rather than probing a fixed range(100),
# so runs with 100 or more epochs are shown in full.
import re
from pathlib import Path

from IPython.display import Image, display

recon_dir = Path('/content/VAE/results/reconstruction')
recon_images = []
for img_path in recon_dir.glob('reconstruction_epoch_*.png'):
    # Only numeric epoch suffixes count; any other file in the folder is ignored.
    match = re.fullmatch(r'reconstruction_epoch_(\d+)\.png', img_path.name)
    if match:
        recon_images.append((int(match.group(1)), img_path))

print('Reconstructions\n-------------------------------\n')
# Sort numerically by epoch (a plain filename sort would put 10 before 2).
for epoch, img_path in sorted(recon_images):
    print(f'Epoch: {epoch}')
    display(Image(str(img_path)))
    print('\n')

In [0]:
#@title Show sample images.
# Display every saved per-epoch sample image in ascending epoch order.
# Epochs are discovered from the file names rather than probing a fixed range(100),
# so runs with 100 or more epochs are shown in full. The numeric-only pattern also
# skips non-epoch files such as sample_epoch_generative.png, as the original did.
import re
from pathlib import Path

from IPython.display import Image, display

sample_dir = Path('/content/VAE/results/sample')
sample_images = []
for img_path in sample_dir.glob('sample_epoch_*.png'):
    match = re.fullmatch(r'sample_epoch_(\d+)\.png', img_path.name)
    if match:
        sample_images.append((int(match.group(1)), img_path))

print('Samples\n-------------------------------\n')
# Sort numerically by epoch (a plain filename sort would put 10 before 2).
for epoch, img_path in sorted(sample_images):
    print(f'Epoch: {epoch}')
    display(Image(str(img_path)))
    print('\n')

In [0]:
#@title Plot losses for given epoch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

epoch = int(input('Epoch: '))
train_losses = np.load(f'/content/VAE/results/losses/train_loss_epoch_{epoch}.npy')
valid_losses = np.load(f'/content/VAE/results/losses/validation_loss_epoch_{epoch}.npy')
test_losses  = np.load(f'/content/VAE/results/losses/test_loss_epoch_{epoch}.npy')

avg_losses = [np.mean(train_losses),np.mean(valid_losses),np.mean(test_losses)]

plt.figure(figsize=(10,5))
plt.plot(train_losses, 'r--', label=f'train - mean: {avg_losses[0]}')
plt.plot(valid_losses, 'g-', label=f'validation - mean: {avg_losses[1]}')
plt.plot(test_losses,  'b-', label=f'test - mean: {avg_losses[2]}')
plt.grid()
plt.legend()
plt.title(f'Losses over time for epoch {epoch}')
plt.show()

In [0]:
#@title Download results and model states
from google.colab import files
!zip -r /content/model_states.zip /content/VAE/model_states
!zip -r /content/results.zip /content/VAE/results
files.download("/content/model_states.zip")
files.download("/content/results.zip")