## Importing the Required Libraries

Reference: [General-purpose, long-context autoregressive modeling with Perceiver AR](https://deepai.org/publication/general-purpose-long-context-autoregressive-modeling-with-perceiver-ar)

In [5]:
import gzip
import random

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from perceiver_ar_pytorch import PerceiverAR
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from sklearn.model_selection import train_test_split

import re

from pychord import Chord
import json

import pickle

## Defining Constants and Helper Functions

In [2]:
# constants

NUM_BATCHES = int(1e4)  # number of optimizer steps run by the training loop
BATCH_SIZE = 64  # sequences per DataLoader batch
GRADIENT_ACCUMULATE_EVERY = 4  # micro-batches back-propagated before each optimizer step
LEARNING_RATE = 1e-4  # learning rate passed to the Adam optimizer
VALIDATE_EVERY = 100  # print training/validation loss every this many steps
GENERATE_EVERY = 500  # print a generated sample every this many steps
GENERATE_LENGTH = 2  # number of tokens requested from model.generate during training
SEQ_LEN = 4  # model max_seq_len and the window length used by the sampling dataset
PREFIX_SEQ_LEN = 1  # cross_attn_seq_len handed to PerceiverAR

In [3]:
# helpers

def cycle(loader):
    """Yield items from ``loader`` endlessly, restarting it each time it is exhausted."""
    while True:
        yield from loader


def decode_token(token):
    """Return the character for ``token``, clamping code points below 32 to a space."""
    code_point = token if token > 32 else 32
    return chr(code_point)


def decode_tokens(tokens):
    """Decode an iterable of token ids into one string, one character per token."""
    return "".join(decode_token(token) for token in tokens)

# Larger configuration kept for reference only: it sits inside a string literal,
# so it is never executed.
'''
model = PerceiverAR(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    dim_head = 64,
    cross_attn_dropout = 0.5,
    max_seq_len = SEQ_LEN,
    cross_attn_seq_len = PREFIX_SEQ_LEN
)
'''

# Configuration actually used for the chord vocabulary.
# NOTE(review): num_tokens is hard-coded here, before the vocabulary is built in
# the data-loading cells — confirm it stays >= the number of distinct chords.
model = PerceiverAR(
    num_tokens = 121, #974 if not ignoring the different bass note
    dim = 128,
    depth = 16,
    heads = 8,
    dim_head = 64,
    cross_attn_dropout = 0.6,
    max_seq_len = SEQ_LEN,
    cross_attn_seq_len = PREFIX_SEQ_LEN
)


# Wrap the network; the training loop calls the wrapper both to obtain the loss
# and to generate samples.
model = AutoregressiveWrapper(model)
# Move the model to the GPU (requires CUDA). As the last expression of the cell
# this also echoes the module tree.
model.cuda()

AutoregressiveWrapper(
  (net): PerceiverAR(
    (token_emb): Embedding(121, 128)
    (pos_emb): Embedding(4, 128)
    (rotary_pos_emb): RotaryEmbedding()
    (perceive_layers): ModuleList(
      (0): ModuleList(
        (0): CausalPrefixAttention(
          (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (context_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_q): Linear(in_features=128, out_features=512, bias=False)
          (to_kv): Linear(in_features=128, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=128, bias=True)
        )
        (1): Sequential(
          (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=128, out_features=512, bias=False)
          (2): GELU(approximate=none)
          (3): Dropout(p=0.0, inplace=False)
          (4): Linear(in_features=512, out_features=128, bias=False)
    

## Loading and Preprocessing the Data

### An approach with custom embeddings (that specify the notes that make up a chord)

In [4]:
# Read the chord file into a list with one entry per line. The context manager
# closes the file handle, which the bare open(...).read() call left open.
with open('data/soul_chords.txt', 'r') as chord_file:
    temp = chord_file.read().splitlines()

In [5]:
# Load the lookup for pitch to number translation.
# embed_chord() reads this mapping through pitches.get(...) to translate the
# components of a chord into integers.
with open('data/pitch_embedding.json') as d:
    pitches = json.load(d)

In [6]:
# Rewrite the diminished ('o' -> 'dim') and augmented ('+' -> 'aug') notation so it works with the PyChord library
chords = [symbol.replace('o', 'dim').replace('+', 'aug') for symbol in temp]

#### Ignoring the different bass notes (i.e., assuming the root is in the bass) to reduce chord variety

In [10]:
# Drop everything from the first '/' onwards, i.e. the bass-note part of the chord symbol.
chords = [re.sub('/.*', '', x) for x in chords]

In [11]:
def embed_chord(chord, size=5):
    """Encode a chord symbol as a fixed-length list of integer pitch numbers.

    Args:
        chord: Chord symbol; ``'N'`` denotes "no chord".
        size: Length of the returned list (default 5). Shorter chords are
            right-padded with zeros, longer ones are truncated so that every
            embedding has the same width.

    Returns:
        A list of ``size`` ints. ``'N'`` and any chord that cannot be parsed or
        translated map to all zeros.
    """
    # Integers (not floats) so that the string key built for the reverse lookup
    # ("0 0 0 0 0") matches what integer tensors decode to.
    if chord == 'N':
        return [0] * size

    # Spell E#/B# as their enharmonic equivalents F/C before parsing.
    name = chord.replace("E#", "F").replace("B#", "C")
    try:
        notes = [
            int(pitches.get(item, item))
            for item in Chord(name).components_with_pitch(root_pitch=3)
        ]
    except Exception:
        # Deliberate best effort: an unparsable chord becomes the zero embedding.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return [0] * size

    # Pad, then cut, so the result is always exactly `size` long.
    return (notes + [0] * size)[:size]

In [13]:
len(set(chords)) # number of distinct chord symbols once the bass notes have been stripped

120

In [14]:
# Chord symbol -> embedding, computed once per distinct chord
embeddings_index = {name: embed_chord(name) for name in set(chords)}
list(embeddings_index.items())[0] # an arbitrary example

('G', [8, 12, 15, 0, 0])

In [15]:
# Reverse lookup table: space-joined embedding string -> chord symbol
inv_embeddings = {
    " ".join(map(str, vector)): name
    for name, vector in embeddings_index.items()
}

In [49]:
# Display the reverse lookup table (embedding string -> chord symbol).
inv_embeddings

{'12 16 19 22 0': 'B7',
 '5 9 13 0 0': 'Eaug',
 '9 13 17 0 0': 'G#aug',
 '3 8 10 0 0': 'Dsus4',
 '11 14 17 20 0': 'A#dim7',
 '3 7 10 14 0': 'DM7',
 '10 14 18 0 0': 'Aaug',
 '2 5 8 0 0': 'C#dim',
 '12 15 19 0 0': 'Bm',
 '4 9 11 0 0': 'D#sus4',
 '10 14 17 20 0': 'A7',
 '5 9 12 0 0': 'E',
 '8 12 15 0 0': 'G',
 '2 6 9 13 0': 'C#M7',
 '5 9 12 15 0': 'E7',
 '10 14 17 21 0': 'AM7',
 '3 7 11 0 0': 'Daug',
 '6 9 13 16 0': 'Fm7',
 '1 6 8 0 0': 'Csus4',
 '12 16 20 0 0': 'Baug',
 '9 12 15 0 0': 'G#dim',
 '1 3 8 0 0': 'Csus2',
 '4 7 10 0 0': 'D#dim',
 '7 10 13 0 0': 'F#dim',
 '2 4 9 0 0': 'C#sus2',
 '7 11 14 0 0': 'F#',
 '6 9 13 0 0': 'Fm',
 '7 10 13 16 0': 'F#dim7',
 '11 15 19 0 0': 'A#aug',
 '6 9 12 15 0': 'Fdim7',
 '1 5 9 0 0': 'Caug',
 '4 8 12 0 0': 'D#aug',
 '4 8 11 15 0': 'D#M7',
 '5 9 12 16 0': 'EM7',
 '11 13 18 0 0': 'A#sus2',
 '9 13 16 0 0': 'G#',
 '12 15 19 22 0': 'Bm7',
 '11 14 18 21 0': 'A#m7',
 '9 12 16 19 0': 'G#m7',
 '11 14 17 0 0': 'A#dim',
 '5 7 12 0 0': 'Esus2',
 '8 11 14 17 0': '

In [16]:
# Embed the full chord sequence, preserving its order
chords_emb = list(map(embed_chord, chords))

In [17]:
# Pack the embeddings into a single array with one row per chord.
# NOTE(review): uint8 assumes every pitch number fits in 0..255 — confirm against pitch_embedding.json.
chords_emb = np.array(chords_emb, dtype=np.uint8)
chords_emb

array([[11, 15, 18, 21,  0],
       [ 6,  9, 13, 16,  0],
       [ 1,  4,  8, 11,  0],
       ...,
       [ 6, 10, 13, 17,  0],
       [11, 15, 18, 22,  0],
       [11, 15, 18,  0,  0]], dtype=uint8)

In [18]:
# Split without shuffling: the sampling dataset draws contiguous windows from
# these arrays, so the chords must stay in their original order. The default
# shuffle=True would scramble the progressions before training.
trX, vaX = train_test_split(chords_emb, test_size=0.3, shuffle=False)
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

In [20]:
def decode_tokens(tokens):
    """Translate embedding rows back into chord symbols via ``inv_embeddings``.

    ``tokens`` must provide ``.tolist()`` returning one sequence of numbers per
    chord. Each row is joined with single spaces to form the lookup key, so a
    row that has no entry in ``inv_embeddings`` raises ``KeyError``.
    """
    keys = [" ".join(map(str, row)) for row in tokens.tolist()]
    return [inv_embeddings[key] for key in keys]

### An approach without custom embeddings (letting the model do its thing)

In [6]:
# Read one chord per line (the context manager closes the file) and strip
# everything from the first '/' onwards, i.e. the bass note.
with open('data/soul_chords.txt', 'r') as chord_file:
    temp = [re.sub('/.*', '', x) for x in chord_file.read().splitlines()]

# Sort the vocabulary so that token ids are identical on every run. Iteration
# order of a set of strings changes between interpreter sessions, which would
# make a saved model undecodable after a kernel restart.
lookup_set = sorted(set(temp))
lookup_nr = list(range(0, len(lookup_set)))
lookup = {chord: nr for nr, chord in enumerate(lookup_set)}

inv_lookup = {v: k for k, v in lookup.items()}

# The ids are stored as uint8 below, which can hold at most 256 distinct values.
if len(lookup_set) > 256:
    raise ValueError(f"{len(lookup_set)} distinct chords do not fit into uint8 token ids")

X = [lookup[x] for x in temp]
# shuffle=False keeps the chords in sequence order; the sampling dataset draws
# contiguous windows, so a shuffled split would train on random chord orderings.
trX, vaX = train_test_split(np.array(X, dtype=np.uint8), test_size=0.30, shuffle=False)
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

In [8]:
# Load the dataset (chords from MIDIs) in pickle format.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open("data/soul_chords.pickle", "rb") as chord_file:
    chords = pickle.load(chord_file)

# Take the first element of every entry as the chord symbol and normalise the
# diminished/augmented notation.
chords = [[x[0].replace('o', 'dim') for x in song] for song in chords]
chords = [[x.replace('+', 'aug') for x in song] for song in chords]

# Shuffle whole songs (not single chords) with a fixed seed, so the train/val
# split is the same on every run. A private Random instance leaves the global
# random state untouched.
random.Random(42).shuffle(chords)

# Flatten the songs into one long chord sequence.
chords = [x for y in chords for x in y]

# Strip everything from the first '/' onwards, i.e. the bass note.
temp = [re.sub('/.*', '', x) for x in chords]

# Sort the vocabulary so that token ids are stable across kernel restarts;
# iteration order of a set of strings is not.
lookup_set = sorted(set(temp))
lookup_nr = list(range(0, len(lookup_set)))
lookup = {chord: nr for nr, chord in enumerate(lookup_set)}

inv_lookup = {v: k for k, v in lookup.items()}

# The ids are stored as uint8 below, which can hold at most 256 distinct values.
if len(lookup_set) > 256:
    raise ValueError(f"{len(lookup_set)} distinct chords do not fit into uint8 token ids")

X = [lookup[x] for x in temp]

# Contiguous 80/20 split so both parts keep the chords in sequence order.
split = int(len(X)*0.8)
trX, vaX = np.array(X[0:split], dtype=np.uint8),np.array(X[split:], dtype=np.uint8)
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

In [8]:
# NOTE(review): these two lines repeat the normalisation from the loading cell
# above. If `chords` has already been flattened into a list of strings there,
# iterating `song` walks over single characters and overwrites `chords` with
# lists of characters. This looks like a stale cell — confirm before running.
chords = [[x[0].replace('o', 'dim') for x in song] for song in chords]
chords = [[x.replace('+', 'aug') for x in song] for song in chords]

In [9]:
# Vocabulary size: number of distinct chord symbols in `temp`.
print(len(set(temp)))

120


In [10]:
# Show the first token and the number of tokens in each split.
print(data_train[0], len(data_train))
print(data_val[0], len(data_val))

tensor(41, dtype=torch.uint8) 134909
tensor(12, dtype=torch.uint8) 33728


In [11]:
def decode_tokens(tokens):
    """Convert token ids back into chord symbols using ``inv_lookup``.

    ``tokens`` must provide ``.tolist()`` returning the ids. An id without an
    entry in ``inv_lookup`` raises ``KeyError``.
    """
    return [inv_lookup[token_id] for token_id in tokens.tolist()]

## Training Setup

In [12]:
class TextSamplerDataset(Dataset):
    """Dataset that serves random contiguous windows from a token tensor.

    Args:
        data: Tensor of tokens, sliced along its first dimension.
        seq_len: Window length; each item holds ``seq_len + 1`` consecutive
            entries, or all of ``data`` when it is shorter than that.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return a randomly positioned window as a ``long`` tensor.

        ``index`` is ignored: every access draws a fresh random start. The
        window is moved to the GPU when CUDA is available and stays on the CPU
        otherwise.
        """
        # Exclusive upper bound for the start so that start + seq_len + 1 stays
        # in range. Clamped to 1 so data no longer than seq_len yields the whole
        # tensor instead of making torch.randint raise on an empty range.
        max_start = max(self.data.size(0) - self.seq_len, 1)
        rand_start = int(torch.randint(0, max_start, (1,)).item())
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda() if torch.cuda.is_available() else full_seq

    def __len__(self):
        """Return the nominal item count: data length divided by ``seq_len``."""
        return self.data.size(0) // self.seq_len


In [13]:
# One random-window dataset per split, both using the same window length
train_dataset, val_dataset = (
    TextSamplerDataset(split_data, SEQ_LEN) for split_data in (data_train, data_val)
)
# Wrap each DataLoader in cycle() so next(...) never runs out of batches
train_loader, val_loader = (
    cycle(DataLoader(dataset, batch_size=BATCH_SIZE))
    for dataset in (train_dataset, val_dataset)
)

In [14]:
# Adam optimizer over all model parameters.
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module alias
# imported at the top of the notebook; the training loop relies on this name.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [15]:
# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    # Accumulate gradients over several micro-batches. Each loss is divided by
    # the number of micro-batches so the accumulated gradient is their mean
    # rather than their sum; otherwise the gradient magnitude (and therefore
    # the effect of the clipping threshold below) grows with
    # GRADIENT_ACCUMULATE_EVERY.
    train_loss = 0.0
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
        # Keep a detached running mean for reporting, covering all micro-batches.
        train_loss = train_loss + loss.detach() / GRADIENT_ACCUMULATE_EVERY

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        print(f"training loss: {float(train_loss)}")
        model.eval()
        with torch.no_grad():
            # Single validation batch: a cheap but noisy estimate.
            val_loss = model(next(val_loader))
            print(f"validation loss: {val_loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Prime with a random validation window, minus its last token.
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print((prime, "*" * 100))

        # Add a batch dimension for generate(), then decode the first result.
        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

training:   0%|                                                                              | 0/10000 [00:00<?, ?it/s]

training loss: 5.448845386505127
validation loss: 5.109672546386719
(['B7', 'EM7', 'E', 'B7'], '****************************************************************************************************')
['Adim7', 'C#dim']


training:   1%|▋                                                                    | 92/10000 [00:20<35:40,  4.63it/s]

training loss: 3.5636377334594727
validation loss: 3.5302047729492188


training:   2%|█▎                                                                  | 192/10000 [00:40<33:49,  4.83it/s]

training loss: 3.520721197128296
validation loss: 3.5595197677612305


training:   3%|█▉                                                                  | 292/10000 [01:00<33:02,  4.90it/s]

training loss: 3.3828325271606445
validation loss: 3.533877372741699


training:   4%|██▋                                                                 | 392/10000 [01:20<32:29,  4.93it/s]

training loss: 3.3905887603759766
validation loss: 3.588961362838745


training:   5%|███▎                                                                | 492/10000 [01:41<32:09,  4.93it/s]

training loss: 3.164304494857788
validation loss: 3.4512412548065186
(['C#M7', 'Fm7', 'Fm', 'A#dim'], '****************************************************************************************************')
['G#7', 'C#M7']


training:   6%|████                                                                | 592/10000 [02:01<31:53,  4.92it/s]

training loss: 3.1034698486328125
validation loss: 3.5375072956085205


training:   7%|████▋                                                               | 692/10000 [02:21<31:36,  4.91it/s]

training loss: 3.2030084133148193
validation loss: 3.5593466758728027


training:   8%|█████▍                                                              | 792/10000 [02:42<31:06,  4.93it/s]

training loss: 3.371645927429199
validation loss: 3.1446969509124756


training:   9%|██████                                                              | 892/10000 [03:02<30:48,  4.93it/s]

training loss: 3.247957468032837
validation loss: 3.6569488048553467


training:  10%|██████▋                                                             | 992/10000 [03:22<30:17,  4.96it/s]

training loss: 3.3002641201019287
validation loss: 3.2886171340942383
(['Bdim', 'F', 'CM7', 'C'], '****************************************************************************************************')
['CM7', 'C7']


training:  11%|███████▎                                                           | 1092/10000 [03:42<30:01,  4.94it/s]

training loss: 3.118952989578247
validation loss: 3.30802059173584


training:  12%|███████▉                                                           | 1192/10000 [04:03<29:43,  4.94it/s]

training loss: 3.17362904548645
validation loss: 3.714118003845215


training:  13%|████████▋                                                          | 1292/10000 [04:23<29:23,  4.94it/s]

training loss: 3.1618587970733643
validation loss: 3.3989508152008057


training:  14%|█████████▎                                                         | 1392/10000 [04:43<29:02,  4.94it/s]

training loss: 3.0798633098602295
validation loss: 3.5752956867218018


training:  15%|█████████▉                                                         | 1492/10000 [05:03<28:39,  4.95it/s]

training loss: 3.082555055618286
validation loss: 3.3012142181396484
(['C7', 'Cm7', 'Adim', 'Cm7'], '****************************************************************************************************')
['Am', 'Adim']


training:  16%|██████████▋                                                        | 1592/10000 [05:24<28:26,  4.93it/s]

training loss: 2.9451732635498047
validation loss: 3.6709909439086914


training:  17%|███████████▎                                                       | 1692/10000 [05:44<28:00,  4.94it/s]

training loss: 3.159874200820923
validation loss: 3.5433223247528076


training:  18%|████████████                                                       | 1792/10000 [06:04<27:41,  4.94it/s]

training loss: 2.9047741889953613
validation loss: 3.395899772644043


training:  19%|████████████▋                                                      | 1892/10000 [06:25<27:30,  4.91it/s]

training loss: 2.7826976776123047
validation loss: 3.841527223587036


training:  20%|█████████████▎                                                     | 1988/10000 [06:46<28:16,  4.72it/s]

training loss: 2.9920806884765625
validation loss: 3.301839590072632
(['A#7', 'C7', 'Dm7', 'Ddim'], '****************************************************************************************************')
['Gm7', 'C7']


training:  21%|█████████████▉                                                     | 2082/10000 [07:06<28:43,  4.59it/s]

training loss: 2.807187080383301
validation loss: 3.462675094604492


training:  22%|██████████████▌                                                    | 2177/10000 [07:27<28:01,  4.65it/s]

training loss: 3.117548942565918
validation loss: 3.445192337036133


training:  23%|███████████████▏                                                   | 2275/10000 [07:47<27:13,  4.73it/s]

training loss: 3.0747222900390625
validation loss: 3.318112373352051


training:  24%|███████████████▉                                                   | 2375/10000 [08:07<26:24,  4.81it/s]

training loss: 2.9711523056030273
validation loss: 3.685460090637207


training:  25%|████████████████▌                                                  | 2474/10000 [08:28<25:50,  4.85it/s]

training loss: 3.1012675762176514
validation loss: 3.586355447769165
(['C', 'A#m', 'F#M7', 'F#dim'], '****************************************************************************************************')
['F#', 'BM7']


training:  26%|█████████████████▏                                                 | 2573/10000 [08:48<25:32,  4.85it/s]

training loss: 2.828500747680664
validation loss: 3.447890043258667


training:  27%|█████████████████▉                                                 | 2672/10000 [09:08<25:02,  4.88it/s]

training loss: 3.01645827293396
validation loss: 3.5151355266571045


training:  28%|██████████████████▌                                                | 2771/10000 [09:29<24:51,  4.85it/s]

training loss: 2.740443468093872
validation loss: 3.56131649017334


training:  29%|███████████████████▏                                               | 2871/10000 [09:49<24:25,  4.87it/s]

training loss: 2.535263776779175
validation loss: 3.3617160320281982


training:  30%|███████████████████▉                                               | 2969/10000 [10:10<24:10,  4.85it/s]

training loss: 2.7249205112457275
validation loss: 3.6015329360961914
(['Bm7', 'F#m', 'AM7', 'Bm7'], '****************************************************************************************************')
['GM7', 'AM7']


training:  31%|████████████████████▌                                              | 3064/10000 [10:30<24:33,  4.71it/s]

training loss: 2.9728100299835205
validation loss: 3.5374255180358887


training:  32%|█████████████████████▏                                             | 3158/10000 [10:51<24:23,  4.68it/s]

training loss: 2.4726014137268066
validation loss: 3.6121997833251953


training:  33%|█████████████████████▊                                             | 3258/10000 [11:11<23:32,  4.77it/s]

training loss: 2.761709451675415
validation loss: 3.3964149951934814


training:  34%|██████████████████████▍                                            | 3356/10000 [11:32<23:09,  4.78it/s]

training loss: 2.655526876449585
validation loss: 3.4931561946868896


training:  35%|███████████████████████▏                                           | 3454/10000 [11:52<22:41,  4.81it/s]

training loss: 2.7532360553741455
validation loss: 3.562941789627075
(['G', 'A', 'C#m7', 'AM7'], '****************************************************************************************************')
['Bm7', 'E']


training:  36%|████████████████████████▏                                          | 3601/10000 [12:22<22:03,  4.83it/s]

training loss: 2.5877771377563477
validation loss: 3.662142515182495


training:  37%|████████████████████████▊                                          | 3700/10000 [12:43<21:45,  4.83it/s]

training loss: 2.692068099975586
validation loss: 3.7908389568328857


training:  38%|█████████████████████████▍                                         | 3797/10000 [13:03<21:26,  4.82it/s]

training loss: 2.394991397857666
validation loss: 3.466324806213379


training:  39%|██████████████████████████                                         | 3895/10000 [13:23<21:03,  4.83it/s]

training loss: 2.6037352085113525
validation loss: 3.428720474243164


training:  40%|██████████████████████████▋                                        | 3992/10000 [13:44<20:51,  4.80it/s]

training loss: 2.3709731101989746
validation loss: 3.40653133392334
(['G7', 'C', 'C', 'F7'], '****************************************************************************************************')
['G', 'F7']


training:  41%|███████████████████████████▍                                       | 4089/10000 [14:04<20:30,  4.80it/s]

training loss: 2.460252523422241
validation loss: 3.487990379333496


training:  42%|████████████████████████████                                       | 4187/10000 [14:24<20:04,  4.83it/s]

training loss: 2.233926773071289
validation loss: 3.6748058795928955


training:  43%|████████████████████████████▋                                      | 4285/10000 [14:44<19:40,  4.84it/s]

training loss: 2.258086919784546
validation loss: 3.6207993030548096


training:  44%|█████████████████████████████▎                                     | 4383/10000 [15:05<19:23,  4.83it/s]

training loss: 2.5057034492492676
validation loss: 3.8787364959716797


training:  45%|██████████████████████████████                                     | 4480/10000 [15:25<19:15,  4.78it/s]

training loss: 2.454847574234009
validation loss: 3.6510846614837646
(['Cm7', 'Gm7', 'Cm', 'Gm7'], '****************************************************************************************************')
['Gm7', 'A#M7']


training:  46%|██████████████████████████████▋                                    | 4575/10000 [15:46<19:09,  4.72it/s]

training loss: 2.417909860610962
validation loss: 3.6396677494049072


training:  47%|███████████████████████████████▎                                   | 4670/10000 [16:06<18:47,  4.73it/s]

training loss: 2.523170232772827
validation loss: 3.5624444484710693


training:  48%|███████████████████████████████▉                                   | 4768/10000 [16:26<18:16,  4.77it/s]

training loss: 2.3122148513793945
validation loss: 3.6167972087860107


training:  49%|████████████████████████████████▌                                  | 4866/10000 [16:46<17:53,  4.78it/s]

training loss: 2.2645606994628906
validation loss: 3.7420387268066406


training:  50%|█████████████████████████████████▎                                 | 4964/10000 [17:07<17:26,  4.81it/s]

training loss: 2.447530746459961
validation loss: 3.7602105140686035
(['C#7', 'F#sus4', 'F#M7', 'F#'], '****************************************************************************************************')
['F#M7', 'G#m']


training:  51%|█████████████████████████████████▉                                 | 5062/10000 [17:27<17:00,  4.84it/s]

training loss: 2.325312852859497
validation loss: 4.13566255569458


training:  52%|██████████████████████████████████▌                                | 5160/10000 [17:47<16:35,  4.86it/s]

training loss: 2.3347623348236084
validation loss: 3.7983596324920654


training:  53%|███████████████████████████████████▏                               | 5259/10000 [18:07<16:09,  4.89it/s]

training loss: 2.1612446308135986
validation loss: 3.666431427001953


training:  54%|███████████████████████████████████▉                               | 5359/10000 [18:27<15:47,  4.90it/s]

training loss: 2.257296323776245
validation loss: 4.083817958831787


training:  55%|████████████████████████████████████▌                              | 5459/10000 [18:48<15:25,  4.91it/s]

training loss: 2.314633369445801
validation loss: 3.737484931945801
(['Fm7', 'G#', 'C#M7', 'G#'], '****************************************************************************************************')
['Cm7', 'G#']


training:  56%|█████████████████████████████████████▏                             | 5559/10000 [19:08<15:04,  4.91it/s]

training loss: 2.161168336868286
validation loss: 3.999681234359741


training:  57%|█████████████████████████████████████▉                             | 5659/10000 [19:28<14:43,  4.91it/s]

training loss: 2.3989531993865967
validation loss: 3.825885772705078


training:  58%|██████████████████████████████████████▌                            | 5759/10000 [19:49<14:22,  4.92it/s]

training loss: 2.401141881942749
validation loss: 3.6086809635162354


training:  59%|███████████████████████████████████████▎                           | 5859/10000 [20:09<14:02,  4.92it/s]

training loss: 2.2056961059570312
validation loss: 3.9318275451660156


training:  60%|███████████████████████████████████████▉                           | 5959/10000 [20:30<13:42,  4.91it/s]

training loss: 2.173673391342163
validation loss: 3.575629949569702
(['CM7', 'Daug', 'GM7', 'CM7'], '****************************************************************************************************')
['G', 'B']


training:  61%|████████████████████████████████████████▌                          | 6059/10000 [20:50<13:22,  4.91it/s]

training loss: 2.057257652282715
validation loss: 3.969864845275879


training:  62%|█████████████████████████████████████████▎                         | 6159/10000 [21:10<13:01,  4.91it/s]

training loss: 2.1024229526519775
validation loss: 4.129100322723389


training:  63%|█████████████████████████████████████████▉                         | 6259/10000 [21:31<12:42,  4.91it/s]

training loss: 2.3087353706359863
validation loss: 3.5848586559295654


training:  64%|██████████████████████████████████████████▌                        | 6358/10000 [21:51<12:22,  4.90it/s]

training loss: 1.9394086599349976
validation loss: 3.948714256286621


training:  65%|███████████████████████████████████████████▎                       | 6458/10000 [22:11<12:01,  4.91it/s]

training loss: 2.239896059036255
validation loss: 3.878621816635132
(['A#m7', 'D#M7', 'G#', 'D#sus2'], '****************************************************************************************************')
['D#', 'D#M7']


training:  66%|███████████████████████████████████████████▉                       | 6558/10000 [22:32<11:40,  4.91it/s]

training loss: 2.0338399410247803
validation loss: 4.300795078277588


training:  67%|████████████████████████████████████████████▌                      | 6658/10000 [22:52<11:19,  4.92it/s]

training loss: 2.1866652965545654
validation loss: 3.9076263904571533


training:  68%|█████████████████████████████████████████████▎                     | 6757/10000 [23:13<11:12,  4.82it/s]

training loss: 2.0373435020446777
validation loss: 3.8474464416503906


training:  69%|█████████████████████████████████████████████▉                     | 6857/10000 [23:33<10:45,  4.87it/s]

training loss: 2.0904974937438965
validation loss: 4.013378620147705


training:  70%|██████████████████████████████████████████████▌                    | 6957/10000 [23:53<10:22,  4.89it/s]

training loss: 1.9146443605422974
validation loss: 3.996220827102661
(['D7', 'GM7', 'Bm', 'CM7'], '****************************************************************************************************')
['G', 'CM7']


training:  71%|███████████████████████████████████████████████▌                   | 7101/10000 [24:25<10:12,  4.73it/s]

training loss: 1.8938900232315063
validation loss: 4.108184337615967


training:  72%|████████████████████████████████████████████████▏                  | 7199/10000 [24:45<09:45,  4.79it/s]

training loss: 2.0644638538360596
validation loss: 3.576327323913574


training:  73%|████████████████████████████████████████████████▉                  | 7296/10000 [25:05<09:30,  4.74it/s]

training loss: 2.1512341499328613
validation loss: 4.227970600128174


training:  74%|█████████████████████████████████████████████████▌                 | 7396/10000 [25:26<08:58,  4.83it/s]

training loss: 2.1400158405303955
validation loss: 3.9992544651031494


training:  75%|██████████████████████████████████████████████████▏                | 7495/10000 [25:46<08:37,  4.84it/s]

training loss: 2.1041297912597656
validation loss: 3.9454281330108643
(['CM7', 'CM7', 'Edim', 'Edim'], '****************************************************************************************************')
['A', 'Edim']


training:  76%|██████████████████████████████████████████████████▊                | 7593/10000 [26:06<08:16,  4.84it/s]

training loss: 1.9909887313842773
validation loss: 4.05290412902832


training:  77%|███████████████████████████████████████████████████▌               | 7692/10000 [26:27<07:53,  4.87it/s]

training loss: 1.9977883100509644
validation loss: 3.8486993312835693


training:  78%|████████████████████████████████████████████████████▏              | 7792/10000 [26:47<07:30,  4.90it/s]

training loss: 1.7639098167419434
validation loss: 4.130814552307129


training:  79%|████████████████████████████████████████████████████▉              | 7892/10000 [27:07<07:09,  4.91it/s]

training loss: 2.1161105632781982
validation loss: 4.178129196166992


training:  80%|█████████████████████████████████████████████████████▌             | 7992/10000 [27:28<06:48,  4.92it/s]

training loss: 1.5935395956039429
validation loss: 4.248739242553711
(['D7', 'D', 'Am7', 'D'], '****************************************************************************************************')
['A', 'A7']


training:  81%|██████████████████████████████████████████████████████▏            | 8092/10000 [27:48<06:28,  4.91it/s]

training loss: 2.1196978092193604
validation loss: 4.388339519500732


training:  82%|██████████████████████████████████████████████████████▉            | 8192/10000 [28:08<06:06,  4.93it/s]

training loss: 1.8074244260787964
validation loss: 3.96450138092041


training:  83%|███████████████████████████████████████████████████████▌           | 8292/10000 [28:28<05:46,  4.93it/s]

training loss: 2.083991765975952
validation loss: 4.403730392456055


training:  84%|████████████████████████████████████████████████████████▏          | 8392/10000 [28:49<05:26,  4.92it/s]

training loss: 1.7692843675613403
validation loss: 4.182260036468506


training:  85%|████████████████████████████████████████████████████████▉          | 8492/10000 [29:09<05:06,  4.92it/s]

training loss: 1.8615082502365112
validation loss: 4.221686840057373
(['Baug', 'B', 'C#aug', 'A#M7'], '****************************************************************************************************')
['A#m', 'A#M7']


training:  86%|█████████████████████████████████████████████████████████▌         | 8592/10000 [29:29<04:45,  4.93it/s]

training loss: 2.0659444332122803
validation loss: 4.189205169677734


training:  87%|██████████████████████████████████████████████████████████▏        | 8692/10000 [29:50<04:25,  4.92it/s]

training loss: 1.8083473443984985
validation loss: 3.9739186763763428


training:  88%|██████████████████████████████████████████████████████████▉        | 8792/10000 [30:10<04:05,  4.92it/s]

training loss: 1.7238572835922241
validation loss: 4.160836219787598


training:  89%|███████████████████████████████████████████████████████████▌       | 8892/10000 [30:30<03:45,  4.92it/s]

training loss: 1.6242971420288086
validation loss: 3.7629919052124023


training:  90%|████████████████████████████████████████████████████████████▏      | 8991/10000 [30:51<03:26,  4.89it/s]

training loss: 1.8526463508605957
validation loss: 4.31276798248291
(['D#m7', 'G#m', 'BM7', 'D#m7'], '****************************************************************************************************')
['D#m', 'C#']


training:  91%|████████████████████████████████████████████████████████████▉      | 9091/10000 [31:11<03:05,  4.91it/s]

training loss: 1.7991302013397217
validation loss: 4.231791019439697


training:  92%|█████████████████████████████████████████████████████████████▌     | 9190/10000 [31:32<02:46,  4.85it/s]

training loss: 1.9418405294418335
validation loss: 3.324223756790161


training:  93%|██████████████████████████████████████████████████████████████▏    | 9288/10000 [31:52<02:26,  4.85it/s]

training loss: 1.7919549942016602
validation loss: 4.010268688201904


training:  94%|██████████████████████████████████████████████████████████████▉    | 9386/10000 [32:13<02:07,  4.80it/s]

training loss: 1.84423828125
validation loss: 4.6943440437316895


training:  95%|███████████████████████████████████████████████████████████████▌   | 9486/10000 [32:33<01:45,  4.87it/s]

training loss: 1.8504031896591187
validation loss: 4.065247058868408
(['E', 'C#m7', 'E', 'CM7'], '****************************************************************************************************')
['B', 'B7']


training:  96%|████████████████████████████████████████████████████████████████▏  | 9585/10000 [32:53<01:25,  4.85it/s]

training loss: 1.8170384168624878
validation loss: 4.431827068328857


training:  97%|████████████████████████████████████████████████████████████████▉  | 9683/10000 [33:14<01:05,  4.86it/s]

training loss: 1.818428874015808
validation loss: 4.111574649810791


training:  98%|█████████████████████████████████████████████████████████████████▌ | 9781/10000 [33:34<00:45,  4.85it/s]

training loss: 1.6337765455245972
validation loss: 3.973446846008301


training:  99%|██████████████████████████████████████████████████████████████████▏| 9879/10000 [33:54<00:24,  4.85it/s]

training loss: 1.9391199350357056
validation loss: 3.879714250564575


training: 100%|██████████████████████████████████████████████████████████████████| 10000/10000 [34:19<00:00,  4.86it/s]


In [16]:
import pickle

# Persist the trained wrapper. The context manager flushes and closes the file,
# which the bare open(...) passed to pickle.dump never did.
# NOTE(review): pickling the whole module ties the file to the current class
# definitions, and the chord <-> id lookup is not saved with it.
with open("model.p", "wb") as model_file:
    pickle.dump(model, model_file)

In [45]:
def chord_prog(prompt, length=4):
    """Generate a chord continuation for a whitespace-separated chord prompt.

    Args:
        prompt: Chord symbols separated by whitespace, e.g. ``"Am7 Dm7 Am7 Dm7"``.
            Only the last ``SEQ_LEN`` chords are fed to the model.
        length: Number of chords to generate (default 4, the value that used
            to be hard-coded).

    Returns:
        The decoded result of ``model.generate`` as a list of chord symbols.

    Raises:
        ValueError: If the prompt contains no chords.
        KeyError: If a chord is not part of the ``lookup`` vocabulary.
    """
    # split() without an argument tolerates repeated or surrounding whitespace.
    names = prompt.split()
    if not names:
        raise ValueError("prompt must contain at least one chord")
    unknown = [name for name in names if name not in lookup]
    if unknown:
        raise KeyError(f"chords not in vocabulary: {unknown}")

    # Build the input directly instead of routing it through the random-window
    # dataset: deterministic, works for any prompt length, and placed on
    # whatever device the model lives on.
    device = next(model.parameters()).device
    ids = torch.tensor(
        [lookup[name] for name in names[-SEQ_LEN:]], dtype=torch.long, device=device
    )
    out = model.generate(ids[None, ...], length)
    return decode_tokens(out[0])

In [None]:
class TextSamplerDataset(Dataset):
    """Dataset that serves random contiguous windows from a token tensor.

    Args:
        data: Tensor of tokens, sliced along its first dimension.
        seq_len: Window length; each item holds ``seq_len + 1`` consecutive
            entries, or all of ``data`` when it is shorter than that.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        # Honour the argument; it used to be silently replaced by a hard-coded 3.
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return a randomly positioned window as a ``long`` tensor.

        ``index`` is ignored: every access draws a fresh random start. The
        window is moved to the GPU when CUDA is available and stays on the CPU
        otherwise.
        """
        # Exclusive upper bound for the start so that start + seq_len + 1 stays
        # in range. Clamped to 1 so data no longer than seq_len (for example a
        # short prompt) yields the whole tensor instead of making torch.randint
        # raise on an empty range.
        max_start = max(self.data.size(0) - self.seq_len, 1)
        rand_start = int(torch.randint(0, max_start, (1,)).item())
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda() if torch.cuda.is_available() else full_seq

    def __len__(self):
        """Return the nominal item count: data length divided by ``seq_len``."""
        return self.data.size(0) // self.seq_len

In [51]:
# Draw ten continuations for the same prompt; as the recorded output shows, the
# result differs from one call to the next.
for i in range(0,10):
    print(chord_prog("Am7 Dm7 Am7 Dm7"))

['Am7', 'Dm7', 'C7', 'C']
['Am7', 'Dm7', 'Am7', 'Dm7']
['C7', 'C', 'G', 'Em7']
['Am7', 'Dm7', 'Am7', 'Dm7']
['CM7', 'Dm7', 'Edim', 'Gm7']
['Am7', 'Dm7', 'Am7', 'Dm7']
['GM7', 'FM7', 'Em7', 'Gm7']
['C7', 'C', 'F#dim', 'F#dim']
['GM7', 'FM7', 'Em7', 'Gm7']
['Am7', 'Dm7', 'Am7', 'Dm7']
