In [None]:
# Reload edited modules (e.g. the llm_ol package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Drop any inherited OMP_NUM_THREADS setting (no-op when it is not set).
# NOTE(review): presumably so it does not interfere with the explicit
# n_threads passed to the model loader below — confirm.
import os

os.environ.pop("OMP_NUM_THREADS", None)

In [None]:
from pathlib import Path

from llm_ol.llm.cpu import load_mistral_instruct

In [None]:
import os

# Mistral-7B-Instruct v0.2 GGUF weights (Q4_K_M). The location can be overridden
# with the MISTRAL_GGUF_PATH environment variable instead of editing this cell.
model_path = Path(
    os.environ.get(
        "MISTRAL_GGUF_PATH", "/ramdisks/mistral-7b-instruct-v0.2.Q4_K_M.gguf"
    )
)
# n_gpu_layers=0: presumably no layers are offloaded to a GPU (loader is llm_ol.llm.cpu).
lm = load_mistral_instruct(model_path, n_threads=8, n_gpu_layers=0, use_mlock=True)

In [None]:
from llm_ol.dataset import data_model

# Wikipedia category graph (depth 3); node data carries a "pages" list.
wikipedia_graph_path = "out/data/wikipedia/v1/full/graph_depth_3.json"
G = data_model.load_graph(wikipedia_graph_path)

In [None]:
# Build one record per page id from the per-node "pages" lists, collecting every
# node the page appears under in its "categories" list (graph iteration order).
items = {}
for node, data in G.nodes(data=True):
    for page in data["pages"]:
        existing = items.get(page["id"])
        if existing is not None:
            existing["categories"].append(node)
        else:
            items[page["id"]] = {**page, "categories": [node]}

In [None]:
import random
from llm_ol.experiments.prompting.create_hierarchy import create_hierarchy

# Sample one page at random and prompt the model to build its category hierarchy.
item = random.choice(tuple(items.values()))
prompt = create_hierarchy(item["title"], item["abstract"])
out = lm + prompt

In [None]:
# Categorise a few random pages one after another, passing the categories
# collected so far into each call (presumably so the model can reuse them).
# NOTE(review): `categorise_article_top_down` is not imported anywhere in this
# notebook — TODO import it from its defining module before running this cell.
N_TOP_DOWN_ROUNDS = 3

categories = set()
result = []

# `items` is a dict keyed by page id, so sample from its values;
# random.choice(items) would index the dict with a random int position.
pages = list(items.values())

for _ in range(N_TOP_DOWN_ROUNDS):
    item = random.choice(pages)
    out = lm + categorise_article_top_down(
        # sorted(): set iteration order varies between runs, which would make
        # the prompt non-deterministic.
        item["title"], item["abstract"], sorted(categories)
    )
    categories.update(out["cats"])
    result.append((item, out["cats"]))

In [None]:
import json

# Load the categorised pages from the "dev" prompting run (JSON Lines).
# Explicit utf-8 avoids locale-dependent decoding; blank lines are skipped
# because json.loads("") would raise.
with open(
    "out/experiments/prompting/dev/categoried_pages.jsonl", encoding="utf-8"
) as f:
    results = [json.loads(line) for line in f if line.strip()]

In [None]:
from collections import defaultdict

# Invert page -> categories into category -> list of pages assigned to it.
categories = defaultdict(list)
for page, cat in ((p, c) for p in results for c in p["categories"]):
    categories[cat].append(page)

In [None]:
import matplotlib.pyplot as plt

# Distribution of category sizes (pages per category), log-scaled counts.
category_sizes = [len(pages) for pages in categories.values()]

fig, ax = plt.subplots()
ax.hist(category_sizes, bins=20, log=True)
ax.set(
    title="Pages per category",
    xlabel="Number of pages in category",
    ylabel="Number of categories (log scale)",
);

In [None]:
import random
import guidance

item = random.choice(list(items.values()))
print(item)

