In [None]:
# Hyper-parameters for the model and the training run, kept in one place so
# every later cell reads the same values.
hparams = dict(
    n_layer=4,            # number of transformer layers
    n_head=4,             # attention heads per layer
    embedding_dim=32,     # width of the token embeddings
    max_seq_length=2048,  # longest token sequence kept / positions in the model
    lr=1e-4,              # Adam learning rate
    batch_size=8,         # sequences sampled per optimisation step
    epoch_length=100,     # optimisation steps that make up one "epoch"
    epochs=100,           # number of such epochs to run
)
# Total number of optimisation steps, derived from the two values above.
hparams.update(iters=hparams["epochs"] * hparams["epoch_length"])

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import requests
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import transformers
from tqdm.notebook import tqdm


# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere below /kaggle/input.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Download the tune collection as JSON and load it into a DataFrame.
# A timeout keeps the cell from hanging forever on a stalled connection.
r = requests.get(
    'https://raw.githubusercontent.com/adactio/TheSession-data/master/json/tunes.json',
    timeout=60,
)
# Fail loudly on HTTP errors. Unlike `assert`, this check is not stripped when
# Python runs with -O, and the exception carries the status code and URL.
r.raise_for_status()
data = pd.DataFrame(r.json())

In [None]:
# Remove spaces, line breaks, tabs and a few stray control / non-breaking
# characters from every ABC string, so only notation characters remain.
unwanted_chars = (" ", "\r", "\n", "\t", "\x14", "\x1a", "\xa0", u"\u2028")


def _drop_unwanted(text):
    """Return `text` with every character in `unwanted_chars` deleted."""
    for char in unwanted_chars:
        text = text.replace(char, "")
    return text


data['abc'] = data['abc'].map(_drop_unwanted)

In [None]:
# Gather every character that occurs in any tune, then build the vocabulary:
# the two special markers first, followed by the sorted unique characters.
all_chars = []
for tune in data['abc']:
    all_chars.extend(tune)
dictionary = ['<bos>', '<eos>'] + list(np.unique(all_chars))

In [None]:
def tensor_to_abc(data):
    """Decode an iterable of token ids into an ABC string.

    Ids equal to -100 (the padding value), 0 and 1 (the '<bos>' and '<eos>'
    slots at the front of `dictionary`) are skipped; every other id is
    replaced by its `dictionary` entry and the pieces are concatenated.
    """
    pieces = []
    for token in data:
        if token == -100 or token == 0 or token == 1:
            continue
        pieces.append(dictionary[token])
    return ''.join(pieces)

def abc_to_list(text):
    """Encode an ABC string as a list of token ids wrapped in <bos> ... <eos>.

    Args:
        text: iterable of symbols (normally a str), each of which must be an
            entry of the module-level `dictionary`.

    Returns:
        A list of ints: the id of '<bos>', one id per symbol of `text`, then
        the id of '<eos>'.

    Raises:
        ValueError: if a symbol (or one of the special markers) is not in
            `dictionary`, the same exception type `list.index` raises.
    """
    # Build the symbol -> id table once per call instead of running a linear
    # `dictionary.index()` search for every character. `setdefault` keeps the
    # first position of a symbol, matching `list.index` semantics.
    token_ids = {}
    for position, symbol in enumerate(dictionary):
        token_ids.setdefault(symbol, position)
    try:
        encoded = [token_ids['<bos>']]
        encoded.extend(token_ids[char] for char in text)
        encoded.append(token_ids['<eos>'])
    except KeyError as missing:
        raise ValueError('%r is not in list' % (missing.args[0],)) from None
    return encoded

In [None]:
# Encode every tune, drop the ones that exceed the model's context length,
# order them longest-first and pad them into one rectangular tensor.
# Padding uses -100 so padded positions can be told apart from real token ids.
longest_allowed = hparams['max_seq_length']
token_lists = sorted(
    (ids for ids in map(abc_to_list, data['abc']) if len(ids) <= longest_allowed),
    key=len,
    reverse=True,
)
features = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(ids) for ids in token_lists],
    batch_first=True,
    padding_value=-100,
)

