In [1]:
import torch
import pickle
from model import *

In [2]:
# load the model
# NOTE: torch.load unpickles the file, which can execute arbitrary code -- only load
# checkpoints you trust. Recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled model object; pass weights_only=False there if loading fails.
# map_location=device puts the weights on the same device new_tokens() allocates its
# context tensor on, so a checkpoint saved on another device still loads and runs.
# (assumes `device` is provided by `from model import *`, as new_tokens() already relies on)
model = torch.load('SmallMusicModel.pth', map_location=device)
# switch to evaluation mode; the trailing ';' keeps the notebook from echoing the model repr
model.eval();

In [3]:
# decoder
# NOTE: pickle.load can execute arbitrary code; only open an encoder.pickle you trust.
with open('encoder.pickle', 'rb') as f:
    encoder = pickle.load(f)
# invert the loaded mapping: each value of `encoder` becomes a key that maps back to its original key
decoder = dict(zip(encoder.values(), encoder.keys()))

# generate from the model
def new_tokens(max_new_tokens=100):
    """Sample token ids from the loaded model.

    Generation is seeded with a (1, 1) long tensor holding the single id 1, created on
    `device`. The call runs under torch.no_grad() because this is pure inference:
    model.eval() alone does not stop autograd from recording the forward passes.

    Args:
        max_new_tokens: forwarded to model.generate as `max_new_tokens`.

    Returns:
        Row 0 of model.generate's output converted to a plain Python list
        (presumably the seed id followed by the sampled ids -- confirm against
        model.generate).
    """
    context = torch.ones((1, 1), dtype=torch.long, device=device)
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=max_new_tokens)
    return generated[0].tolist()

# decoder (prints and returns for different uses)
def decode(generated_list, print_list=True, return_list=False):
    """Turn generated token ids into chord progressions split on '<EOS>'.

    Each id is looked up in the module-level `decoder`. Tokens accumulate into the
    current progression until an '<EOS>' token closes it; a closed progression is kept
    only if it holds between 2 and 7 tokens. Tokens after the final '<EOS>' are
    never closed and are therefore discarded.

    Args:
        generated_list: iterable of ids that are keys of `decoder`.
        print_list: when true, print each kept progression on its own line, every
            token followed by two spaces.
        return_list: when true, return the kept progressions as a list of lists.

    Returns:
        The list of kept progressions if `return_list` is true, otherwise None.
    """
    tokens = list(map(decoder.__getitem__, generated_list))

    progressions = []
    current = []
    for chord in tokens:
        if chord != '<EOS>':
            current.append(chord)
            continue
        # end of a progression: keep it only when its length is sensible
        if 1 < len(current) < 8:  # no silly chord progressions
            progressions.append(current)
        current = []

    if print_list:
        for progression in progressions:
            # same text as printing each token with end='  ' and then a newline
            print(*progression, sep='  ', end='  \n')

    if return_list:
        return progressions
    return None

def generate(n=100):
    """Sample about n tokens from the model and pass them to decode.

    decode is called with its default flags, so with the decode defined in this
    notebook the progressions are printed and the return value is None.
    """
    sampled_ids = new_tokens(n)
    return decode(sampled_ids)

In [4]:
# look at some output (per the original note: model after 1000 training steps)
# generate() prints the kept progressions itself and returns None, so only stdout appears below
generate()

Adim7  E7b9  
C  G  C  
C  G  Am  F  G  
C  Am7  Fadd9  G  
F  C  G  
C  F  G  C  
C  F  G  C  
C  Am  G  F  
F  G  Am  G  
Am  F  C  G  
E  C7  
Am  F  C  G  
C  Dm  G7  
C  G  Am  G  F  
F  C  G  Am  
Am  C  G  
C  Am  Dm  F  


In [4]:
# look at some output (10000 training steps / overfit)
# NOTE(review): this notebook only ever loads 'SmallMusicModel.pth', so presumably the
# checkpoint file was replaced and the load cell re-run before this cell -- confirm
generate()

Adim7  E7b9  Am7  
Am  G  F  C  G  
F  C  G  Dm  F  C  G  
Dm  Em  F  G  
G  Fadd9  C  
G  Am  F  G  C  
F  C  Gsus4  
C  Am  F  
C  Am  Dm  G7b9  
F  Am  G  
Cadd9  C  G  
C  G  Dm  F  
C  Dm11  Fsus2  
C  Am  Dm  F  G  
Am  Am7  Am6  Am  F  G  
C  Am  G  F  
Am  F  C  G  
