In [1]:
import tensorflow as tf
import tensorflow_text as tf_text
import tensorflow_hub as tf_hub
from tqdm import tqdm

import chromadb

2024-03-23 14:16:37.559827: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load the model stored in the local "./model" directory through TensorFlow Hub.
# It is called later with a list of strings to produce embeddings.
# NOTE(review): presumably the otherwise-unused `tensorflow_text` import registers
# ops this model needs — confirm before removing that import.
model = tf_hub.load("./model")

2024-03-23 14:16:43.222513: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [3]:
def get_embedings(texts: list[str]) -> list[list[float]]:
    """Embed a batch of strings with the module-level ``model``.

    Args:
        texts: Strings to embed; they are passed to the model in one call.

    Returns:
        The model output converted to nested Python lists via
        ``.numpy().tolist()``.
    """
    model_output = model(texts)
    as_array = model_output.numpy()
    return as_array.tolist()

In [4]:
# Read the whole rules file into memory: one list entry per line, newlines kept.
with open("./data/rules.txt", "r") as rules_file:
    rules = list(rules_file)


In [5]:
# Parse the category index at the top of the rules file.
# Each header line has the form "<number>. <category name>"; the header ends at
# the first line that is exactly "---", whose index is kept in `begin_rules` so
# the next stage knows where the rule bodies start.
rule_category_dict = {}
begin_rules = 0
for idx, line in enumerate(rules):
    stripped = line.strip()
    if stripped == "---":
        begin_rules = idx
        break
    if not stripped:
        # Blank header lines carry no category; skip instead of failing to unpack.
        continue
    # Split only on the first ". " so a category name that itself contains ". "
    # is kept intact (splitting on every occurrence lost the first separator).
    num, category = stripped.split(". ", 1)
    rule_category_dict[int(num)] = category


In [6]:
def is_rule_identifier(word: str) -> bool:
    """Tell whether ``word`` looks like a rule number such as ``102.4a``.

    A word qualifies when at most one of its characters is something other
    than a digit or ``"."``.

    Note that this also accepts the empty string and any single character
    (e.g. ``""`` or ``"a"``), because they contain at most one such character.

    Args:
        word: The token to inspect (typically the first word of a line).

    Returns:
        True if ``word`` has at most one character outside ``0-9`` and ``"."``.
    """
    allowed_chars = "1234567890."
    other_chars = sum(1 for symbol in word if symbol not in allowed_chars)
    return other_chars <= 1

In [7]:
# Parse the rule bodies that follow the "---" marker into
# {rule identifier: full rule text}. A rule may span several physical lines,
# e.g. "100.1. These Magic rules apply to any Magic game with two or more
# players, including two-player\ngames and multiplayer games."; continuation
# lines are appended to the rule that is currently open.
# Parsing stops at the "---G" marker, whose index is stored in `begin_glossary`.
rules_dict = {}
rule_identifier = ""
rule_text = ""
begin_glossary = 0
for idx, line in enumerate(rules[begin_rules + 1:], start=begin_rules + 1):
    stripped = line.strip()

    if stripped == "---G":
        begin_glossary = idx
        break

    # Diagnostic: report lines that end in a dangling "See rule(s)" reference
    # (printed with a 1-based line number).
    if stripped.endswith(("See rule", "See rules")):
        print(idx + 1, stripped)

    first_word = line.split(" ")[0].strip()
    # A blank line yields an empty first word, which is_rule_identifier also
    # accepts; it therefore closes the open rule, and any text that follows
    # before the next numbered rule is collected under "" and never stored.
    if is_rule_identifier(first_word):
        if rule_identifier:
            rules_dict[rule_identifier] = rule_text

        # start a new rule
        rule_identifier = first_word
        rule_text = ""

    rule_text += stripped + " "

# Store the rule that was still open when the loop ended; without this the
# last rule is lost whenever it is not followed by a blank line.
if rule_identifier:
    rules_dict[rule_identifier] = rule_text


In [8]:
# Open the on-disk Chroma database under ./data/chromadb and fetch the "rules"
# collection, creating it if it does not exist yet.
client = chromadb.PersistentClient(path="./data/chromadb")
collection = client.get_or_create_collection(name="rules")

