In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# unset OMP_NUM_THREADS
import os

# Remove OMP_NUM_THREADS from this process's environment; pop() with a default is a
# no-op when the variable is not set.
os.environ.pop("OMP_NUM_THREADS", None)

In [None]:
from pathlib import Path
import random
import networkx as nx
import json
from tqdm import tqdm
import numpy as np
import graph_tool.all as gt

from llm_ol.dataset import data_model
from llm_ol.dataset.wikipedia import ROOT_CATEGORY_ID
from llm_ol.utils.nx_to_gt import nx_to_gt

In [None]:
from openai import Client

# Client for an OpenAI-compatible HTTP endpoint served on localhost:8080.
# NOTE(review): api_key="none" is a placeholder; presumably the local server does not
# check credentials — confirm before pointing base_url at a hosted endpoint.
client = Client(
    api_key="none",
    base_url="http://localhost:8080/v1",
)

In [None]:
import re
import textwrap
from llm_ol.experiments.llm.templates import RESPONSE_REGEX
from llm_ol.experiments.llm.prompting.create_hierarchy_v2 import template

# Render the hierarchy prompt for the "Single whip law" article, using "Stoning" as the
# single in-context example. Each example path is a list of category titles and every
# one of them starts at "Main topic classifications".
prompt = template.render(
    title="Single whip law",
    abstract="""The Single whip law or the "Single whip reform" (simplified Chinese: 一条鞭法; traditional Chinese: 一條鞭法; pinyin: Yì Tiáo Biān Fǎ) was a fiscal law first instituted during the middle Ming dynasty, in the early 16th century, and then promulgated throughout the empire in 1580 by Zhang Juzheng.[1]
The measure aimed primarily to simplify the complex fiscal code under Ming law, by commuting most obligations towards the central government — from land and poll taxes to the labour obligations of the peasantry and the tributes of prefectural and county officials — into a single silver payment, at a level based on the population and cultivated land in each prefecture. Therefore, by reducing complexity, the Single Whip law reduced the costs of tax collection, while also increasing the tax base. """,
    examples=[
        {
            "title": "Stoning",
            "abstract": """Stoning, or lapidation, is a method of capital punishment where a group throws stones at a person until the subject dies from blunt trauma. It has been attested as a form of punishment for grave misdeeds since ancient history.
The Torah and Talmud prescribe stoning as punishment for a number of offenses. Over the centuries, Rabbinic Judaism developed a number of procedural constraints which made these laws practically unenforceable. Although stoning is not mentioned in the Quran, classical Islamic jurisprudence (fiqh) imposed stoning as a hadd (sharia-prescribed) punishment for certain forms of zina (illicit sexual intercourse) on the basis of hadith (sayings and actions attributed to the Islamic prophet Muhammad). It also developed a number of procedural requirements which made zina difficult to prove in practice.""",
            "paths": [
                [
                    "Main topic classifications",
                    "Human behavior",
                    "Abuse",
                    "Cruelty",
                    "Torture",
                ],
                ["Main topic classifications", "Human behavior", "Violence", "Torture"],
                [
                    "Main topic classifications",
                    "Law",
                    "Law-related events",
                    "Crimes",
                    "Torture",
                ],
                [
                    "Main topic classifications",
                    "Law",
                    "Legal aspects of death",
                    "Killings by type",
                ],
                ["Main topic classifications", "Society", "Violence", "Torture"],
            ],
        }
    ],
)

# print("\n".join(textwrap.wrap(prompt, width=100, replace_whitespace=False)))

# completion = client.chat.completions.create(
#     messages=[
#         {
#             "role": "user",
#             "content": prompt,
#         }
#     ],
#     model="gpt-3.5-turbo",
#     # extra_body={"guided_regex": RESPONSE_REGEX},
#     temperature=0,
#     max_tokens=128,
# )
# out = completion.choices[0].message.content
# print(out)
# print(re.fullmatch(RESPONSE_REGEX, out).group(0))

# Send the rendered prompt as a single user turn and print the first choice's text.
user_turn = {"role": "user", "content": prompt}
completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[user_turn],
    temperature=0.1,
    max_tokens=256,
    # Extra request-body field; presumably the server uses it to constrain decoding
    # to RESPONSE_REGEX — confirm against the serving backend.
    extra_body={"guided_regex": RESPONSE_REGEX},
)
first_choice = completion.choices[0]
print(first_choice.message.content)

In [None]:
# Load the graph stored in graph_depth_3.json. Other cells in this notebook read each
# node's "pages" and "title" attributes from it.
G = data_model.load_graph("out/data/wikipedia/v2/full/graph_depth_3.json")

