In [None]:
# Automatically reload imported modules (e.g. edits to llm_ol) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Remove OMP_NUM_THREADS from this process's environment if it is set;
# when it is absent this is a no-op.
import os

os.environ.pop("OMP_NUM_THREADS", None)

In [None]:
from pathlib import Path

from llm_ol.llm.cpu import load_mistral_instruct

In [None]:
import os

# GGUF model file. The default is a machine-specific absolute path; set the
# LLM_OL_MODEL_PATH environment variable to use a different location without
# editing this cell.
DEFAULT_MODEL_PATH = "/ramdisks/mistral-7b-instruct-v0.2.Q4_K_M.gguf"
N_THREADS = 32  # CPU threads handed to load_mistral_instruct

model_path = Path(os.environ.get("LLM_OL_MODEL_PATH", DEFAULT_MODEL_PATH))
# NOTE(review): n_gpu_layers=-1 presumably means "offload every layer" (llama.cpp
# convention) — confirm in load_mistral_instruct.
lm = load_mistral_instruct(model_path, n_threads=N_THREADS, n_gpu_layers=-1)

In [None]:
from llm_ol.dataset.wikipedia import load_dataset

# Wikipedia graph limited to max_depth=1; every node carries a "pages" list
# (consumed by the next cell).
WIKIPEDIA_GRAPH_PATH = "out/data/wikipedia/v1/full/full_graph.json"
G = load_dataset(WIKIPEDIA_GRAPH_PATH, max_depth=1)

In [None]:
# Flatten the pages of every node into plain {"title", "abstract"} records,
# preserving node order and page order within each node.
items = [
    {"title": page["title"], "abstract": page["abstract"]}
    for _, node_data in G.nodes(data=True)
    for page in node_data["pages"]
]

In [None]:
import random

from llm_ol.experiments.prompting.create_hierarchy import create_hierarchy

# Draw one article at random and append the create_hierarchy program to the model.
item = random.choice(items)
hierarchy_program = create_hierarchy(item["title"], item["abstract"])
out = lm + hierarchy_program

In [None]:
from llm_ol.llm.templates.categorise_article import categorise_article_top_down

# Display the most recently sampled article as this cell's output.
item

In [None]:
# Top-down categorisation: for each sampled article, pass the categories
# collected so far to the prompt template, then add whatever categories the
# model returns to the pool.
N_CATEGORISE_SAMPLES = 3

categories = set()
result = []  # (article, categories returned for it) pairs

for _ in range(N_CATEGORISE_SAMPLES):
    item = random.choice(items)
    # Sorted so the prompt lists existing categories in a stable order; iterating
    # a set of strings depends on per-process hash randomisation.
    out = lm + categorise_article_top_down(
        item["title"], item["abstract"], sorted(categories)
    )
    categories.update(out["cats"])
    result.append((item, out["cats"]))

In [None]:
import json

# JSONL: one JSON object per line. Each record needs a "categories" list,
# which the next cell groups by. Blank lines are skipped instead of crashing
# json.loads, and UTF-8 is explicit so non-ASCII titles decode on any locale.
with open(
    "out/experiments/prompting/dev/categoried_pages.jsonl", encoding="utf-8"
) as f:
    results = [json.loads(line) for line in f if line.strip()]

In [None]:
from collections import defaultdict

# Invert page -> categories into category -> [pages assigned to that category].
categories = defaultdict(list)
for record in results:
    for category_name in record["categories"]:
        categories[category_name].append(record)

In [None]:
import matplotlib.pyplot as plt

# Distribution of how many pages were assigned to each category; the count
# axis is log-scaled because a few categories dominate.
fig, ax = plt.subplots()
ax.hist([len(pages) for pages in categories.values()], bins=20, log=True)
ax.set(
    title="Pages per category",
    xlabel="Number of pages assigned to the category",
    ylabel="Number of categories (log scale)",
)
plt.show()

In [None]:
import random

import guidance

# Sample one article and ask the model to place it in a JSON category hierarchy.
item = random.choice(items)
print(item)

# Instructions plus a worked example. Kept separate from the article template
# because the example's JSON braces would otherwise be parsed by str.format.
HIERARCHY_INSTRUCTIONS = """The following is an article's title and abstract. Your task is to assign this article to a suitable category hierarchy. \
A category is typically represented by a word or a short phrase, representing broader topics/concepts that the article is about. \
A category hierarchy is a directed acyclic graph that starts with a detailed categorisation and becomes more and more \
general higher up the hierarchy, until it reaches the special base category "ROOT".

An article on "Addition" might have the following category hierarchy:

```json
{
    "ROOT": {
        "Mathematics": {
            "Mathematical notation": "LEAF"
        },
        "Entities": {
            "Systems": {
                "Notation": {
                    "Mathematical notation": "LEAF"
                }
            }
        }
    }
}
```"""
# Only this part is formatted with the article's fields.
ARTICLE_TEMPLATE = """
Title: {title}
{abstract}
"""
MAX_HIERARCHY_TOKENS = 500

