# EDH card pair prediction

1. build train / test / dev data set
    1. get EDH card pair recommendations from edhrec.com. these have prediction value 1
    1. generate false pairs (prediction value 0) by randomly pairing cards
    1. split, stratifying on card color identity, card type, rarity.
    1. convert cards into sentences
1. fine-tune
    1. load a pre-trained bert model and fine-tune it on the prediction task "card a, card b --> {yes, no} was an edhrec recommendation"
1. make deck predictions for one of my existing decks

In [None]:
# !pip install datasets transformers

In [None]:
import csv
import itertools
import os
from glob import glob

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset, load_from_disk
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import (BertConfig, BertTokenizerFast,
                          BertForNextSentencePrediction,
                          DataCollatorWithPadding,
                          PreTrainedModel, PreTrainedTokenizerFast,
                          Trainer, TrainingArguments, )

In [None]:
%matplotlib inline

## build train / test / dev data set

### get EDH card pair recommendations from edhrec.com

these will have prediction value 1

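if we just ran with this, how many total pairs could we generate this way? basically, for every card in deck X, every other card in deck X is a valid pair. for a deck of n cards that's n·(n-1)/2 unordered pairs (or n·(n-1) ordered ones, since the model sees (card a, card b) in order), summed over every deck.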

at first I was going to say no way, buuuuuut it's actually not terrible... we want big data, after all

we would need to generate about 32 million negative labels if that were the dataset we were interested in
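a minimal sketch of that count, assuming the edhrec data can be reshaped into a hypothetical `{deck_id: [card names]}` dict. it reports both the raw ordered-pair total and the number of distinct unordered pairs once pairs repeated across decks are collapsed, which is the number that actually matters for sizing the negatives.

```python
from itertools import combinations


def count_pairs(decks):
    """Count co-occurring card pairs across deck lists.

    `decks` is a hypothetical {deck_id: [card names]} mapping.
    Returns (ordered pair total, distinct unordered pairs).
    """
    ordered = 0
    unique = set()
    for cards in decks.values():
        deck = sorted(set(cards))              # ignore duplicate entries within a deck
        n = len(deck)
        ordered += n * (n - 1)                 # (a, b) and (b, a) both count
        unique.update(combinations(deck, 2))   # sorted input -> each unordered pair once
    return ordered, len(unique)
```

e.g. two 3-card decks sharing two cards give 12 ordered pairs but only 5 distinct unordered ones.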

### get all cards from mtgjson

to generate false pairs we will randomly select from all cards. about 65% of all MTG cards are referenced on edhrec, but the rest are also, presumably, good choices for 0 labels
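random pairing can occasionally draw a pair that edhrec actually recommends. one simple guard is rejection sampling against the known positives; this is a sketch, assuming the positives are held as a set of `(card_a, card_b)` tuples (the function and argument names are hypothetical).

```python
import random


def sample_negative_pairs(all_cards, positive_pairs, n, seed=1337, max_tries=None):
    """Draw n distinct random (card_a, card_b, 0) pairs.

    Self-pairs and any pair present in `positive_pairs` (in either order)
    are rejected. Raises ValueError if n pairs can't be found in max_tries.
    """
    rng = random.Random(seed)
    cards = list(all_cards)
    max_tries = max_tries or 100 * n
    seen = set()
    negatives = []
    for _ in range(max_tries):
        if len(negatives) == n:
            break
        a, b = rng.sample(cards, 2)  # two distinct cards, so no self-pairs
        if (a, b) in positive_pairs or (b, a) in positive_pairs or (a, b) in seen:
            continue
        seen.add((a, b))
        negatives.append((a, b, 0))
    if len(negatives) < n:
        raise ValueError(f"only found {len(negatives)} of {n} negative pairs")
    return negatives
```

with ~2/3 of cards showing up on edhrec but any given pair being rare, rejections should be uncommon, so the retry budget mostly matters for tiny toy inputs.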

we can eventually use this dataframe to create a generator of true card pairs off of a single card anchor
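a sketch of what an anchor-based generator could look like, assuming the edhrec data can be flattened into `(card, deck)` rows. both lookups and function names are hypothetical; the idea is just card --> decks --> every other card in those decks.

```python
from collections import defaultdict


def build_indexes(card_deck_rows):
    """Build card -> decks and deck -> cards lookups from (card, deck) rows."""
    card_to_decks = defaultdict(set)
    deck_to_cards = defaultdict(set)
    for card, deck in card_deck_rows:
        card_to_decks[card].add(deck)
        deck_to_cards[deck].add(card)
    return card_to_decks, deck_to_cards


def true_pairs_for(anchor, card_to_decks, deck_to_cards):
    """Yield (anchor, other, 1) for every card sharing at least one deck with anchor."""
    partners = set()
    for deck in card_to_decks.get(anchor, ()):  # .get avoids inserting unknown cards
        partners |= deck_to_cards[deck]
    partners.discard(anchor)
    for other in sorted(partners):
        yield anchor, other, 1
```

being a generator, it only materializes one anchor's partners at a time, which keeps memory flat even if the full pair set is huge.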

### split, stratifying on card color identity, card type, rarity.

we will split on cards. this is actually tricky, right? it would be easy if we could just do a 90/5/5 split and there were enough pairs between cards in the same 5% slice to build an entire test / val set, but I actually suspect we might have a problem producing that many records. oh well, I guess we'll see in due time
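splitting on cards means a pair is only usable if both of its cards land in the same split; any cross-split pair has to be dropped or a card leaks between splits. with cards assigned independently at random, a pair falls entirely inside the 5% test slice only about 0.05 × 0.05 = 0.25% of the time, which is the scarcity worried about above. a sketch of the filtering step, assuming hypothetical `card_a` / `card_b` / `label` columns and a card -> split mapping:

```python
import pandas as pd


def assign_pair_splits(pairs, card_split):
    """Keep only pairs whose two cards landed in the same split.

    `pairs` has columns card_a, card_b, label; `card_split` maps card name
    -> 'train' / 'val' / 'test'. Returns (kept pairs with a `split`
    column, number of dropped cross-split pairs). Cards missing from the
    mapping map to NaN and are dropped too.
    """
    out = pairs.assign(split_a=pairs.card_a.map(card_split),
                       split_b=pairs.card_b.map(card_split))
    same = out.split_a == out.split_b
    kept = (out[same]
            .rename(columns={'split_a': 'split'})
            .drop(columns='split_b'))
    return kept, int((~same).sum())
```

reporting the dropped count makes it easy to see how much data the card-level split costs.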

since we want to stratify on so many things, and any card has roughly a 2/3 chance of showing up in a true (label 1) pair, I actually think fully random sampling is appropriate. we can look at how the split breaks down by other features if we need to
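a sketch of a fully random 90/5/5 assignment over unique card names (the 90/5/5 fractions and the seed are assumptions). to check the breakdown afterwards, something like `pd.crosstab(feature, split, normalize='columns')` works for any feature Series (rarity, color identity, ...) indexed by the same unique card names.

```python
import numpy as np
import pandas as pd


def random_card_split(card_names, fracs=(0.90, 0.05, 0.05), seed=1337):
    """Assign each unique card name to train / val / test uniformly at random.

    Pass unique names only (e.g. cards.index.unique()), otherwise different
    printings of the same card could land in different splits.
    """
    rng = np.random.default_rng(seed)
    labels = rng.choice(['train', 'val', 'test'], size=len(card_names), p=list(fracs))
    return pd.Series(labels, index=pd.Index(card_names, name='name'), name='split')
```

plain random assignment doesn't guarantee exact 90/5/5 counts, but with tens of thousands of cards the proportions land very close.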

### convert cards into sentences

let's just go with this, see how it works out
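for reference, one way a card could be flattened into a single "sentence". this is a hypothetical format, not necessarily the one used above; the field names follow mtgjson's card schema (`name`, `manaCost`, `type`, `text`, `power`, `toughness`), and missing fields are assumed to be `None` or absent rather than `NaN`.

```python
def card_to_sentence(card):
    """Flatten a card record (a dict, or anything with .get) into one text string.

    Missing fields are skipped; power/toughness are only added when both exist.
    """
    parts = [card.get('name'), card.get('manaCost'), card.get('type')]
    if card.get('power') is not None and card.get('toughness') is not None:
        parts.append(f"{card['power']}/{card['toughness']}")
    parts.append(card.get('text'))
    return ' | '.join(p for p in parts if p)
```

e.g. `{'name': 'Counterspell', 'manaCost': '{U}{U}', 'type': 'Instant', 'text': 'Counter target spell.'}` becomes `Counterspell | {U}{U} | Instant | Counter target spell.`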

### create a `huggingface` `datasets` dataset

following along with the relatively simple example [here](https://github.com/huggingface/datasets/blob/master/datasets/squad/squad.py)

#### custom dataset loader?

meh let's try the `csv` loader first

#### `csv` loader

generate `csv`s the same way we were doing `parquet` (see appendix) and load those as datasets
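a sketch of writing one split's pairs as numbered, fully quoted csv chunks under a path like `data/<split>/`, which is the layout the `load_dataset('csv', ...)` call further down globs for. the `text_a` / `text_b` / `label` columns match that loader; the chunk size and file naming are assumptions.

```python
import csv
import os

import pandas as pd


def write_split_csvs(pairs, out_dir, chunk_size=1_000):
    """Write a pairs DataFrame to out_dir/00000.csv, out_dir/00001.csv, ...

    Every field is quoted (csv.QUOTE_ALL) so card text containing commas,
    quotes or newlines round-trips cleanly. Returns the written paths.
    """
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, start in enumerate(range(0, len(pairs), chunk_size)):
        path = os.path.join(out_dir, f'{i:05d}.csv')  # zero-padded so sorted() keeps order
        pairs.iloc[start:start + chunk_size].to_csv(path, index=False, quoting=csv.QUOTE_ALL)
        paths.append(path)
    return paths
```

usage would be something like `write_split_csvs(train_pairs, os.path.join('.', 'data', 'train'))`, once per split.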

#### loading the csvs as datasets, then shuffling, tokenizing, etc.

+ tokenizing from [here](https://huggingface.co/docs/datasets/processing.html#processing-data-in-batches)

## fine-tune

## double-checking our trained model

next steps

+ what do our false positives look like
+ what is the separation like for "cards that have been on edhrec" vs. "cards that haven't"
    + i.e. do we just predict "both cards have been on EDHREC"?
    + did we create a dataset that is just (edhrec cards, either type)? I thought we were making (either type, either type)
+ what is the sorted list of recommendations given an existing deck

why are these all only edhrec cards? I thought I was generating pairs from both sides?

is this a problem? when a new card shows up and has never been seen before, will the model be unable to handle it? I think not, because presumably there were cards in test / val that it had never seen before (have I verified that?).
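a quick way to check that last question: measure what share of eval pairs contain at least one card (card text, here) that never appears in the training pairs. a sketch, assuming the splits are converted to pandas first, e.g. `unseen_card_share(dataset['train'].to_pandas(), dataset['test'].to_pandas())` (the full train split may be large to pull into memory).

```python
import pandas as pd


def unseen_card_share(train_df, eval_df, cols=('text_a', 'text_b')):
    """Fraction of eval rows where at least one side never appears in train."""
    seen = set()
    for c in cols:
        seen.update(train_df[c])
    unseen = ~eval_df[cols[0]].isin(seen) | ~eval_df[cols[1]].isin(seen)
    return float(unseen.mean())
```

if this comes back near zero, the test set never really exercised the "brand new card" case.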

In [None]:
# loading the trained model
config = BertConfig.from_pretrained('edhrec-bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('edhrec-bert-base-uncased', config=config)
tokenizer = BertTokenizerFast.from_pretrained('edhrec-bert-base-uncased')

In [None]:
EVAL_BATCH = 36

training_args = TrainingArguments(
    output_dir='./ignore',
    per_device_eval_batch_size=EVAL_BATCH,    # batch size for evaluation
    label_names=['labels'],
)

trainer = Trainer(model=model, args=training_args, )

In [None]:
MAX_LENGTH = 300

# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

def tokenizer_map_func(rec):
    return tokenizer(rec['text_a'], rec['text_b'],
                     padding='max_length',
                     max_length=MAX_LENGTH,
                     truncation=True)


def fix_label(rec):
    return {'label_as_int': [int(_) for _ in rec['label']]}


split_types = ['val', 'test', 'train']
dataset = (load_dataset('csv',
                        data_files={split_type: sorted(glob(os.path.join('.', 'data', split_type, '*.csv')))
                                    for split_type in split_types},
                        quoting=csv.QUOTE_ALL)
           #.map(fix_label,
           #     batched=True)
           .shuffle(seeds={split_type: 1337
                           for split_type in split_types})
           .map(tokenizer_map_func,
                batched=True))

TODO

+ make the combo dataframe
+ convert that into a dataset (probably a `.from_pandas` or some shit)
+ pass that to eval
+ sort by predictions
+ profit
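for the "sort by predictions" step, each candidate gets one probability per card already in the deck, so those need collapsing into one score per candidate. a sketch, assuming `p1` is attached to `df_to_check` in its original row order before any sorting, e.g. `df_to_check.assign(p1=softmax(p.predictions, axis=1)[:, 1])` (`Trainer.predict` evaluates sequentially, so rows should line up with `kykar.csv`).

```python
import pandas as pd


def rank_candidates(pair_scores, how='mean'):
    """Collapse (deck card, candidate) pair probabilities into one score per candidate.

    `pair_scores` needs columns name_a (card already in the deck), name_b
    (candidate card) and p1 (predicted probability of an edhrec pair).
    `how` is any pandas aggregation name, e.g. 'mean' or 'max'.
    """
    return (pair_scores
            .groupby('name_b')
            .p1
            .agg(how)
            .sort_values(ascending=False))
```

`'mean'` favors candidates that fit the deck broadly, while `'max'` surfaces cards with one very strong partner; which is more useful is a judgment call.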

run the following on a computer with `mtg` installed

In [None]:
kykar_cards = [
    "Aetherflux Reservoir",
    "Anointed Procession",
    "As Foretold",
    "Austere Command",
    "Baral, Chief of Compliance",
    "Blue Sun's Zenith",
    "Boros Charm",
    "Chaos Warp",
    "Counterspell",
    "Cultivator's Caravan",
    "Cyclonic Rift",
    "Desolate Lighthouse",
    "Disallow",
    "Dismantling Blow",
    "Docent of Perfection",
    "Dovin's Veto",
    "Eerie Interlude",
    "Fact or Fiction",
    "Faithless Looting",
    "Gitaxian Probe",
    "Guttersnipe",
    "Impulse",
    "Kykar, Wind's Fury",
    "Mentor of the Meek",
    "Mizzix of the Izmagnus",
    "Mizzix's Mastery",
    "Murmuring Mystic",
    "Mystic Confluence",
    "Mystic Speculation",
    "Mystical Tutor",
    "Narset Transcendent",
    "Neurok Stealthsuit",
    "Niv-Mizzet, Parun",
    "Omniscience",
    "Ponder",
    "Preordain",
    "Primal Amulet",
    "Ral, Izzet Viceroy",
    "Reliquary Tower",
    "Render Silent",
    "Rhystic Study",
    "Serum Visions",
    "Sram's Expertise",
    "Stroke of Genius",
    "Sunforger",
    "Supreme Verdict",
    "Swords to Plowshares",
    "Taigam, Ojutai Master",
    "Talrand, Sky Summoner",
    "Teferi, Hero of Dominaria",
    "Teferi, Time Raveler",
    "The Locust God",
    "Thought Vessel",
    "Tidespout Tyrant",
    "Trail of Evidence",
    "Vandalblast",
    "Young Pyromancer", 
]

In [None]:
sets_to_check = [
    "2XM",
    "AKR",
    "C20",
    "CC1",
    "CMC",
    "CMR",
    "IKO",
    "JMP",
    "KHC",
    "M21",
    "MB1",
    "MH2",
    "Q03",
    "SLD",
    "SLU",
    "SS3",
    "THB",
    "TSR",
    "ZNC",
    "ZNE",
    "ZNR",
]

In [None]:
cards_to_check = (cards
                  [cards.setname.isin(sets_to_check)
                   & (~cards.rarity.isin(['common', 'uncommon']))
                   & cards.colorIdentity.apply(lambda x: set(x).difference(['W', 'U', 'R']) == set())
                   & ~cards.index.isin(kykar_cards)]
                  .index
                  .unique())
cards_to_check.shape

In [None]:
card_text.loc[kykar_cards[0], 'text']

In [None]:
df_to_check = pd.DataFrame([{'text_a': card_text.loc[kc, 'text'],
                             'text_b': card_text.loc[ctc, 'text'],
                             'name_a': kc,
                             'name_b': ctc}
                            for kc in kykar_cards
                            for ctc in cards_to_check])

df_to_check.head()

In [None]:
df_to_check.to_csv(os.path.join('.', 'kykar.csv'),
                   index=False,
                   quoting=csv.QUOTE_ALL)

now run the following on any machine that has `kykar.csv` copied to it

In [None]:
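# a single data file lands in a 'train' split; split='train' returns a plain
# Dataset (not a DatasetDict), which is what trainer.predict expects
ds_to_check = (load_dataset('csv',
                            data_files='kykar.csv',
                            quoting=csv.QUOTE_ALL,
                            split='train')
               .map(tokenizer_map_func, batched=True))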

In [None]:
p = trainer.predict(ds_to_check)

p.predictions.shape

In [None]:
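from scipy.special import softmax

probs = softmax(p.predictions, axis=1)

z = pd.DataFrame({'p1': probs[:, 1],
                  'y_pred': probs.argmax(axis=1)})

# kykar.csv has no label column, so p.label_ids is None for it;
# only score accuracy when predicting on a labeled split
if p.label_ids is not None:
    z.loc[:, 'y'] = p.label_ids
    z.loc[:, 'is_right'] = z.y == z.y_pred

z.sort_values(by='p1', ascending=False, inplace=True)
z.reset_index(drop=True, inplace=True)

z.is_right.value_counts() if 'is_right' in z else z.head()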

In [None]:
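# fraction of all true (label 1) pairs recovered as we walk down the ranking by p1;
# needs labels, so only meaningful for predictions on a labeled split like dataset['test']
total_true = z.y.sum()

(z.y.cumsum() / total_true).plot()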

In [None]:
z.head()

In [None]:
d1k = dataset['test'].select(range(1_000))
d1k

In [None]:
d1h = dataset['test'].select(range(100))
d1h

In [None]:
pd.set_option('display.max_columns', None)  # or 1000
pd.set_option('display.max_rows', None)  # or 1000
pd.set_option('display.max_colwidth', None)

In [None]:
import pandas as pd
import torch


def get_preds(n=100, chunk_size=100):
    """Score the first ~n test pairs in chunks, sorted by confidence (p_delta)."""
    model.eval()  # disable dropout for inference
    chunks = []

    for i in range(0, n, chunk_size):
        print(f"i = {i}")
        d = dataset['test'][i: i + chunk_size]
        with torch.no_grad():  # no gradients needed; saves a lot of memory
            logits = model(**{k: torch.as_tensor(np.array(v))
                              for (k, v) in d.items()
                              if k in ['attention_mask', 'input_ids', 'token_type_ids']}).logits
            p = logits.softmax(1)

        z_now = pd.DataFrame(p.numpy(), columns=['p0', 'p1'])
        for key in ['label', 'text_a', 'text_b']:
            z_now[key] = d[key]
        chunks.append(z_now)

    # DataFrame.append was removed in pandas 2.0; concat once at the end instead
    z = pd.concat(chunks, ignore_index=True)
    z.loc[:, 'p_delta'] = (z.p0 - z.p1).abs()

    # label 1 == edhrec pair, so a prediction is right when (p1 > p0) matches the label
    z.loc[:, 'is_right'] = (z.p1 > z.p0) == z.label

    z.sort_values(by='p_delta', inplace=True, ascending=False)

    return z

In [None]:
z = get_preds(500)

z.tail(20)

In [None]:
z.is_right.value_counts()

In [None]:
z[~z.is_right].head(15)

In [None]:
is_edhrec = card_text.copy()
is_edhrec.loc[:, 'is_edhrec'] = is_edhrec.index.isin(edhrec_cards.name.unique())
is_edhrec.reset_index(inplace=True)
is_edhrec.head(10)

In [None]:
(z
 .merge(is_edhrec.rename(columns={'text': 'text_a', 'is_edhrec': 'is_edhrec_a'})[['text_a', 'is_edhrec_a']],
        how='left',
        on='text_a')
 .merge(is_edhrec.rename(columns={'text': 'text_b', 'is_edhrec': 'is_edhrec_b'})[['text_b', 'is_edhrec_b']],
        how='left',
        on='text_b')
 .groupby(['is_right', 'is_edhrec_b'])
 .is_right.count())

## make deck predictions for one of my existing decks

# appendix

the following is either hacking, stuff that didn't work, etc.

### tokenizing sentences

~~we will be reusing most of the text sentences above several times; might as well tokenize them all up front once instead of tokenizing most of them 100x later~~

just do shit the way the documentation suggests we should, i.e. tokenize the completely built pair parquet files below

In [None]:
# from transformers import RobertaTokenizerFast
# tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

In [None]:
# def my_tokenizer(row, *args, **kwargs):
#     return pd.Series(tokenizer(row.text, *args, **kwargs))

In [None]:
# (card_text.head(20)
#  .apply(my_tokenizer, axis=1, truncation=True, padding=True))

In [None]:
# card_text = (card_text
#              .join(card_text
#                    .apply(my_tokenizer, axis=1, truncation=True, padding=True)))

# card_text.head(10)

### making the pair suggestions dataset

okay so we have

1. a train / test / val split of all cards
1. a series of card text values (our "sentences")
1. a list of `card --> deck` relationships

the task now is to

1. generate positive and negative cases for each card
    + positive: `card --> deck <-- card`
    + negative: just not that
1. look up their text values
1. write those values to file
    + probably want to chunk this up somehow, maybe write 1k sentences per parquet
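the positive case `card --> deck <-- card` maps naturally onto a pandas self-join on the deck column. a sketch, assuming a hypothetical long table with one row per `(card, deck)` relationship; note each deck of n cards produces n·(n-1) rows before deduplication, so at full scale this may need to run deck-by-deck or in chunks.

```python
import pandas as pd


def positive_pairs(card_deck):
    """All ordered (card_a, card_b) pairs sharing at least one deck, labeled 1.

    `card_deck` has columns `card` and `deck`, one row per relationship.
    """
    pairs = (card_deck.merge(card_deck, on='deck', suffixes=('_a', '_b'))
             .query('card_a != card_b')[['card_a', 'card_b']]
             .drop_duplicates()  # pairs that co-occur in several decks appear once
             .reset_index(drop=True))
    pairs['label'] = 1
    return pairs
```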

### build the pytorch datasets

basing this in large part off of [this doc page](https://huggingface.co/transformers/custom_datasets.html#nlplib)

#### do the encodings

so, the below killed the kernel... :(