In [4]:
import json
from pathlib import Path
from collections import defaultdict, Counter
import numpy as np
import re
import hashlib
from tqdm import tqdm

In [4]:
disambig = Path("../results/disambiguate_merges/merges_with_ids/round_5/results/all_merges_disambig.jsonl")
counts = Path("../results/disambiguate_merges/traced_lineages/final_enriched_counts.jsonl")

# Pull the "id" field out of every JSONL record, keeping file order and repeats.
with disambig.open() as f:
    dis = [json.loads(record)["id"] for record in f]

with counts.open() as f:
    c = [json.loads(record)["id"] for record in f]

In [9]:
# JSONL file that is passed to load_original_counts_from_file further down in this cell.
og_data = Path("../results/new_descriptors/all_descriptors_new.jsonl")

def _normalize_descriptor(s: str) -> str:
    return re.sub(r"[_\s]+", " ", (s or "")).strip().lower()


def pair_id(descriptor: str, explainer: str, *, length: int = 12) -> str:
    """Return a short hexadecimal ID for a (descriptor, explainer) pair.

    The descriptor is normalised with ``_normalize_descriptor``; the explainer
    is used exactly as given. Both parts are joined with U+241F, hashed with
    SHA-1 over the UTF-8 bytes, and the hex digest is cut to ``length`` characters.
    """
    normalized = _normalize_descriptor(descriptor)
    digest = hashlib.sha1(f"{normalized}\u241f{explainer}".encode("utf-8"))
    return digest.hexdigest()[:length]

def _split_pair_raw(text: str):
    """Break a raw "descriptor;explainer" string at its first semicolon.

    Returns ``(normalised descriptor, stripped explainer)``. When the string
    holds no semicolon at all, ``("", "")`` is returned instead.
    Intended to mirror the split logic in extract_descriptor_groups.py.
    """
    head, separator, tail = text.partition(";")
    if not separator:
        # No ";" present: nothing to split on.
        return "", ""
    return _normalize_descriptor(head), tail.strip()

def load_original_counts_from_file(path: Path):
    """Return the pair ID of every *original* (possibly duplicated) pair in a RAW extractor file.

    Despite the function name, the result is a flat list of pair IDs, not a
    mapping of counts. Repeats are kept, so ``Counter(result)`` gives the
    per-pair counts.

    Only RAW extractor lines are handled: every non-blank line must be a JSON
    object with the keys ``"similarity"`` and ``"descriptors"``. The entry of
    ``"descriptors"`` at the position of the largest similarity is selected,
    and each of its items is split as a raw "descriptor;explainer" string.
    Items without a semicolon, or whose descriptor or explainer is empty after
    cleaning, are skipped.

    Args:
        path: JSONL file to read (decoded as UTF-8).

    Returns:
        List of pair IDs in file order.

    Raises:
        KeyError: if a JSON object lacks ``"similarity"`` or ``"descriptors"``
            (for example a PROCESSED line with ``"descriptor"``/``"explainer"`` keys).
    """
    all_pids = []

    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            # Keep only the descriptor group with the highest similarity score.
            best_idx = int(np.argmax(row["similarity"]))
            for desc_exp in row["descriptors"][best_idx]:
                d, e = _split_pair_raw(desc_exp)
                if not d or not e:
                    continue
                all_pids.append(pair_id(d, e))

    return all_pids

# Pair IDs of all original pairs: a plain list in file order, repeats included.
og_pids = load_original_counts_from_file(og_data)

In [15]:
def get_input_pids(round_num):
    """Return the list of IDs belonging to round ``round_num``.

    Round 0 means the original descriptor file, read through
    ``load_original_counts_from_file``. Any other round yields the ``"id"`` of
    each record in that round's ``all_merges_disambig.jsonl``, in file order.
    """
    if round_num != 0:
        disambig = Path(f"../results/disambiguate_merges/merges_with_ids/round_{round_num}/results/all_merges_disambig.jsonl")
        with disambig.open() as handle:
            return [json.loads(record)["id"] for record in handle]

    og_data = Path("../results/new_descriptors/all_descriptors_new.jsonl")
    return load_original_counts_from_file(og_data)

def get_result_pids(round_num):
    """Read the result IDs of round ``round_num``.

    Returns a tuple ``(lin, dis)``:
      - ``lin``: all entries of ``"source_pair_ids"`` across the records of
        ``all_merges_full_lineage.jsonl``, flattened in file order.
      - ``dis``: the ``"id"`` of each record of ``all_merges_disambig.jsonl``.
    """
    lineage = Path(f"../results/disambiguate_merges/merges_with_ids/round_{round_num}/results/all_merges_full_lineage.jsonl")
    disambig = Path(f"../results/disambiguate_merges/merges_with_ids/round_{round_num}/results/all_merges_disambig.jsonl")

    with lineage.open() as handle:
        lin = [
            source_id
            for record in handle
            for source_id in json.loads(record)["source_pair_ids"]
        ]

    with disambig.open() as handle:
        dis = [json.loads(record)["id"] for record in handle]

    return lin, dis

