In [None]:
import sys
from importlib import reload
from pathlib import Path
from time import perf_counter

from rich.console import Console

sys.path.append("..")

import src.util as util
from src import text_process
from src import plotly_plots as pp
from src import data_functions as datafun

# Re-import every project module so that edits saved to disk since the kernel
# first imported them take effect when this cell is re-run.
reload(pp)
reload(util)
reload(text_process)
# datafun is called by the overlap cells further down, so refresh it here
# together with the other project modules.
reload(datafun)
# Rich console used below for coloured output.
cons = Console()

In [None]:
# Read the dataset file from ../data through the project loader.
examples = util.load_dataset_parallel(
    Path("../data/dataset.ndjson"),
)
print(f"loaded {len(examples)} examples")
# Convert the loaded examples into a DataFrame; later cells read its
# "tokens" and "tags" columns.
df = util.dataset_to_df(examples)
# Last expression of the cell, so the notebook displays the final rows.
df.tail()

In [None]:
# Collect the distinct values of the "tokens" and "tags" columns as sorted
# Python lists: explode() flattens, unique() de-duplicates, sort() orders.
# NOTE(review): assumes both columns hold one list per row and that df is a
# polars DataFrame (the chained Series.sort() suggests so) — confirm against
# util.dataset_to_df.
vocab_tokens = df["tokens"].explode().unique().sort().to_list()
tag_tokens = df["tags"].explode().unique().sort().to_list()
# The `=` specifier prints the expression text together with its value.
print(f"{len(vocab_tokens)=}")
print(f"{len(tag_tokens)=}")

In [None]:
# Three short hand-written documents containing repeated words.
# NOTE(review): not referenced by any other cell visible in this part of the
# notebook — presumably a scratch fixture; confirm before relying on it.
dummy_docs = [
    "hello world hello world",
    "goodbye cruel world goodbye everyone",
    "hello everyone what is going on hello world hello",
]

In [None]:
# Size argument handed to the pairwise overlap helpers; the last cell passes
# the same value as the n-gram size `n` of overlap_pairwise_optimized.
N = 3

# Run the pairwise overlap helper once on the tags and once on the tokens.
# Each call returns two values: the first is plotted as a heatmap in the next
# cell, the second is iterated below as (i, j, overlap) triples.
# NOTE(review): presumably the first value is a full overlap matrix and the
# second lists only high-overlap pairs — confirm in src/data_functions.
overlap_tags, high_tag_ovr = datafun.overlap_pairwise_simple(df["tags"].to_list(), N)
overlap_tokens, high_token_ovr = datafun.overlap_pairwise_simple(
    df["tokens"].to_list(), N
)
print(f"high overlap by tags: {len(high_tag_ovr)} pairs")
print(f"high overlap by token: {len(high_token_ovr)} pairs")

# For every reported token pair, print the two row indices and the overlap as
# a percentage, followed by each row's tokens joined into a single string
# (row i in cyan, row j in yellow).
for i, j, ovr in high_token_ovr:
    print(f"\n{str((i, j)).ljust(12)}: {ovr:.1%} overlap")
    # highlight=False stops rich from auto-colouring parts of the text.
    cons.print(df[i]["tokens"].list.join("")[0], style="cyan", highlight=False)
    cons.print(df[j]["tokens"].list.join("")[0], style="yellow", highlight=False)

In [None]:
# Pick up any on-disk edits to the plotting helpers before drawing.
reload(pp)


# One heatmap per overlap result from the previous cell: tags, then tokens.
pp.heatmap(overlap_tags).show()
pp.heatmap(overlap_tokens).show()

In [None]:
from collections.abc import Sequence
from time import time

# Refresh data_functions so the helpers called in this cell reflect the latest
# on-disk edits to that module.
reload(datafun)


def overlap_pairwise_optimized(
    docs: Sequence[list[str]], n: int = 3, thr=0.5, n_jobs: int = 1
):
    """Find document pairs whose n-gram overlap reaches a threshold.

    The n-gram collections for all documents are built once, up front, through
    ``datafun.get_ngrams_all`` and then reused for every pairwise comparison
    made with ``datafun.get_overlap``.

    Args:
        docs: The documents to compare, one list of string tokens each.
        n: N-gram size, forwarded to ``datafun.get_ngrams_all``.
        thr: Inclusive lower bound an overlap must reach to be reported.
        n_jobs: Forwarded unchanged to ``datafun.get_ngrams_all``
            (presumably the number of parallel workers — confirm there).

    Returns:
        A ``(None, high)`` tuple. The first slot is always ``None`` because no
        overlap matrix is built here. ``high`` is a list of
        ``(i, j, overlap)`` tuples with ``j < i``, ordered from the largest
        overlap to the smallest.
    """
    # Build every document's n-grams a single time instead of once per pair.
    sets_by_doc = datafun.get_ngrams_all(docs, n, n_jobs)

    # Lazily score each unordered pair exactly once (second index < first).
    scored_pairs = (
        (first, second, datafun.get_overlap(sets_by_doc[first], sets_by_doc[second]))
        for first in range(len(docs))
        for second in range(first)
    )

    # Keep only the pairs at or above the threshold.
    high = [entry for entry in scored_pairs if entry[-1] >= thr]

    # Largest overlap first; ties keep their discovery order (stable sort).
    high.sort(key=lambda entry: -entry[-1])
    return None, high


# Time the optimized pairwise comparison on the token column.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# short intervals; time() reports wall-clock time, which can be coarse and can
# jump if the system clock is adjusted.
ts = perf_counter()
ovr_mat_tok_opt, high_ovr_opt = overlap_pairwise_optimized(
    df["tokens"].to_list(), N, n_jobs=2
)
print(f"time: {1000 * (perf_counter() - ts):.1f} ms")

# Heatmap left disabled: overlap_pairwise_optimized returns None in place of
# an overlap matrix, so there is nothing to plot.
# pp.heatmap(ovr_mat_tok_opt).show()
