In [1]:
import torch
import pickle
from model import *
import pandas as pd
import numpy as np

In [2]:
# load the models: each checkpoint is unpickled and switched to eval mode
def _load_for_inference(path):
    """Load a pickled model checkpoint from ``path`` and put it in evaluation mode.

    NOTE(review): these files appear to be whole-module pickles (``.eval()`` is
    called on the result), so the model classes must be importable (presumably via
    ``from model import *``). Newer PyTorch releases default ``torch.load`` to
    ``weights_only=True``, which may reject such files — confirm the torch version.
    """
    loaded = torch.load(path)
    loaded.eval()
    return loaded

m1k = _load_for_inference('data/SmallMusicModel_1000steps_lr3.pth')  # 1000 training steps, lr=1e-3
m10k = _load_for_inference('data/SmallMusicModel_10000steps_lr3_overfit.pth')  # 10000 training steps, lr=1e-3, overfit
m4 = _load_for_inference('data/SmallMusicModel_10000steps_lr4.pth')  # 10000 training steps, lr=1e-4
m4u = _load_for_inference('data/SmallMusicModel_1000steps_lr4_underfit.pth')  # 1000 training steps, lr=1e-4, underfit

# make a list of the machine learning models
ml_models = [m1k, m10k, m4, m4u]

In [3]:
# Load the token encoder and invert it into `decoder`, which maps generated ids back to
# tokens. If two tokens shared an id, the later entry wins, as with a dict comprehension.
# NOTE(review): pickle.load can execute arbitrary code; only load this trusted local file.
with open('data/encoder.pickle', 'rb') as f:
    encoder = pickle.load(f)
decoder = dict(zip(encoder.values(), encoder.keys()))

# generate from the model
def new_tokens(model=ml_models[0], n=100, random_start=False, start_token=2):
    """Sample ``n`` new token ids from ``model``, seeded with a single token.

    Args:
        model: a loaded model exposing ``generate(context, max_new_tokens=...)``.
        n: number of new tokens requested from ``model.generate``.
        random_start: if True, seed with an id drawn uniformly from the keys of
            ``decoder`` (unseeded NumPy RNG) instead of ``start_token``.
        start_token: id used to seed generation when ``random_start`` is False.
            The default 2 is the value previously hard-coded here
            (NOTE(review): presumably a start/separator token — confirm in encoder).

    Returns:
        list[int]: row 0 of the tensor returned by ``model.generate``
        (assumed to be the seed id followed by the sampled ids — TODO confirm).
    """
    first_token = np.random.choice(list(decoder.keys())) if random_start else start_token
    context = torch.tensor([[first_token]], dtype=torch.long, device=device)
    # Pure inference: disable autograd so no computation graph is built while sampling.
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=n)
    return generated[0].tolist()

# decoder (prints and returns for different uses)
def decode(generated_list, print_list=True, return_list=False,
           min_len=2, max_len=7, id_to_token=None):
    """Split a sequence of token ids into chord progressions delimited by '<EOS>'.

    Args:
        generated_list: iterable of token ids.
        print_list: if True, print each kept progression on its own line, every
            token followed by two spaces.
        return_list: if True, return the kept progressions; otherwise return None.
        min_len: smallest progression length (inclusive) that is kept.
        max_len: largest progression length (inclusive) that is kept. The defaults
            2 and 7 reproduce the original ``1 < len < 8`` filter that discards
            degenerate ("silly") progressions.
        id_to_token: optional id -> token mapping; defaults to the module-level
            ``decoder``.

    Returns:
        list[list[str]] of kept progressions when ``return_list`` is True, else None.

    Tokens after the final '<EOS>' form an unterminated progression and are
    discarded. Empty progressions (consecutive '<EOS>') are always dropped when
    ``min_len`` >= 1.

    Raises:
        KeyError: if an id is missing from the mapping.
    """
    mapping = decoder if id_to_token is None else id_to_token
    # Decode everything up front so an unknown id fails before anything is printed.
    decoded_list = [mapping[x] for x in generated_list]
    output_list = []
    current = []
    for token in decoded_list:
        if token == '<EOS>':
            if min_len <= len(current) <= max_len:
                output_list.append(current)
            current = []
        else:
            current.append(token)
    if print_list:
        for progression in output_list:
            for token in progression:
                print(f'{token}', end='  ')
            print()
    return output_list if return_list else None

def generate(model=ml_models[0], n=100, print_list=True, return_list=False):
    """Sample roughly ``n`` tokens from ``model`` and decode them into chord progressions.

    Thin wrapper around ``new_tokens`` followed by ``decode``; ``print_list`` and
    ``return_list`` are forwarded to ``decode`` unchanged.
    """
    token_ids = new_tokens(model=model, n=n)
    return decode(token_ids, print_list=print_list, return_list=return_list)

def create_chord_progressions(model=ml_models[0], n=1000):
    """Generate about ``n`` tokens with ``model`` and collect the decoded progressions.

    ``n`` counts tokens, not progressions. Returns a DataFrame with a single
    ``'chord_progression'`` column holding one list of decoded tokens per row.
    """
    # relatively slow: >2s / 100 tokens, batches of 1000 would be about 22s each
    progressions = generate(model=model, n=n, print_list=False, return_list=True)
    frame = pd.DataFrame({'chord_progression': progressions})
    return frame

In [4]:
# how many chord progressions does our training set have? (rows, columns)
training_shape = pd.read_pickle('data/df_clean.pickle').shape
training_shape

(17619, 5)

