# mtg language model: commander to recommendation task

this is functionally the same as the neighboring notebook, except that the model here is pre-trained on a "commander: card" nsp task rather than on the "card to card" nsp task.

In [None]:
import csv
import functools
import glob
import os
import random
import shutil
import subprocess

import datasets
import numpy as np
import pandas as pd
import tokenizers
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import (BertConfig,
                          BertForPreTraining,
                          BertForNextSentencePrediction,
                          BertTokenizerFast,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          Trainer,
                          TrainingArguments,
                          pipeline, )

from utils import build_tokenizer_map_func, grouper

In [None]:
# run flags: each one switches a group of cells in this notebook on or off

# gates the `mtg.cards` import; required by the tokenizer / input-building flags
IS_LOCAL_LAPTOP = False
# dump the card corpus to disk and train the wordpiece tokenizer
TRAIN_TOKENIZER = False
# build the train / validation / test splits and write them as parquet files
MAKE_MLM_NSP_DATASET_INPUTS = False
# build the tokenized mlm + nsp dataset from those parquet files
MAKE_MLM_NSP_DATASET = False
# build that dataset in chunks using parallel subprocesses
DO_CHUNKED_DATASET_BUILDING = False
# write the shuffled dataset to disk
SAVE_DATASET_TO_FILE = False
# read a previously saved dataset from disk
LOAD_DATASET_FROM_FILE = False
# build the model + trainer, train, save, and evaluate
DO_PRE_TRAINING = False
# use the tiny 80-record train / eval subsets and the "-small" paths
DO_SMALL_PRETRAIN_TEST = False
# NOTE(review): not referenced anywhere else in the visible notebook
CHECK_PRETRAINED_MODEL_RESULTS = False
# write one csv of (commander, candidate card) rows per commander
MAKE_DECK_RECOMMENDATION_INPUTS = False
# load the trained model and score those csv files
DO_DECK_RECOMMENDATIONS = False

# optional cap on the number of card pairs; None means no cap (use all
# possible pairs)
NUM_MAX_PAIRS = None
# NUM_MAX_PAIRS = 500
# NUM_MAX_PAIRS = 1

In [None]:
# sanity checks on the flag combination chosen above. each check only bites
# when its flag is switched on, and they run in a fixed order so the first
# violated one is the one that gets reported.

# tokenizer training reads the card corpus, which needs the laptop-only import
assert IS_LOCAL_LAPTOP or not TRAIN_TOKENIZER

# chunked building without a cap on the number of pairs is far too much data
assert NUM_MAX_PAIRS is not None or not DO_CHUNKED_DATASET_BUILDING, "this is like 2 TB of data"

# both input-building steps read the raw card data as well
assert IS_LOCAL_LAPTOP or not MAKE_MLM_NSP_DATASET_INPUTS
assert IS_LOCAL_LAPTOP or not MAKE_DECK_RECOMMENDATION_INPUTS

In [None]:
# names of the dataset splits: train, validation, and the two test sets
# (held-out commanders and held-out card sets)
SPLITS = ['train', 'validation', 'test_cmdr', 'test_set']
# seed used for the commander sampling and for the dataset shuffle
SEED = 1337

# for the chunking of the dataset builder

# os.cpu_count() returns None when the count cannot be determined; fall back
# to a single process so the multiplication below cannot fail
N_PROCS = os.cpu_count() or 1
# ten chunks per process, run N_PROCS at a time
N_CHUNKS = N_PROCS * 10

In [None]:
# render dataframes in full: no cap on the number of columns or rows shown,
# and no truncation of long cell contents
pd.options.display.max_columns = None
pd.options.display.max_rows = None
pd.options.display.max_colwidth = None

## get raw text dataset

In [None]:
# the `mtg` package is only imported when the flag says we are on the local
# laptop; `get_cards` below needs it
if IS_LOCAL_LAPTOP:
    import mtg.cards