In [None]:
# Display the dimensions of the padded token tensor (one row per kept tune).
features.shape

In [None]:
# Describe a small GPT-2: vocabulary size comes from the character dictionary,
# and the context length matches the longest sequence kept in `features`.
gpt2_config = transformers.GPT2Config(
    vocab_size=len(dictionary),
    n_positions=hparams['max_seq_length'],
    n_ctx=hparams['max_seq_length'],
    n_embd=hparams["embedding_dim"],
    n_layer=hparams["n_layer"],
    n_head=hparams["n_head"],
)
model = transformers.GPT2LMHeadModel(gpt2_config)

# Plain Adam over all model parameters with the configured learning rate.
optim = torch.optim.Adam(model.parameters(), lr=hparams["lr"])

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# The whole padded dataset is moved to the device up front, so batches can be
# sliced from it directly during training.
features = features.to(device)

In [None]:
# Train for hparams['iters'] optimisation steps on randomly sampled batches,
# logging the loss and checkpointing once every hparams['epoch_length'] steps.
model.train()
for i in tqdm(range(0, hparams['iters'])):
    # Sample a batch of rows (np.random.choice samples with replacement).
    indices = np.random.choice(len(features), hparams['batch_size'])
    labels = features[indices]

    # -100 marks padding. It is not a valid embedding index, so the model
    # input gets a real id (0) there; the labels keep -100 so the loss skips
    # those positions.
    mask = labels == -100
    batch = labels.clone()
    batch[mask] = 0

    optim.zero_grad()
    # attention_mask must be 1 on real tokens and 0 on padding, i.e. the
    # inverse of the padding mask. Passing the padding mask itself would make
    # the model attend to padding and ignore the tune.
    outputs = model(batch, attention_mask=(labels != -100).long(), labels=labels)
    # Index 0 is the loss whether the model returns a tuple or a ModelOutput.
    loss = outputs[0]
    loss.backward()
    optim.step()

    if i % hparams['epoch_length'] == 0:
        tqdm.write('Epoch %d, loss %f' % (i // hparams['epoch_length'], loss.item()))
        torch.save(model.state_dict(), '/kaggle/working/model.pth')

# The in-loop checkpoint last fires at the start of the final epoch; save once
# more so the weights on disk include the remaining updates.
torch.save(model.state_dict(), '/kaggle/working/model.pth')



In [None]:
def synthesize(starting_sequence, max_new_tokens=None):
    """Greedily continue an ABC prompt with the trained model.

    The prompt is encoded as '<bos>' + characters (without a trailing '<eos>',
    so the model is asked to continue the tune rather than to predict what
    follows its end). Tokens are then generated one at a time, always taking
    the most likely next token, until the model emits '<eos>' or the length
    budget is used up.

    Args:
        starting_sequence: ABC text to start from; every character must be in
            `dictionary`.
        max_new_tokens: optional cap on the number of generated tokens. By
            default generation may continue until the sequence reaches
            hparams['max_seq_length'].

    Returns:
        `starting_sequence` followed by the generated ABC text. The prompt is
        returned unchanged when there is no room left to generate.

    Side effects:
        Switches `model` to eval mode.
    """
    eos_id = dictionary.index('<eos>')
    # abc_to_list wraps the text in <bos> ... <eos>; drop the trailing <eos>.
    prompt_ids = abc_to_list(starting_sequence)[:-1]

    room = hparams['max_seq_length'] - len(prompt_ids)
    if max_new_tokens is not None:
        room = min(room, max_new_tokens)

    generated = []
    model.eval()
    with torch.no_grad():
        tokens = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max(room, 0)):
            # Index 0 holds the logits whether the model returns a tuple or a
            # ModelOutput; only the last position predicts the next token.
            logits = model(tokens)[0]
            next_id = int(logits[0, -1].argmax())
            if next_id == eos_id:
                break
            generated.append(next_id)
            next_token = torch.tensor([[next_id]], dtype=torch.long, device=device)
            tokens = torch.cat([tokens, next_token], dim=1)
    return starting_sequence + tensor_to_abc(generated)

In [None]:
# Try the model on a short ABC prompt; the returned string is shown as the cell output.
synthesize("aabBB")