In [None]:
# Index pages by id. A page listed under several graph nodes ends up as one record
# whose "categories" list holds every node it was found under.
items = {}
for node, data in G.nodes(data=True):
    for page in data["pages"]:
        record = items.get(page["id"])
        if record is None:
            # First sighting: copy the page fields and start its category list.
            items[page["id"]] = {**page, "categories": [node]}
        else:
            record["categories"].append(node)

In [None]:
import random
from llm_ol.experiments.prompting.create_hierarchy import create_hierarchy

# Pick one page record at random and apply create_hierarchy to its title and abstract.
# NOTE(review): `lm` is not defined anywhere in the visible notebook, so this cell fails
# on a fresh kernel unless it is created elsewhere — confirm where it comes from.
item = random.choice(list(items.values()))
out = lm + create_hierarchy(item["title"], item["abstract"])

In [None]:
# Categorise three random pages, passing the categories collected so far into each call.
# NOTE(review): `lm` and `categorise_article_top_down` are neither defined nor imported
# in the visible notebook — they must be in scope before this cell can run.
categories = set()

result = []

# `items` is a dict keyed by page id, so random.choice(items) would index it with a
# random integer position and raise KeyError. Sample from the page records instead,
# as the other cells do.
page_records = list(items.values())

for _ in range(3):
    item = random.choice(page_records)
    out = lm + categorise_article_top_down(
        item["title"], item["abstract"], list(categories)
    )
    categories.update(out["cats"])
    result.append((item, out["cats"]))

In [None]:
import json

# Decode the JSONL file: one JSON document per line, collected into a list.
with open("out/experiments/prompting/dev/categoried_pages.jsonl") as f:
    results = list(map(json.loads, f))

In [None]:
from collections import defaultdict

# Invert page -> categories: map every category to the list of pages that carry it.
categories = defaultdict(list)
page_category_pairs = ((cat, page) for page in results for cat in page["categories"])
for cat, page in page_category_pairs:
    categories[cat].append(page)

In [None]:
import matplotlib.pyplot as plt

# Histogram (20 bins, log-scaled counts) of the number of pages per category.
category_sizes = [len(pages) for pages in categories.values()]
plt.hist(category_sizes, bins=20, log=True)

In [None]:
import random
import guidance

# Sample one page record and show it.
item = random.choice(list(items.values()))
print(item)

# The prompt is two concatenated literals: fixed instructions with two worked examples,
# then a per-article suffix. Only the second literal goes through str.format (the method
# call binds tighter than `+`), which fills {title} and {abstract} from `item`.
s = """The following is an article's title and abstract. Your task is to assign this article to suitable category hierarchy. \
A category is typically represented by a word or a short phrase, representing broader topics/concepts that the article is about. \
A category hierarchy is a directed acyclic graph that starts with a detailed categorisation and becomes more and more \
general higher up the hierarchy, until it reaches the special base category "Main topic classification".

An example hierarchy for an article on "Single whip law" might be have the following category hierarchy:

```txt
Main topic classifications -> Economy -> Economic history -> History of taxation
Main topic classifications -> Law -> Law by issue -> Legal history by issue -> History of taxation
Main topic classifications -> Law -> Law by issue -> Tax law
Main topic classifications -> Law -> Law stubs -> Asian law stubs
Main topic classifications -> Politics -> Political history -> History of taxation
```

Another example hierarchy for an article on "Stoning" is:

```txt
Main topic classifications -> Human behavior -> Abuse -> Cruelty -> Torture
Main topic classifications -> Human behavior -> Violence -> Torture
Main topic classifications -> Law -> Law-related events -> Crimes -> Torture
Main topic classifications -> Law -> Legal aspects of death -> Killings by type
Main topic classifications -> Society -> Violence -> Torture
```""" + """

Title: {title}
{abstract}

Provide a category hierarchy for the above article. Use the same format as the examples above.
""".format(
    **item
)

# NOTE(review): `lm` is not defined in the visible notebook; presumably a guidance model
# object created elsewhere — confirm before running on a fresh kernel.
with guidance.instruction():
    out = lm + s
# Open the answer's code fence ourselves, then generate (at most 500 tokens) until the
# closing fence; the generated text is captured under the name "hierarchy".
out += "```txt\n"
out += guidance.gen(name="hierarchy", max_tokens=500, stop="```")

In [None]:
import json

# Load the JSONL results. Each record's "hierarchy" field is itself a JSON string, so it
# is decoded in a second step; records whose hierarchy is not valid JSON are kept with
# "hierarchy" set to None.
results = []
with open("out/experiments/prompting/dev-h/categoried_pages.jsonl") as f:
    for line in f:
        item = json.loads(line)
        try:
            parsed_hierarchy = json.loads(item["hierarchy"])
        except json.JSONDecodeError:
            print(f"Failed to parse hierarchy for {item['title']}")
            parsed_hierarchy = None
        item["hierarchy"] = parsed_hierarchy
        results.append(item)

In [None]:
import networkx as nx