In [None]:
@functools.lru_cache(None)
def get_cards():
    """Build the one-row-per-card dataframe used throughout this notebook.

    Reads ``mtg.cards.cards_df()``, keeps a fixed set of columns, sorts by
    ``name`` and ``multiverseId`` in descending order and collapses to one
    row per name with ``groupby('name').first()``. The index (card name) is
    lower-cased, cards whose type is exactly ``'Scheme'`` are dropped, and a
    ``mytext`` column is added of the form
    ``"<mana cost> <type>[ <power>/<toughness>]: <rules text>"``, lower-cased
    and with single and double quotes removed.

    The result is memoized, so every caller receives the *same* dataframe
    object; copy it before mutating.

    Returns:
        pandas.DataFrame indexed by lower-cased card name.
    """
    cards = (mtg.cards.cards_df()
             [['name', 'multiverseId', 'scryfallId', 'type', 'manaCost',
               'text', 'setname', 'power', 'toughness']]
             .sort_values(by=['name', 'multiverseId'], ascending=False)
             .groupby('name')
             .first())
    cards.index = cards.index.str.lower()
    # explicit copy so the column assignments below write to our own frame
    # rather than to a slice of the unfiltered one
    cards = cards[cards.type != 'Scheme'].copy()
    # regex=True is spelled out: newer pandas treats the pattern as a literal
    # string by default, which would leave the whitespace runs untouched
    cards['mytext'] = (cards.manaCost.fillna('{0}')
                       + ' '
                       + cards.type
                       # missing power / toughness makes this piece NaN -> ''
                       + ((' ' + cards.power + '/' + cards.toughness).fillna(''))
                       + ': '
                       + cards.text.str.replace(r'\s+', ' ', regex=True).fillna(''))
    cards['mytext'] = cards.mytext.str.lower().str.replace('["\']', '', regex=True)
    return cards

In [None]:
# file the tokenizer is trained on: one card text per line
F_CORPUS = 'mtg-corpus.txt'

def get_corpus(f=F_CORPUS):
    """Write the ``mytext`` column of ``get_cards()`` to ``f``.

    The file has no header row and no index column.
    """
    card_texts = get_cards().mytext
    card_texts.to_csv(f, sep='\t', index=False, header=False)

In [None]:
# dump the card text corpus to F_CORPUS, but only when we are about to train
# the tokenizer on it
if TRAIN_TOKENIZER:
    get_corpus()

In [None]:
# eyeball the first few lines of the corpus file (ipython shell escape)
if TRAIN_TOKENIZER:
    !head -n5 mtg-corpus.txt

## train a tokenizer

I found the `BERT` special tokens by looking at the defaults for

```py
tokenizers.BertWordPieceTokenizer.train?
```

In [None]:
%%time

# directory the tokenizer vocab is saved into; the same name is later used to
# load the tokenizer
MODEL_NAME = 'mtg-language-v2'

# if you run with vocab_size = 300_000 (default), you get 11,761 vocab items
# any number larger than that will include null tokens and will also include
# single words. we will choose a number that is *just* below that here
VOCAB_SIZE = 11_500

if TRAIN_TOKENIZER:
    #trainable_tokenizer = tokenizers.ByteLevelBPETokenizer(lowercase=True)
    trainable_tokenizer = tokenizers.BertWordPieceTokenizer(lowercase=True)

    trainable_tokenizer.train(files=F_CORPUS,
                              vocab_size=VOCAB_SIZE,
                              special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])

    # create the output directory from python (replaces the `!mkdir -p`
    # shell escape, which only works inside ipython with a posix shell)
    os.makedirs(MODEL_NAME, exist_ok=True)

    trainable_tokenizer.save_model(MODEL_NAME)

In [None]:
# look at the last entries of the freshly written vocab file
if TRAIN_TOKENIZER:
    !tail -n40 {MODEL_NAME}/vocab.txt

In [None]:
# number of lines in the vocab file
if TRAIN_TOKENIZER:
    !wc -l {MODEL_NAME}/vocab.txt

## make a training dataset

we will do pre-training of a completely un-initialized model down below, so we will need a dataset to train on to do that. `bert` has two objectives -- a masked language model (mlm) and a next sentence prediction.

in order to do this, we will need to create a dataset with the following features:

+ standard bert tokenizer outputs
    + `input_ids`
    + `attention_mask`
    + `token_type_ids`
+ `labels`: _optional_, these are basically the same thing as `input_ids`, but allow for ignoring certain tokens for the purpose of loss calculations
+ `next_sentence_label`: these are the `0, 1` values indicating whether or not sentences A and B are continuations (in the original model) or `edhrec` commander / card pairs (our model)

more is better here, of course; I think the goal has to be full coverage of all cards and all edhrec pair-ups. to that end, I will create a dataset builder of a generator type.

