In [1]:
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import wandb
import os
import json

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmacosta[0m (use `wandb login --relogin` to force relogin)


True

In [2]:
TOKENIZER_SAVEDIR = Path('/home/macosta/ttmp/primus-data/primus-mei/mei-tokenizer/')
LM_MODEL_SAVEDIR = Path('/home/macosta/ttmp/primus-models/gpt2-lm-mei/')
Path(LM_MODEL_SAVEDIR).mkdir(exist_ok=True)
PRIMUS_TXT_FILES = Path('/home/macosta/ttmp/primus-data/primus-mei/mei-cleaned/')

In [4]:
[torch.cuda.device(i) for i in range(torch.cuda.device_count())]

[<torch.cuda.device at 0x7ff7d3072dd0>, <torch.cuda.device at 0x7ff7d3072f90>]

In [5]:
torch.cuda.set_device(0)
print('Cuda available: ', torch.cuda.is_available())

Cuda available:  True


In [6]:
torch.cuda.current_device()

0

In [4]:
VOCAB_SIZE = 30000
MAX_LEN = 512
tokenizer = GPT2TokenizerFast.from_pretrained(TOKENIZER_SAVEDIR, max_len=MAX_LEN)

file /home/macosta/ttmp/primus-data/primus-mei/mei-tokenizer/config.json not found
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
tokenizer.unk_token = '<unk>'
tokenizer.bos_token = '<s>'
tokenizer.eos_token = '</s>'
tokenizer.add_special_tokens({'pad_token': '<pad>'})

0

In [6]:
ACTUAL_VOCAB_SIZE = len(tokenizer.vocab)

In [29]:
ACTUAL_VOCAB_SIZE

769

In [89]:
def get_quarters(token):
    if token == 'whole':
        return 4
    elif token == 'half':
        return 2
    elif token == 'quarter':
        return 1
    elif token == 'eighth':
        return 0.5
    elif token == 'sixteenth':
        return 0.25
    return 0

def get_sixteenths(token):
    if token == 'whole':
        return 16
    elif token == 'half':
        return 8
    elif token == 'quarter':
        return 4
    elif token == 'eighth':
        return 2
    elif token == 'sixteenth':
        return 1
    return 0

def parse_time_sig(time_sig):
    time_sig = time_sig.split('-')[1]
    if time_sig == 'C':
        return 4, 4
    elif time_sig == 'C/':
        return 2, 2
    elif '/' in time_sig:
        top, bottom = time_sig.split('/')
        return int(top), int(bottom)
    else:
        return 4, 4
    
def get_rhythmic_sequence(tokens):
    sixteenths_per_bar = 0
    sixteenths_left_in_bar = 0
    last_sixteenths_duration = 0
    rhythmic_sequence = []
    for token in tokens:
        if len(token) > 4 and token[:4] == 'time':
            beats_per_bar, denom = parse_time_sig(token)
            sixteenths_per_beat = 16 // denom
            sixteenths_per_bar = beats_per_bar * sixteenths_per_beat
            sixteenths_left_in_bar = sixteenths_per_bar
            last_beat_duration = 0
        elif token == 'barline':
            sixteenths_left_in_bar = sixteenths_per_bar
        elif token == 'dot':
            sixteenths_left_in_bar -= last_sixteenths_duration * 0.5
        elif token == 'dotdot':
            sixteenths_left_in_bar -= last_sixteenths_duration * 0.25
        elif token == '</s>':
            sixteenths_per_bar = 0
            sixteenths_left_in_bar = 0
            last_sixteenths_duration = 0
        else:
            sixteenth_duration = get_sixteenths(token)
            sixteenths_left_in_bar -= sixteenth_duration
            if sixteenth_duration > 0:
                last_sixteenths_duration = sixteenth_duration
        sixteenths_left_in_bar = max(sixteenths_left_in_bar, 0)
        rhythmic_sequence.append(int(sixteenths_left_in_bar))
    return rhythmic_sequence
    
def token_to_action(token):
    if token == 'barline' or token == '</s>' or token == '<s>':
        return ("RESET", 0)
    elif token == 'dot':
        return ("USE_LAST_DURATION", 0.5)
    elif token == 'dotdot':
        return ("USE_LAST_DURATION", 0.25)
    elif len(token) > 14 and token[:14] == 'timeSignature-':
        return ("SET_TIMESIG", parse_time_sig(token))
    else:
        return ("DECREMENT", get_quarters(token))
    
# index_to_token = {v:k for k, v in tokenizer.vocab.items()}
# def index_to_action(index):
#     token = index_to_token[index]
#     return token, token_to_action(token)

# index_to_action_dict = {}
# for index in index_to_token:
#     index_to_action_dict[index] = index_to_action(index)

In [18]:
# with open('/home/macosta/ttmp/primus-models/index-action-dict.json', 'w') as f:
#     json.dump(index_to_action_dict, f)

In [8]:
config = GPT2Config(
    vocab_size=ACTUAL_VOCAB_SIZE,
    n_positions=MAX_LEN,
    n_head=12,
#     index_to_action = index_to_action
)

model = GPT2LMHeadModel(config=config)
print('Num parameters:', model.num_parameters())

TYING WEIGHTS!
GETTING EMBEDDINGS!
Num parameters: 86039808


In [9]:
class CustomDataset(Dataset):
    def __init__(self, src_files, tokenizer, max_length):
        self.examples = []
        for src_file in tqdm(src_files):
            words = src_file.read_text(encoding="utf-8")
            words = words.split()
            words = ['<s>'] + words + ['</s>']
            for i in range(0, len(words), max_length):
                word_string = ' '.join(words[i:i+max_length])
                tokenized = tokenizer.encode(word_string, max_length=max_length, padding='max_length')
                assert(len(tokenized) == max_length)
                self.examples.append(tokenized)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i])