# Empty directed graph that walk_hierarchy() fills with parent -> child edges.
# NOTE(review): this rebinds `G`, which an earlier cell assigned from
# data_model.load_graph(...). Later cells read G.nodes[...]["title"], an attribute that
# nodes created through add_edge here do not carry — confirm the intended execution
# order, or give this graph its own name.
G = nx.DiGraph()


def walk_hierarchy(hierarchy: dict):
    """Add every parent -> child relation of a nested hierarchy dict to the global ``G``.

    Each key is a category. Its value is either the string ``"LEAF"`` (nothing below
    it) or another dict whose keys are its children; nested dicts are walked
    recursively. Any other value is reported with ``print`` and skipped.
    """
    for parent, below in hierarchy.items():
        if isinstance(below, dict):
            for child in below:
                G.add_edge(parent, child)
            walk_hierarchy(below)
            continue
        if below != "LEAF":
            print(f"Unknown type {parent} -> {below}")


# Feed every successfully parsed hierarchy into the graph; records whose hierarchy is
# None are skipped.
for item in results:
    item_hierarchy = item["hierarchy"]
    if item_hierarchy is None:
        continue
    walk_hierarchy(item_hierarchy)

In [None]:
import random

# show random subgraphs
random_root = random.choice(list(G.nodes))
while not (5 < len(random_subgraph := nx.ego_graph(G, random_root, radius=2)) < 30):
    random_root = random.choice(list(G.nodes))
# fig, ax = plt.subplots(figsize=(6, 6))
# nx.draw_networkx(random_subgraph, with_labels=True, ax=ax, pos=nx.circular_layout(random_subgraph))
# ax.set(title=f"Random subgraph of {random_root}")

print(random_root)
A = nx.drawing.nx_agraph.to_agraph(random_subgraph)
A.layout("fdp")
A.draw(f"out/experiments/prompting/dev-h/visualisation/{random_root}.png")
A

In [None]:
random_leaf = random.choice(list(G.nodes))
print(G.nodes[random_leaf]["title"])

# Print root -> node paths as " -> "-joined titles, in the order
# nx.shortest_simple_paths yields them, stopping after the 11th.
for i, path in enumerate(
    nx.shortest_simple_paths(G, ROOT_CATEGORY_ID, random_leaf), start=1
):
    names = [G.nodes[node]["title"] for node in path]
    print(" -> ".join(names))
    if i > 10:
        break

In [None]:
def hierarchy(node, n: int):
    """Collect title paths from the root category to ``node``.

    Paths are taken from ``nx.shortest_simple_paths`` on the global ``G``, in the
    order that generator yields them, and converted from node ids to node titles.

    Args:
        node: Target node id in ``G``.
        n: Collection stops as soon as more than ``n`` paths have been gathered, so at
            most ``n + 1`` paths are returned (the original limit, kept unchanged).
            NOTE(review): confirm whether ``n`` was meant as an exact maximum.

    Returns:
        A list of paths, each a list of node titles starting at ``ROOT_CATEGORY_ID``
        and ending at ``node``.
    """
    paths = []
    for path in nx.shortest_simple_paths(G, ROOT_CATEGORY_ID, node):
        # `step` rather than `node`, so the parameter is not shadowed.
        paths.append([G.nodes[step]["title"] for step in path])
        if len(paths) > n:
            break
    return paths

In [None]:
def paths_to_root(G_gt: gt.Graph, nx_to_gt_map, gt_to_nx_map, page, cutoff=None):
    """Enumerate the category-title paths leading from the root category to ``page``.

    The page is attached to ``G_gt`` as a temporary vertex with one incoming edge from
    each of its categories. All paths from the root vertex to that vertex are
    enumerated with ``gt.all_paths``, and the temporary vertex is removed again before
    returning, including when an error occurs part-way through.

    Args:
        G_gt: graph-tool graph to search; mutated temporarily, as described above.
        nx_to_gt_map: Mapping from networkx node id to graph-tool vertex, indexed with
            the page's categories and with ``ROOT_CATEGORY_ID``.
        gt_to_nx_map: Reverse mapping, indexed with the vertices on each path.
        page: Page record; only ``page["categories"]`` is read.
        cutoff: Passed through to ``gt.all_paths``.

    Returns:
        A list of tuples of category titles (looked up in the global ``G``), one per
        path, with the page vertex itself dropped. The list is shuffled in place
        with ``random.shuffle`` before being returned.
    """
    page_node = G_gt.add_vertex()
    try:
        # Edge creation lives inside the try block so that a category missing from
        # nx_to_gt_map cannot leave the temporary vertex behind in G_gt.
        for category in page["categories"]:
            G_gt.add_edge(nx_to_gt_map[category], page_node)

        paths = [
            # path[:-1] drops the trailing temporary page vertex.
            tuple(G.nodes[gt_to_nx_map[vertex]]["title"] for vertex in path[:-1])
            for path in gt.all_paths(
                G_gt,
                source=nx_to_gt_map[ROOT_CATEGORY_ID],
                target=page_node,
                cutoff=cutoff,
            )
        ]
    finally:
        G_gt.remove_vertex(page_node)

    random.shuffle(paths)
    return paths