+ [writing a dataset loading script walkthrough here](https://huggingface.co/docs/datasets/add_dataset.html)
+ [code template example here](https://github.com/huggingface/datasets/blob/master/templates/new_dataset_script.py)

we have a few steps:

1. create a train / test / val split of card names
1. create a train / test / val split of edhrec pairings
1. save the above to files we can move around (parquet is fine)
1. create a datasetloader object that can take the above and generate the full datasets (and I do mean full!)

### train / test / val split

there are two types of generalization we care about:

+ how does it generalize to a completely new commander?
+ how does it generalize to a completely new set of recommendation cards?

we will create a pair of test sets instead of just one:

1. all recommendations from a number of held-out commanders
1. all recommendations for 2 held out mtg sets (`ZNR` and `IKO` are the most recent sets with json data as of 2020-01-17, so we will use those)

for train and validation, we will split based on set, as we did previously. this will give us a train and val set of cards, and we will apply that split to *both* commanders and recommended cards.

a given record is a pair of commander and card, so, all things being equal, we have four types of pairs. assuming a held-out card is equally likely to show up as a commander or as a recommended card, a holdout fraction of `alpha` would result in

| commander | card | frac |
|-|-|-|
| control | control | `(1 - alpha) ** 2` |
| test | control | `(1 - alpha) * alpha` |
| control | test | `(1 - alpha) * alpha` |
| test | test | `alpha ** 2` |

of these, only the first is in our train dataset. if we want the train dataset here to be 95% of all records, we need `(1 - alpha) ** 2 = 0.95`

In [None]:
# solve (1 - alpha) ** 2 = 0.95 for alpha: the per-card holdout fraction that
# leaves 95% of the (commander, card) records fully inside the control group
alpha = 1 - 0.95 ** .5
alpha

In [None]:
# directory holding the parquet split files written and read below; created
# on the spot if it does not exist yet
DATA_DIR = 'mtg-language-cmdr-rec-data'
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
# the edhrec extractor is only needed when we (re)build the split files
if MAKE_MLM_NSP_DATASET_INPUTS:
    import mtg.extract.edhrec

In [None]:
def get_test_commanders(target_test_frac, commanders=None, seed=SEED):
    """Pick the commanders that are held out for the commander test set.

    Args:
        target_test_frac: fraction of the commanders to hold out.
        commanders: array-like of commander names. When ``None`` the unique
            commanders are read from
            ``mtg.extract.edhrec.get_commanders_and_cards()``.
        seed: value passed to ``np.random.seed`` before sampling. Note that
            this reseeds numpy's *global* generator as a side effect.

    Returns:
        numpy array of ``round(n_commanders * target_test_frac)`` distinct
        commander names.
    """
    if commanders is None:
        commanders = (mtg.extract.edhrec.get_commanders_and_cards()
                      .commander
                      .unique())
    # accept plain lists as well as arrays
    commanders = np.asarray(commanders)
    np.random.seed(seed)
    # replace=False: np.random.choice samples with replacement by default,
    # which could return the same commander twice and so hold out fewer
    # distinct commanders than the target fraction asks for
    return np.random.choice(commanders,
                            size=round(commanders.size * target_test_frac),
                            replace=False)

In [None]:
def clean_card_name(c):
    """Lower-case a card name and keep only the text before the first ' // '."""
    first_part, _sep, _rest = c.lower().partition(' // ')
    return first_part

assert clean_card_name('Bruna, the Fading Light // Brisela, Voice of Nightmares') == 'bruna, the fading light'

In [None]:
# card sets whose cards are held out for the `test_set` split
TEST_SETNAMES = ['ZNR']
# fraction of commanders held out for the `test_cmdr` split
TEST_CMDR_FRAC = 0.03
# fraction of the non-test cards that go into the validation pool
VAL_FRAC = alpha


if MAKE_MLM_NSP_DATASET_INPUTS:
    edhrec_cards = (mtg.extract.edhrec.get_commanders_and_cards()
                    [['name', 'commander']])
    # normalize both name columns the same way
    edhrec_cards.name = edhrec_cards.name.apply(clean_card_name)
    edhrec_cards.commander = edhrec_cards.commander.apply(clean_card_name)
    edhrec_cards.name = edhrec_cards.name.str.lower().str.replace('//', '/')
    edhrec_cards.commander = edhrec_cards.commander.str.lower().str.replace('//', '/')
    commanders = edhrec_cards.commander.unique()

    # make the two test sets

    # test commanders
    # NOTE(review): this mask is built before the merges below and is later
    # combined with masks built after them; that only lines up if the merges
    # keep the row count and index unchanged -- confirm
    test_commanders = get_test_commanders(TEST_CMDR_FRAC, commanders)
    is_test_commander = edhrec_cards.commander.isin(test_commanders)
    test_cmdr = edhrec_cards[is_test_commander]

    # test set cards
    cards = get_cards()
    cardsets = (cards
                .reset_index()
                [['name', 'setname']])
    edhrec_cards = (edhrec_cards
                    # card set
                    .merge(cardsets, how='left', on='name')
                    # commander set
                    .merge(cardsets.rename(columns={'setname': 'cmdr_setname',
                                                    'name': 'commander'}),
                           how='left', on='commander'))
    is_test_set = ((~is_test_commander)
                   & (edhrec_cards.setname.isin(TEST_SETNAMES)
                      | edhrec_cards.cmdr_setname.isin(TEST_SETNAMES)))
    test_set = edhrec_cards[is_test_set]

    # make the train and val sets. split all cards based on alpha, then
    # subset the non-test cards based on whether either side is in the val set
    is_test = is_test_commander | is_test_set
    train_val = edhrec_cards[~is_test]

    train_val_cards = cards[~cards.setname.isin(TEST_SETNAMES)].index.values
    # this draw uses numpy's global generator, which get_test_commanders
    # seeded above. replace=False so that the validation pool really holds
    # VAL_FRAC of the cards (the default samples with replacement and can
    # pick the same card more than once)
    val_cards = np.random.choice(train_val_cards,
                                 size=round(train_val_cards.size * VAL_FRAC),
                                 replace=False)

    has_val_card = (train_val.name.isin(val_cards)
                    | train_val.commander.isin(val_cards))

    val = train_val[has_val_card]
    train = train_val[~has_val_card]

    print(f"num records train:       {train.shape[0]:,} ({train.shape[0] / edhrec_cards.shape[0]:.2%})")
    print(f"num records val:         {val.shape[0]:,} ({val.shape[0] / edhrec_cards.shape[0]:.2%})")
    print(f"num records test (cmdr): {test_cmdr.shape[0]:,} ({test_cmdr.shape[0] / edhrec_cards.shape[0]:.2%})")
    print(f"num records test (set):  {test_set.shape[0]:,} ({test_set.shape[0] / edhrec_cards.shape[0]:.2%})")

    # save card splits (will be used for negative labels)
    # this is a little tricky: we want train to just be train,
    # but all the others can be any card (val, test_cmdr, test_set)
    train_cards = (cards
                   [(~cards.setname.isin(TEST_SETNAMES))
                    & (~cards.index.isin(val_cards))])

    train_cards.to_parquet(os.path.join(DATA_DIR, 'cards.train.parquet'), index=True)
    for splitname in ['validation', 'test_cmdr', 'test_set']:
        cards.to_parquet(os.path.join(DATA_DIR, f'cards.{splitname}.parquet'), index=True)

    # save edhrecs
    train.to_parquet(os.path.join(DATA_DIR, "edhrec.train.parquet"), index=False)
    val.to_parquet(os.path.join(DATA_DIR, "edhrec.validation.parquet"), index=False)
    test_cmdr.to_parquet(os.path.join(DATA_DIR, "edhrec.test_cmdr.parquet"), index=False)
    test_set.to_parquet(os.path.join(DATA_DIR, "edhrec.test_set.parquet"), index=False)

### the datasetbuilder

In [None]:
# list the split files currently sitting in the data directory
!ls -alh {DATA_DIR}

In [None]:
# # quick check on the max length of our sequences:
# # we can easily build a tokenizer and apply it to every sentence
# # directly; this will give us a max length for a single sentence
# # and then our dataset max length is approx double that.

# from transformers import BertTokenizerFast

# tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# z = tokenizer(cards.mytext.unique().tolist(), max_length=1024)

# import collections

# l = [len(_) for _ in z['input_ids']]
# c = collections.Counter(l)

# df = (pd.DataFrame([{'k': k, 'v': v} for (k, v) in c.items()])
#       .sort_values(by='k'))
# df.loc[:, 'cs'] = df.v.cumsum() / df.v.sum()
# print(f"max single sequence length: {df.k.max()}")
# # df.plot('k', 'cs')

# import numpy as np
# pairs = np.random.choice(l, size=(100_000, 2))
# import matplotlib.pyplot as plt
# plt.hist(pairs.sum(axis=1), bins=50);

moral of the story from the above: almost every card is <100 tokens, max is 185. almost every *pair* of sequences is under 175 total tokens. 200 is *extremely* conservative actually

In [None]:
def make_chunked_dataset_shell_cmd(n_chunks, c, num_max_pairs, max_seq_length):
    """Build the shell command that processes one chunk of the dataset.

    Args:
        n_chunks: total number of chunks, passed to the script as ``-n``.
        c: index of the chunk this command handles, passed as ``-c``; it is
            also used in the name of the log file.
        num_max_pairs: optional cap passed as ``-m``; the flag is left out
            entirely when this is ``None``.
        max_seq_length: passed to the script as ``-l``.

    Returns:
        A single shell command string; stdout and stderr are redirected to
        ``logs/log.<c>.txt``.
    """
    nmp_flag = f' -m {num_max_pairs}' if num_max_pairs is not None else ''
    # NOTE(review): the -d directory below is hard-coded and differs from the
    # notebook's DATA_DIR ('mtg-language-cmdr-rec-data') -- confirm which one
    # the chunk script is supposed to read
    return (f'python edhrec_dataset_chunk_proc.py'
            # use the argument, not the notebook-level N_CHUNKS constant
            f' -n {n_chunks}'
            f' -c {c}'
            f' -d mtg-language-data-cmdr-rec-data'
            f' -o mtg-mlm-nsp-cmdr-rec-dataset-chunks'
            f'{nmp_flag}'
            f' -l {max_seq_length}'
            f' -p cmdr-rec'
            f' > logs/log.{c}.txt 2>&1')

In [None]:
# clean up the logs and outputs of any earlier chunked run so the new run
# starts from empty directories; missing directories are not an error
if DO_CHUNKED_DATASET_BUILDING:
    shutil.rmtree('logs', ignore_errors=True)
    shutil.rmtree('mtg-mlm-nsp-cmdr-rec-dataset-chunks/', ignore_errors=True)

In [None]:
# max token length handed to the tokenizer map function / the chunk script
MAX_SEQ_LENGTH = 200

if MAKE_MLM_NSP_DATASET:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # report how many cards each split's card file holds
    for split in SPLITS:
        n = pd.read_parquet(os.path.join(DATA_DIR, f'cards.{split}.parquet')).shape[0]
        print(f"num records {split}: {n}")

    if DO_CHUNKED_DATASET_BUILDING:
        # the shell commands redirect their output into this directory
        os.makedirs('logs', exist_ok=True)

        commands = [make_chunked_dataset_shell_cmd(N_CHUNKS, c, NUM_MAX_PAIRS, MAX_SEQ_LENGTH)
                    for c in range(N_CHUNKS)]

        print(f'executing {N_CHUNKS} commands in chunks of {N_PROCS} parallel commands')

        for cmd_grp in grouper(tqdm(commands), N_PROCS, ''):
            # presumably grouper pads the last group with the '' fill value;
            # those placeholders are not commands, so skip them
            processes = [(cmd, subprocess.Popen(cmd, shell=True))
                         for cmd in cmd_grp
                         if cmd]
            for cmd, p in processes:
                # a failed chunk would otherwise go unnoticed and simply be
                # missing from the concatenated dataset below; report it but
                # keep going
                if p.wait() != 0:
                    print(f'WARNING: exit code {p.returncode} from: {cmd}')

            # this group's chunks were just saved in two places -- the hf
            # cache directory and the local directory. to avoid running out
            # of disk space, we will wipe the cache directory after every
            # proc group
            hf_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'datasets')
            if os.path.isdir(hf_cache_dir):
                shutil.rmtree(hf_cache_dir)

        # stitch the per-chunk datasets back together, split by split
        base_dataset = datasets.DatasetDict({
            split: datasets.concatenate_datasets(
                [datasets.load_from_disk(f)[split]
                 for f in glob.glob('mtg-mlm-nsp-cmdr-rec-dataset-chunks/*')])
            for split in SPLITS})
    else:
        tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
        tokenizer_map_func = build_tokenizer_map_func(tokenizer, max_length=MAX_SEQ_LENGTH)

        base_dataset = (datasets.load_dataset('edhrec_dataset.py',
                                              data_dir=DATA_DIR,
                                              num_max_pairs=NUM_MAX_PAIRS,
                                              pair_type='cmdr-rec')
                        .map(tokenizer_map_func, batched=False, num_proc=N_PROCS))

    # same seed for every split so the shuffle is reproducible
    dataset = (base_dataset
               .shuffle(seeds={split: SEED for split in SPLITS}))

In [None]:
# display the unshuffled dataset (only exists when it was built in this run)
base_dataset if MAKE_MLM_NSP_DATASET else None

In [None]:
# display the per-split shape of the unshuffled dataset
base_dataset.shape if MAKE_MLM_NSP_DATASET else None

In [None]:
# display the shuffled dataset
dataset if MAKE_MLM_NSP_DATASET else None

In [None]:
# dataset.shape if MAKE_MLM_NSP_DATASET else None

In [None]:
# persist the shuffled dataset; the non-small load path below reads this
# same directory name
if SAVE_DATASET_TO_FILE:
    dataset.save_to_disk('mtg-mlm-nsp-cmdr-rec-dataset')

In [None]:
# %%sh
# aws s3 ls s3://mtg-research

In [None]:
if LOAD_DATASET_FROM_FILE:
    # the small pre-training test reads the "-small" copy of the dataset
    if DO_SMALL_PRETRAIN_TEST:
        f = 'mtg-mlm-nsp-cmdr-rec-dataset-small'
    else:
        f = 'mtg-mlm-nsp-cmdr-rec-dataset'
    dataset = datasets.load_from_disk(f)

In [None]:
# display the dataset that was just loaded from disk
dataset if LOAD_DATASET_FROM_FILE else None

In [None]:
# double check the next sentence labels!
import pandas as pd

# first ten training records, as a dict of column name -> list of values
if LOAD_DATASET_FROM_FILE:
    z = dataset['train'][:10]

# show just the card names next to their label columns
(pd.DataFrame({k: z[k]
               for k in ['name_a', 'name_b', 'next_sentence_label', 'rec_set_type']})
 if LOAD_DATASET_FROM_FILE
 else None)

with the smaller dataset defined above,

+ ~~add the tokenizer~~
+ ~~build the trainer and train config~~
+ ~~do a train round with the smaller datasets~~
+ ~~any difference at all??~~

if it looks promising,

+ build a *real* dataset
    + ~~write chunked dataset processor to leverage multiple cpus~~
    + ~~move to a **CPU** box -- not a GPU box. this dataset creation is done on the CPU.~~
    + ~~set up cpu box~~
        + `ssh-keygen -t ed25519 -C "r.zach.lamberty@gmail.com"`
        + `cat ~/.ssh/id_ed25519.pub`
        + add to github [here](https://github.com/settings/keys)
        + `git clone git@github.com:RZachLamberty/mtg-research.git`
        + `source activate ...`
        + `pip install datasets transformers`
        + `jupyter notebook --no-browser --ip=0.0.0.0`
    + ~~locally~~
        + `ssh -NfL 9999:localhost:8888 rzlcpu`
        + https://localhost:9999
        + `scp -r mtg-language rzlcpu:~/mtg-research/bert-edh-pair-prediction/mtg-language`
        + `scp -r mtg-language-data rzlcpu:~/mtg-research/bert-edh-pair-prediction/mtg-language-data`
    + ~~in the jupyter notebook, update the flags at the top~~
        + `True`:
            + `MAKE_MLM_NSP_DATASET`
            + `DO_CHUNKED_DATASET_BUILDING`
            + `SAVE_DATASET_TO_FILE`
        + all others `False`
    + ~~save this dataset and copy it down~~
+ train
    + move to GPU
    + set up gpu box
        + try a `p2.8xlarge` instead of the p3
        + hard drive should be at least 500 gb
        + attach the read s3 iam role
        + `ssh-keygen -t ed25519 -C "r.zach.lamberty@gmail.com"`
        + `cat ~/.ssh/id_ed25519.pub`
        + add to github [here](https://github.com/settings/keys)
        + `git clone git@github.com:RZachLamberty/mtg-research.git`
        + in a new terminal, kick this off asap:
            + `aws s3 sync s3://mtg-research/bert-edh-pair-prediction/mtg-mlm-nsp-dataset mtg-mlm-nsp-dataset`
        + go back to the original terminal
        + `screen -S j`
        + `source activate ...`
        + `pip install datasets transformers`
        + `jupyter notebook --no-browser --ip=0.0.0.0`
        + `cd ~/mtg-research/bert-edh-pair-prediction`
    + run it

## pre-training

we actually do care most about the nsp (next sentence prediction) task -- for us, that's the "is edhrec pair / isn't edhrec pair" concept. this means we *have* to do a `bert` model, because all of the other models dropped that task in favor of others.

WE WERE HERE

In [None]:
if DO_PRE_TRAINING:
    # tokenizer trained earlier in this notebook
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

    # default bert configuration apart from the vocabulary size; the model is
    # built from the config alone, i.e. with freshly initialized weights
    config = BertConfig(vocab_size=VOCAB_SIZE)
    model = BertForPreTraining(config=config)
    # # use this format to pick up from an aborted run
    # model = BertForPreTraining.from_pretrained(f'./{output_dir}/checkpoint-5750')

    print(f'number of params in model: {model.num_parameters():,}')

In [None]:
# pick batch sizes, logging cadence, datasets and output directory for either
# the small smoke test or the full run, then assemble the Trainer
if DO_PRE_TRAINING:
    if DO_SMALL_PRETRAIN_TEST:
        # tiny run: 80 records each for train and eval
        TRAIN_BATCH = 8
        EVAL_BATCH = 8
        LOGGING_STEPS = 250
        EARLY_STOPPING_PATIENCE = 8
        train_dataset = dataset['train'].select(range(80))
        eval_dataset = dataset['validation'].select(range(80))
        output_dir = './mtg-language-results-v2-small'
    else:
        TRAIN_BATCH = 51
        EVAL_BATCH = 186
        #LOGGING_STEPS = 50
        #EARLY_STOPPING_PATIENCE = 8
        LOGGING_STEPS = 100
        EARLY_STOPPING_PATIENCE = 5
        train_dataset = dataset['train']
        eval_dataset = dataset['validation']
        
        ## (disabled) every 10th record, then cut down to an exact batch
        ## size multiple, calculated as:
        ##   # divmod(dataset['validation'].num_rows // 10, 186)[0] * 186
        #eval_dataset = (dataset['validation']
        #                .filter(lambda example, indice: indice % 10 == 0,
        #                        with_indices=True)
        #                .filter(lambda example, indice: indice < 2_232,
        #                        with_indices=True))
        #
        #assert eval_dataset.num_rows % EVAL_BATCH == 0
        #print(eval_dataset.num_rows)
        output_dir = './mtg-language-results-v2'
    
    # this callback stops the job once the eval loss has failed to improve for
    # EARLY_STOPPING_PATIENCE consecutive evaluations (one evaluation every
    # eval_steps = logging_steps steps)
    esc = EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE,
                                early_stopping_threshold=0.0)
    
    # this is responsible for adding the `labels` feature that we actually train on
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                    mlm=True,
                                                    mlm_probability=0.15)

    training_args = TrainingArguments(
        output_dir=output_dir,                    # output directory
        num_train_epochs=1,                       # total # of training epochs
        per_device_train_batch_size=TRAIN_BATCH,  # batch size per device during training
        per_device_eval_batch_size=EVAL_BATCH,    # batch size for evaluation
        warmup_steps=500,                         # number of warmup steps for learning rate scheduler
        weight_decay=0.01,                        # strength of weight decay
        logging_dir='./logs',                     # directory for storing logs
        # my custom ones
        logging_steps=LOGGING_STEPS,
        overwrite_output_dir=True,
        evaluation_strategy='steps',
        logging_first_step=True,
        # same value as the notebook-level SEED constant
        seed=1337,
        dataloader_drop_last=True,
        dataloader_num_workers=30,
        #label_names=['labels', 'next_sentence_label'],
        load_best_model_at_end=True,
        save_total_limit=10,
    )

    trainer = Trainer(
        model=model,                  # the instantiated 🤗 Transformers model to be trained
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=eval_dataset,    # evaluation dataset
        callbacks=[esc],
        tokenizer=tokenizer,
        data_collator=data_collator
    )

In [None]:
# kick off pre-training with the trainer assembled above
if DO_PRE_TRAINING:
    trainer.train()

In [None]:
# where the trained model is saved: a separate directory for the small test,
# otherwise the same directory the tokenizer was saved into (MODEL_NAME)
local_model_dir = 'mtg-language-v2-test-small' if DO_SMALL_PRETRAIN_TEST else MODEL_NAME

In [None]:
# write the trained model to local_model_dir
if DO_PRE_TRAINING:
    trainer.save_model(local_model_dir)

In [None]:
# evaluate the trained model on the two held-out test splits. (the old
# `if split != 'train'` filter was dropped: the list never contains 'train',
# so it could not exclude anything.)
if DO_PRE_TRAINING:
    evals = {split: trainer.evaluate(dataset[split])
             for split in ['test_cmdr', 'test_set']}

evals if DO_PRE_TRAINING else None

In [None]:
# fill-mask pipeline on top of the model we just saved; without a trained
# model there is nothing to probe, so the name is bound to None instead
if DO_PRE_TRAINING:
    fill_mask = pipeline('fill-mask',
                         model=local_model_dir,
                         tokenizer=tokenizer)
else:
    fill_mask = None

In [None]:
# # this is to get the text string we mask below
# dataset['validation'][0]['text_a'] if DO_PRE_TRAINING else None

In [None]:
# easy probe: the card type is masked ("legendary [MASK] -- vampire knight");
# the unmasked text in the harder probe below has "creature" in this position
fill_mask('{3}{w}{b} legendary [MASK] — vampire knight 4/4: vigilance, lifelink {t}, pay 7 life: destroy target nonland permanent. activate this ability only during your turn.') if DO_PRE_TRAINING else None

In [None]:
# harder probe: a word inside the rules text is masked
# ("destroy target nonland [MASK]"; the original word is "permanent")
fill_mask('{3}{w}{b} legendary creature — vampire knight 4/4: vigilance, lifelink {t}, pay 7 life: destroy target nonland [MASK]. activate this ability only during your turn.') if DO_PRE_TRAINING else None

## deck recommendations

the dataset here is different: for each commander card you care about, text a is the commander card and text b is every possible other card in history (lol)

In [None]:
# commanders we want recommendations for. names are lower-case because they
# are used to look rows up in the lower-cased index of `get_cards()`
commanders = ['merieke ri berit',
              'wort, boggart auntie',
              "kykar, wind's fury",
              'purphoros, bronze-blooded',
              'marchesa, the black rose',
              "trostani, selesnya's voice",
              'mizzix of the izmagnus',
              'oona, queen of the fae',
              'zada, hedron grinder',
              'breya, etherium shaper']

In [None]:
def commander_file_name(c):
    """Return the csv file name used for commander ``c``.

    Apostrophes and commas are removed and spaces become dashes, e.g.
    ``"kykar, wind's fury"`` -> ``'kykar-winds-fury.csv'``.
    """
    # build the name from the argument; the previous body read the global
    # ``name_a`` and therefore ignored whatever was passed in
    f_out = (c
             .replace("'", '')
             .replace(',', '')
             .replace(' ', '-'))
    return f'{f_out}.csv'

In [None]:
# for every commander, write a csv pairing it (name_a / text_a) with every
# other card (name_b / text_b)
if MAKE_DECK_RECOMMENDATION_INPUTS:
    cards = get_cards()

    for name_a in commanders:
        text_a = cards.loc[name_a].mytext

        df = (cards
              # every card except the commander itself
              [cards.index != name_a]
              .copy()
              .reset_index()
              [['name', 'mytext']]
              # the key used to be misspelled 'mytest', which left the column
              # named 'mytext' and never produced the 'text_b' column that the
              # scoring cell reads
              .rename(columns={'name': 'name_b',
                               'mytext': 'text_b'}))
        df.loc[:, 'name_a'] = name_a
        df.loc[:, 'text_a'] = text_a

        f_out = commander_file_name(name_a)
        print(f_out)
        df.to_csv(f_out, index=False, quoting=csv.QUOTE_ALL)

In [None]:
if DO_DECK_RECOMMENDATIONS:
    EVAL_BATCH = 186

    # tokenizer and trained weights all live under MODEL_NAME; the weights
    # are loaded into the next-sentence-prediction variant of the model
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    config = BertConfig.from_pretrained(MODEL_NAME)
    model = BertForNextSentencePrediction.from_pretrained(MODEL_NAME, config=config)

    # the trainer is only used for batched prediction here
    training_args = TrainingArguments(per_device_eval_batch_size=EVAL_BATCH,
                                      output_dir='./ignore')

    trainer = Trainer(args=training_args, model=model)

In [None]:
if DO_DECK_RECOMMENDATIONS:
    tokenizer_map_func = build_tokenizer_map_func(tokenizer, max_length=MAX_SEQ_LENGTH)

    # `deck_names` is not defined anywhere in the visible notebook; the csv
    # inputs are written one per entry of `commanders`, so those are the
    # decks to load
    deck_names = commanders

    # one split per commander, keyed by the commander name
    ds_to_check = (datasets.load_dataset('csv',
                                         data_files={k: commander_file_name(k)
                                                     for k in deck_names},
                                         quoting=csv.QUOTE_ALL)
                   .map(tokenizer_map_func, batched=True))

ds_to_check if DO_DECK_RECOMMENDATIONS else None

In [None]:
from scipy.special import softmax

# score every (commander, card) pair and write one ranked parquet file per
# commander
if DO_DECK_RECOMMENDATIONS:
    # iterate over the splits that were actually loaded (the keys of
    # ds_to_check) instead of a separately maintained list of names
    for deck_name in ds_to_check:
        deck_ds = ds_to_check[deck_name]

        print(f"deck_name = {deck_name}")
        p = trainer.predict(deck_ds)
        print(f"p.predictions.shape = {p.predictions.shape}")

        # turn the two logits per row into probabilities
        probs = softmax(p.predictions, axis=1)

        z = pd.DataFrame({'p0': probs[:, 0],
                          'p1': probs[:, 1],
                          'y_pred': probs.argmax(axis=1),
                          'name_b': deck_ds['name_b'],
                          'text_b': deck_ds['text_b']})

        # NOTE(review): ranking by p0 assumes class 0 is the "is an edhrec
        # pair" class -- confirm against the label convention of the dataset
        # script
        recs = (z
                .groupby(['name_b', 'text_b'])
                .p0
                .median()
                .sort_values(ascending=False)
                .reset_index())

        recs.to_parquet(f"{deck_name}.mlm-v2.parquet")