# Keep the whole prompt in one template so .format() applies to all of it
# (with `a + b.format(...)`, only `b` would be formatted). The root category
# name matches the one used in the example hierarchies.
prompt_template = """The following is an article's title and abstract. Your task is to assign this article to a suitable category hierarchy. \
A category is typically represented by a word or a short phrase, representing broader topics/concepts that the article is about. \
A category hierarchy is a directed acyclic graph that starts with a detailed categorisation and becomes more and more \
general higher up the hierarchy, until it reaches the special base category "Main topic classifications".

For example, an article on "Single whip law" might have the following category hierarchy:

```txt
Main topic classifications -> Economy -> Economic history -> History of taxation
Main topic classifications -> Law -> Law by issue -> Legal history by issue -> History of taxation
Main topic classifications -> Law -> Law by issue -> Tax law
Main topic classifications -> Law -> Law stubs -> Asian law stubs
Main topic classifications -> Politics -> Political history -> History of taxation
```

Another example hierarchy for an article on "Stoning" is:

```txt
Main topic classifications -> Human behavior -> Abuse -> Cruelty -> Torture
Main topic classifications -> Human behavior -> Violence -> Torture
Main topic classifications -> Law -> Law-related events -> Crimes -> Torture
Main topic classifications -> Law -> Legal aspects of death -> Killings by type
Main topic classifications -> Society -> Violence -> Torture
```

Title: {title}
{abstract}

Provide a category hierarchy for the above article. Use the same format as the examples above.
"""
s = prompt_template.format(title=item["title"], abstract=item["abstract"])

# Send the prompt as the instruction, then prime the response with an opening
# ```txt fence and generate until the closing fence.
with guidance.instruction():
    out = lm + s
out += "```txt\n"
out += guidance.gen(name="hierarchy", max_tokens=500, stop="```")

In [None]:
import json

# Load the "dev-h" run. Each record's "hierarchy" field holds a JSON string;
# records whose hierarchy cannot be decoded are kept with hierarchy=None.
results = []
with open(
    "out/experiments/prompting/dev-h/categoried_pages.jsonl", encoding="utf-8"
) as f:
    for line in f:
        if not line.strip():
            continue  # json.loads("") would raise
        record = json.loads(line)
        try:
            record["hierarchy"] = json.loads(record["hierarchy"])
        except (json.JSONDecodeError, TypeError):
            # TypeError: the stored hierarchy is not a string (e.g. null).
            print(f"Failed to parse hierarchy for {record['title']}")
            record["hierarchy"] = None
        results.append(record)

In [None]:
import networkx as nx

# NOTE(review): this rebinds `G`, shadowing the Wikipedia category graph loaded
# earlier via data_model.load_graph. Later cells that use ROOT_CATEGORY_ID and
# G.nodes[...]["title"] appear to expect the Wikipedia graph — consider giving
# this graph its own name.
G = nx.DiGraph()


def walk_hierarchy(hierarchy: dict, graph=None):
    """Add parent -> child edges for a nested hierarchy mapping to a graph.

    Each key maps either to the string "LEAF" (no children) or to a dict of
    children, which is walked recursively. Any other value is reported with
    a print and skipped.

    Args:
        hierarchy: Mapping of category name -> sub-hierarchy dict or "LEAF".
        graph: Graph to add edges to; defaults to the module-level ``G``.
    """
    if graph is None:
        graph = G
    for parent, sub_hierarchy in hierarchy.items():
        if sub_hierarchy == "LEAF":
            continue
        if isinstance(sub_hierarchy, dict):
            for child in sub_hierarchy:
                graph.add_edge(parent, child)
            walk_hierarchy(sub_hierarchy, graph)
        else:
            print(f"Unknown type {parent} -> {sub_hierarchy}")


for item in results:
    tree = item["hierarchy"]
    if tree is None:
        continue
    # Parsed JSON may be a list/str rather than a mapping; .items() would crash.
    if not isinstance(tree, dict):
        print(f"Skipping non-dict hierarchy for {item['title']}")
        continue
    walk_hierarchy(tree, G)

In [None]:
import random
import re
from pathlib import Path

# Pick a random node whose radius-2 ego graph has between 6 and 29 nodes and
# render it with Graphviz. Attempts are capped so the cell cannot loop forever
# when no node qualifies.
MAX_SUBGRAPH_ATTEMPTS = 1000