In [9]:
# Sanity check: total number of parsed rule identifiers vs. number of distinct ones.
rule_identifiers = [*rules_dict]
print(len(rule_identifiers), len({*rule_identifiers}))

2691 2691


In [10]:
# Embed every parsed rule and store it in the Chroma collection, in batches of
# at most `max_batch_size` rules per model call / insert.
max_batch_size = 128


def _add_rules_batch(batch):
    """Embed one batch of (rule_identifier, rule_text) pairs and add it to ``collection``.

    Does nothing for an empty batch, so neither the model nor the collection is
    called with empty inputs.
    """
    if not batch:
        return
    texts = [text for _, text in batch]
    collection.add(
        ids=[identifier for identifier, _ in batch],
        embeddings=get_embedings(texts),
        documents=texts,
    )


rule_items = list(rules_dict.items())
# The progress bar counts rules (not batches), advancing once per finished batch.
with tqdm(total=len(rule_items)) as progress:
    for start in range(0, len(rule_items), max_batch_size):
        batch = rule_items[start:start + max_batch_size]
        _add_rules_batch(batch)
        progress.update(len(batch))

 14%|█▍        | 384/2691 [00:18<01:44, 22.03it/s]2024-03-23 14:17:10.159623: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 807469056 exceeds 10% of free system memory.
2024-03-23 14:17:12.011845: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 807469056 exceeds 10% of free system memory.
2024-03-23 14:17:14.009290: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 807469056 exceeds 10% of free system memory.
2024-03-23 14:17:15.751829: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 807469056 exceeds 10% of free system memory.
2024-03-23 14:17:18.061999: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 807469056 exceeds 10% of free system memory.
 19%|█▉        | 512/2691 [00:30<02:16, 15.92it/s]INFO:backoff:Backing off send_request(...) for 0.5s (requests.exceptions.ReadTimeout: HTTPSConnectionPool(host='us-api.i.posthog.com', port=443): Read timed out

In [14]:
def search(texts: list[str] | str) -> list[dict]:
    if isinstance(texts, str):
        texts = [texts]
    
    embeddings = get_embedings(texts)

    r = collection.query(
        query_embeddings=embeddings
    )
    #{ids: list[list[ids]], distances: list[list[distances]], documents: list[list[documents]]}

    r_out = []

    for i in range(len(texts)):
        ids = r["ids"][i]
        distances = r["distances"][i]
        documents = r["documents"][i]

        r_arr = []
        for j in range(len(ids)):
            r_arr.append(
                {
                    "id": ids[j],
                    "distance": distances[j],
                    "text": documents[j]
                }
            )
        
        r_out.append(r_arr)
    
    return r_out

In [26]:
# Demo query: print every field of every hit, with "-" after each hit and "="
# after each query's result list.
r = search(["triggered ability"])

for query_hits in r:
    for hit in query_hits:
        for field in hit:
            print(field, hit[field])
        print("-")
    print("=")


id 702.13a
distance 1.0540034770965576
text 702.13a Intimidate is an evasion ability. 
-
id 508.1m
distance 1.102388620376587
text 508.1m Any abilities that trigger on attackers being declared trigger. 
-
id 702.111a
distance 1.1382246017456055
text 702.111a Menace is an evasion ability. 
-
id 702.142b
distance 1.1396211385726929
text 702.142b Effects may refer to boast abilities. If an effect refers to a creature boasting, it means its boast ability being activated. 
-
id 702.36a
distance 1.1773734092712402
text 702.36a Fear is an evasion ability. 
-
id 702.12a
distance 1.1810660362243652
text 702.12a Indestructible is a static ability. 
-
id 702.118a
distance 1.186344861984253
text 702.118a Skulk is an evasion ability. 
-
id 702.90a
distance 1.1903629302978516
text 702.90a Infect is a static ability. 
-
id 702.9a
distance 1.1960893869400024
text 702.9a Flying is an evasion ability. 
-
id 702.14b
distance 1.198228120803833
text 702.14b Landwalk is an evasion ability. 
-
=
