In [0]:
#@title Download repository and dataset
# download repo and data
!git clone https://github.com/DTU-VAE/VAE.git
%cd /content/VAE
# !git checkout ****
import os
os.environ['PYTHONPATH'] += ":/content/VAE"

%cd /content
import requests
url = "https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip"
zip_file = requests.get(url)
with open("midi.zip", "wb") as zfile:
  zfile.write(zip_file.content)

import zipfile
with zipfile.ZipFile("midi.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/VAE/data/")

%cd /content/VAE/midi

print('\n\nScript usage\n------------------------------------')
!python3 midi.py -h

In [0]:
# Start training without bootstrapping, using midi.py's default settings.
# Relies on the working directory being /content/VAE/midi, which the setup
# cell's final %cd establishes.
!python3 midi.py

In [0]:
# Start training with bootstrapping: --bootstrap is given
# ../model_states/model_epoch_1.tar (presumably a checkpoint saved by an
# earlier run -- TODO confirm it exists before running this cell).
# See the `midi.py -h` output printed by the setup cell for the exact meaning
# of --epochs and --log-interval.
!python3 midi.py --bootstrap ../model_states/model_epoch_1.tar --epochs 5 --log-interval 1000

In [0]:
#@title Show reconstruction images. First half is original, second half is reconstruction
import re
from IPython.display import Image, display
from pathlib import Path

print('Reconstructions\n-------------------------------\n')

# Discover the epochs that were actually saved instead of probing a fixed
# range(100), which silently hid images for epoch 100 and above.
reconstruction_dir = Path('/content/VAE/results/reconstruction')
# Accept only canonical integers (no leading zeros), i.e. exactly the file
# names that f'reconstruction_epoch_{epoch}.png' can produce.
epoch_name = re.compile(r'reconstruction_epoch_(0|[1-9]\d*)\.png')

images_by_epoch = {}
for my_file in reconstruction_dir.glob('reconstruction_epoch_*.png'):
    matched = epoch_name.fullmatch(my_file.name)
    if matched and my_file.is_file():
        images_by_epoch[int(matched.group(1))] = my_file

# Sort numerically so epoch 10 follows epoch 9 rather than epoch 1.
for epoch in sorted(images_by_epoch):
    print(f'Epoch: {epoch}')
    display(Image(str(images_by_epoch[epoch])))
    print('\n')

In [0]:
#@title Show sample images.
import re
from IPython.display import Image, display
from pathlib import Path

print('Samples\n-------------------------------\n')

# Discover the epochs that were actually saved instead of probing a fixed
# range(100), which silently hid images for epoch 100 and above.
sample_dir = Path('/content/VAE/results/sample')
# Accept only canonical integers (no leading zeros), i.e. exactly the file
# names that f'sample_epoch_{epoch}.png' can produce.
epoch_name = re.compile(r'sample_epoch_(0|[1-9]\d*)\.png')

images_by_epoch = {}
for my_file in sample_dir.glob('sample_epoch_*.png'):
    matched = epoch_name.fullmatch(my_file.name)
    if matched and my_file.is_file():
        images_by_epoch[int(matched.group(1))] = my_file

# Sort numerically so epoch 10 follows epoch 9 rather than epoch 1.
for epoch in sorted(images_by_epoch):
    print(f'Epoch: {epoch}')
    display(Image(str(images_by_epoch[epoch])))
    print('\n')

In [0]:
import numpy as np
import matplotlib.pylab as plt
from matplotlib import colors as c
import matplotlib.patches as mpatches
from scipy import signal

