In [1]:
import torch
import pickle
from model import *
import pandas as pd
import numpy as np

In [2]:
# load the models (each checkpoint is a whole pickled model object)
m1k = torch.load('data/SmallMusicModel_1000steps_lr3.pth')  # 1000 training steps, lr=1e-3
m10k = torch.load('data/SmallMusicModel_10000steps_lr3_overfit.pth')  # 10000 training steps, lr=1e-3, overfit
m4 = torch.load('data/SmallMusicModel_10000steps_lr4.pth')  # 10000 training steps, lr=1e-4
m4u = torch.load('data/SmallMusicModel_1000steps_lr4_underfit.pth')  # 1000 training steps, lr=1e-4, underfit

# make a list of the machine learning models
ml_models = [m1k, m10k, m4, m4u]

# call eval() on every loaded model before it is used for generation
for loaded_model in ml_models:
    loaded_model.eval()

In [3]:
# decoder: invert the pickled encoder (token -> id) into an id -> token lookup
with open('data/encoder.pickle', 'rb') as f:
    encoder = pickle.load(f)
decoder = dict(zip(encoder.values(), encoder.keys()))

# generate from the model
def new_tokens(model=ml_models[0], n=100, random_start=False):
    """Sample a sequence of token ids from ``model``.

    Parameters
    ----------
    model : object exposing ``generate(context, max_new_tokens=...)``
        Defaults to the first entry of ``ml_models`` (bound at definition time).
    n : int
        Passed to ``model.generate`` as ``max_new_tokens``.
    random_start : bool
        If True, seed the context with a token id drawn uniformly from the
        keys of ``decoder``; otherwise seed it with the fixed id 2.

    Returns
    -------
    list
        ``model.generate(...)[0].tolist()`` -- the first generated sequence as a
        plain Python list.
    """
    if random_start:
        start_token = np.random.choice(list(decoder))
    else:
        # NOTE(review): 2 is presumably a start/separator token id -- confirm against the encoder
        start_token = 2
    # a 1x1 context tensor holding just the start token
    context = torch.tensor([[start_token]], dtype=torch.long, device=device)
    generated = model.generate(context, max_new_tokens=n)
    return generated[0].tolist()

# decoder (prints and returns for different uses)
def decode(generated_list, print_list=True, return_list=False):
    """Turn generated token ids into chord progressions split on ``'<EOS>'``.

    Each id is looked up in the module-level ``decoder`` mapping.  Tokens
    between consecutive ``'<EOS>'`` markers form one progression; only
    progressions with 2 to 8 tokens are kept.  Tokens after the final
    ``'<EOS>'`` are dropped because no marker closes them.

    Parameters
    ----------
    generated_list : iterable of token ids (keys of ``decoder``)
    print_list : bool
        If True, print every kept progression on its own line, each token
        followed by two spaces.
    return_list : bool
        If True, return the kept progressions; otherwise return None.

    Returns
    -------
    list of lists, or None when ``return_list`` is False.
    """
    progressions = []
    current = []
    for token_id in generated_list:
        chord = decoder[token_id]
        if chord != '<EOS>':
            current.append(chord)
            continue
        # end of a progression: ignore short and long progressions
        if 2 <= len(current) <= 8:
            progressions.append(current)
        current = []

    if print_list:
        for progression in progressions:
            print(''.join(f'{chord}  ' for chord in progression))

    if return_list:
        return progressions
    return None

def generate(model=ml_models[0], n=100, print_list=True, return_list=False):
    """Generate approximately ``n`` tokens worth of chord progressions.

    Samples token ids with ``new_tokens`` (fixed start token) and hands them
    to ``decode``; ``print_list`` and ``return_list`` are forwarded unchanged,
    so the return value is whatever ``decode`` returns.
    """
    token_ids = new_tokens(model=model, n=n)
    return decode(token_ids, print_list=print_list, return_list=return_list)

def create_chord_progressions(model=ml_models[0], n=1000):
    """Generate one batch of progressions and wrap it in a DataFrame.

    Calls ``generate`` silently (no printing) for ``n`` tokens and returns a
    single-column frame, ``'chord_progression'``, with one progression per row.

    Author's timing note: relatively slow: >2s / 100 tokens, batches of 1000
    would be about 22s each.
    """
    progressions = generate(model=model, n=n, print_list=False, return_list=True)
    return pd.DataFrame({'chord_progression': progressions})

In [4]:
# how many chord progressions does our training set have?
# Only progressions with 2-8 chords are counted (same bounds as the filter in decode).
df_train = (
    pd.read_pickle('data/df_clean.pickle')[['song', 'artist', 'chord_progression_C']]
    .rename(columns={'chord_progression_C': 'chord_progression'})
    .loc[lambda frame: frame.chord_progression.apply(len).between(2, 8)]
    .reset_index(drop=True)
)
len_dataset = len(df_train)
len_dataset

16348

