In [2]:
# Re-import edited project modules (e.g. mtg_ai) automatically before executing each cell.
%reload_ext autoreload
%autoreload 2

# Set up colored logging

In [1]:
import logging

import colorlog
import pandas as pd

# Colored console logging: each record's level name, logger name and message are
# tinted according to its level.
handler = colorlog.StreamHandler()
fmt = "%(log_color)s%(levelname)s:%(name)s:%(message)s"
formatter = colorlog.ColoredFormatter(
    fmt,
    log_colors={
        "DEBUG": "purple",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    },
)
handler.setFormatter(formatter)
# basicConfig is a no-op once the root logger has handlers; force=True replaces them so
# re-running this cell actually re-applies the configuration.
logging.basicConfig(level=logging.INFO, handlers=[handler], force=True)
logger = logging.getLogger("mtg-ai")
logger.propagate = True  # the default; records reach the root handler configured above

# Raise pandas' display limits so wide DataFrames (e.g. the card table below) are
# shown without truncation up to these sizes.
pd.set_option("display.max_columns", 500)
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_colwidth", 500)

# Card Dataframe

In [None]:
from IPython.display import display

from mtg_ai.cards import MTGDatabase

database = MTGDatabase()
cards_df = database.df

# How many cards use each layout.
print(cards_df.layout.value_counts())

# Rows for each of these Brisela-related names, one table per exact name match.
for card_name in (
    "Bruna, the Fading Light // Brisela, Voice of Nightmares",
    "Gisela, the Broken Blade // Brisela, Voice of Nightmares",
    "Brisela, Voice of Nightmares",
):
    display(cards_df.loc[cards_df.name == card_name])

# Every card whose name contains "Ajani".
display(cards_df.loc[cards_df.name.str.contains("Ajani")])

# Dataset Builder

In [None]:
from mtg_ai.cards.training_data_builder import build_datasets

# Build the training datasets. all_merged=True presumably also produces the merged "all"
# dataset loaded in the next cell — confirm in training_data_builder.
build_datasets(all_merged=True)


In [None]:
from mtg_ai.cards import MTGDatasetLoader

# Load the "all" dataset; the return value is left as this cell's displayed output.
MTGDatasetLoader.load_dataset("all")

# Training

### Train on all datasets merged ("all")

In [None]:
from mtg_ai.ai import MTGCardAITrainerPipeline, ModelAndTokenizer
from mtg_ai.cards.training_data_builder import QUESTION_ANSWER_FOLDER

# Train on the merged "all" dataset without resuming from a checkpoint.
model = ModelAndTokenizer.UNSLOTH_LLAMA_3_2_3B_INSTRUCT_Q8
training_pipeline = MTGCardAITrainerPipeline(
    base_model_name=model.value,
    dataset_directory=QUESTION_ANSWER_FOLDER,
)
training_pipeline.train("all", resume_from_checkpoint=False)

### Train on cards

In [None]:
from mtg_ai.ai import MTGCardAITrainerPipeline, ModelAndTokenizer
from mtg_ai.cards.training_data_builder import QUESTION_ANSWER_FOLDER

# Train on the "cards" dataset, resuming from a saved checkpoint
# (presumably one left by an earlier run — confirm in MTGCardAITrainerPipeline.train).
model = ModelAndTokenizer.UNSLOTH_LLAMA_3_2_3B_INSTRUCT_Q8
training_pipeline = MTGCardAITrainerPipeline(
    base_model_name=model.value,
    dataset_directory=QUESTION_ANSWER_FOLDER,
)
training_pipeline.train("cards", resume_from_checkpoint=True)

### Train on rules

In [None]:
from mtg_ai.ai import MTGCardAITrainerPipeline, ModelAndTokenizer
from mtg_ai.cards.training_data_builder import QUESTION_ANSWER_FOLDER

# Train on the "rules" dataset from scratch with a very small learning rate and weight decay.
model = ModelAndTokenizer.UNSLOTH_LLAMA_3_2_3B_INSTRUCT_Q8
training_pipeline = MTGCardAITrainerPipeline(
    base_model_name=model.value,
    dataset_directory=QUESTION_ANSWER_FOLDER,
)
training_pipeline.train(
    "rules",
    resume_from_checkpoint=False,
    learning_rate=1e-7,
    weight_decay=2e-8,
)

### Train on combos

In [None]:
from mtg_ai.ai import MTGCardAITrainerPipeline, ModelAndTokenizer
from mtg_ai.cards.training_data_builder import QUESTION_ANSWER_FOLDER