def plot_reconstruction(original, reconstruction, filename="recon_diff", conv_size=3, repeats=3):
    """Save three figures comparing a 2-D signal with its reconstruction.

    Files written with ``plt.savefig`` at 300 dpi:

    * ``<filename>_original``       -- ``original`` as a binary image
    * ``<filename>_reconstruction`` -- ``reconstruction`` as a binary image
    * ``<filename>_print``          -- colour-coded overlay of the two

    Overlay colours:

    * black      -- cell set in ``original`` only
    * chartreuse -- cell set in both arrays ("exact reconstruction")
    * red/orange -- cell set in ``reconstruction`` only; the shade is taken
      from how many ``original`` cells lie in the surrounding
      ``conv_size`` x ``conv_size`` window, relative to the largest such
      count (deep red = none nearby). A cell whose count equals that
      maximum is given the value 1 and therefore drawn as chartreuse.
    * cells set in neither array are NaN and receive no colormap colour.

    Args:
        original: 2-D ``numpy.ndarray``; presumably a 0/1 array, since only
            the values 0 and 1 are remapped -- TODO confirm with callers.
        reconstruction: ``numpy.ndarray`` with the same shape as ``original``.
        filename: Prefix of the three output file names.
        conv_size: Odd window size (>= 3) of the neighbourhood count.
        repeats: How many times each column is repeated along axis 1 to
            widen the image.

    Raises:
        AssertionError: If an input is not an ndarray, the shapes differ, or
            ``conv_size`` is even or smaller than 3.
    """
    assert isinstance(original, np.ndarray), "'original' argument is not of type 'numpy.ndarray'"
    assert isinstance(reconstruction, np.ndarray), "'reconstruction' argument is not of type 'numpy.ndarray'"
    assert original.shape == reconstruction.shape, "Input arrays have unidentical shape"
    assert conv_size >= 3, "'conv_size' must be greater or equal to 3"
    assert conv_size % 2 == 1, "'conv_size' must be odd number"

    def save_input(src, extension):
        # Save one input as a black/white image named <filename>_<extension>.
        src_plot = src.copy()
        src_plot = src_plot.astype('bool')
        src_plot = np.repeat(src_plot, repeats, axis=1)
        plt.figure(figsize=(10, 5))
        plt.axis("off")
        plt.imshow(src_plot, cmap='binary')
        plt.savefig(filename + "_" + extension, dpi=300)

    save_input(original, "original")
    save_input(reconstruction, "reconstruction")

    def mask_op_source(src, mask, op):
        # "eq":  cells set in both src and mask.
        # "out": cells set in mask but not in src.
        assert op in ["eq", "out"], "Undefined mask operation"

        src = src.astype('bool')
        mask = mask.astype('bool')

        if op == "eq":
            return np.logical_and(src, mask)
        if op == "out":
            return np.logical_and(mask, ~src)

    # Boolean masks of agreeing cells and of reconstruction-only cells.
    original = original.astype('uint8')
    reconstruction = reconstruction.astype('uint8')
    m_in = mask_op_source(original, reconstruction, "eq")
    m_out = mask_op_source(original, reconstruction, "out")

    # Count the original cells inside each conv_size x conv_size window.
    conv_mask = np.ones((conv_size, conv_size), dtype="int8")
    conv_original = signal.convolve2d(original, conv_mask, mode="same", boundary="fill").astype("float")

    # Build the overlay values: NaN = empty, 0 = original only, 1 = in both.
    original = original.astype('float')
    # np.nan instead of np.NaN: the capitalised alias was removed in NumPy 2.0.
    original[original == 0] = np.nan
    original[original == 1] = 0
    np.putmask(original, m_in, 1)

    # Rescale the counts to [-1, 0]; guard against dividing by zero when the
    # original is entirely empty.
    conv_max = np.max(conv_original)
    conv_max = 1 if conv_max == 0 else conv_max
    conv_original -= conv_max
    conv_original /= conv_max
    # A score of exactly 0 would be drawn black (the "original" colour), so it
    # is moved to 1 instead.
    conv_original[conv_original == 0] = 1
    original[m_out] = conv_original[m_out]

    # Widen the image by repeating every column.
    original = np.repeat(original, repeats, axis=1)

    # Overlay figure: six red-to-yellow shades, black, then chartreuse.
    plt.figure(figsize=(10, 5))
    cMap = c.ListedColormap(['#FF0000', '#FF2A00', '#FF5500', '#FF8000', '#FFAA00', '#FFD500',
        'black', 'chartreuse', 'chartreuse', 'chartreuse', 'chartreuse', 'chartreuse', 'chartreuse'])
    patches = [
        mpatches.Patch(color="black", label="Original signal"),
        mpatches.Patch(color="chartreuse", label="Exact reconstruction"),
        mpatches.Patch(color="#FF0000", label="Inexact reconstruction"),
    ]

    plt.axis("off")
    plt.legend(handles=patches, bbox_to_anchor=(.733, 1), loc=2, borderaxespad=0.)
    plt.imshow(original, cmap=cMap, vmin=-1, vmax=1)

    plt.savefig(filename + "_print", dpi=300)


In [0]:
# plot_reconstruction() requires two equally-shaped numpy arrays; calling it
# with no arguments raises TypeError. The arrays below are SYNTHETIC
# placeholders that only make the cell runnable -- replace them with a real
# input and the model's reconstruction of it.
demo_rng = np.random.default_rng(0)
demo_original = (demo_rng.random((32, 64)) < 0.1).astype('uint8')
# Flip a small fraction of cells so the overlay shows both exact and inexact
# reconstruction colours.
demo_flips = demo_rng.random(demo_original.shape) < 0.03
demo_reconstruction = np.where(demo_flips, 1 - demo_original, demo_original).astype('uint8')

plot_reconstruction(demo_original, demo_reconstruction)

In [0]:
#@title Plot loss for given epoch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

epoch = int(input('Epoch: '))
train_losses = np.load(f'/content/VAE/results/losses/train_loss_epoch_{epoch}.npy')
valid_losses = np.load(f'/content/VAE/results/losses/validation_loss_epoch_{epoch}.npy')
test_losses  = np.load(f'/content/VAE/results/losses/test_loss_epoch_{epoch}.npy')

avg_losses = [np.mean(train_losses),np.mean(valid_losses),np.mean(test_losses)]

plt.figure(figsize=(10,5))
plt.plot(train_losses, 'r--', label=f'train - mean: {avg_losses[0]}')
plt.plot(valid_losses, 'g-', label=f'validation - mean: {avg_losses[1]}')
plt.plot(test_losses,  'b-', label=f'test - mean: {avg_losses[2]}')
plt.grid()
plt.legend()
plt.title(f'Losses over time for epoch {epoch}')
plt.show()

In [0]:
#@title Download results and model states
from google.colab import files
!zip -r /content/model_states.zip /content/VAE/model_states
!zip -r /content/results.zip /content/VAE/results
files.download("/content/model_states.zip")
files.download("/content/results.zip")