candidate_nodes = list(G.nodes)
for _ in range(MAX_SUBGRAPH_ATTEMPTS):
    random_root = random.choice(candidate_nodes)
    random_subgraph = nx.ego_graph(G, random_root, radius=2)
    if 5 < len(random_subgraph) < 30:
        break
else:
    raise RuntimeError(
        "No node with a radius-2 ego graph of 6-29 nodes found in "
        f"{MAX_SUBGRAPH_ATTEMPTS} attempts"
    )

print(random_root)
A = nx.drawing.nx_agraph.to_agraph(random_subgraph)
A.layout("fdp")

# Node names come from the parsed hierarchy JSON and may contain characters
# such as "/" that are unsafe in a filename, so sanitise before using them.
out_dir = Path("out/experiments/prompting/dev-h/visualisation")
out_dir.mkdir(parents=True, exist_ok=True)
safe_name = re.sub(r"[^\w\-. ]+", "_", str(random_root))
A.draw(str(out_dir / f"{safe_name}.png"))
A

In [None]:
import random
import networkx as nx
from llm_ol.dataset.wikipedia import ROOT_CATEGORY_ID

In [None]:
from itertools import islice

# Print up to N_EXAMPLE_PATHS shortest root-to-node paths (as category titles)
# for one randomly chosen node (not necessarily a leaf).
N_EXAMPLE_PATHS = 10

random_leaf = random.choice(list(G.nodes))
print(G.nodes[random_leaf]["title"])

for path in islice(
    nx.shortest_simple_paths(G, ROOT_CATEGORY_ID, random_leaf), N_EXAMPLE_PATHS
):
    names = [G.nodes[node]["title"] for node in path]
    print(" -> ".join(names))

In [None]:
def hierarchy(node, n: int):
    """Return up to ``n`` shortest simple paths from the root category to ``node``.

    Paths are taken in the order produced by ``nx.shortest_simple_paths`` over
    the module-level graph ``G``, starting at ``ROOT_CATEGORY_ID``. Each path is
    converted to the list of its nodes' ``"title"`` attributes, root first.

    Args:
        node: Target node id in ``G``.
        n: Maximum number of paths to return.

    Returns:
        A list of at most ``n`` paths, each a list of title strings.
    """
    from itertools import islice

    path_iter = nx.shortest_simple_paths(G, ROOT_CATEGORY_ID, node)
    return [
        [G.nodes[path_node]["title"] for path_node in path]
        for path in islice(path_iter, n)
    ]

In [None]:
def paths_to_root(page, n: int):
    """Return up to ``n`` shortest category paths from the root to ``page``.

    Pages are stored as node data rather than as nodes of ``G``, so the page is
    attached temporarily as a sink node with an edge from each of its (known)
    categories. Paths from ``ROOT_CATEGORY_ID`` to that sink are enumerated with
    ``nx.shortest_simple_paths``, and the sink is always removed again
    afterwards. The sink itself is dropped from every returned path.

    Args:
        page: Mapping with at least ``"id"`` and ``"categories"`` (node ids in ``G``).
        n: Maximum number of paths to return.

    Returns:
        A sorted list of at most ``n`` tuples of category titles, root first.
        Empty when none of the page's categories is reachable from the root.
    """
    from itertools import islice

    # A tuple key cannot collide with an existing node id, so removing it in
    # `finally` can never delete a real node of G.
    sink = ("__paths_to_root_sink__", page["id"])
    G.add_node(sink)
    try:
        # Only link categories already in G: add_edge would otherwise create
        # title-less nodes that outlive this call.
        G.add_edges_from(
            (category, sink) for category in page["categories"] if category in G
        )
        try:
            paths = [
                tuple(G.nodes[node]["title"] for node in path[:-1])
                for path in islice(
                    nx.shortest_simple_paths(G, ROOT_CATEGORY_ID, sink), n
                )
            ]
        except nx.NetworkXNoPath:
            paths = []
    finally:
        G.remove_node(sink)

    # Lexicographic order of the title tuples.
    return sorted(paths)

In [None]:
import json

# Pick one page at random and print its shortest category paths to the root.
item = random.choice(list(items.values()))
print(item["title"])
paths = paths_to_root(item, 3)
for path in paths:
    print(" -> ".join(path))