In [None]:
import torch
from utils import CharRNN, sample, seq2csv
import matplotlib.pyplot as plt
from pprint import pprint

# Load Model

In [None]:
# Restore the trained CharRNN from a saved checkpoint dict. The keys read below are
# 'tokens' (presumably the character vocabulary — confirm), 'n_hidden' and 'n_layers'
# (architecture hyper-parameters), 'state_dict' (trained weights); 'loss_history'
# is used by the next cell.
# Passing the path instead of an open() handle lets torch.load open *and close* the
# file itself, so no file descriptor is leaked. map_location='cpu' remaps any tensors
# saved on GPU onto the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint_path = 'model/epoch_13158.pth'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
net = CharRNN(checkpoint['tokens'], n_hidden=checkpoint['n_hidden'], n_layers=checkpoint['n_layers'])
net.load_state_dict(checkpoint['state_dict'])

In [None]:
# Training loss recorded in the checkpoint, one value per epoch, drawn on a
# log-scaled y-axis.
loss_history = checkpoint['loss_history']
epochs = range(len(loss_history))

fig, ax = plt.subplots(figsize=(16, 5))
ax.semilogy(epochs, loss_history)
ax.set_xlabel("epoch")
ax.set_ylabel("loss");

# Compose Music

In [None]:
# Composition settings — tweak these, then run the next cell.

# Where the composed result is written (passed to seq2csv as the file name)
fname = 'mymusic'

# Sampling parameters forwarded to sample():
prime = "A4-512-512"   # seed sequence the RNN is warmed up with before generating
top_k = 3              # next token is drawn at random from the k most likely predictions
compose_len = 1_500    # length of each generated sequence

# MIDI channels to compose for: one sequence is generated per entry, and each
# entry must be at most 15
channel = [0]

In [None]:
# Sample one sequence per MIDI channel from the trained network and write them out
# with seq2csv. Sampling is stochastic (top-k), and an attempt can fail — presumably
# because an under-trained model may emit a sequence seq2csv cannot handle (confirm) —
# so the whole composition is retried up to `max_retries` times.
max_retries = 10

# Validate once, up front: MIDI has 16 channels, numbered 0-15 here. Checking for an
# empty list also avoids an obscure ValueError from max([]).
assert channel and all(0 <= ch <= 15 for ch in channel), \
    f"channel must be a non-empty list of MIDI channels in 0-15, got {channel!r}"

for idx_retry in range(1, max_retries + 1):
    seqs = {}  # fresh per attempt so nothing from a failed attempt is carried over
    try:
        for i in range(len(channel)):
            seq = sample(net, compose_len, prime=prime, top_k=top_k)
            # Drop the last whitespace-separated token (presumably because it may be
            # truncated by the length limit — confirm against sample()).
            seq = " ".join(seq.split()[:-1])
            seqs[i + 1] = seq
        seq2csv(seqs, fname, channel)
        pprint(seqs)
        break
    except Exception as err:
        # Catch Exception rather than using a bare `except:` so KeyboardInterrupt /
        # SystemExit still stop the cell, and show the cause instead of hiding it.
        print(f"Retry music composing... [{idx_retry}] ({type(err).__name__}: {err})")
else:
    # Runs only when every attempt failed (the loop never hit `break`).
    print("Music composition failed. Try to train the model longer")