In [1]:
import json
from pathlib import Path

# Input dataset, read as JSON Lines (one JSON object per line).
path = Path("Dataset/causenet-precision.jsonl")

# How many records to load for a quick look at the schema.
SAMPLE_N = 3

# Stream the file and stop as soon as SAMPLE_N records are collected, so the
# whole dataset is never held in memory just for a preview.
data = []
with path.open("r", encoding="utf-8") as f:
    for line in f:
        if len(data) >= SAMPLE_N:
            break
        # Blank lines would make json.loads raise; skip them, as the full
        # graph loader further down in this notebook does.
        if not line.strip():
            continue
        data.append(json.loads(line))

In [2]:
# Display the first sampled record to inspect its structure.
data[0]

{'causal_relation': {'cause': {'concept': 'accident'},
  'effect': {'concept': 'death'}},
 'sources': [{'type': 'wikipedia_sentence',
   'payload': {'wikipedia_page_id': '45710',
    'wikipedia_page_title': 'Forensic science',
    'wikipedia_revision_id': '861543252',
    'wikipedia_revision_timestamp': '2018-09-28T06:01:51Z',
    'sentence_section_heading': 'Origins of forensic science and early methods',
    'sentence_section_level': '3',
    'sentence': 'For example, the book also described how to distinguish between a drowning (water in the lungs) and strangulation (broken neck cartilage), along with other evidence from examining corpses on determining if a death was caused by murder, suicide or an accident.',
    'path_pattern': '[[cause]]/N\t-nmod:agent\tcaused/VBN\t+nsubjpass\t[[effect]]/N'}},
  {'type': 'wikipedia_sentence',
   'payload': {'wikipedia_page_id': '124363',
    'wikipedia_page_title': 'Goodsprings, Nevada',
    'wikipedia_revision_id': '857283240',
    'wikipedia_r

In [3]:
from collections import defaultdict, deque

# Two views of the same directed graph, built in one pass over the file:
#   adj  : cause  -> effects it points to
#   radj : effect -> causes pointing at it
adj = defaultdict(set)
radj = defaultdict(set)

with path.open("r", encoding="utf-8") as f:
    for raw in f:
        if raw.strip():
            relation = json.loads(raw)["causal_relation"]
            src = relation["cause"]["concept"]
            dst = relation["effect"]["concept"]
            adj[src].add(dst)
            radj[dst].add(src)

# Freeze into plain dicts of sorted lists: deterministic iteration order, and
# lookups of unknown concepts no longer create empty entries.
adj = {concept: sorted(targets) for concept, targets in adj.items()}
radj = {concept: sorted(sources) for concept, sources in radj.items()}

def reverse_bfs_distance(radj, target):
    """Return the shortest hop count to ``target`` for every node that can reach it.

    Args:
        radj: mapping ``node -> iterable of predecessors`` (reverse adjacency).
        target: node to measure distances to.

    Returns:
        dict mapping each node with a directed path to ``target`` to the
        length of its shortest such path; ``target`` itself maps to 0.
    """
    hops_to_target = {target: 0}
    frontier = [target]
    depth = 0
    # Expand one distance level at a time; within a level, nodes are handled
    # in discovery order, which matches a FIFO-queue traversal.
    while frontier:
        depth += 1
        discovered = []
        for current in frontier:
            for parent in radj.get(current, []):
                if parent in hops_to_target:
                    continue
                hops_to_target[parent] = depth
                discovered.append(parent)
        frontier = discovered
    return hops_to_target

# Shortest hop count to "death" for every concept that has a directed path to
# it; the target itself is included with distance 0.
dist_to_death = reverse_bfs_distance(radj, "death")
# Summary: number of distinct cause concepts, distinct effect concepts, and
# entries in dist_to_death (which counts "death" itself).
print("graph built:", "causes", len(adj), "effects", len(radj), "nodes_reaching_death", len(dist_to_death))

graph built: causes 51863 effects 38324 nodes_reaching_death 39370


In [4]:
TARGET = "death"
MIN_HOPS = 3
MAX_HOPS = 4

# exact[d] = every node with a walk of exactly d edges ending at TARGET.
# Nothing is marked as visited, so walks may revisit nodes and one node can
# appear in several layers; these are not shortest distances.
exact = {}
layer = {TARGET}
for d in range(1, MAX_HOPS + 1):
    layer = {parent for node in layer for parent in radj.get(node, [])}
    exact[d] = layer
    print(f"exact_hops={d}: {len(layer)} heads")

# Heads that have a walk whose length falls inside [MIN_HOPS, MAX_HOPS].
heads_union = set().union(*(exact.get(d, set()) for d in range(MIN_HOPS, MAX_HOPS + 1)))
heads_union.discard(TARGET)

# Heads that reach TARGET by a path of any length.
heads_any = set(dist_to_death) - {TARGET}

print(f"unique (head -> {TARGET}) pairs with hops in [{MIN_HOPS}, {MAX_HOPS}]: {len(heads_union)}")
print(f"unique (head -> {TARGET}) pairs (any hops): {len(heads_any)}")
print("contains tobacco_smoking:", "tobacco_smoking" in heads_union)

exact_hops=1: 5168 heads
exact_hops=2: 30791 heads
exact_hops=3: 38558 heads
exact_hops=4: 39291 heads
unique (head -> death) pairs with hops in [3, 4]: 39290
unique (head -> death) pairs (any hops): 39369
contains tobacco_smoking: True


In [5]:
import json
import random
from pathlib import Path
import re
from collections import defaultdict, deque



# Sampling configuration for this cell.
MIN_HOPS = 4
MAX_HOPS = 4
SAMPLE_K = 120
SEED = 0

# Assign each concept a dense integer id in sorted-name order, so the graph
# can be stored as plain lists indexed by id.
nodes = sorted(adj.keys() | radj.keys())
node2id = dict(zip(nodes, range(len(nodes))))
id2node = nodes

# Forward adjacency by id; concepts that never appear as a cause get [].
adj_i = [[node2id[effect] for effect in adj.get(name, ())] for name in nodes]

# Reverse adjacency, derived from adj_i.
radj_i = [[] for _ in nodes]
for src_id, successors in enumerate(adj_i):
    for dst_id in successors:
        radj_i[dst_id].append(src_id)

# Function words that are ignored when tokenizing concept names.
stop = {
    'a', 'an', 'the', 'and', 'or',
    'of', 'to', 'in', 'on', 'for', 'with', 'without', 'as', 'at', 'by', 'from', 'into', 'over', 'under',
    'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'due', 'after', 'before', 'during', 'between', 'within', 'among',
    'more', 'less', 'than',
}


def base_tokens(s: str):
    """Split ``s`` into lowercase ASCII-alphanumeric words.

    Words shorter than three characters and words listed in ``stop`` are
    dropped. Order and duplicates are preserved.
    """
    kept = []
    for piece in re.split(r'[^A-Za-z0-9]+', s.lower()):
        if len(piece) < 3 or piece in stop:
            continue
        kept.append(piece)
    return kept


def tokenize(s: str):
    """Return the set of match features for the name ``s``.

    Every base token is a feature. Tokens of length >= 5 also contribute
    their 3-character prefix and suffix, and tokens of length >= 8 their
    4-character prefix and suffix, so names sharing only part of a word
    still overlap.
    """
    features = set()
    for word in base_tokens(s):
        features.add(word)
        # (minimum word length, affix width)
        for min_len, width in ((5, 3), (8, 4)):
            if len(word) >= min_len:
                features.update((word[:width], word[-width:]))
    return features

# Feature set of every node, indexed by node id.
tokens_by_id = list(map(tokenize, id2node))

# Inverted index: feature -> ids of the nodes whose name produces it.
token_to_ids = defaultdict(list)
for node_id, node_features in enumerate(tokens_by_id):
    for feature in node_features:
        token_to_ids[feature].append(node_id)

# Scratch arrays for repeated traversals. A node counts as visited in the
# current traversal when its entry equals the current `mark`, so the arrays
# never need to be cleared between traversals.
seen_fwd = [0] * len(nodes)
seen_bwd = [0] * len(nodes)
mark = 0

def bfs(start: int, graph, seen, mark: int):
    """Stamp ``mark`` into ``seen`` for ``start`` and every node reachable from it.

    Args:
        start: id of the node to start from.
        graph: adjacency indexed by node id (``graph[u]`` lists successors).
        seen: mutable sequence indexed by node id, updated in place. Entries
            already equal to ``mark`` are treated as visited and not expanded.
        mark: value written for each visited node.

    Returns:
        None; the result is the mutation of ``seen``.
    """
    seen[start] = mark
    pending = deque((start,))
    while pending:
        for neighbor in graph[pending.popleft()]:
            if seen[neighbor] == mark:
                continue
            seen[neighbor] = mark
            pending.append(neighbor)

# Seeded generator so the sampling below is reproducible for a given SEED.
rng = random.Random(SEED)
# Node ids already emitted as 'unrelated_to_head', tracked so later heads can
# prefer a negative that has not been used yet.
used_unrelated = set()

def pick_two_simple_paths_len4(uid):
    """Find two 4-edge simple paths from ``uid`` that end at the same node.

    Paths are enumerated depth-first over the module-level ``adj_i`` in
    adjacency-list order, and no node may repeat within a path. The search
    stops at the first end node reached by a second path.

    Returns:
        ``(tail_id, first_path, second_path)`` with each path a list of five
        node ids starting at ``uid``, or ``(None, None, None)`` when no end
        node is reached by two paths.
    """
    first_path_to = {}
    for hop1 in adj_i[uid]:
        if hop1 == uid:
            continue
        for hop2 in adj_i[hop1]:
            if hop2 == uid or hop2 == hop1:
                continue
            for hop3 in adj_i[hop2]:
                if hop3 in (uid, hop1, hop2):
                    continue
                for end in adj_i[hop3]:
                    if end in (uid, hop1, hop2, hop3):
                        continue
                    path = [uid, hop1, hop2, hop3, end]
                    earlier = first_path_to.get(end)
                    if earlier is not None:
                        # Second arrival at this end node: done.
                        return end, earlier, path
                    first_path_to[end] = path
    return None, None, None

# Visit candidate heads in a seeded random order so the sample is reproducible.
head_ids = list(range(len(nodes)))
rng.shuffle(head_ids)

selected_pairs = []
for uid in head_ids:
    if len(selected_pairs) >= SAMPLE_K:
        break
    if not adj_i[uid]:
        # No outgoing edge: this node cannot start a chain.
        continue

    # Need two 4-hop simple chains from this head to one shared tail. The
    # chain length is fixed by the helper, independent of MIN_HOPS/MAX_HOPS.
    tid, p1, p2 = pick_two_simple_paths_len4(uid)
    if tid is None:
        continue

    # Stamp every descendant (seen_fwd) and ancestor (seen_bwd) of the head
    # with a fresh mark; the negative must lie outside both sets.
    mark += 1
    bfs(uid, adj_i, seen_fwd, mark)
    bfs(uid, radj_i, seen_bwd, mark)

    # Candidate negatives: nodes sharing at least one name feature with the head.
    head_tokens = tokens_by_id[uid]
    cand = set()
    for t in head_tokens:
        cand.update(token_to_ids.get(t, ()))
    cand.discard(uid)

    scored = []
    for cid in cand:
        if seen_fwd[cid] == mark or seen_bwd[cid] == mark:
            continue
        overlap = head_tokens & tokens_by_id[cid]
        if not overlap:
            continue
        # Rank by total shared characters, then shared-feature count; the
        # random component breaks ties reproducibly.
        scored.append(((sum(len(x) for x in overlap), len(overlap), rng.random()), cid))

    # Best-scoring candidate not used before; reuse the top one if all were used.
    scored.sort(reverse=True)
    neg = None
    for _, cid in scored:
        if cid not in used_unrelated:
            neg = cid
            break
    if neg is None and scored:
        neg = scored[0][1]

    if neg is None:
        # No lexically similar candidate: fall back to any node that is
        # neither a descendant nor an ancestor of the head.
        pool = [i for i in range(len(nodes)) if i != uid and seen_fwd[i] != mark and seen_bwd[i] != mark]
        rng.shuffle(pool)
        for cid in pool:
            if cid not in used_unrelated:
                neg = cid
                break
        if neg is None:
            if not pool:
                # Every other node is causally connected to this head, so no
                # negative exists; skip the head instead of indexing an
                # empty list.
                continue
            neg = pool[0]

    used_unrelated.add(neg)
    selected_pairs.append(
        {
            'head': id2node[uid],
            'tail': id2node[tid],
            'min_hops': MIN_HOPS,
            'max_hops': MAX_HOPS,
            'chain_1': [id2node[x] for x in p1],
            'chain_2': [id2node[x] for x in p2],
            'unrelated_to_head': id2node[neg],
        }
    )

# Fail loudly rather than write a short sample.
if len(selected_pairs) < SAMPLE_K:
    raise RuntimeError(f'Only found {len(selected_pairs)} pairs')

# Persist as JSON Lines, one record per line.
out_path = Path(f"results/causenet_head_tail_pairs_min{MIN_HOPS}_max{MAX_HOPS}_{SAMPLE_K}.jsonl")
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w", encoding="utf-8") as f:
    for rec in selected_pairs:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print("min_hops", MIN_HOPS, "max_hops", MAX_HOPS)
print("sample_pairs", len(selected_pairs))
print("saved", out_path)
selected_pairs[:2]

min_hops 4 max_hops 4
sample_pairs 120
saved results\causenet_head_tail_pairs_min4_max4_120.jsonl


[{'head': 'butterfly',
  'tail': 'anxiety',
  'min_hops': 4,
  'max_hops': 4,
  'chain_1': ['butterfly', 'hurricane', 'changes', 'abuse', 'anxiety'],
  'chain_2': ['butterfly', 'hurricane', 'changes', 'accidents', 'anxiety'],
  'unrelated_to_head': "flap_of_a_butterfly_'s_wings_in_brazil"},
 {'head': 'metabolic_rate',
  'tail': 'complications',
  'min_hops': 4,
  'max_hops': 4,
  'chain_1': ['metabolic_rate',
   'fat_loss',
   'weight_loss',
   'anemia',
   'complications'],
  'chain_2': ['metabolic_rate',
   'fat_loss',
   'weight_loss',
   'anorexia',
   'complications'],
  'unrelated_to_head': 'low_metabolic_rate'}]

In [6]:
# Query: paths that begin START -> MID and finish at TARGET.
START, MID, TARGET = "tobacco_smoking", "lung_cancer", "death"

# Total path lengths (in edges) to enumerate, and how many paths to list per length.
MIN_HOPS, MAX_HOPS = 4, 6
MAX_SHOW_PER_HOPS = 20

prefix = [START, MID]
# The forced first hop must be an actual edge of the graph.
first_hop_exists = MID in adj.get(START, [])
if not first_hop_exists:
    raise ValueError(f"Missing edge: {START} -> {MID}")

def iter_simple_paths_exact_hops(adj, dist_to_target, start_path, target, total_hops, limit=None):
    """Yield simple paths extending ``start_path`` that reach ``target`` in exactly ``total_hops`` edges.

    Depth-first search with an explicit stack. Successors are pushed in
    reverse so they are explored in adjacency-list order.

    Args:
        adj: mapping ``node -> sequence of successors``.
        dist_to_target: mapping ``node -> shortest hop count to target``;
            nodes missing from it are treated as unable to reach ``target``.
            Must have been computed for the same ``target``.
        start_path: non-empty sequence of nodes that every result starts
            with. Its edges count toward ``total_hops``. An empty sequence
            yields nothing.
        target: node every yielded path must end on.
        total_hops: exact number of edges in each yielded path.
        limit: maximum number of paths to yield; ``None`` means no limit and
            a value <= 0 yields nothing.

    Yields:
        Lists of nodes. Each is a fresh list, never the caller's
        ``start_path`` object.
    """
    # The limit is otherwise only checked after a yield, which would let one
    # path through for limit=0.
    if limit is not None and limit <= 0:
        return

    # Copy so the caller's object is never yielded, and so any sequence type
    # (not just list) can be extended below.
    prefix_nodes = list(start_path)
    if not prefix_nodes:
        return

    stack = [(prefix_nodes[-1], prefix_nodes, set(prefix_nodes))]
    found = 0
    while stack:
        node, path_nodes, seen = stack.pop()
        hops = len(path_nodes) - 1

        # Prune: target unreachable, or not reachable within the remaining hops.
        d = dist_to_target.get(node)
        if d is None or d > total_hops - hops:
            continue

        if node == target and hops == total_hops:
            yield path_nodes
            found += 1
            if limit is not None and found >= limit:
                return
            continue

        # Hop budget spent without landing on target, or target reached too
        # early: a simple path cannot return to target, so stop here.
        if hops == total_hops or node == target:
            continue

        for nxt in reversed(adj.get(node, [])):
            if nxt in seen:
                continue
            stack.append((nxt, path_nodes + [nxt], seen | {nxt}))

# For each total length, collect up to MAX_SHOW_PER_HOPS paths and list them.
paths_by_hops = {}
for hops in range(MIN_HOPS, MAX_HOPS + 1):
    paths_by_hops[hops] = paths = list(
        iter_simple_paths_exact_hops(adj, dist_to_death, prefix, TARGET, hops, limit=MAX_SHOW_PER_HOPS)
    )
    print(f"hops={hops}: showing {len(paths)} paths")
    for chain in paths:
        print("  " + " -> ".join(chain))

hops=4: showing 20 paths
  tobacco_smoking -> lung_cancer -> brain_metastasis -> symptoms -> death
  tobacco_smoking -> lung_cancer -> cancer -> abnormalities -> death
  tobacco_smoking -> lung_cancer -> cancer -> aging -> death
  tobacco_smoking -> lung_cancer -> cancer -> ailments -> death
  tobacco_smoking -> lung_cancer -> cancer -> anaemia -> death
  tobacco_smoking -> lung_cancer -> cancer -> anemia -> death
  tobacco_smoking -> lung_cancer -> cancer -> anorexia -> death
  tobacco_smoking -> lung_cancer -> cancer -> anxiety -> death
  tobacco_smoking -> lung_cancer -> cancer -> arthritis -> death
  tobacco_smoking -> lung_cancer -> cancer -> asbestos -> death
  tobacco_smoking -> lung_cancer -> cancer -> asbestos_exposure -> death
  tobacco_smoking -> lung_cancer -> cancer -> bacteria -> death
  tobacco_smoking -> lung_cancer -> cancer -> birth_defects -> death
  tobacco_smoking -> lung_cancer -> cancer -> bleeding -> death
  tobacco_smoking -> lung_cancer -> cancer -> blindness 