In [4]:
import os
import torch
import numpy as np
import torch.nn as nn
import muse.supplier as spr
import muse.processor2 as pcr
import muse.model2 as mdl
import muse.trainer as trn
import muse.visualizer as vis

import warnings
warnings.filterwarnings("ignore")

In [5]:
### ===== Hyperparameters ===== ###

# --- data source ---
instrument = 'Piano'            # instrument key used to select songs out of instru2corpus
filepath = "./../input/beeth/"  # directory handed to pcr.get_midis to collect .midi files

# --- windowing (presumably window length and step between windows; both go to pcr.train_test_split) ---
song_len, stride = 200, 200

# --- reproducibility: seed passed to pcr.train_test_split(seed=...) ---
seed_load = 592643464

# --- compute ---
device = trn.get_device()

In [6]:
### ===== Data Preprocessing ===== ###
# Reads filepath, instrument, song_len, stride and seed_load from the hyperparameter cell.
# seed_load is defined only there (fixed to a seed whose split is known to work), so this
# cell has a single, explicit source for the split seed and is safe to re-run.

all_midis, filenames = pcr.get_midis(filepath)  # load all .midi files
Corpus, instru2corpus = pcr.get_notes_batch(all_midis)  # extract all notes and sort by instrument

Corpus, fmap, rmap = pcr.get_map(Corpus)  # get forward-map and reverse-map from corpus
Corpus2, fmap2, rmap2 = pcr.get_map_offset_v2(instru2corpus, instrument)  # forward/reverse maps for offsets
# NOTE(review): both maps above are built before remove_short, so their vocabularies also cover
# songs dropped on the next line -- confirm that is intended.
instru2corpus = pcr.remove_short(instru2corpus)  # remove songs that are too short

# Train/validation split for the chosen instrument; seed_load is passed as the split seed.
X_train_melody, X_val_melody, X_train_offset, X_val_offset = pcr.train_test_split(
    instru2corpus, instrument, fmap, song_len, stride, seed=seed_load, process='center'
)
X_train_melody, X_val_melody = pcr.batchify(X_train_melody), pcr.batchify(X_val_melody)  # reshape and turn into tensor
X_train_offset, X_val_offset = pcr.batchify(X_train_offset), pcr.batchify(X_val_offset)  # reshape and turn into tensor

# Apply the offset forward-map (fmap2) to the offset windows.
X_train_offset = pcr.fmap_offset(X_train_offset, fmap2, song_len)
X_val_offset = pcr.fmap_offset(X_val_offset, fmap2, song_len)

# Joint maps built from the melody map (fmap) and the offset map (fmap2).
fmap_j, rmap_j = pcr.get_joint_map(fmap, fmap2)
classes_j = len(fmap_j)  # mapping keys are already unique -> size of the joint vocabulary

X_train_joint = pcr.zip_(X_train_melody, X_train_offset, rmap, rmap2, fmap_j)
X_val_joint = pcr.zip_(X_val_melody, X_val_offset, rmap, rmap2, fmap_j)

In [8]:
### ===== Load model ===== ###

# Architecture skeleton; classes_j (the joint vocabulary size) is its third constructor argument.
model1 = mdl.cnn_varautoencoder(1, 1, classes_j, std=1.0)

# trn.load_model returns an indexable result whose first element is the model restored from
# the 'vae6_CE' checkpoint. Module.eval() switches it to inference mode and returns the module itself.
model6 = trn.load_model('vae6_CE', model1, device)[0].eval()
model6

cnn_varautoencoder(
  (encoder): cnn_varencoder(
    (conv1): Conv1d(1, 16, kernel_size=(21,), stride=(1,))
    (conv21): Conv1d(16, 4, kernel_size=(11,), stride=(1,))
    (conv22): Conv1d(16, 4, kernel_size=(11,), stride=(1,))
    (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (relu): ReLU()
    (flat): Flatten(start_dim=1, end_dim=-1)
    (linear): Linear(in_features=160, out_features=20, bias=True)
  )
  (decoder): cnn_vardecoder(
    (tconv1): ConvTranspose1d(4, 16, kernel_size=(11,), stride=(1,))
    (tconv2): ConvTranspose1d(16, 1, kernel_size=(31,), stride=(1,))
    (relu): ReLU()
    (sigmoid): Sigmoid()
    (linear): Linear(in_features=20, out_features=640, bias=True)
    (dropout): Dropout(p=0, inplace=False)
  )
)

In [37]:
idx = 80        # training example to reconstruct; set to None to pick one at random
base = 12       # forwarded to pcr.gen_stream(base=...) when rendering
fname = 'temp'  # file name forwarded to pcr.gen_stream(fname=...)

if idx is None:
    idx = np.random.randint(0, len(X_train_joint))

# Pure inference: no autograd graph is needed. The example is moved onto whatever device the
# model's weights live on, and the result is brought back to the CPU because Tensor.numpy()
# only works on CPU tensors (this previously failed when trn.get_device() selected a GPU).
model_device = next(model6.parameters()).device
with torch.no_grad():
    recons = model6(X_train_joint[idx].to(model_device))[0].squeeze().cpu().numpy() * 20.0
# NOTE(review): the 20.0 factor looks like it undoes a scaling applied to the joint tokens
# before training -- confirm against pcr and consider naming it in the config cell.
melody, duration = pcr.rmap_safe(rmap_j, recons).T

In [38]:
# Running start positions from the decoded durations: offset[0] is 0 and offset[k] is the sum
# of the first k durations, all divided by 3. The result has len(duration) + 1 entries.
# NOTE(review): the /3 scale factor is undocumented -- confirm the unit it converts to.
step_sizes = np.fromiter(map(float, duration), dtype=float)
offset = np.concatenate(([0.0], np.cumsum(step_sizes))) / 3

In [39]:
# Render the reconstructed notes; the displayed result is a music21 Stream.
# NOTE(review): offset has one more entry than melody (it ends with the running total) --
# confirm gen_stream expects that; output=True presumably writes to fname.
pcr.gen_stream(melody, offset, fname=fname, base=base, output=True)

<music21.stream.Stream 0x1efc94e7670>