In [135]:
def safecheck_tokens(tokens):
    if 'thirty_second' in tokens or 'sixty_fourth' in tokens:
        return False
    timesig = [t for t in tokens if t[:4] == 'time'][0]
    beats_per_bar, denom = parse_time_sig(timesig)
    sixteenths_per_beat = 16 // denom
    sixteenths_per_bar = beats_per_bar * sixteenths_per_beat
    return sixteenths_per_bar <= 16

In [139]:
class RhythmicDataset(Dataset):
    def __init__(self, src_files, tokenizer, max_length):
        self.examples = []
        for src_file in tqdm(src_files):
            words = src_file.read_text(encoding="utf-8")
            words = words.split()
            words = ['<s>'] + words + ['</s>']
            if not safecheck_tokens(words):
                continue
            rhythmic_sequence = get_rhythmic_sequence(words)
            # Chunk in groups of max_length
            for i in range(0, len(words), max_length):
                word_string = ' '.join(words[i:i+max_length])
                tokenized = tokenizer.encode(word_string, max_length=max_length, padding='max_length')
                relevant_rhythms = rhythmic_sequence[i:i+max_length]
                # Pad rhythm information as necessary
                if len(relevant_rhythms) < max_length:
                    relevant_rhythms += [0] * (max_length - len(relevant_rhythms))
                # Bit shift rhythm and add token
                rhythm_encoded = [(r << 10) + t for r, t in zip(relevant_rhythms, tokenized)]
                assert(len(rhythm_encoded) == max_length)
                self.examples.append(rhythm_encoded)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i])

In [10]:
def create_train_test_datasets(tokenizer, max_length, fraction=1.0, test_size=0.1):
    src_files = list(Path(PRIMUS_TXT_FILES).glob("**/*.mei"))
    src_files = src_files[:int(len(src_files) * fraction)]
    split_index = int(len(src_files) * (1 - test_size))
    train_files = src_files[:split_index]
    test_files = src_files[split_index:]
    train_dataset = CustomDataset(train_files, tokenizer, max_length=max_length)
    test_dataset = CustomDataset(test_files, tokenizer, max_length=max_length)
    return train_dataset, test_dataset

In [11]:
train_dataset, test_dataset = create_train_test_datasets(tokenizer, MAX_LEN, fraction=1, test_size=0.05)

100%|█████████████████████████████████████| 8328/8328 [00:03<00:00, 2118.65it/s]
100%|███████████████████████████████████████| 439/439 [00:00<00:00, 2096.10it/s]


In [28]:
# train_dataset.__getitem__(102)

tensor([ 0, 47,  4, 45,  4, 46,  4, 48,  4, 37, 55, 81, 57, 66,  4, 50,  4, 54,
        52, 65, 51, 11, 12, 11, 53, 13,  5, 44,  4, 42,  4, 49,  4, 21, 20, 22,
         4, 19, 12, 11,  4, 18, 12, 11,  4,  9, 36, 11,  6, 10,  8, 10,  7, 30,
         5,  9,  6, 14,  8, 10,  7, 30,  5,  9,  6, 10,  8, 10,  7, 30,  5,  9,
         6, 10,  8, 10,  7, 30,  4, 63, 61, 33,  5, 62,  4, 15,  4, 17,  4, 16,
         4, 21, 20, 22,  4, 19, 12, 11,  4, 18, 12, 11,  4,  9,  6, 10,  8, 10,
         7, 29,  5,  9,  6, 10,  8, 10,  7, 30,  5,  9,  6, 10, 88, 87,  8, 10,
         7, 29,  5, 35,  6, 10,  5, 15,  4, 17,  4, 16,  4, 43,  4, 41,  4, 39,
         4, 38,  4, 40,  4,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1, 

In [13]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [14]:
training_args = TrainingArguments(
    output_dir=LM_MODEL_SAVEDIR,
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    save_steps=10000,
    logging_steps=3000,
    evaluation_strategy="steps",
    eval_steps=3000,
    save_total_limit=1,
    prediction_loss_only=False,
    report_to="wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [30]:
ret = trainer.train()

NameError: name 'trainer' is not defined

In [None]:
trainer.save_model(LM_MODEL_SAVEDIR)