# RAG for cards

In [1]:
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1")

In [2]:
import json

In [3]:
with open("../data/cards/FDN.json") as f:
    data = json.load(f)

In [4]:
data['data']['cards'][0].keys

<function dict.keys>

In [48]:
from langchain_community.document_loaders import JSONLoader

extraction_data = ['colorIdentity', 'colors', 'convertedManaCost', 'edhrecRank', 
 'keywords', 'legalities', 'manaCost', 'manaValue', 'name',
 'power', 'rarity', 
 'setCode', 'subtypes', 'supertypes', 'text', 
 'toughness', 'type', 'types']

extraction_schema = """
.data.cards[] | {
            colors: .colors,
            convertedManaCost: .convertedManaCost,
            keywords: .keywords,
            manaCost: .manaCost,
            name: .name,
            power: .power,
            rarity: .rarity,
            subtypes: .subtypes,
            supertypes: .supertypes,
            text: .text,
            toughness: .toughness,
            types: .types
        }
"""

loader = JSONLoader(
    file_path="../data/cards/FDN.json",
    jq_schema=extraction_schema,
    text_content=False,
)

In [49]:
docs = loader.load()

In [50]:
print(docs[0]), print(len(docs))

page_content='{"colors": [], "convertedManaCost": 7, "keywords": ["First strike", "Lifelink", "Menace", "Reach", "Trample", "Vigilance", "Ward"], "manaCost": "{7}", "name": "Sire of Seven Deaths", "power": "7", "rarity": "mythic", "subtypes": ["Eldrazi"], "supertypes": [], "text": "First strike, vigilance\nMenace, trample\nReach, lifelink\nWard\u2014Pay 7 life.", "toughness": "7", "types": ["Creature"]}' metadata={'source': '/home/giles/code/deep_mtg/data/cards/FDN.json', 'seq_num': 1}
730


(None, None)

In [51]:
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="snowflake-arctic-embed2")
vector_store = InMemoryVectorStore(embeddings)

In [52]:
vector_1 = embeddings.embed_query(docs[0].page_content)
vector_2 = embeddings.embed_query(docs[1].page_content)

assert len(vector_1) == len(vector_2)
print(f"Generated vectors of length {len(vector_1)}\n")
print(vector_1[:10])

Generated vectors of length 1024

[-0.022801735, 0.041194554, 0.017705947, 0.01626768, -0.011525806, -0.0617139, 0.037792053, -0.017419714, -0.022631949, -0.007076665]


In [53]:
len(docs[0].page_content)

392

In [56]:
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system",
      "You are an expert Magic: The Gathering player."
      "You are helping a new player to understand what various cards do from a high-level perspective."
      "When provided with a card, you should write a concise summary of the card."
      "You can assume that the player understands the basic rules of the game"
      "Include basic keyword terms like 'flying' or 'trample', but do not explain what they mean."
      "Do not include details of rarity or set information."
      "Rather than quantifying attributes of the card, instead use qualitative terms like 'strong' or 'weak' to describe the card."
      "Include the name of the card in the summary."
      "Include details of the card's role in that, strengths, and weaknesses."
      "Do not return anything other than the summary of the card."),
    ("user", "{card}")]
)

summary = llm.invoke(prompt.invoke({'card': docs[0].page_content})).content
summary

'Sire of Seven Deaths is a strong, flying creature with formidable power and toughness. Its unique ability to ward itself for 7 life points can be a game-changer in the right situations, giving it a significant survival advantage. However, its large converted mana cost makes it difficult to play early on, and its inability to generate card draw or protection for itself may leave it vulnerable to removal spells.'

In [55]:
len(summary)

413

In [25]:
ids = vector_store.add_documents(documents=docs)

In [30]:
vector_store.get_by_ids(['3fd34a5a-9832-4fc6-8f6a-d7384a64bd51'])

