<a href="https://colab.research.google.com/github/burn874/mtg/blob/main/collab_quick_start.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

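#### Init

Clone the project repository, install it, and download and unpack the 17lands VOW Premier Draft public draft and game datasets.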

In [None]:
# Clone the project repository and move into the checkout; the downloads and
# extracted CSVs below land in this directory.
!git clone https://github.com/RyanSaxe/mtg.git
%cd mtg
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install .
# Fetch the 17lands VOW Premier Draft public datasets. `-nc` (no-clobber) skips
# archives that already exist, so re-running this cell does not re-download them
# as `*.1` duplicates.
!wget -nc https://17lands-public.s3.amazonaws.com/analysis_data/draft_data/draft-data.VOW.PremierDraft.tar.gz
!wget -nc https://17lands-public.s3.amazonaws.com/analysis_data/game_data/game-data.VOW.PremierDraft.tar.gz
!tar -xf draft-data.VOW.PremierDraft.tar.gz
!tar -xf game-data.VOW.PremierDraft.tar.gz

#### Imports

In [None]:
import pickle
from mtg.ml.generator import DraftGenerator, create_train_and_val_gens
from mtg.ml.models import DraftBot
from mtg.ml.trainer import Trainer
from mtg.utils.display import draft_log_ai
from mtg.obj.expansion import VOW

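#### Preprocessing

Parse the raw CSVs into a `VOW` expansion object and pickle it to `expansion.pkl`, so the training cells can reload it without repeating this step.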

In [None]:
# Directory holding the unpacked 17lands CSVs (the repository checkout on Colab).
data_dir = '/content/mtg/'
game_data = data_dir + 'game_data_public.VOW.PremierDraft.csv'
draft_data = data_dir + 'draft_data_public.VOW.PremierDraft.csv'
expansion_fname = 'expansion.pkl'

# Build the VOW expansion from the Bo1 game log and the draft log.
# ml_data=True presumably also prepares the ML tables used later
# (e.g. card_data_for_ML) — TODO confirm against mtg.obj.expansion.
expansion = VOW(bo1=game_data, draft=draft_data, ml_data=True)

# Cache the parsed expansion so the training cells can reload it directly.
with open(expansion_fname, "wb") as pkl_file:
    pickle.dump(expansion, pkl_file)

## Training
#### Generators

In [None]:
# Path of the cached expansion written by the preprocessing cell; redefined here
# so this cell can run on its own once the pickle exists.
expansion_fname = "expansion.pkl"

# --- data ---
batch_size = 32
# NOTE(review): train_p = 1.0 presumably sends every draft to the training split;
# confirm what val_gen holds when include_val=True below.
train_p = 1.0

# --- architecture ---
emb_dim = 128
num_encoder_heads = 8
num_decoder_heads = 8
pointwise_ffn_width = 512
num_encoder_layers = 1
num_decoder_layers = 1
emb_dropout = 0.0
transformer_dropout = 0.1
output_MLP = True

# --- optimisation / loss weights ---
lr_warmup = 2000
emb_margin = 1.0
emb_lambda = 0.5
rare_lambda = 10
cmc_lambda = 0.1

# --- run settings ---
epochs = 10
verbose = True
model_name = 'draft_model'

# Only unpickle a file this notebook produced itself: pickle can execute code.
with open(expansion_fname, "rb") as pkl_file:
    expansion = pickle.load(pkl_file)

# Split the draft log by draft_id into training and validation generators.
train_gen, val_gen = create_train_and_val_gens(
    expansion.draft,
    expansion.cards,
    train_p=train_p,
    id_col="draft_id",
    train_batch_size=batch_size,
    generator=DraftGenerator,
    include_val=True,
)

#### Model

In [None]:
# Size of the position axis passed as `t`: one more than the largest pick
# position seen in the draft log.
n_positions = expansion.draft["position"].max() + 1

model = DraftBot(
    name="DraftBot",
    cards=train_gen.cards,
    card_data=expansion.card_data_for_ML[5:],
    t=n_positions,
    emb_dim=emb_dim,
    num_encoder_heads=num_encoder_heads,
    num_decoder_heads=num_decoder_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    pointwise_ffn_width=pointwise_ffn_width,
    emb_dropout=emb_dropout,
    memory_dropout=transformer_dropout,
    output_MLP=output_MLP,
)

model.compile(
    learning_rate={"warmup_steps": lr_warmup},
    margin=emb_margin,
    emb_lambda=emb_lambda,
    rare_lambda=rare_lambda,
    cmc_lambda=cmc_lambda,
    # NOTE(review): DraftBot above receives card_data_for_ML[5:], while this slice
    # also drops the final row (.iloc[5:-1]); confirm the difference is intended.
    card_data=expansion.card_data_for_ML.iloc[5:-1],
)

#### Run Training

In [None]:
# Loss terms the trainer prints as training progresses.
loss_keys_to_print = ["prediction_loss", "embedding_loss", "rare_loss", "cmc_loss"]

trainer = Trainer(model, generator=train_gen, val_generator=val_gen)
trainer.train(epochs, print_keys=loss_keys_to_print, verbose=verbose)

In [None]:
# Run inference once before saving, so the model is serialized with the right
# input parameters for inference. Any public 17lands draft log works here.
sample_draft_url = "https://www.17lands.com/draft/79dcc54822204a20a88a0e68ec3f8564"
output_df, attention = draft_log_ai(
    sample_draft_url,
    model,
    t=model.t,
    n_cards=model.n_cards,
    idx_to_name=model.idx_to_name,
    return_attention=True,
)
model.save(model_name)