# mtg language model

let's follow along with [this notebook](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb) to train a language model from scratch on our MTG corpus

In [None]:
import functools
import math
import os
import random

import datasets
import pandas as pd
import tokenizers
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

In [None]:
IS_LOCAL_LAPTOP = True
TRAIN_TOKENIZER = True
MAKE_MLM_NSP_DATASET_INPUTS = True
MAKE_MLM_NSP_DATASET = True
DO_PRE_TRAINING = True
DO_SMALL_PRETRAIN_TEST = True

# whether to cap the number of card pairs per card (an int) or use all
# possible pairs (None)
# NUM_MAX_PAIRS = None
NUM_MAX_PAIRS = 1

In [None]:
if TRAIN_TOKENIZER:
    assert IS_LOCAL_LAPTOP

In [None]:
SPLITS = ['train', 'test', 'validation']
SEED = 1337
NUM_PROC = 4

In [None]:
pd.set_option('display.max_columns', None)  # or 1000
pd.set_option('display.max_rows', None)  # or 1000
pd.set_option('display.max_colwidth', None)

## get raw text dataset

In [None]:
if IS_LOCAL_LAPTOP:
    import mtg.cards

In [None]:
@functools.lru_cache(None)
def get_cards():
    cards = (mtg.cards.cards_df()
             [['name', 'multiverseId', 'scryfallId', 'type', 'manaCost',
               'text', 'setname', 'power', 'toughness']]
             .sort_values(by=['name', 'multiverseId'], ascending=False)
             .groupby('name')
             .first())
    cards.index = cards.index.str.lower()
    # .copy() so the column assignment below doesn't write into a slice
    cards = cards[cards.type != 'Scheme'].copy()
    # mana cost, type line, optional p/t, then rules text with whitespace collapsed
    cards.loc[:, 'mytext'] = (cards.manaCost.fillna('{0}')
                              + ' '
                              + cards.type
                              + ((' ' + cards.power + '/' + cards.toughness).fillna(''))
                              + ': '
                              + cards.text.str.replace(r'\s+', ' ', regex=True).fillna(''))
    # lowercase and strip quotes; regex=True is explicit because pandas >= 2 matches literally by default
    cards.mytext = cards.mytext.str.lower().str.replace('["\']', '', regex=True)
    return cards

In [None]:
F_CORPUS = 'mtg-corpus.txt'

def get_corpus(f=F_CORPUS):
    get_cards().mytext.to_csv(f, header=False, index=False, sep='\t')

In [None]:
if TRAIN_TOKENIZER:
    get_corpus()

In [None]:
if TRAIN_TOKENIZER:
    !head -n5 mtg-corpus.txt

## train a tokenizer

I found the `BERT` special tokens by looking at the defaults for

```py
tokenizers.BertWordPieceTokenizer.train?
```

In [None]:
import tokenizers
tokenizers.__version__

In [None]:
%%time

MODEL_NAME = 'mtg-language'

# with a large vocab_size (the default is 30_000), training yields only 11,761
# vocab items. any number larger than that will include null tokens and will
# also include single words, so we choose a number *just* below that here
VOCAB_SIZE = 11_500

if TRAIN_TOKENIZER:
    #trainable_tokenizer = tokenizers.ByteLevelBPETokenizer(lowercase=True)
    trainable_tokenizer = tokenizers.BertWordPieceTokenizer(lowercase=True)
    
    trainable_tokenizer.train(files=F_CORPUS,
                              vocab_size=VOCAB_SIZE,
                              special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])
    
    !mkdir -p {MODEL_NAME}

    trainable_tokenizer.save_model(MODEL_NAME)

In [None]:
if TRAIN_TOKENIZER:
    !tail -n40 {MODEL_NAME}/vocab.txt

In [None]:
if TRAIN_TOKENIZER:
    !wc -l {MODEL_NAME}/vocab.txt
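
for intuition about how `vocab.txt` gets used: WordPiece splits each word greedily, longest match first, and every piece after the first carries the `##` prefix; if any part of a word can't be matched, the whole word becomes `[UNK]`. below is a minimal pure-python sketch of just that lookup step, with a made-up toy vocab (the real `BertWordPieceTokenizer` also lowercases and splits on whitespace and punctuation before this step).

```py
def wordpiece(word, vocab, unk="[UNK]"):
    """greedy longest-match-first wordpiece split of a single word.

    vocab is a set of strings; continuation pieces are stored with a
    leading '##', as in the vocab.txt written by save_model above.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # shrink the candidate substring from the right until it is in the vocab
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            # no prefix of the remainder is known: the whole word maps to [UNK]
            return [unk]
        pieces.append(match)
        start = end
    return pieces


# toy vocab for illustration only, not taken from the trained mtg vocab
toy_vocab = {"life", "##link", "vig", "##ilance", "[UNK]"}
print(wordpiece("lifelink", toy_vocab))   # ['life', '##link']
print(wordpiece("vigilance", toy_vocab))  # ['vig', '##ilance']
print(wordpiece("trample", toy_vocab))    # ['[UNK]']
```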

## make a training dataset

we will pre-train a completely un-initialized model below, so we need a dataset to train it on. `bert` has two pre-training objectives: masked language modeling (mlm) and next sentence prediction (nsp).

in order to do this, we will need to create a dataset with the following features:

+ standard bert tokenizer outputs
    + `input_ids`
    + `attention_mask`
    + `token_type_ids`
+ `labels`: _optional_; these start out identical to `input_ids`, and any position set to `-100` is ignored when computing the (mlm) loss
+ `next_sentence_label`: the `0, 1` values indicating whether or not sentences A and B are continuations (in the original model) or `edhrec` pairs (our model). in `transformers`, `0` means B follows A and `1` means B is random
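
note that `add_labels` below copies `input_ids` into `labels` unchanged; unless masking happens somewhere we can't see here (e.g. inside `utils.build_tokenizer_map_func`), the mlm objective reduces to copying the input. a minimal pure-python sketch of the standard bert recipe (pick ~15% of non-special tokens; of those, 80% become `[MASK]`, 10% a random token, 10% stay as-is; label everything else `-100`). `transformers.DataCollatorForLanguageModeling` applies the same idea batch-wise. the token ids in the usage line are assumptions based on the special-token order used when training the tokenizer.

```py
import random


def mask_for_mlm(input_ids, mask_id, vocab_size, special_ids,
                 mlm_probability=0.15, seed=0):
    """bert-style masking: returns (masked_input_ids, labels).

    labels hold the original id at each selected position and -100
    everywhere else, so the mlm loss ignores unselected tokens.
    """
    rng = random.Random(seed)
    masked = list(input_ids)
    labels = [-100] * len(input_ids)
    for i, tok in enumerate(input_ids):
        # never select [CLS]/[SEP]/[PAD]; otherwise select with mlm_probability
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            masked[i] = mask_id                    # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token id
        # remaining 10%: keep the original token
    return masked, labels


# assumed ids: [PAD]=0, [CLS]=2, [SEP]=3, [MASK]=4 (special-token order above)
masked, labels = mask_for_mlm([2, 10, 11, 12, 3], mask_id=4,
                              vocab_size=11_500, special_ids={0, 2, 3})
print(masked, labels)
```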

more is better here, of course; I think the goal has to be full coverage of all cards and all edhrec pairings. to that end, I will write a generator-based dataset builder (a `datasets.GeneratorBasedBuilder`).

+ [writing a dataset loading script walkthrough here](https://huggingface.co/docs/datasets/add_dataset.html)
+ [code template example here](https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py)

we have a few steps:

1. create a train / test / val split of card names
1. create a train / test / val split of edhrec pairings
1. save the above to files we can move around (parquet is fine)
1. write a dataset loading script that takes the above and generates the full datasets (and I do mean full!)
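
`edhrec_dataset.py` itself isn't shown in this notebook. as a hypothetical sketch of what step 4's generator might do, here is a plain-python version; the `text_b` field name, the one-negative-per-positive sampling, and the per-commander cap are all assumptions (only `text_a` appears elsewhere in the notebook). a `datasets.GeneratorBasedBuilder` would yield keyed examples like these from its `_generate_examples` method. labels follow the `transformers` convention (`0` = pair, `1` = random).

```py
import random
from collections import defaultdict


def generate_nsp_examples(card_text, pairs, num_max_pairs=None, seed=1337):
    """hypothetical sketch: yield nsp examples from edhrec (name, commander) pairs.

    card_text maps card name -> card text (e.g. the mytext column).
    positives (label 0) are real edhrec pairings; each positive gets one
    negative (label 1) built from a random card edhrec never paired with
    that commander. num_max_pairs caps positives per commander; None keeps all.
    """
    rng = random.Random(seed)
    by_commander = defaultdict(set)
    for name, commander in pairs:
        if name in card_text and commander in card_text:
            by_commander[commander].add(name)

    all_names = sorted(card_text)
    for commander in sorted(by_commander):
        paired = by_commander[commander]
        names = sorted(paired)
        if num_max_pairs is not None:
            names = rng.sample(names, min(num_max_pairs, len(names)))
        # negative candidates: cards not paired with this commander, excluding itself
        unpaired = [n for n in all_names if n not in paired and n != commander]
        for name in names:
            yield {"text_a": card_text[commander], "text_b": card_text[name],
                   "next_sentence_label": 0}
            if unpaired:
                yield {"text_a": card_text[commander],
                       "text_b": card_text[rng.choice(unpaired)],
                       "next_sentence_label": 1}


card_text = {"cmdr": "cmdr text", "a": "a text", "b": "b text", "c": "c text"}
pairs = [("a", "cmdr"), ("b", "cmdr")]
for ex in generate_nsp_examples(card_text, pairs):
    print(ex)
```

with `num_max_pairs=None` the output grows with the number of pairings, which is why the split fractions below are square-rooted in the uncapped case.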

### train / test / val split for card names

+ train and val are just splits on cards
+ test is a straight holdout of a few entire sets, to test generalizability when new sets drop

we pick the test sets first, then do a normal train/val split on the remaining cards

In [None]:
DATA_DIR = 'mtg-language-data'
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
def get_test_sets(target_test_frac, seed=SEED):
    """randomly pick whole mtg sets until they hold >= target_test_frac of all cards"""
    cards = get_cards()
    mtg_set_sizes = cards.setname.value_counts()
    mtg_sets = set(mtg_set_sizes.index)

    test_sets = set()
    test_sets_size = 0
    test_set_target_size = target_test_frac * cards.shape[0]

    rng = random.Random(seed)
    while test_sets_size < test_set_target_size:
        # sort first: python's iteration order for a set of strings changes
        # between runs (hash randomization), so list(mtg_sets) isn't reproducible
        s = rng.choice(sorted(mtg_sets))
        test_sets.add(s)
        mtg_sets.remove(s)
        test_sets_size += mtg_set_sizes[s]
        #print(f"test_sets {test_sets} contain {test_sets_size} cards")
    return test_sets

In [None]:
TARGET_TEST_FRAC = 0.05
VAL_FRAC = 0.05

# when pairs are uncapped, the number of records is basically #cards ^ 2,
# so if we want 5% of all *sentence pairs* to be val, we need 0.05 ^ (1/2) of *cards*
# (which is about 22%)

if MAKE_MLM_NSP_DATASET_INPUTS:
    cards = get_cards()
    
    # if things are capped per card, number of cards is all that matters
    # if they are not, then number of cards squared determines the set size
    tf = TARGET_TEST_FRAC if (NUM_MAX_PAIRS is not None) else TARGET_TEST_FRAC ** .5
    vf = VAL_FRAC if (NUM_MAX_PAIRS is not None) else VAL_FRAC ** .5
    
    test_sets = get_test_sets(tf)
    is_test = cards.setname.isin(test_sets)
    test = cards[is_test]
    train_val = cards[~is_test]
    
    # vf is a fraction of *all* cards (already square-rooted above when pairs
    # are uncapped); rescale it to a fraction of the train_val cards
    adj_val_frac = vf / (1 - tf)
    
    train, val = train_test_split(train_val, test_size=adj_val_frac, random_state=SEED)

    print(f"num records train: {train.shape[0]}")
    print(f"num records test:  {test.shape[0]}")
    print(f"num records val:   {val.shape[0]}")
    
    train.to_parquet(os.path.join(DATA_DIR, "cards.train.parquet"), index=True)
    test.to_parquet(os.path.join(DATA_DIR, "cards.test.parquet"), index=True)
    val.to_parquet(os.path.join(DATA_DIR, "cards.validation.parquet"), index=True)
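
a quick numeric check of the square-root reasoning in the comments above, assuming uncapped pairs are formed only within a split: a split holding a fraction f of the cards then holds about f² of the pairs. the 20,000-card total is a made-up round number.

```py
import math


def pair_fraction(card_frac, num_cards=20_000):
    """fraction of all card pairs that fall in a split holding card_frac of the cards."""
    n_split = round(card_frac * num_cards)
    return (n_split * n_split) / (num_cards * num_cards)


card_frac = math.sqrt(0.05)
print(f"{card_frac:.4f}")                 # 0.2236
print(f"{pair_fraction(card_frac):.4f}")  # 0.0500
```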

### train / test / val split for edhrec pairs

use the card splits defined above to partition the edhrec pairings into train / test / val. each pairing goes to the split of its `name` card; its `commander` can come from any split.

In [None]:
if MAKE_MLM_NSP_DATASET_INPUTS:
    import mtg.extract.edhrec

    edhrec_cards = (mtg.extract.edhrec.get_commanders_and_cards()
                    [['name', 'commander']])
    commanders = edhrec_cards.commander.unique()
    cmdr_cmdr_df = pd.DataFrame([[c, c] for c in commanders],
                                columns=['name', 'commander'])
    # DataFrame.append was removed in pandas 2.0; pd.concat does the same job
    edhrec_cards = pd.concat([edhrec_cards, cmdr_cmdr_df], ignore_index=True)
    
    edhrec_cards.name = edhrec_cards.name.str.lower().str.replace('//', '/')
    edhrec_cards.commander = edhrec_cards.commander.str.lower().str.replace('//', '/')
    
    edhrec_train = edhrec_cards[edhrec_cards.name.isin(train.index)].copy()
    edhrec_test = edhrec_cards[edhrec_cards.name.isin(test.index)].copy()
    edhrec_val = edhrec_cards[edhrec_cards.name.isin(val.index)].copy()

    print(f"num records train: {edhrec_train.shape[0]}")
    print(f"num records test:  {edhrec_test.shape[0]}")
    print(f"num records val:   {edhrec_val.shape[0]}")
    
    edhrec_train.to_parquet(os.path.join(DATA_DIR, "edhrec.train.parquet"), index=False)
    edhrec_test.to_parquet(os.path.join(DATA_DIR, "edhrec.test.parquet"), index=False)
    edhrec_val.to_parquet(os.path.join(DATA_DIR, "edhrec.validation.parquet"), index=False)

In [None]:
edhrec_train.head(2) if MAKE_MLM_NSP_DATASET_INPUTS else None

In [None]:
edhrec_test.head(2) if MAKE_MLM_NSP_DATASET_INPUTS else None
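
because the split is decided by `name` alone, a pairing's `commander` may live in a different split. a hypothetical pure-python check for how often that happens; in the notebook, `split_of` could be built from `train.index`, `val.index`, and `test.index`, and the toy names below are made up.

```py
def cross_split_pairings(pairs, split_of):
    """count (name, commander) pairings whose two cards fall in different splits.

    split_of maps card name -> split label; pairings with an unknown card are skipped.
    """
    crossing = 0
    for name, commander in pairs:
        a, b = split_of.get(name), split_of.get(commander)
        if a is not None and b is not None and a != b:
            crossing += 1
    return crossing


split_of = {"card a": "train", "commander x": "train", "card b": "test"}
pairs = [("card a", "commander x"), ("card b", "commander x"),
         ("commander x", "commander x")]
print(cross_split_pairings(pairs, split_of))  # 1
```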

### the dataset builder

In [None]:
!ls -alh {DATA_DIR}

In [None]:
# # quick check on the max length of our sequences:
# # we can easily build a tokenizer and apply it to every sentence
# # directly; this will give us a max length for a single sentence
# # and then our dataset max length is approx double that.

# from transformers import BertTokenizerFast

# tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# z = tokenizer(cards.mytext.unique().tolist(), max_length=1024)

# import collections

# l = [len(_) for _ in z['input_ids']]
# c = collections.Counter(l)

# df = (pd.DataFrame([{'k': k, 'v': v} for (k, v) in c.items()])
#       .sort_values(by='k'))
# df.loc[:, 'cs'] = df.v.cumsum() / df.v.sum()
# print(f"max single sequence length: {df.k.max()}")
# # df.plot('k', 'cs')

# import numpy as np
# pairs = np.random.choice(l, size=(100_000, 2))
# import matplotlib.pyplot as plt
# plt.hist(pairs.sum(axis=1), bins=50);

moral of the story from the above: almost every card is under 100 tokens (the max is 185), and almost every *pair* of sequences is under 175 tokens total. a max length of 200 is therefore conservative; only the rare pair of two long cards will exceed it.
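
one small correction to the commented-out estimate: each separately encoded card already carries `[CLS]` and `[SEP]`, while a pair encodes as `[CLS] a [SEP] b [SEP]`, so a pair is one token shorter than the sum of its two single encodings. a sketch of the same estimate with that correction, assuming `single_lengths` holds per-card encoded lengths like `l` above:

```py
import random


def frac_pairs_over(single_lengths, max_len, num_samples=100_000, seed=1337):
    """estimate the fraction of random card pairs whose encoding exceeds max_len.

    each single length includes [CLS] and [SEP]; a pair adds only one more
    [SEP], so len(pair) = len(a) + len(b) - 1.
    """
    rng = random.Random(seed)
    over = 0
    for _ in range(num_samples):
        a, b = rng.choice(single_lengths), rng.choice(single_lengths)
        if a + b - 1 > max_len:
            over += 1
    return over / num_samples


# toy lengths for illustration; in the notebook this would be `l`
print(frac_pairs_over([40, 60, 90, 185], max_len=200))
```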

In [None]:
from transformers import BertTokenizerFast

from utils import build_tokenizer_map_func

MAX_SEQ_LENGTH = 200

if MAKE_MLM_NSP_DATASET:
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    tokenizer_map_func = build_tokenizer_map_func(tokenizer, max_length=MAX_SEQ_LENGTH)

In [None]:
def add_labels(examples):
    # mlm labels start as an exact copy of input_ids (no masking is applied here)
    return {'labels': examples['input_ids']}

In [None]:
if MAKE_MLM_NSP_DATASET:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    dataset = (datasets.load_dataset('edhrec_dataset.py', data_dir=DATA_DIR, num_max_pairs=NUM_MAX_PAIRS)
               .shuffle(seeds={split: SEED for split in SPLITS})
               .map(tokenizer_map_func,
                    batched=False,
                    num_proc=NUM_PROC)
               .map(add_labels,
                    batched=True,
                    num_proc=NUM_PROC));

In [None]:
dataset if MAKE_MLM_NSP_DATASET else None

In [None]:
dataset.shape if MAKE_MLM_NSP_DATASET else None

with the smaller dataset defined above,

+ ~~add the tokenizer~~
+ ~~build the trainer and train config~~
+ do a train round with the smaller datasets
+ any difference at all??

if it looks promising,

+ build a *real* dataset
    + update the 5 --> 100 or even None
    + move to a **CPU** box -- not a GPU box. this dataset creation is done on the CPU.
    + save this dataset and copy it down
+ train
    + move to GPU (upload the saved dataset)
    + **add early stopping**
    + run it

## pre-training

we actually care most about the nsp (next sentence prediction) task: for us, that's the "is edhrec pair / isn't edhrec pair" question. this means we *have* to use a `bert` model, because most later models dropped nsp in favor of other objectives (`roberta`, for example, pre-trains on mlm alone).
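
once trained, the nsp head can be read as an "is edhrec pair" score: `BertForPreTraining` returns `seq_relationship_logits` with two logits per example, and in `transformers`' convention index 0 means "B follows A" (which is the edhrec-pair class, assuming our dataset follows that convention). a minimal sketch of turning those two logits into a probability; the commented lines show one way to get the logits from the `model` and `tokenizer` defined in this notebook.

```py
import math


def pair_probability(logits):
    """softmax over the two nsp logits; returns p(index 0), i.e. p("B follows A")."""
    l0, l1 = logits
    m = max(l0, l1)  # subtract the max for numerical stability
    e0, e1 = math.exp(l0 - m), math.exp(l1 - m)
    return e0 / (e0 + e1)


# usage with a trained model (requires torch + transformers):
# enc = tokenizer(text_a, text_b, return_tensors="pt")
# with torch.no_grad():
#     out = model(**enc)
# p = pair_probability(out.seq_relationship_logits[0].tolist())

print(round(pair_probability([2.0, 0.0]), 3))  # 0.881
```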

In [None]:
from transformers import (BertConfig,
                          BertForPreTraining,
                          BertTokenizerFast,
                          Trainer,
                          TrainingArguments, )

In [None]:
if DO_PRE_TRAINING:
    config = BertConfig(vocab_size=VOCAB_SIZE)
    model = BertForPreTraining(config=config)
    print(f'{model.num_parameters():,}')

In [None]:
if DO_SMALL_PRETRAIN_TEST:
    TRAIN_BATCH = 8
    EVAL_BATCH = 8
    LOGGING_STEPS = 250
    train_dataset = dataset['train'].select(range(80))
    eval_dataset = dataset['validation'].select(range(80))
else:
    TRAIN_BATCH = 36
    EVAL_BATCH = 36
    LOGGING_STEPS = 250
    train_dataset = dataset['train']
    eval_dataset = dataset['validation']


if DO_PRE_TRAINING:
    output_dir = './mtg-language-results-v1'
    
    # # use this format to pick up from an aborted run
    # model = BertForPreTraining.from_pretrained(os.path.join(output_dir, 'checkpoint-5750'))

    training_args = TrainingArguments(
        output_dir=output_dir,                    # output directory
        num_train_epochs=1,                       # total # of training epochs
        per_device_train_batch_size=TRAIN_BATCH,  # batch size per device during training
        per_device_eval_batch_size=EVAL_BATCH,    # batch size for evaluation
        warmup_steps=500,                         # number of warmup steps for learning rate scheduler
        weight_decay=0.01,                        # strength of weight decay
        logging_dir='./logs',                     # directory for storing logs
        # my custom ones
        logging_steps=LOGGING_STEPS,
        overwrite_output_dir=True,
        evaluation_strategy='steps',              # renamed `eval_strategy` in newer transformers releases
        logging_first_step=True,
        seed=SEED,
        dataloader_drop_last=True,
        dataloader_num_workers=30,
        label_names=['labels', 'next_sentence_label'],
        load_best_model_at_end=True,
        save_total_limit=10,
    )

    trainer = Trainer(
        model=model,                  # the instantiated 🤗 Transformers model to be trained
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=eval_dataset,    # evaluation dataset
    )

In [None]:
if DO_PRE_TRAINING:
    trainer.train()

In [None]:
trainer.evaluate(train_dataset)

In [None]:
trainer.evaluate(eval_dataset)

In [None]:
trainer.save_model('mtg-language-test-small')

In [None]:
from transformers import pipeline

fill_mask = pipeline(
    'fill-mask',
    model='./mtg-language-test-small',
    tokenizer=tokenizer)

In [None]:
dataset['validation'][0]['text_a']

In [None]:
fill_mask('{3}{w}{b} legendary creature — vampire knight 4/4: vigilance, lifelink {t}, pay 7 life: destroy target nonland [MASK]. activate this ability only during your turn.')