s = HIERARCHY_INSTRUCTIONS + ARTICLE_TEMPLATE.format(**item)

# The prompt goes inside the instruction block. The response is primed with an
# opening JSON fence, and generation stops at the closing fence.
with guidance.instruction():
    out = lm + s
out += "```json\n"
out += guidance.gen(name="hierarchy", max_tokens=MAX_HIERARCHY_TOKENS, stop="```")

In [None]:
import json

# Load generated hierarchies. Each record's "hierarchy" field holds JSON text
# (presumably the model's raw `hierarchy` output). Unparseable text, or JSON that
# is not an object, is reported and replaced with None so later cells can skip it.
results = []
with open(
    "out/experiments/prompting/dev-h/categoried_pages.jsonl", encoding="utf-8"
) as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines in the JSONL file
        item = json.loads(line)
        try:
            hierarchy = json.loads(item["hierarchy"])
        except json.JSONDecodeError:
            print(f"Failed to parse hierarchy for {item['title']}")
            hierarchy = None
        else:
            if not isinstance(hierarchy, dict):
                # walk_hierarchy needs a mapping of category -> sub-hierarchy;
                # a list/str/number would crash it at `.items()`.
                print(f"Hierarchy for {item['title']} is not a JSON object")
                hierarchy = None
        item["hierarchy"] = hierarchy
        results.append(item)

In [None]:
import networkx as nx

# Directed graph of parent -> child category edges, filled by walk_hierarchy below.
# NOTE(review): this rebinds G, replacing the Wikipedia graph loaded earlier.
G = nx.DiGraph()


def walk_hierarchy(hierarchy: dict, graph: "nx.DiGraph | None" = None) -> None:
    """Add a parent -> child edge for every nested category in ``hierarchy``.

    ``hierarchy`` maps a category name either to the string ``"LEAF"`` (no
    children) or to a nested dict of the same shape. Any other value is
    reported with ``print`` and skipped.

    Args:
        hierarchy: Nested category mapping, e.g. ``{"ROOT": {"Maths": "LEAF"}}``.
        graph: Graph to add edges to (anything with ``add_edge(u, v)``). Defaults
            to the notebook-level ``G`` for backward compatibility; pass it
            explicitly to avoid depending on a global that the notebook rebinds.
    """
    if graph is None:
        graph = G
    for parent, sub_hierarchy in hierarchy.items():
        if sub_hierarchy == "LEAF":
            continue
        if isinstance(sub_hierarchy, dict):
            for child in sub_hierarchy:
                graph.add_edge(parent, child)
            # Recurse into the children, keeping the same target graph.
            walk_hierarchy(sub_hierarchy, graph)
        else:
            print(f"Unknown type {parent} -> {sub_hierarchy}")


# Merge every successfully parsed hierarchy into the graph; skip failed parses.
for item in results:
    if item["hierarchy"] is None:
        continue
    walk_hierarchy(item["hierarchy"])

In [None]:
import random
import re
from pathlib import Path

# Pick a random category whose radius-2 ego graph has strictly between
# MIN_SUBGRAPH_NODES and MAX_SUBGRAPH_NODES nodes, then render it with
# Graphviz's fdp layout. Sampling is bounded so the cell cannot spin forever
# when no node fits the size window.
MIN_SUBGRAPH_NODES, MAX_SUBGRAPH_NODES = 5, 30  # exclusive bounds
MAX_SAMPLING_ATTEMPTS = 10_000
VISUALISATION_DIR = Path("out/experiments/prompting/dev-h/visualisation")

candidate_roots = list(G.nodes)  # hoisted: the node list does not change
if not candidate_roots:
    raise ValueError("Hierarchy graph G is empty; nothing to visualise.")
for _ in range(MAX_SAMPLING_ATTEMPTS):
    random_root = random.choice(candidate_roots)
    random_subgraph = nx.ego_graph(G, random_root, radius=2)
    if MIN_SUBGRAPH_NODES < len(random_subgraph) < MAX_SUBGRAPH_NODES:
        break
else:
    raise RuntimeError(
        f"No node with a radius-2 ego graph of {MIN_SUBGRAPH_NODES + 1}-"
        f"{MAX_SUBGRAPH_NODES - 1} nodes found in {MAX_SAMPLING_ATTEMPTS} attempts."
    )

print(random_root)
A = nx.drawing.nx_agraph.to_agraph(random_subgraph)
A.layout("fdp")
VISUALISATION_DIR.mkdir(parents=True, exist_ok=True)
# Category names may contain path separators or other characters that are
# invalid in file names; replace them so the image is saved inside the directory.
safe_name = re.sub(r'[\\/:*?"<>|]', "_", random_root)
A.draw(str(VISUALISATION_DIR / f"{safe_name}.png"))
A