# Train on the "combos" dataset from scratch, with explicit batch sizes and no
# gradient accumulation.
model = ModelAndTokenizer.UNSLOTH_LLAMA_3_2_3B_INSTRUCT_Q8
training_pipeline = MTGCardAITrainerPipeline(
    base_model_name=model.value,
    dataset_directory=QUESTION_ANSWER_FOLDER,
)
training_pipeline.train(
    "combos",
    resume_from_checkpoint=False,
    learning_rate=1e-7,
    weight_decay=2e-8,
    train_batch_size=16,
    eval_batch_size=8,
    gradient_accumulation_steps=1,
)

# Merge the LoRA adapters into the base model

In [None]:
from pathlib import Path

from mtg_ai.ai import ModelAndTokenizer, save_combined_model

# Base model (enum member carrying the model id, tokenizer and GGUF file) and the
# directory holding the training results to combine with it.
# NOTE(review): the training cells above use UNSLOTH_LLAMA_3_2_3B_INSTRUCT_Q8 as the base;
# confirm this member matches the base the adapters in ./results were trained on.
model_name = ModelAndTokenizer.BARTOWSKI_LLAMA_3_2_8B_INSTRUCT_Q4_K_L
model_dir = Path("./results").resolve()

save_combined_model(
    model_dir=model_dir,
    model_name=model_name.value,
    tokenizer_name=model_name.tokenizer,
    gguf_file=model_name.gguf_file,
)

# Inference

In [None]:
%env TQDM_DISABLE = 1
%env HAYSTACK_PROGRESS_BARS = 0

import tqdm
from pathlib import Path
from functools import partialmethod

import tqdm.auto
# handler = colorlog.StreamHandler()
# fmt = "%(log_color)s%(levelname)s:%(name)s:%(message)s"
# formatter = colorlog.ColoredFormatter(
#     fmt,
#     log_colors={
#         "DEBUG": "purple",
#         "INFO": "green",
#         "WARNING": "yellow",
#         "ERROR": "red",
#         "CRITICAL": "red,bg_white",
#     },
# )
# handler.setFormatter(formatter)
# logging.basicConfig(level=logging.DEBUG, handlers=[handler])

from mtg_ai.utils import is_tqdm_disabled
print(is_tqdm_disabled())
from mtg_ai.ai import MTGAIRunner

ai_model_name: str = "./results/combos"
rag_embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
runner = MTGAIRunner(
    ai_model_name=ai_model_name, rag_embedding_model_name=rag_embedding_model_name
)

test_questions = [
    "What is the converted mana cost of Acquisitions Expert?",
    "What is the converted mana cost of Loch Korrigan?",
    "What is the type of Loch Korrigan?",
    "What is the text of Loch Korrigan?",
    "What is the text of Tarmogoyf?",
    "What is the type of Ajani, Nacatl Pariah",
    "What is the text of Ajani, Nacatl Pariah",
    "What is the cmc of Ajani, Nacatl Pariah",
    "What is the text of Ajani, Nacatl Avenger",
    "What is the type of Ajani, Nacatl Avenger",
    "What is a Reversible cards work?",
    "What can I combo with Ajani, Nacatl Avenger?",
]
max_new_tokens = 500
top_k = 3

for tq in test_questions:
    print()
    print(tq)
    for word in runner.run(tq, max_new_tokens=500, filters=None, top_k=3, stream_output=True, temperature=0.3):
        if "\\n" in word:
            print(word, flush=True)
        else:
            print(word, end="", flush=True)
    print()

# Experiments

In [None]:
from mtg_ai.cards.edh_combos import EDHComboDatabase

# Construct the EDH combo database used by the exploratory cells below.
edh_combos = EDHComboDatabase()

In [2]:
import os

# Show the current TQDM_MININTERVAL environment value (None when unset).
os.environ.get("TQDM_MININTERVAL")

In [None]:
from mtg_ai.cards.edh_combos import EDHComboDatabase
from mtg_ai.cards import MTGDatabase
from mtg_ai.cards.training_data_builder import DataEntry
from IPython.display import display
# Both databases are constructed here, independently of the earlier cells that also create them.
edh_combos = EDHComboDatabase()

database = MTGDatabase()

zone_locations_to_text = {
    "B": "on the battlefield",
    "G": "in the graveyard",
    "H": "in your hand",
    "L": "in the library",
    "E": "exiled",
    "C": "in the command zone",
}


def _zones_to_text(zones) -> str:
    """Render zone codes (e.g. "B" or "BG") as one " or "-joined phrase.

    Raises KeyError for a code missing from ``zone_locations_to_text``.
    """
    return " or ".join(zone_locations_to_text[zone] for zone in zones)


