# mtg language model

let's follow along with [this notebook](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb) to train a model from scratch on our MTG corpus

In [49]:
import functools
import glob
import os
import random
import subprocess

import datasets
import pandas as pd
import tokenizers
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import (BertConfig,
                          BertForPreTraining,
                          BertTokenizerFast,
                          Trainer,
                          TrainingArguments,
                          pipeline, )

from utils import grouper

In [2]:
IS_LOCAL_LAPTOP = True
TRAIN_TOKENIZER = True
MAKE_MLM_NSP_DATASET_INPUTS = True
MAKE_MLM_NSP_DATASET = True
DO_CHUNKED_DATASET_BUILDING = True
SAVE_DATASET_TO_FILE = True
LOAD_DATASET_FROM_FILE = False
DO_PRE_TRAINING = False
DO_SMALL_PRETRAIN_TEST = False

# cap on the number of card pairs generated per card; None means we use
# all possible pairs
NUM_MAX_PAIRS = None
# NUM_MAX_PAIRS = 1

In [3]:
if TRAIN_TOKENIZER:
    assert IS_LOCAL_LAPTOP

In [4]:
SPLITS = ['train', 'test', 'validation']
SEED = 1337

# for the chunking of the dataset builder

N_PROCS = os.cpu_count()
N_CHUNKS = 200

In [5]:
pd.set_option('display.max_columns', None)  # or 1000
pd.set_option('display.max_rows', None)  # or 1000
pd.set_option('display.max_colwidth', None)

## get raw text dataset

In [6]:
if IS_LOCAL_LAPTOP:
    import mtg.cards

In [7]:
@functools.lru_cache(None)
def get_cards():
    cards = (mtg.cards.cards_df()
             [['name', 'multiverseId', 'scryfallId', 'type', 'manaCost',
               'text', 'setname', 'power', 'toughness']]
             .sort_values(by=['name', 'multiverseId'], ascending=False)
             .groupby('name')
             .first())
    cards.index = cards.index.str.lower()
    cards = cards[cards.type != 'Scheme'].copy()
    cards.loc[:, 'mytext'] = (cards.manaCost.fillna('{0}')
                              + ' '
                              + cards.type
                              + ((' ' + cards.power + '/' + cards.toughness).fillna(''))
                              + ': '
                              # regex=True is required: pandas >= 2.0 treats patterns literally by default
                              + cards.text.str.replace(r'\s+', ' ', regex=True).fillna(''))
    cards.mytext = cards.mytext.str.lower().str.replace('["\']', '', regex=True)
    return cards
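to make the per-card corpus format concrete, here is a hypothetical pure-Python version of the `mytext` construction for a single card (`format_card` is an illustrative name, not part of `mtg.cards`). it mirrors the pandas expression above: a missing mana cost becomes `{0}`, power/toughness is appended only when both are present, runs of whitespace collapse to one space, and the line is lowercased with quote characters removed.

```python
import re


def format_card(mana_cost, card_type, power, toughness, text):
    """Build one corpus line: '<mana> <type>[ <p>/<t>]: <text>', lowercased, no quotes."""
    mana = mana_cost if mana_cost is not None else '{0}'          # like fillna('{0}')
    has_pt = power is not None and toughness is not None
    pt = f' {power}/{toughness}' if has_pt else ''                 # missing p/t -> ''
    body = re.sub(r'\s+', ' ', text) if text is not None else ''   # collapse whitespace
    line = f'{mana} {card_type}{pt}: {body}'
    return re.sub('["\']', '', line.lower())                       # drop quote characters


print(format_card('{1}{R}', 'Artifact Creature', '2', '1', 'Haste\n"Go!"'))
# -> {1}{r} artifact creature 2/1: haste go!
```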

In [8]:
F_CORPUS = 'mtg-corpus.txt'

def get_corpus(f=F_CORPUS):
    get_cards().mytext.to_csv(f, header=False, index=False, sep='\t')

In [9]:
if TRAIN_TOKENIZER:
    get_corpus()

In [10]:
if TRAIN_TOKENIZER:
    !head -n5 mtg-corpus.txt

{2}{r}{r}{g}{g} enchantment: at the beginning of your upkeep, you may say ach! hans, run! its the . . . and the name of a creature card. if you do, search your library for a card with that name, put it onto the battlefield, then shuffle your library. that creature gains haste. exile it at the beginning of the next end step.
{2}{b} enchantment: {3}{b}, exile a permanent you control with a league of dastardly doom watermark: return a permanent card with a league of dastardly doom watermark from your graveyard to the battlefield.
{w}{u}{b}{r}{g} summon — legend: cannot be the target of spells or effects. world champion has power and toughness equal to the life total of target opponent. {0}: discard your hand to search your library for 1996 world champion and reveal it to all players. shuffle your library and put 1996 world champion on top of it. use this ability only at the beginning of your upkeep, and only if 1996 world champion is in your library.
{4}{w}{b} enchantment: spells and a

## train a tokenizer

I found the `BERT` special tokens by looking at the defaults for

```py
tokenizers.BertWordPieceTokenizer.train?
```

In [12]:
%%time

MODEL_NAME = 'mtg-language'

# with a vocab_size well above what this corpus supports (the default is
# 30_000), you get 11,761 vocab items. any number larger than that will
# include null tokens and will also include single words. we will choose a
# number that is *just* below that here
VOCAB_SIZE = 11_500

if TRAIN_TOKENIZER:
    #trainable_tokenizer = tokenizers.ByteLevelBPETokenizer(lowercase=True)
    trainable_tokenizer = tokenizers.BertWordPieceTokenizer(lowercase=True)
    
    trainable_tokenizer.train(files=F_CORPUS,
                              vocab_size=VOCAB_SIZE,
                              special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])
    
    !mkdir -p {MODEL_NAME}

    trainable_tokenizer.save_model(MODEL_NAME)

CPU times: user 5.76 s, sys: 426 ms, total: 6.19 s
Wall time: 2.29 s


In [13]:
if TRAIN_TOKENIZER:
    !tail -n40 {MODEL_NAME}/vocab.txt

reliquary
mediocrity
petravark
petrahydrox
scrapheap
lichenthrope
ertais
geargrabber
unspeakable
sluggishness
hollowhenge
belzenlok
cerulean
vizkopa
vitaspore
stingmoggie
miscreant
misfortune
saltskitter
saltcrusted
comeuppance
collaborator
eddytrail
prognostic
polyraptor
morselhoarder
hemorrhage
melancholy
pummeler
tromokratis
uncontested
thraximundar
astrolabe
bangchuckers
epochrasite
gurzigost
kaleidostone
moondrakes
nucklavee
svyelunite


In [14]:
if TRAIN_TOKENIZER:
    !wc -l {MODEL_NAME}/vocab.txt

   11400 mtg-language/vocab.txt


## make a training dataset

we will pre-train a completely uninitialized (randomly initialized) model below, so we need a dataset to train it on. `bert` has two pre-training objectives: masked language modeling (mlm) and next sentence prediction (nsp).

in order to do this, we will need to create a dataset with the following features:

+ standard bert tokenizer outputs
    + `input_ids`
    + `attention_mask`
    + `token_type_ids`
+ `labels`: _optional_; the same token ids as `input_ids`, except that positions set to `-100` are ignored in the loss calculation (for mlm, every position except the masked tokens is normally set to `-100`)
+ `next_sentence_label`: the `0, 1` values indicating whether sentences A and B are continuations (in the original model) or `edhrec` pairs (our model). note that in `BertForPreTraining`, `0` means B *does* follow A and `1` means B is random
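as a hedged illustration of how the tokenizer features fit together for one sentence pair, here is a pure-Python sketch of the layout a BERT tokenizer produces: `[CLS] A [SEP] B [SEP]`, `token_type_ids` of 0 over the first segment and 1 over the second, and an `attention_mask` of 1 on real tokens and 0 on padding. the `build_pair_features` helper and the token ids are hypothetical; in this notebook `BertTokenizerFast` does this work (including truncation).

```python
def build_pair_features(ids_a, ids_b, cls_id, sep_id, pad_id, max_length):
    """Lay out one (A, B) pair the way BERT expects; assumes the pair already fits."""
    input_ids = [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]
    # segment A covers [CLS] + A + first [SEP]; segment B covers B + final [SEP]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    if len(input_ids) > max_length:
        raise ValueError('pair too long; a real tokenizer would truncate here')
    n_pad = max_length - len(input_ids)
    return {
        'input_ids': input_ids + [pad_id] * n_pad,
        'token_type_ids': token_type_ids + [0] * n_pad,
        'attention_mask': [1] * len(input_ids) + [0] * n_pad,
    }


# hypothetical ids following the special_tokens order used for the tokenizer:
# [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4
features = build_pair_features([10, 11], [20], cls_id=2, sep_id=3, pad_id=0, max_length=8)
```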

more is better here, of course; I think the goal has to be full coverage of all cards and all edhrec pairings. to that end, I will write a generator-based dataset builder (a `datasets.GeneratorBasedBuilder`).

+ [writing a dataset loading script walkthrough here](https://huggingface.co/docs/datasets/add_dataset.html)
+ [code template example here](https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py)

we have a few steps:

1. create a train / test / val split of card names
1. create a train / test / val split of edhrec pairings
1. save the above to files we can move around (parquet is fine)
1. write a dataset loading script that takes the above and generates the full datasets (and I do mean full!)
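step 4 is the heart of the builder. a hedged pure-Python sketch of the record generator for option 2 below (every card paired with every other card in its split, optionally capped): `generate_pair_records` and its arguments are hypothetical, not the contents of `edhrec_dataset.py`. a `datasets` builder would yield records like these, as `(key, example)` tuples, from its `_generate_examples` method. the labels assume the `BertForPreTraining` convention: `0` for an edhrec pair ("B follows A"), `1` otherwise.

```python
import random


def generate_pair_records(card_texts, edhrec_pairs, num_max_pairs=None, seed=1337):
    """Yield one record per ordered (card, other card) combination.

    card_texts: dict of lowercase card name -> card text (the `mytext` column)
    edhrec_pairs: set of (name, commander) tuples
    num_max_pairs: if set, sample at most this many partners per card
    """
    rng = random.Random(seed)
    names = sorted(card_texts)
    for name in names:
        partners = [other for other in names if other != name]
        if num_max_pairs is not None:
            partners = rng.sample(partners, min(num_max_pairs, len(partners)))
        for other in partners:
            is_pair = (name, other) in edhrec_pairs or (other, name) in edhrec_pairs
            yield {'text_a': card_texts[name],
                   'text_b': card_texts[other],
                   'next_sentence_label': 0 if is_pair else 1}


records = list(generate_pair_records({'a': 'ta', 'b': 'tb', 'c': 'tc'}, {('a', 'b')}))
```

with `n` cards this yields `n * (n - 1)` records uncapped, which is the "roughly num_cards ^ 2" growth discussed below.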

### train / test / val split for card names

+ train and val are random splits of the remaining cards
+ test is a straight holdout of a few entire sets, to test generalizability when new sets are released

split off test first, then do a normal train/val split on the rest

In [15]:
DATA_DIR = 'mtg-language-data'
os.makedirs(DATA_DIR, exist_ok=True)

In [16]:
def get_test_sets(target_test_frac, seed=SEED):
    cards = get_cards()
    mtg_set_sizes = cards.setname.value_counts()
    mtg_sets = set(cards.setname.unique())

    test_sets = set()
    test_sets_size = 0
    test_set_target_size = target_test_frac * cards.shape[0]

    random.seed(seed)
    while test_sets_size < test_set_target_size:
        s = random.choice(list(mtg_sets))
        test_sets.add(s)
        mtg_sets.remove(s)
        test_sets_size += mtg_set_sizes[s]
        #print(f"test_sets {test_sets} contain {test_sets_size} cards")
    return test_sets

we want to create three datasets (`train`, `test`, `validation`) holding approximately 90%, 5%, and 5% of the records. there are two ways we could generate records for those datasets, though:

+ a fixed number (e.g. 100) of pairs per card
    + e.g. take card a and generate up to 100 true / false high / false low pairs
+ as many pairs per card as are allowed under the provided card split
    + e.g. combine every card with every other card and also provide a 0/1 label based on whether or not it was an edhrec pair

we want to create these *record-level* splits by making *card-level* splits first. these two different end results require two different splitting fractions for the cards:

+ for option 1 (fixed pairs per card), the number of *records* per *card split* is N * num_cards
+ for option 2 (*not* capped), the number of *records* per *card split* is roughly num_cards ^ 2

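option 1 is easy (just do .9/.05/.05). option 2 is trickier: with $\alpha$ the fraction of *cards* in train and $\beta$ the fraction in each of test and validation, we need $\alpha + 2\beta = 1$, and we also want the *record* ratio $\dfrac{\alpha ^ 2}{\beta ^ 2} \approx \dfrac{0.9}{0.05}$, i.e. $\alpha ^ 2 \approx 18 \beta ^ 2$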

together, this means

$$
\begin{aligned}
\alpha &= 1 - 2 \beta \\
\alpha ^ 2 &= 1 - 4 \beta + 4 \beta ^ 2 \\
1 - 4 \beta + 4 \beta ^ 2 &\approx 18 \beta ^ 2 \\
1 - 4 \beta - 14 \beta ^ 2 &\approx 0
\end{aligned}
$$
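solving that quadratic explicitly (multiplying through by $-1$ and taking the positive root, since $\beta$ is a fraction) gives a closed form that agrees with the numeric value computed in the next cell:

$$
\begin{aligned}
14 \beta ^ 2 + 4 \beta - 1 &= 0 \\
\beta &= \frac{-4 + \sqrt{4 ^ 2 + 4 \cdot 14}}{2 \cdot 14} = \frac{-4 + 6 \sqrt{2}}{28} = \frac{3 \sqrt{2} - 2}{14} \approx 0.1602 \\
\alpha &= 1 - 2 \beta = \frac{9 - 3 \sqrt{2}}{7} \approx 0.6796, \qquad \frac{\alpha}{\beta} = 3 \sqrt{2} = \sqrt{18}
\end{aligned}
$$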

In [17]:
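# solve 1 - 4*beta - 14*beta**2 = 0 written as a*beta**2 + b*beta + c = 0;
# with a < 0, this sign choice in the quadratic formula gives the positive root
a, b, c = -14, -4, 1
beta = (-b - (b ** 2 - 4 * a * c) ** .5) / (2 * a)
beta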

0.16018862050852034
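a quick check that this card-level split gives the intended *record-level* split under option 2: if the records in a split grow like the square of the number of cards in it, the record fractions should come out at 0.9 / 0.05 / 0.05. this is a standalone sketch; it recomputes `beta` rather than reusing the notebook variable.

```python
import math

# positive root of 14 * beta**2 + 4 * beta - 1 = 0 (the quadratic derived above)
beta = (-4 + math.sqrt(4 ** 2 + 4 * 14)) / (2 * 14)
alpha = 1 - 2 * beta  # card fraction for train

# option 2: records in a split grow like (cards in that split) ** 2
weights = {'train': alpha ** 2, 'test': beta ** 2, 'validation': beta ** 2}
total = sum(weights.values())
record_fracs = {split: w / total for split, w in weights.items()}
print(record_fracs)  # roughly {'train': 0.9, 'test': 0.05, 'validation': 0.05}
```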

In [18]:
if NUM_MAX_PAIRS is None:
    TARGET_TEST_FRAC = beta
    VAL_FRAC = beta
else:
    TARGET_TEST_FRAC = 0.05
    VAL_FRAC = 0.05

if MAKE_MLM_NSP_DATASET_INPUTS:
    cards = get_cards()
    
    test_sets = get_test_sets(TARGET_TEST_FRAC)
    is_test = cards.setname.isin(test_sets)
    test = cards[is_test]
    train_val = cards[~is_test]
    
    # we want val frac relative to the *entire* dataset, which means we need to
    # scale it up when part of that dataset has been claimed by the test set
    # in other words, we want n_val records, where n_val = val_frac * N_total
    # now that n_test have been removed,
    # 
    #   n_val = adj_val_frac * (N_total - n_test)
    #   adj_val_frac = n_val / (N_total - n_test)
    #   adj_val_frac = val_frac * N_total / (N_total - n_test)
    adj_val_frac = VAL_FRAC * cards.shape[0] / train_val.shape[0]
    
    train, val = train_test_split(train_val, test_size=adj_val_frac, random_state=SEED)

    print(f"num records train: {train.shape[0]}")
    print(f"num records test:  {test.shape[0]}")
    print(f"num records val:   {val.shape[0]}")
    
    train.to_parquet(os.path.join(DATA_DIR, "cards.train.parquet"), index=True)
    test.to_parquet(os.path.join(DATA_DIR, "cards.test.parquet"), index=True)
    val.to_parquet(os.path.join(DATA_DIR, "cards.validation.parquet"), index=True)

num records train: 14769
num records test:  3501
num records val:   3486
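as a sanity check of the `adj_val_frac` scaling, the printed card counts can be reproduced by hand. this assumes scikit-learn's documented behaviour for a float `test_size`: the held-out split gets `ceil(test_size * n_samples)` rows and the train split gets the rest.

```python
import math

# card counts printed by the split cell above
n_train, n_test, n_val_printed = 14769, 3501, 3486
n_total = n_train + n_test + n_val_printed
val_frac = 0.16018862050852034  # beta, printed earlier

n_train_val = n_total - n_test
adj_val_frac = val_frac * n_total / n_train_val
# scikit-learn rounds a float test_size up: n_val = ceil(test_size * n_samples)
n_val = math.ceil(adj_val_frac * n_train_val)
print(n_total, round(adj_val_frac, 4), n_val)
```

so the validation set is about 16% of *all* cards even though it was drawn from the ~84% left after removing the test sets.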


### train / test / val split for edhrec pairs

use the card splits defined above to split all edhrec pairings: each pairing goes to whichever split contains its card (`name`)

In [19]:
if MAKE_MLM_NSP_DATASET_INPUTS:
    import mtg.extract.edhrec

    edhrec_cards = (mtg.extract.edhrec.get_commanders_and_cards()
                    [['name', 'commander']])
    commanders = edhrec_cards.commander.unique()
    cmdr_cmdr_df = pd.DataFrame([[c, c] for c in commanders],
                                columns=['name', 'commander'])
    # DataFrame.append was removed in pandas 2.0; pd.concat keeps the same result
    edhrec_cards = pd.concat([edhrec_cards, cmdr_cmdr_df])
    
    edhrec_cards.name = edhrec_cards.name.str.lower().str.replace('//', '/')
    edhrec_cards.commander = edhrec_cards.commander.str.lower().str.replace('//', '/')
    
    edhrec_train = edhrec_cards[edhrec_cards.name.isin(train.index)].copy()
    edhrec_test = edhrec_cards[edhrec_cards.name.isin(test.index)].copy()
    edhrec_val = edhrec_cards[edhrec_cards.name.isin(val.index)].copy()

    print(f"num records train: {edhrec_train.shape[0]}")
    print(f"num records test:  {edhrec_test.shape[0]}")
    print(f"num records val:   {edhrec_val.shape[0]}")
    
    edhrec_train.to_parquet(os.path.join(DATA_DIR, "edhrec.train.parquet"), index=False)
    edhrec_test.to_parquet(os.path.join(DATA_DIR, "edhrec.test.parquet"), index=False)
    edhrec_val.to_parquet(os.path.join(DATA_DIR, "edhrec.validation.parquet"), index=False)

num records train: 153021
num records test:  56685
num records val:   30442


In [20]:
edhrec_train.head(2) if MAKE_MLM_NSP_DATASET_INPUTS else None

Unnamed: 0,name,commander
0,forsaken monument,"kozilek, the great distortion"
1,war room,"kozilek, the great distortion"


In [21]:
edhrec_test.head(2) if MAKE_MLM_NSP_DATASET_INPUTS else None

Unnamed: 0,name,commander
7,hedron archive,"kozilek, the great distortion"
10,mind stone,"kozilek, the great distortion"


### the datasetbuilder

In [None]:
!ls -alh {DATA_DIR}

In [None]:
# # quick check on the max length of our sequences:
# # we can easily build a tokenizer and apply it to every sentence
# # directly; this will give us a max length for a single sentence
# # and then our dataset max length is approx double that.

# from transformers import BertTokenizerFast

# tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# z = tokenizer(cards.mytext.unique().tolist(), max_length=1024)

# import collections

# l = [len(_) for _ in z['input_ids']]
# c = collections.Counter(l)

# df = (pd.DataFrame([{'k': k, 'v': v} for (k, v) in c.items()])
#       .sort_values(by='k'))
# df.loc[:, 'cs'] = df.v.cumsum() / df.v.sum()
# print(f"max single sequence length: {df.k.max()}")
# # df.plot('k', 'cs')

# import numpy as np
# pairs = np.random.choice(l, size=(100_000, 2))
# import matplotlib.pyplot as plt
# plt.hist(pairs.sum(axis=1), bins=50);

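takeaway from the check above: almost every card is under 100 tokens (max 185), and almost every *pair* of cards is under 175 tokens combined. a `MAX_SEQ_LENGTH` of 200 is therefore very conservative for typical pairs, although the rare pair of two very long cards will still exceed it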

In [None]:
MAX_SEQ_LENGTH = 200

if MAKE_MLM_NSP_DATASET:
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    tokenizer_map_func = build_tokenizer_map_func(tokenizer, max_length=MAX_SEQ_LENGTH)
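`build_tokenizer_map_func` is called here but not defined in this part of the notebook. a minimal hypothetical sketch, assuming each record stores its pair in `text_a` / `text_b` columns (`text_a` appears at the end of the notebook; `text_b` is an assumption): it closes over the tokenizer and returns a per-example function for `Dataset.map(..., batched=False)`. padding to `max_length` matters because the `Trainer` below is built without a data collator, so batches are stacked as-is and must already have equal lengths.

```python
def build_tokenizer_map_func(tokenizer, max_length):
    """Return a function that tokenizes one example's (text_a, text_b) pair."""
    def tokenizer_map_func(example):
        # standard Hugging Face tokenizer call for a sentence pair; the returned
        # input_ids / token_type_ids / attention_mask are merged into the record
        return tokenizer(example['text_a'],
                         example['text_b'],
                         truncation=True,
                         max_length=max_length,
                         padding='max_length')
    return tokenizer_map_func
```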

In [None]:
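def add_labels(examples):
    # labels are a straight copy of input_ids: nothing is masked here, and the
    # Trainer below is built without a masking data collator, so as written the
    # mlm objective asks the model to predict tokens it can already see
    return {'labels': examples['input_ids']}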
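in `transformers`, the usual way to get real mlm masking is to pass `DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)` as the `Trainer`'s `data_collator`, which masks each batch on the fly. to show what that recipe does, here is a pure-Python sketch for one sequence (the `mask_tokens` helper is hypothetical, not the library's code): about 15% of non-special tokens are selected; of those, 80% become `[MASK]`, 10% a random token, 10% stay unchanged; labels are `-100` everywhere else.

```python
import random


def mask_tokens(input_ids, mask_id, vocab_size, special_ids, mlm_probability=0.15, seed=None):
    """Return (masked_input_ids, labels) for BERT-style masked language modeling."""
    rng = random.Random(seed)
    masked = list(input_ids)
    labels = [-100] * len(input_ids)  # -100 = ignored by the loss
    for i, tok in enumerate(input_ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = mask_id                    # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels
```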

In [None]:
if MAKE_MLM_NSP_DATASET:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    if DO_CHUNKED_DATASET_BUILDING:
        commands = [
            f'python edhrec_dataset_chunk_proc.py -n {N_CHUNKS} -c {c} -d mtg-language-data -o mtg-mlm-nsp-dataset-chunks > log.{c}.txt 2>&1'
            for c in range(N_CHUNKS)
        ]

        print(f'executing {N_CHUNKS} commands in chunks of {N_PROCS} parallel commands')

        import shutil

        for cmd_grp in grouper(tqdm(commands), N_PROCS, ''):
            # the '' fillvalue pads the last group, so skip those placeholders
            processes = [subprocess.Popen(cmd, shell=True) for cmd in cmd_grp if cmd]
            for p in processes:
                p.wait()

            # we just saved up to N_PROCS chunks in two places -- the hf cache
            # directory and the local directory. to avoid running out of disk
            # space, we wipe the cache directory after every proc group
            # (shutil.rmtree, since os.rmdir only removes *empty* directories)
            hf_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'datasets')
            if os.path.isdir(hf_cache_dir):
                shutil.rmtree(hf_cache_dir)

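        # assumes each chunk directory holds a DatasetDict keyed by split
        chunk_dirs = sorted(glob.glob('mtg-mlm-nsp-dataset-chunks/*'))
        base_dataset = datasets.DatasetDict({
            split: datasets.concatenate_datasets([datasets.load_from_disk(f)[split]
                                                  for f in chunk_dirs])
            for split in SPLITS})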
    else:
        base_dataset = datasets.load_dataset('edhrec_dataset.py',
                                             data_dir=DATA_DIR,
                                             num_max_pairs=NUM_MAX_PAIRS)

    dataset = (base_dataset
               .shuffle(seeds={split: SEED for split in SPLITS})
               .map(tokenizer_map_func,
                    batched=False,
                    num_proc=N_PROCS)
               .map(add_labels,
                    batched=True,
                    num_proc=N_PROCS));
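the chunked loop above waits for every process in a group before launching the next group, so one slow chunk leaves the other CPUs idle. a hedged, stdlib-only alternative (the `run_commands` name is hypothetical) keeps up to `max_parallel` shell commands running at all times with a thread pool; the threads only wait on subprocesses, so the GIL is not a bottleneck.

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor


def run_commands(commands, max_parallel):
    """Run shell commands with at most `max_parallel` running at once.

    a new command starts as soon as any running one finishes. returns the
    exit codes in the same order as `commands`.
    """
    def run_one(cmd):
        return subprocess.run(cmd, shell=True).returncode

    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        return list(pool.map(run_one, commands))
```

in the notebook this would be called as `run_commands(commands, max_parallel=N_PROCS)`; note that the cache-wiping step would then need to happen once at the end (or per chunk inside the chunk script) rather than per group.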

In [None]:
dataset if MAKE_MLM_NSP_DATASET else None

In [None]:
dataset.shape if MAKE_MLM_NSP_DATASET else None

In [None]:
if SAVE_DATASET_TO_FILE:
    dataset.save_to_disk('mtg-mlm-nsp-dataset')

with the smaller dataset defined above,

+ ~~add the tokenizer~~
+ ~~build the trainer and train config~~
+ ~~do a train round with the smaller datasets~~
+ ~~any difference at all??~~

if it looks promising,

+ build a *real* dataset
    + ~~write chunked dataset processor to leverage multiple cpus~~
    + move to a **CPU** box -- not a GPU box. this dataset creation is done on the CPU.
    + set up cpu box
        + `ssh-keygen -t ed25519 -C "r.zach.lamberty@gmail.com"`
        + `cat ~/.ssh/id_ed25519`
        + add to github [here](https://github.com/settings/keys)
        + `git clone git@github.com:RZachLamberty/mtg-research.git`
        + `source activate ...`
        + `pip install datasets transformers`
        + `jupyter notebook --no-browser --ip=0.0.0.0`
    + locally
        + `ssh -NfL 9999:localhost:8888 rzlcpu`
        + http://localhost:9999
        + `scp -r mtg-language rzlcpu:~/mtg-research/bert-edh-pair-prediction/mtg-language`
        + `scp -r mtg-language-data rzlcpu:~/mtg-research/bert-edh-pair-prediction/mtg-language-data`
    + in the jupyter notebook, update the flags at the top
        + `True`:
            + `MAKE_MLM_NSP_DATASET`
            + `DO_CHUNKED_DATASET_BUILDING`
        + all others `False`
    + save this dataset and copy it down
+ train
    + move to a GPU box (upload the saved dataset)
    + **add early stopping**
    + run it

## pre-training

we actually care most about the nsp (next sentence prediction) task: for us, that's the "is edhrec pair / isn't edhrec pair" concept. this means we *have* to use a `bert` model, because later variants either dropped that task (e.g. RoBERTa) or replaced it with a different one (e.g. ALBERT's sentence-order prediction).
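once trained, the nsp head is what scores a candidate pairing. `BertForPreTraining` returns `seq_relationship_logits` with two logits per pair; a hedged sketch of turning one row of them into a probability (a softmax over the two logits) follows. which index means "edhrec pair" depends on how `next_sentence_label` was encoded in the dataset; this sketch assumes the library convention, where index 0 is "B follows A".

```python
import math


def pair_probability(seq_relationship_logits):
    """Convert one row of nsp logits [logit_0, logit_1] into P(label == 0)."""
    l0, l1 = seq_relationship_logits
    m = max(l0, l1)  # subtract the max for numerical stability
    e0, e1 = math.exp(l0 - m), math.exp(l1 - m)
    return e0 / (e0 + e1)
```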

In [None]:
if DO_PRE_TRAINING:
    config = BertConfig(vocab_size=VOCAB_SIZE)
    model = BertForPreTraining(config=config)
    print(f'{model.num_parameters():,}')

In [None]:
if DO_SMALL_PRETRAIN_TEST:
    TRAIN_BATCH = 8
    EVAL_BATCH = 8
    LOGGING_STEPS = 250
    train_dataset = dataset['train'].select(range(80))
    eval_dataset = dataset['validation'].select(range(80))
else:
    TRAIN_BATCH = 36
    EVAL_BATCH = 36
    LOGGING_STEPS = 250
    train_dataset = dataset['train']
    eval_dataset = dataset['validation']


if DO_PRE_TRAINING:
    output_dir = './mtg-language-results-v1'
    
    # # use this format to pick up from an aborted run
    # model = BertForPreTraining.from_pretrained(f'./{output_dir}/checkpoint-5750')

    training_args = TrainingArguments(
        output_dir=output_dir,                    # output directory
        num_train_epochs=1,                       # total # of training epochs
        per_device_train_batch_size=TRAIN_BATCH,  # batch size per device during training
        per_device_eval_batch_size=EVAL_BATCH,    # batch size for evaluation
        warmup_steps=500,                         # number of warmup steps for learning rate scheduler
        weight_decay=0.01,                        # strength of weight decay
        logging_dir='./logs',                     # directory for storing logs
        # my custom ones
        logging_steps=LOGGING_STEPS,
        overwrite_output_dir=True,
        evaluation_strategy='steps',
        logging_first_step=True,
        seed=1337,
        dataloader_drop_last=True,
        dataloader_num_workers=30,
        label_names=['labels', 'next_sentence_label'],
        load_best_model_at_end=True,
        save_total_limit=10,
    )

    trainer = Trainer(
        model=model,                  # the instantiated 🤗 Transformers model to be trained
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=eval_dataset,    # evaluation dataset
    )

In [None]:
if DO_PRE_TRAINING:
    trainer.train()

In [None]:
trainer.evaluate(train_dataset)

In [None]:
trainer.evaluate(eval_dataset)

In [None]:
trainer.save_model('mtg-language-test-small')

In [None]:
fill_mask = pipeline(
    'fill-mask',
    model='./mtg-language-test-small',
    tokenizer=tokenizer)

In [None]:
dataset['validation'][0]['text_a']

In [None]:
fill_mask('{3}{w}{b} legendary creature — vampire knight 4/4: vigilance, lifelink {t}, pay 7 life: destroy target nonland [MASK]. activate this ability only during your turn.')