In [None]:
# Convert G with nx_to_gt, which returns the graph-tool graph plus the two id maps that
# paths_to_root needs.
G_gt, nx_to_gt_map, gt_to_nx_map = nx_to_gt(G)

# Show the root paths of length <= cutoff for one random page.
item = random.choice(list(items.values()))
print(item["title"])
item_paths = paths_to_root(G_gt, nx_to_gt_map, gt_to_nx_map, item, cutoff=3)
for path in item_paths:
    print(" -> ".join(path))

In [None]:
G_gt, nx_to_gt_map, gt_to_nx_map = nx_to_gt(G)
n = len(items)
results = []

# With n == len(items) the sample is simply every page in random order.
print(f"Sample {n}/{len(items)} items")
sampled_items = random.sample(list(items.values()), n)
for item in tqdm(sampled_items):
    item_paths = paths_to_root(G_gt, nx_to_gt_map, gt_to_nx_map, item, cutoff=6)
    results.append(item_paths)

In [None]:
def coverage(results):
    """Return the fraction of graph edges that appear in at least one sampled path.

    Edges of the global ``G`` are compared as ``(parent title, child title)`` pairs,
    because the sampled paths hold titles rather than node ids.

    Args:
        results: Iterable of path collections (one collection per page). Each path is
            a sequence of titles; consecutive titles form one edge.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when ``G`` has no edges.
    """
    uncovered = {(G.nodes[u]["title"], G.nodes[v]["title"]) for u, v in G.edges()}
    # The denominator must count the same thing as the numerator (unique title pairs).
    # Using len(G.edges()) would overstate coverage whenever two edges share titles.
    total = len(uncovered)
    if total == 0:
        # Nothing to cover; avoid the division by zero.
        return 0.0

    for paths in results:
        for path in paths:
            uncovered.difference_update(zip(path[:-1], path[1:]))

    return 1 - len(uncovered) / total


# Edge coverage achieved by all sampled paths (shown as the cell output).
coverage(results)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
ax1, ax2, ax3 = axes

# Left panel: distribution of the number of paths per page (log-scaled axis).
paths_per_page = [len(ps) for ps in results]
sns.histplot(paths_per_page, bins=10, log_scale=True, stat="density", ax=ax1)
ax1.set(xlabel="Number of paths")

# Middle panel: distribution of path lengths over all pages.
path_lengths = [len(p) for ps in results for p in ps]
sns.histplot(path_lengths, discrete=True, stat="density", ax=ax2)
ax2.set(xlabel="Path length")

# Right panel: edge coverage of the first i sampled pages, at 11 evenly spaced values
# of i between 1 and len(results).
xs = np.linspace(1, len(results), 11, dtype=int)
coverage_curve = [coverage(results[:i]) for i in xs]
sns.lineplot(x=xs, y=coverage_curve, marker="o", ax=ax3)
ax3.set(xlabel="Number of samples", ylabel="Coverage", ylim=(0, 1))

fig.tight_layout()
# NOTE(review): the file name says cutoff_5, but the sampling cell calls paths_to_root
# with cutoff=6 — confirm which cutoff this figure is meant to describe.
fig.savefig("out/graphs/cutoff_5_depth_3_n_paths.png", dpi=144)

In [None]:
# Start from every edge (as a title pair) and strike out each one that occurs as a
# consecutive pair in some sampled path. What remains was never covered.
all_edges = G.edges()
not_covered = {(G.nodes[u]["title"], G.nodes[v]["title"]) for u, v in all_edges}
for paths in results:
    for path in paths:
        not_covered.difference_update(zip(path[:-1], path[1:]))

not_covered

In [None]:
# Rebind G to the graph stored in train_graph.json; cells below this point operate on it.
G = data_model.load_graph("out/data/wikipedia/v2/train_test_split/train_graph.json")

In [None]:
# Node id to inspect (hard-coded).
id_ = 31686682

# Shortest path from the root recorded in the graph's own attributes to that node,
# displayed as a list of titles.
path = nx.shortest_path(G, source=G.graph["root"], target=id_)
[G.nodes[step]["title"] for step in path]

In [None]:
# Scratch check: a graph with three nodes (0, 1, 2) and no edges.
G_test = nx.empty_graph(3)

# Shortest-path lengths from the source set {0, 1}. With no edges, presumably only the
# sources themselves are reported — check the cell output.
nx.multi_source_dijkstra_path_length(G_test, {0, 1})