In [None]:
# HACK: use the project root as the working directory, presumably so that
# project-local imports (e.g. `utils.constants`) and relative paths resolve.
import os
from pathlib import Path

PROJECT_ROOT_NAME = 'language-model-toxicity'

_start_dir = Path.cwd()
_cwd = _start_dir
while _cwd.name != PROJECT_ROOT_NAME:
    if _cwd.parent == _cwd:
        # At the filesystem root, `..` is the root itself; without this guard
        # the loop would never terminate.
        raise FileNotFoundError(
            f'No ancestor of {_start_dir} is named {PROJECT_ROOT_NAME!r}'
        )
    os.chdir(_cwd.parent)
    _cwd = Path.cwd()

# Define helper functions

In [None]:
from functools import partial
import multiprocessing as mp
from pathlib import Path

import numpy as np
from datasketch import MinHash, LeanMinHash, MinHashLSH
from joblib import load
from nltk import ngrams
from tqdm.auto import tqdm

from utils.constants import DATA_DIR, OUTPUT_DIR

In [None]:
# NOTE(review): leftover sanity check that only displays a default-constructed
# MinHash; nothing below uses the result, so this cell can likely be removed.
MinHash()

In [None]:
def make_corpus_iter(corpus_dir: Path):
    """Lazily yield ``(doc_id, doc)`` pairs from every ``*.joblib`` shard in ``corpus_dir``.

    Shards are visited in sorted filename order. Each shard is loaded with
    ``joblib.load`` and must be a sized, iterable collection of documents.
    If a sibling ``<stem>_filenames.txt`` exists, its whitespace-separated
    tokens are used as the document ids, in order. Otherwise ids are
    synthesized as ``'<stem>-<index>'``.

    Args:
        corpus_dir: Directory containing the ``.joblib`` shards.

    Yields:
        ``(doc_id, doc)`` tuples, shard by shard.

    Raises:
        ValueError: If a shard's filenames file lists a different number of ids
            than the shard has documents. ``zip`` would otherwise truncate
            silently and hide the mismatch.
    """
    files = sorted(file for file in corpus_dir.iterdir() if file.suffix == '.joblib')

    for file in files:
        print("Loading file:", file)
        docs = load(file)

        # Prefer explicit ids from the sibling filenames file; fall back to positional ids.
        filenames_file = file.with_name(f'{file.stem}_filenames.txt')
        if filenames_file.exists():
            doc_ids = filenames_file.read_text().split()
            if len(doc_ids) != len(docs):
                raise ValueError(
                    f'{filenames_file} lists {len(doc_ids)} ids but {file} '
                    f'contains {len(docs)} documents'
                )
        else:
            doc_ids = (f'{file.stem}-{idx}' for idx in range(len(docs)))

        yield from zip(doc_ids, docs)

In [None]:
def make_minhash_mapping(item, shingles: int, num_perm: int):
    """Compute a ``LeanMinHash`` signature for one ``(doc_id, doc)`` pair.

    ``doc`` is split into ``shingles``-grams with ``nltk.ngrams``. Duplicate
    n-grams are collapsed first, so each distinct shingle is hashed once. Each
    n-gram is concatenated with ``''.join``, UTF-8 encoded, and added to a
    ``MinHash`` with ``num_perm`` permutations.

    NOTE(review): ``''.join`` implies ``doc`` is a ``str`` (character-level
    shingles) -- confirm. A document shorter than ``shingles`` produces no
    n-grams and therefore an empty MinHash.

    Args:
        item: ``(doc_id, doc)`` tuple, e.g. as yielded by ``make_corpus_iter``.
        shingles: n-gram size used to shingle ``doc``.
        num_perm: Number of MinHash permutations.

    Returns:
        ``(doc_id, LeanMinHash)``.
    """
    doc_id, doc = item

    # Honor the requested shingle size. It was previously hard-coded to 5,
    # which silently ignored the `shingles` argument.
    shingle_set = set(ngrams(doc, shingles))

    minhash = MinHash(num_perm=num_perm)
    for shingle in shingle_set:
        minhash.update(''.join(shingle).encode('utf8'))

    # Convert to datasketch's lightweight LeanMinHash representation.
    return doc_id, LeanMinHash(minhash)