In [5]:
# create data for exploratory data analysis
def create_data(model=ml_models[0], total_rows=len_dataset, min_rows=17600):
    """Generate batches of chord progressions until at least ``min_rows`` exist.

    Parameters
    ----------
    model : model passed through to ``create_chord_progressions``
        Defaults to the first entry of ``ml_models`` (bound at definition time).
    total_rows : int
        NOTE(review): accepted for backward compatibility but NOT consulted --
        the original loop never used it and stopped at a hard-coded 17600.
        Confirm the intended target before wiring it in; switching the default
        run to ``total_rows`` would shrink the generated datasets.
    min_rows : int
        Keep generating until the accumulated row count reaches this value.
        Defaults to 17600, the threshold that was previously hard-coded, so
        existing calls behave exactly as before.  Pass a small value for a
        quick smoke run.

    Returns
    -------
    pandas.DataFrame
        All generated batches stacked in generation order.  Row labels are kept
        from each batch (they restart at 0 per batch); call
        ``reset_index(drop=True)`` if unique labels are needed.

    A status line is printed before every 5th additional batch.
    """
    # Collect the batches and concatenate once at the end instead of
    # re-copying the growing frame on every iteration.
    batches = [create_chord_progressions(model=model)]
    n_rows = batches[0].shape[0]
    counter = 0
    n_update = 5
    while n_rows < min_rows:
        if counter % n_update == 0:  # status update
            print(n_rows, 'chord progressions generated')
        batch = create_chord_progressions(model=model)
        batches.append(batch)
        n_rows += batch.shape[0]
        counter += 1
    return pd.concat(batches)

In [6]:
%%time
# Generate progressions from m1k (1000 training steps, lr=1e-3) and cache them on disk.
# The recorded run of this cell took about 28 minutes.
df_m1k = create_data(m1k)
df_m1k.to_pickle('data/df_m1k.pickle')

192 chord progressions generated
1133 chord progressions generated
2066 chord progressions generated
3031 chord progressions generated
3947 chord progressions generated
4906 chord progressions generated
5831 chord progressions generated
6768 chord progressions generated
7701 chord progressions generated
8663 chord progressions generated
9599 chord progressions generated
10563 chord progressions generated
11512 chord progressions generated
12452 chord progressions generated
13399 chord progressions generated
14337 chord progressions generated
15286 chord progressions generated
16203 chord progressions generated
17148 chord progressions generated
Wall time: 27min 56s


In [7]:
%%time
# Generate progressions from m10k (10000 training steps, lr=1e-3, overfit) and cache them on disk.
# The recorded run of this cell took about 30 minutes.
df_m10k = create_data(m10k)
df_m10k.to_pickle('data/df_m10k.pickle')

174 chord progressions generated
1008 chord progressions generated
1877 chord progressions generated
2737 chord progressions generated
3596 chord progressions generated
4461 chord progressions generated
5313 chord progressions generated
6154 chord progressions generated
7006 chord progressions generated
7868 chord progressions generated
8722 chord progressions generated
9600 chord progressions generated
10457 chord progressions generated
11293 chord progressions generated
12185 chord progressions generated
13046 chord progressions generated
13907 chord progressions generated
14786 chord progressions generated
15646 chord progressions generated
16507 chord progressions generated
17363 chord progressions generated
Wall time: 30min 12s


In [8]:
%%time
# Generate progressions from m4 (10000 training steps, lr=1e-4) and cache them on disk.
# The recorded run of this cell took about 29 minutes.
df_m4 = create_data(m4)
df_m4.to_pickle('data/df_m4.pickle')

183 chord progressions generated
1064 chord progressions generated
1963 chord progressions generated
2862 chord progressions generated
3746 chord progressions generated
4651 chord progressions generated
5545 chord progressions generated
6435 chord progressions generated
7328 chord progressions generated
8222 chord progressions generated
9122 chord progressions generated
10017 chord progressions generated
10904 chord progressions generated
11781 chord progressions generated
12657 chord progressions generated
13551 chord progressions generated
14428 chord progressions generated
15309 chord progressions generated
16185 chord progressions generated
17081 chord progressions generated
Wall time: 29min 2s


In [9]:
%%time
# Generate progressions from m4u (1000 training steps, lr=1e-4, underfit) and cache them on disk.
# The recorded run of this cell took about 28 minutes.
df_m4u = create_data(m4u)
df_m4u.to_pickle('data/df_m4u.pickle')

191 chord progressions generated
1163 chord progressions generated
2100 chord progressions generated
3056 chord progressions generated
3998 chord progressions generated
4941 chord progressions generated
5880 chord progressions generated
6840 chord progressions generated
7778 chord progressions generated
8718 chord progressions generated
9682 chord progressions generated
10633 chord progressions generated
11588 chord progressions generated
12519 chord progressions generated
13481 chord progressions generated
14417 chord progressions generated
15366 chord progressions generated
16302 chord progressions generated
17259 chord progressions generated
Wall time: 27min 37s