def check_coverage(input_pids, lin, dis):
    """Report which input IDs are not accounted for in a round's results.

    An input ID counts as covered when it occurs in ``lin`` (lineage source
    IDs) or, failing that, in ``dis`` (disambiguated output IDs). Every ID
    covered by neither is printed as ``Pid <id> dropped!``.

    Args:
        input_pids: iterable of IDs that went into the round (repeats ignored).
        lin: iterable of lineage source IDs.
        dis: iterable of disambiguated output IDs.

    Returns:
        dict with
          - ``"in_lin"``: number of distinct input IDs found in ``lin``,
          - ``"in_dis"``: number of distinct input IDs found only in ``dis``,
          - ``"dropped"``: list of distinct input IDs found in neither,
            ordered by their string form so the report is reproducible.
    """
    lin_set = set(lin)
    dis_set = set(dis)
    ogs = set(input_pids)

    not_in_lin = ogs - lin_set
    in_lin = len(ogs) - len(not_in_lin)
    in_dis = len(not_in_lin & dis_set)   # only those not already in lin
    # Sets iterate in arbitrary order; sort so repeated runs print the same report.
    dropped = sorted(not_in_lin - dis_set, key=str)
    for pid in dropped:
        print(f"Pid {pid} dropped!")

    return {"in_lin": in_lin, "in_dis": in_dis, "dropped": dropped}



In [16]:
# For each consecutive pair of rounds, compare the IDs of round i (round 0 = the
# original descriptor file) against the lineage source IDs and disambiguated IDs
# of round i+1; check_coverage prints every input ID found in neither.
for i in tqdm(range(0,5)):
    input_pids = get_input_pids(i)
    lin, dis = get_result_pids(i+1)
    check_coverage(input_pids, lin, dis)

100%|██████████| 5/5 [01:40<00:00, 20.04s/it]


In [5]:
# NOTE(review): absolute, user-specific paths. Presumably these are the same files
# as the relative `counts` / `disambig` paths used earlier in this notebook; verify,
# and consider deriving both from one configurable base directory.
counts_path = Path("/scratch/project_462000963/users/tarkkaot/LLM_document_descriptors/results/disambiguate_merges/traced_lineages/final_enriched_counts.jsonl")
disambig_path = Path("/scratch/project_462000963/users/tarkkaot/LLM_document_descriptors/results/disambiguate_merges/merges_with_ids/round_5/results/all_merges_disambig.jsonl")

In [7]:
# Read the "id" of every record from both files, preserving order and repeats.
with counts_path.open() as f:
    counts_ids = [json.loads(record)["id"] for record in f]

with disambig_path.open() as f:
    disambig_ids = [json.loads(record)["id"] for record in f]

In [8]:
# True when both files hold the same IDs with the same multiplicities (order ignored).
Counter(counts_ids) == Counter(disambig_ids)

True

In [27]:
# Flatten the keys of every record's "original_counts" entry into one list.
with counts_path.open() as f:
    source_pids = [
        pid
        for record in f
        for pid in json.loads(record)["original_counts"].keys()
    ]

In [20]:
# Spot-check one original pair ID: print the position of its first occurrence
# in source_pids (nothing is printed when it does not occur).
search_pid = og_pids[123135]

first_hit = next(
    (position for position, candidate in enumerate(source_pids) if candidate == search_pid),
    None,
)
if first_hit is not None:
    print(first_hit)

3809584


In [55]:
# Spot-check: is this specific ID among the collected source pair IDs?
"542b30d658d6" in source_pids

True

In [30]:
# Total number of collected source pair IDs (one entry per key per record).
len(source_pids)

7403667

In [26]:
from datasets import load_dataset
import json
from pathlib import Path

In [34]:
# Load the train split of the IMDB dataset in non-streaming mode.
ds = load_dataset("stanfordnlp/imdb", split="train", streaming=False)

In [33]:
# Count the rows by walking the dataset once.
l = sum(1 for _ in ds)

l

25000

In [35]:
# Display the dataset summary (features and row count).
ds

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [12]:
imdb = Path("../results/benchmarks/imdb/descriptors_imdb.jsonl")

# Parse every line of the JSONL file into one object per record, in file order.
with imdb.open() as f:
    data = [json.loads(record) for record in f]

In [16]:
# Keep only the first record for each distinct "document" value, in original order.
first_by_document = {}
for doc in data:
    # setdefault leaves an existing entry untouched, so later repeats are ignored.
    first_by_document.setdefault(doc["document"], doc)

seen = set(first_by_document)
final = list(first_by_document.values())

In [20]:
# Size of the `seen` set, i.e. how many distinct values it currently holds.
len(seen)

24904

In [24]:
# Collect repeated texts in the raw dataset: the first occurrence of a text goes
# into `seen`; every later occurrence of the same text is appended to `dups`.
seen = set()
dups = []
for doc in ds["text"]:
    if doc not in seen:
        seen.add(doc)
    else:
        dups.append(doc)