In [None]:
def parallel_create_minhashes(corpus_iter, shingles: int, num_perm: int, n_jobs: int, chunksize=1000):
    """Stream ``(doc_id, LeanMinHash)`` pairs computed by ``n_jobs`` worker processes.

    Each item of ``corpus_iter`` is handed to ``make_minhash_mapping`` (with the
    given ``shingles`` / ``num_perm``) via ``Pool.imap`` in batches of
    ``chunksize``. Results are yielded in input order. The pool lives as long as
    this generator does and is terminated when the generator is exhausted or
    closed.

    NOTE(review): ``Pool.imap`` may read ahead from ``corpus_iter`` faster than
    results are consumed -- confirm memory use on large corpora.
    """
    worker = partial(make_minhash_mapping, num_perm=num_perm, shingles=shingles)
    with mp.Pool(processes=n_jobs) as pool:
        results = pool.imap(worker, corpus_iter, chunksize=chunksize)
        for pair in results:
            yield pair

# Create MinHashLSH for WebText

In [None]:
# MinHash configuration: NUM_PERM permutations per signature (also used by the
# LSH index) and SHINGLES-gram shingles per document.
NUM_PERM, SHINGLES = 128, 5

In [None]:
# LSH returns candidate documents whose estimated Jaccard similarity is roughly
# at or above this threshold.
JACCARD = 0.9

# The index must use the same permutation count as the signatures inserted into it.
lsh = MinHashLSH(num_perm=NUM_PERM, threshold=JACCARD)

In [None]:
wt_len = 8_282_020  # expected number of WebText documents; used only as the tqdm total
wt_iter = make_corpus_iter(DATA_DIR / 'detokenized_webtext')

# `parallel_create_minhashes` has no `total` parameter (passing one raises a
# TypeError), so the expected length goes to tqdm below instead.
mh_iter = parallel_create_minhashes(wt_iter, shingles=SHINGLES, num_perm=NUM_PERM, n_jobs=96)
wt_minhashes = {}

# Keep every signature (presumably for later lookups) and index it in the LSH in one pass.
with lsh.insertion_session() as session:
    for key, minhash in tqdm(mh_iter, total=wt_len):
        wt_minhashes[key] = minhash
        # Skip datasketch's duplicate-key check for speed. This assumes doc ids
        # are unique and that this cell runs once against a freshly built `lsh`.
        session.insert(key, minhash, check_duplication=False)

# Finish indexing WebText and save the LSH, then create MinHashes for OpenWebText

In [None]:
# Insert any WebText signatures still left in `mh_iter` into the LSH.
# NOTE(review): the previous cell already drains `mh_iter`, so this loop is
# normally a no-op and looks like a leftover that can be deleted. No extra
# `mp.Pool` is needed here: `mh_iter` comes from `parallel_create_minhashes`,
# which manages its own worker pool.
with lsh.insertion_session() as session:
    for key, minhash in tqdm(mh_iter, total=wt_len):
        session.insert(key, minhash)

In [None]:
import pickle

# Persist the WebText LSH index. Create the output directory first so the dump
# does not fail when it does not exist yet.
# Only unpickle this file from a trusted source: unpickling can execute code.
lsh_path = OUTPUT_DIR / 'datasketch_v1.pkl'
lsh_path.parent.mkdir(parents=True, exist_ok=True)
with open(lsh_path, 'wb') as f:
    pickle.dump(lsh, f)

In [None]:
owtc_dir = DATA_DIR / 'openwebtext_shards'
owtc_iter = make_corpus_iter(corpus_dir=owtc_dir)  # generator: shards load only when iterated
owtc_len = 8_013_769  # expected number of OWTC documents, used as the tqdm total below

In [None]:
# Create MinHashes for OWTC with the same helper and settings (SHINGLES, NUM_PERM)
# as the WebText signatures, so they can be queried against `lsh`.
# Note: `make_minhash_` exists only as a local inside `parallel_create_minhashes`;
# referencing it at module level raises a NameError, so call the helper instead.
owtc_mh_iter = parallel_create_minhashes(owtc_iter, shingles=SHINGLES, num_perm=NUM_PERM, n_jobs=96)
owtc_minhashes = dict(tqdm(owtc_mh_iter, total=owtc_len))

In [None]:
# Smoke test: sign the first OWTC document, then query it against the WebText LSH below.
# Use a fresh corpus iterator: `owtc_iter` has already been drained by the cell
# above, so `next(owtc_iter)` would raise StopIteration.
test_id, test_mh = make_minhash_mapping(
    next(make_corpus_iter(owtc_dir)), shingles=SHINGLES, num_perm=NUM_PERM
)

In [None]:
# Display the sample document's id and its MinHash signature.
test_id, test_mh

In [None]:
# Keys of indexed WebText documents that the LSH reports as candidates for the
# sample (estimated Jaccard similarity roughly at or above JACCARD).
test_matches = lsh.query(test_mh)

In [None]:
# Number of WebText near-duplicate candidates found for the sample document.
len(test_matches)