[Document(id='3fd34a5a-9832-4fc6-8f6a-d7384a64bd51', metadata={'source': '/home/giles/code/deep_mtg/data/cards/FDN.json', 'seq_num': 1}, page_content='{"colorIdentity": [], "colors": [], "convertedManaCost": 7, "edhrecRank": 8500, "keywords": ["First strike", "Lifelink", "Menace", "Reach", "Trample", "Vigilance", "Ward"], "legalities": {"alchemy": "Legal", "brawl": "Legal", "commander": "Legal", "duel": "Legal", "explorer": "Legal", "future": "Legal", "gladiator": "Legal", "historic": "Legal", "legacy": "Legal", "modern": "Legal", "oathbreaker": "Legal", "pioneer": "Legal", "standard": "Legal", "standardbrawl": "Legal", "timeless": "Legal", "vintage": "Legal"}, "manaCost": "{7}", "manaValue": 7, "name": "Sire of Seven Deaths", "power": "7", "rarity": "mythic", "setCode": "FDN", "subtypes": ["Eldrazi"], "supertypes": [], "text": "First strike, vigilance\\nMenace, trample\\nReach, lifelink\\nWard\\u2014Pay 7 life.", "toughness": "7", "type": "Creature \\u2014 Eldrazi", "types": ["Creatur

In [18]:
results = vector_store.similarity_search(
    "Manacost of omniscience"
)

print(results[0].page_content)

{"colorIdentity": ["U"], "colors": ["U"], "convertedManaCost": 10, "edhrecRank": 1123, "keywords": null, "legalities": {"alchemy": "Legal", "brawl": "Legal", "commander": "Legal", "duel": "Legal", "explorer": "Legal", "future": "Legal", "gladiator": "Legal", "historic": "Legal", "legacy": "Legal", "modern": "Legal", "oathbreaker": "Legal", "pioneer": "Legal", "standard": "Legal", "standardbrawl": "Legal", "timeless": "Legal", "vintage": "Legal"}, "manaCost": "{7}{U}{U}{U}", "manaValue": 10, "name": "Omniscience", "power": null, "rarity": "mythic", "setCode": "FDN", "subtypes": [], "supertypes": [], "text": "You may cast spells from your hand without paying their mana costs.", "toughness": null, "type": "Enchantment", "types": ["Enchantment"]}


In [19]:
results

[Document(id='7a195142-963b-465f-8284-43a6da22dc4d', metadata={'source': '/home/giles/code/deep_mtg/data/cards/FDN.json', 'seq_num': 161}, page_content='{"colorIdentity": ["U"], "colors": ["U"], "convertedManaCost": 10, "edhrecRank": 1123, "keywords": null, "legalities": {"alchemy": "Legal", "brawl": "Legal", "commander": "Legal", "duel": "Legal", "explorer": "Legal", "future": "Legal", "gladiator": "Legal", "historic": "Legal", "legacy": "Legal", "modern": "Legal", "oathbreaker": "Legal", "pioneer": "Legal", "standard": "Legal", "standardbrawl": "Legal", "timeless": "Legal", "vintage": "Legal"}, "manaCost": "{7}{U}{U}{U}", "manaValue": 10, "name": "Omniscience", "power": null, "rarity": "mythic", "setCode": "FDN", "subtypes": [], "supertypes": [], "text": "You may cast spells from your hand without paying their mana costs.", "toughness": null, "type": "Enchantment", "types": ["Enchantment"]}'),
 Document(id='b24d68f1-3457-4584-9d6c-6b42c5d9a030', metadata={'source': '/home/giles/code/

In [33]:
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.tools import BaseTool
from typing import Optional, Type

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field




In [34]:
class ScalableQuery(BaseModel):
    query: str = Field(description="Retrieval query")
    k: int = Field(description="Number of results to return")
    score_threshold: float = Field(default=0.0, description="Minimum similarity score to return. Float between zero and one.")

In [84]:
from tqdm import tqdm

class CardsRetriever(BaseTool):
    sets_path: Path
    llm: ChatOllama
    embeddings: OllamaEmbeddings
    summary_prompt: ChatPromptTemplate
    card_vector_store: None | InMemoryVectorStore = None
    extraction_schema: str = """
        .data.cards[] | {
            colors: .colors,
            convertedManaCost: .convertedManaCost,
            keywords: .keywords,
            manaCost: .manaCost,
            name: .name,
            power: .power,
            rarity: .rarity,
            subtypes: .subtypes,
            supertypes: .supertypes,
            text: .text,
            toughness: .toughness,
            types: .types
        }
    """
    recreate_storage: bool = False

    name: str = "SetsRetriever"
    description: str = "Provides relevant information for cards in Magic."
    args_schema: Type[BaseModel] = ScalableQuery

    def __init__(self, sets_path: Path, llm:ChatOllama, embeddings: OllamaEmbeddings, recreate_storage: bool = False):
        summary_prompt = ChatPromptTemplate.from_messages((
            [("system",
            "You are an expert Magic: The Gathering player."
            "You are helping a new player to understand what various cards do from a high-level perspective."
            "When provided with a card, you should write a concise summary of the card."
            "You can assume that the player understands the basic rules of the game"
            "Include basic keyword terms like 'flying' or 'trample', but do not explain what they mean."
            "Do not include details of rarity or set information."
            "Rather than quantifying attributes of the card, instead use qualitative terms like 'strong' or 'weak' to describe the card."
            "Include the name of the card in the summary."
            "Include details of the card's role in that, strengths, and weaknesses."
            "Include the mana colors of the card, along with the a qualitative description of the mana cost."
            "Do not return anything other than the summary of the card."),
            ("user", "{card}")]
        ))
        super().__init__(sets_path=sets_path, llm=llm, embeddings=embeddings, summary_prompt=summary_prompt, recreate_storage=recreate_storage)
        self.create_storage()

    def create_storage(self) -> None:
        if self.recreate_storage or not (self.sets_path/"cards.vec").exists():
            self.card_vector_store = InMemoryVectorStore(self.embeddings)
            for s in self.sets_path.glob("*.json"):
                print(f"Loading {s}...")
                loader = JSONLoader(s, self.extraction_schema, text_content=False)
                cards = loader.load()

                # Remove duplicates and basic lands
                filtered_cards = []
                filtered_hashes = []
                for card in cards:
                    # card_dict = json.loads(card.page_content)
                    # if card_dict["name"] in ["Plains", "Island", "Swamp", "Mountain", "Forest"]:
                    #     continue
                    if (h := hash(card.page_content)) not in filtered_hashes:
                        filtered_cards.append(card)
                        filtered_hashes.append(h)

                # Create card summaries
                print(f"Creating summaries for {len(filtered_cards)} cards...")
                for card in tqdm(filtered_cards):
                    summary = self.llm.invoke(self.summary_prompt.invoke({'card': card.page_content})).content

                    card.page_content = '{"summary": "' + summary + '", ' + card.page_content[1:]

                self.card_vector_store.add_documents(documents=filtered_cards)
                print(f"Loaded {len(filtered_cards)} cards from set {s}.")
            
            print(f"Dumping vectors to disk {self.sets_path/"cards.vec"}...")
            self.card_vector_store.dump(self.sets_path/"cards.vec")
        
        else:
            print(f"Loading vectors from disk {self.sets_path/'cards.vec'}...")
            self.card_vector_store = InMemoryVectorStore(self.embeddings).load(self.sets_path/"cards.vec", self.embeddings)

    def _run(
        self, query: str, k: int, score_threshold: float = 0.0, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> list[str]:
        """Retrieve information related to a query."""
        retrieved_cards = self.card_vector_store.similarity_search_with_score(query, k=k)
        if k <= 2:
            score_threshold = 0.0
        filtered_cards = [card[0].page_content for card in retrieved_cards if card[1] > score_threshold]
        return filtered_cards

In [85]:
retriever = CardsRetriever(sets_path=Path("../data/cards"), llm=llm, embeddings=embeddings)

Loading vectors from disk ../data/cards/cards.vec...


In [86]:
retriever.invoke({'query':"Sire of seven deaths", 'k':5, 'score_threshold':0.})

['{"summary": "Sire of Seven Deaths is a strong Eldrazi creature with a high mana cost that\'s considered expensive. It has flying and vigilance, making it a formidable attacker. Its presence on the battlefield allows you to pay 7 life in case you need to save yourself from its wrath, but this also means it can be vulnerable if your opponent finds a way to deal with its threat. Overall, this card is a high-risk, high-reward threat that can swing games in favor of its controller, especially if they have a strong removal spell or a way to protect their life total.", "colors": [], "convertedManaCost": 7, "keywords": ["First strike", "Lifelink", "Menace", "Reach", "Trample", "Vigilance", "Ward"], "manaCost": "{7}", "name": "Sire of Seven Deaths", "power": "7", "rarity": "mythic", "subtypes": ["Eldrazi"], "supertypes": [], "text": "First strike, vigilance\\nMenace, trample\\nReach, lifelink\\nWard\\u2014Pay 7 life.", "toughness": "7", "types": ["Creature"]}',
 '{"summary": "Nine-Lives Famil

In [174]:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

memory = MemorySaver()
agent_executor = create_react_agent(llm, [retriever], checkpointer=memory)
config = {"configurable": {"thread_id": "test_cards_retriever"}}

In [175]:
input_message = (
    "What is the mana cost of Omniscience?"
)

for event in agent_executor.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    event["messages"][-1].pretty_print()


What is the mana cost of Omniscience?


Tool Calls:
  SetsRetriever (2118514c-4422-4287-b96c-fbb85e38b68f)
 Call ID: 2118514c-4422-4287-b96c-fbb85e38b68f
  Args:
    k: 1
    query: Omniscience
    score_threshold: 0
Name: SetsRetriever

["{\"colors\": [\"U\"], \"convertedManaCost\": 10, \"keywords\": null, \"manaCost\": \"{7}{U}{U}{U}\", \"manaValue\": 10, \"name\": \"Omniscience\", \"power\": null, \"rarity\": \"mythic\", \"subtypes\": [], \"supertypes\": [], \"text\": \"You may cast spells from your hand without paying their mana costs.\", \"toughness\": null, \"types\": [\"Enchantment\"]}"]

The mana cost of Omniscience is {7}{U}{U}{U}.


In [176]:
input_message = (
    "How many demons are there? You can retrieve as many cards as you want"
)

for event in agent_executor.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    event["messages"][-1].pretty_print()


How many demons are there? You can retrieve as many cards as you want
Tool Calls:
  SetsRetriever (e8e49fc2-f0cd-43ee-9318-57b3d8cbb0d5)
 Call ID: e8e49fc2-f0cd-43ee-9318-57b3d8cbb0d5
  Args:
    k: 100
    query: demons
    score_threshold: 0
Name: SetsRetriever

["{\"colors\": [\"B\"], \"convertedManaCost\": 3, \"keywords\": null, \"manaCost\": \"{2}{B}\", \"manaValue\": 3, \"name\": \"Infernal Vessel\", \"power\": \"2\", \"rarity\": \"uncommon\", \"subtypes\": [\"Human\", \"Cleric\"], \"supertypes\": [], \"text\": \"When this creature dies, if it wasn't a Demon, return it to the battlefield under its owner's control with two +1/+1 counters on it. It's a Demon in addition to its other types.\", \"toughness\": \"1\", \"types\": [\"Creature\"]}", "{\"colors\": [\"B\"], \"convertedManaCost\": 4, \"keywords\": [\"Flying\"], \"manaCost\": \"{2}{B}{B}\", \"manaValue\": 4, \"name\": \"Desecration Demon\", \"power\": \"6\", \"rarity\": \"rare\", \"subtypes\": [\"Demon\"], \"supertypes\": []