In [1]:
%reload_ext mtg_ai
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from mtg_ai.data import MTGCards
from mtg_ai import constants
from mtg_ai.training import MTGCardTraining

In [None]:
# Build the trainer with its default arguments. The recorded output below
# shows that construction turns the card rows into strings, tokenizes them,
# and downloads the model shards, so this cell is slow on a fresh machine.
mtg_card_ai = MTGCardTraining()


turning rows into strings: 18136it [00:00, 58812.99it/s]
tokenizing rows: 100%|██████████████████████████████████████████████████████████| 18136/18136 [00:10<00:00, 1735.73it/s]


Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/5.06G [00:00<?, ?B/s]

In [None]:
# Run training on the trainer instance bound to `mtg_card_ai`.
mtg_card_ai.train()

In [None]:
import pandas as pd
from tqdm import tqdm
from pathlib import Path


In [None]:
def remove_uneeded_columns(df: pd.DataFrame):
    """Rename ``name`` to ``card_name`` and drop the columns in ``constants.drop_columns``.

    The input frame is left untouched; a new frame is returned, so the
    function is safe to re-run and to use inside a ``.pipe`` chain.

    Args:
        df: Card table containing a ``name`` column and every column listed
            in ``constants.drop_columns``.

    Returns:
        A new DataFrame with the renamed column and without the dropped ones.

    Raises:
        KeyError: If a column listed in ``constants.drop_columns`` is missing.
    """
    return (
        df.rename(columns={"name": "card_name"})
        .drop(columns=constants.drop_columns)
    )

def filter_for_modern(df: pd.DataFrame):
    """Keep only the cards that are legal in Modern and drop ``legalities``.

    Each entry of the ``legalities`` column is expected to be a dict; a card
    is kept when its ``"modern"`` value equals ``"legal"``. Entries that are
    not dicts (for example missing values) are treated as not legal.

    The input frame is not modified, and the mask is applied positionally,
    so the function works for any index and for an empty frame.

    Args:
        df: Card table with a ``legalities`` column.

    Returns:
        A new DataFrame of Modern-legal cards, without the ``legalities``
        column and with a fresh 0..n-1 index.
    """
    is_modern_legal = df["legalities"].map(
        lambda entry: isinstance(entry, dict) and entry.get("modern") == "legal"
    )
    return (
        # A plain boolean array avoids index alignment between mask and frame.
        df.loc[is_modern_legal.to_numpy(dtype=bool)]
        .drop(columns="legalities")
        .reset_index(drop=True)
    )

def fill_empty_values(df: pd.DataFrame):
    """Replace missing values with per-column placeholders.

    Colour and mana columns are filled with ``constants.NA_STRING``, the
    power/toughness/loyalty columns with ``constants.NAN_STRING``, and
    ``edhrec_rank`` with ``0``. The input frame is not modified.

    Args:
        df: Card table; any of the listed columns may contain missing values.

    Returns:
        A new DataFrame with the placeholders filled in.
    """
    placeholder_columns = ["mana_cost", "colors", "color_identity", "produced_mana", "color_indicator"]
    stat_columns = ["power", "toughness", "loyalty"]
    fill_values = {
        **dict.fromkeys(placeholder_columns, constants.NA_STRING),
        **dict.fromkeys(stat_columns, constants.NAN_STRING),
        "edhrec_rank": 0,
    }
    return df.fillna(fill_values)

def merge_lists(df: pd.DataFrame):
    """Collapse list-valued columns into plain strings.

    ``keywords`` entries are joined with ``", "``; the colour columns are
    concatenated without a separator. A value that is already a string is
    left unchanged by the concatenation, so previously filled placeholder
    strings pass through as they are. The input frame is not modified.

    Args:
        df: Card table whose ``keywords`` and colour columns hold lists of
            strings (or strings). Colour columns must not contain missing
            values, because ``"".join`` is applied to every cell.

    Returns:
        A new DataFrame with the same columns, in the same order.
    """
    columns = ["colors", "color_identity", "color_indicator", "produced_mana"]
    joined_colors = {column: df[column].map("".join) for column in columns}
    return df.assign(keywords=df["keywords"].str.join(", "), **joined_colors)

def sort_color_strings(df: pd.DataFrame):
    """Normalise the ordering of every colour string.

    Applies ``constants.MTGColorCombo._sort_multicolor_str`` to each cell of
    the four colour columns. The input frame is not modified.

    Args:
        df: Card table whose colour columns already hold strings.

    Returns:
        A new DataFrame with the same columns, in the same order.
    """
    columns = ["colors", "color_identity", "color_indicator", "produced_mana"]
    sort_colors = constants.MTGColorCombo._sort_multicolor_str
    return df.assign(**{column: df[column].map(sort_colors) for column in columns})

def convert_column_types(df: pd.DataFrame):
    """Cast every known card column to its final dtype.

    Free-text columns become ``str``, low-cardinality columns use
    ``constants.CATEGORY``, ``cmc`` becomes ``float`` and ``edhrec_rank``
    becomes ``int``.

    Args:
        df: Card table containing all of the columns listed below.

    Returns:
        A new DataFrame with the converted dtypes.
    """
    text_columns = (
        "oracle_id",
        "card_name",
        "type_line",
        "power",
        "toughness",
        "oracle_text",
    )
    categorical_columns = (
        "rarity",
        "mana_cost",
        "colors",
        "color_identity",
        "loyalty",
        "produced_mana",
        "set_type",
        "layout",
        "color_indicator",
    )
    dtype_by_column = {
        **dict.fromkeys(text_columns, str),
        **dict.fromkeys(categorical_columns, constants.CATEGORY),
        "cmc": float,
        "edhrec_rank": int,
    }
    return df.astype(dtype_by_column)

# Location of the oracle-cards JSON export, relative to the working directory.
mtg_data_path = Path("./data/oracle-cards-20231121100139.json")
# Cleaning pipeline: every stage takes a DataFrame and returns a DataFrame.
# Order matters: fill_empty_values must run before merge_lists, because
# merge_lists calls "".join on every cell of the colour columns and that
# raises a TypeError on an unfilled missing value.
df: pd.DataFrame = (
    pd.read_json(mtg_data_path)
    .pipe(remove_uneeded_columns)  # rename name -> card_name, drop constants.drop_columns
    .pipe(filter_for_modern)  # keep rows whose legalities["modern"] == "legal"
    .pipe(fill_empty_values)  # placeholders for missing colour/stat values
    .pipe(merge_lists)  # list-valued cells -> strings
    .pipe(sort_color_strings)  # normalise colour-string ordering
    .pipe(convert_column_types)  # final dtypes (str / category / float / int)
)

In [None]:
# Display the card table currently bound to `df`.
df

In [None]:
# Show only the rows whose layout is "flip".
df[df["layout"] == "flip"]

In [None]:
# Reload the untouched JSON export to inspect the raw color_identity values.
# NOTE(review): this rebinds `df` to the raw frame, so every later cell that
# reads `df` sees unprocessed data rather than the cleaned pipeline output.
mtg_data_path = Path("./data/oracle-cards-20231121100139.json")
df = pd.read_json(mtg_data_path)
df["color_identity"].dropna()

In [None]:
# Preview the colors column with each entry concatenated into one string.
df["colors"].str.join("")

In [None]:
# List the cmc values that are greater than 10.
df.loc[df["cmc"] > 10, "cmc"]

In [None]:
# Build a second trainer instance, this time passing num_epochs=4 explicitly.
training = MTGCardTraining(num_epochs=4)

In [None]:
# Run training on the trainer instance bound to `training`.
training.train()