def _nonblank_lines(text) -> list[str]:
    """Split free text into lines, dropping blank ones.

    Missing values (``None`` or pandas' float NaN, which is truthy) yield ``[]``
    instead of failing on ``.splitlines()``.
    """
    if not isinstance(text, str):
        return []
    return [line for line in text.splitlines() if line.strip()]


def _format_features(features) -> str:
    """Bullet-list the combo's result features, spelling out ETB/LTB."""
    bullets = []
    for feature_name in features["feature_name"]:
        # Apply both expansions so a feature mentioning LTB *and* ETB is fully spelled out.
        feature_name = feature_name.replace("LTB", "leaves the battlefield")
        feature_name = feature_name.replace("ETB", "enters the battlefield")
        bullets.append(f"  - {feature_name}")
    return "\n".join(bullets)


def _format_steps(steps) -> str:
    """Number the non-blank lines of the combo's step text, starting at 1."""
    return "\n".join(
        f"  {number}. {step}" for number, step in enumerate(_nonblank_lines(steps), start=1)
    )


def _zone_prerequisites(cards) -> list[str]:
    """Describe where each card must be; one summary line when all cards share the same zones."""
    distinct_zones = cards["zone_locations"].unique()
    if len(distinct_zones) == 1:
        # Render every allowed zone code (e.g. "BG" -> battlefield or graveyard),
        # not just the first one.
        return [f"  - All permanents must be {_zones_to_text(distinct_zones[0])}"]
    return [
        f"  - {card_name} must be {_zones_to_text(zones)}"
        for card_name, zones in zip(cards["card_name"], cards["zone_locations"])
    ]


def build_cards_to_combo_question_answer_dataset(database: "MTGDatabase", edh_combos: "EDHComboDatabase"):
    """Build one question/answer training entry per EDH combo.

    Args:
        database: Card database. Not used by the current implementation; kept for
            interface compatibility.
        edh_combos: Iterable of combos. Each combo is indexed as ``combo["cards"]``
            (table with ``card_name`` and ``zone_locations`` columns, where each
            zone value is a sequence of single-letter zone codes),
            ``combo["features"]`` (table with a ``feature_name`` column) and
            ``combo["combo"]`` (record with ``identity``, ``manaNeeded``, ``steps``
            and ``otherPrerequisites``).

    Returns:
        A list of ``DataEntry`` objects, one per combo, in iteration order.

    Raises:
        KeyError: if a card's zone code is not in ``zone_locations_to_text``.
    """
    result: list[DataEntry] = []
    for combo in edh_combos:
        cards = combo["cards"]
        details = combo["combo"]
        card_names_text = ", ".join(cards["card_name"].to_list())

        question = f"How can you create a combo with {card_names_text}?"
        answer = (
            f"This combo can be formed with {card_names_text}\n\n"
            f"Color identity: {details['identity']}\n"
            f"Mana cost: {details['manaNeeded']}\n"
            "Steps:\n"
            f"{_format_steps(details['steps'])}"
            "\n\n"
            "Result:\n"
            f"{_format_features(combo['features'])}"
        )

        prerequisites = _zone_prerequisites(cards) + [
            f"  - {prerequisite}"
            for prerequisite in _nonblank_lines(details["otherPrerequisites"])
        ]
        if prerequisites:
            answer += "\n\nOther prerequisites:\n" + "\n".join(prerequisites)

        result.append(DataEntry(question=question, answer=answer))
    return result
    
# Build one question/answer entry per combo in edh_combos.
result = build_cards_to_combo_question_answer_dataset(database, edh_combos)

In [None]:
# Number of generated combo question/answer entries.
len(result)

In [None]:
combo = edh_combos.get_combo("647-1069-1256-5499")
# First character of the first distinct zone_locations value among this combo's cards.
distinct_zones = combo["cards"]["zone_locations"].unique()
distinct_zones[0][0]


In [None]:
v = edh_combos.get_combo("647-1069-1256-5499")
# Raw prerequisites text, a blank line, then each step line numbered from 1.
print(v["combo"]["otherPrerequisites"])
print()
for step_number, step in enumerate(v["combo"]["steps"].splitlines(), start=1):
    print(f"{step_number}. {step}")

In [None]:
# List the public attributes/methods of the combo database. The previous bare
# `edh_combos.` was a leftover tab-completion probe and a SyntaxError.
[name for name in dir(edh_combos) if not name.startswith("_")]