In [5]:
# create data for exploratory data analysis
def create_data(model=ml_models[0], total_rows=17600):
    """Generate at least ``total_rows`` chord progressions from ``model``.

    Calls ``create_chord_progressions`` repeatedly and prints a status line with
    the running row count before each additional batch. The default target 17600
    roughly matches the size of the training set (``df_clean``).

    Args:
        model: the model to sample from.
        total_rows: minimum number of rows to produce. The last batch may overshoot it.

    Returns:
        pd.DataFrame with a fresh 0..N-1 RangeIndex. Batches are concatenated once,
        so index labels are unique.

    NOTE(review): loops forever if the model never yields a valid progression.
    """
    batches = [create_chord_progressions(model=model)]
    row_count = batches[0].shape[0]
    # Honour total_rows; previously the loop compared against a hard-coded 17600.
    while row_count < total_rows:
        print(row_count, 'chord progressions generated')  # status update
        batch = create_chord_progressions(model=model)
        batches.append(batch)
        row_count += batch.shape[0]
    # Concatenate once: concat inside the loop was quadratic and duplicated index labels.
    return pd.concat(batches, ignore_index=True)

In [6]:
%%time
# Generate progressions from m1k (1000 training steps, lr=1e-3) and save them to disk.
# create_data keeps sampling ~1000-token batches until its row target is reached, so this is slow.
df_m1k = create_data(m1k)
df_m1k.to_pickle('data/df_m1k.pickle')

189 chord progressions generated
376 chord progressions generated
561 chord progressions generated
749 chord progressions generated
927 chord progressions generated
1113 chord progressions generated
1290 chord progressions generated
1470 chord progressions generated
1650 chord progressions generated
1827 chord progressions generated
2019 chord progressions generated
2199 chord progressions generated
2374 chord progressions generated
2565 chord progressions generated
2757 chord progressions generated
2951 chord progressions generated
3135 chord progressions generated
3300 chord progressions generated
3480 chord progressions generated
3657 chord progressions generated
3830 chord progressions generated
4021 chord progressions generated
4212 chord progressions generated
4397 chord progressions generated
4578 chord progressions generated
4773 chord progressions generated
4962 chord progressions generated
5146 chord progressions generated
5331 chord progressions generated
5519 chord progress

In [10]:
%%time
# Generate progressions from m10k (10000 training steps, lr=1e-3, overfit) and save them to disk.
# create_data keeps sampling ~1000-token batches until its row target is reached, so this is slow.
df_m10k = create_data(m10k)
df_m10k.to_pickle('data/df_m10k.pickle')

180 chord progressions generated
344 chord progressions generated
510 chord progressions generated
679 chord progressions generated
852 chord progressions generated
1012 chord progressions generated
1166 chord progressions generated
1320 chord progressions generated
1503 chord progressions generated
1687 chord progressions generated
1855 chord progressions generated
2028 chord progressions generated
2189 chord progressions generated
2347 chord progressions generated
2527 chord progressions generated
2688 chord progressions generated
2854 chord progressions generated
3042 chord progressions generated
3207 chord progressions generated
3375 chord progressions generated
3538 chord progressions generated
3711 chord progressions generated
3873 chord progressions generated
4037 chord progressions generated
4200 chord progressions generated
4369 chord progressions generated
4535 chord progressions generated
4700 chord progressions generated
4866 chord progressions generated
5029 chord progress

In [11]:
%%time
# Generate progressions from m4 (10000 training steps, lr=1e-4) and save them to disk.
# create_data keeps sampling ~1000-token batches until its row target is reached, so this is slow.
df_m4 = create_data(m4)
df_m4.to_pickle('data/df_m4.pickle')

171 chord progressions generated
343 chord progressions generated
516 chord progressions generated
695 chord progressions generated
865 chord progressions generated
1039 chord progressions generated
1200 chord progressions generated
1374 chord progressions generated
1558 chord progressions generated
1726 chord progressions generated
1901 chord progressions generated
2071 chord progressions generated
2231 chord progressions generated
2408 chord progressions generated
2574 chord progressions generated
2761 chord progressions generated
2923 chord progressions generated
3099 chord progressions generated
3262 chord progressions generated
3431 chord progressions generated
3623 chord progressions generated
3792 chord progressions generated
3961 chord progressions generated
4123 chord progressions generated
4290 chord progressions generated
4456 chord progressions generated
4608 chord progressions generated
4780 chord progressions generated
4957 chord progressions generated
5136 chord progress

In [9]:
%%time
# Generate progressions from m4u (1000 training steps, lr=1e-4, underfit) and save them to disk.
# create_data keeps sampling ~1000-token batches until its row target is reached, so this is slow.
df_m4u = create_data(m4u)
df_m4u.to_pickle('data/df_m4u.pickle')

183 chord progressions generated
373 chord progressions generated
566 chord progressions generated
751 chord progressions generated
940 chord progressions generated
1131 chord progressions generated
1309 chord progressions generated
1489 chord progressions generated
1680 chord progressions generated
1868 chord progressions generated
2046 chord progressions generated
2232 chord progressions generated
2415 chord progressions generated
2597 chord progressions generated
2771 chord progressions generated
2949 chord progressions generated
3132 chord progressions generated
3322 chord progressions generated
3508 chord progressions generated
3692 chord progressions generated
3876 chord progressions generated
4058 chord progressions generated
4248 chord progressions generated
4429 chord progressions generated
4620 chord progressions generated
4801 chord progressions generated
4974 chord progressions generated
5166 chord progressions generated
5355 chord progressions generated
5536 chord progress