In [None]:
# HACK: use project root as the working directory 
from pathlib import Path

while Path.cwd().name != 'language-model-toxicity':
    %cd ..

In [None]:
import logging
from typing import List

import dask
from joblib import Memory, Parallel, delayed, dump
from lsh import cache, minhash
import numpy as np
from itertools import chain, islice

from utils.constants import DATA_DIR, OUTPUT_DIR
from utils.webtext import load_meta, delayed_corpus, split_docs

# On-disk joblib memoization store for this notebook's expensive calls.
MEMORY_DIR = OUTPUT_DIR / 'cache' / 'webtext_overlap'
mem = Memory(MEMORY_DIR)

In [None]:
cached_meta = mem.cache(load_meta)


def _load_corpus_meta(subdir):
    """Load (memoized) metadata for ``DATA_DIR / subdir``.

    Returns the metadata object together with its element 0, which is the
    collection of shard files that ``corpus_iter`` iterates over.
    """
    meta = cached_meta(DATA_DIR / subdir)
    return meta, meta[0]


wt_meta, wt_files = _load_corpus_meta('webtext')
owtc_meta, owtc_files = _load_corpus_meta('openwebtext_bpe')

## Find duplicates with LSH

In [None]:
def train(document_feed, char_ngram=3, seeds=100, bands=5, hashbytes=4, n_jobs=1):
    """MinHash-fingerprint every document and index the fingerprints in an LSH cache.

    Parameters
    ----------
    document_feed : iterable of (doc_id, doc)
        Pairs of a document id and the document content passed to
        ``MinHasher.fingerprint``. The id is stored alongside the fingerprint.
    char_ngram, seeds, hashbytes : int
        Passed through to ``minhash.MinHasher``.
    bands : int
        Number of bands for the LSH ``Cache``. ``seeds`` must be a multiple of it.
    n_jobs : int
        Number of joblib threads used to compute fingerprints.

    Returns
    -------
    tuple
        ``(hasher, lshcache)``: the MinHasher used and the populated LSH Cache.

    Raises
    ------
    ValueError
        If ``seeds`` is not a multiple of ``bands``.
    """
    # Import locally instead of using the global `cache` module: the notebook
    # rebinds the global name `cache` to the Cache instance this function
    # returns, which would make any later call fail on `cache.Cache`.
    from lsh.cache import Cache

    # Validate arguments before building anything.
    if seeds % bands != 0:
        raise ValueError('Seeds has to be a multiple of bands. {} % {} != 0'.format(seeds, bands))
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=char_ngram, hashbytes=hashbytes)

    def _fingerprint(doc_id, doc):
        # Keep the id next to its fingerprint so results can be indexed afterwards.
        return doc_id, hasher.fingerprint(doc)

    # Threading backend: all workers share the same `hasher` object.
    fingerprints = Parallel(n_jobs=n_jobs, verbose=1, backend='threading')(
        delayed(_fingerprint)(doc_id, doc) for doc_id, doc in document_feed
    )

    lshcache = Cache(num_bands=bands, hasher=hasher)
    for doc_id, fingerprint in fingerprints:
        lshcache.add_fingerprint(fingerprint, doc_id=doc_id)

    return hasher, lshcache

In [None]:
def corpus_iter(files: List[Path], name: str):
    """Stream every document from a corpus's shard files.

    Yields ``((index, name), payload)`` pairs, where:

    - ``index`` counts documents across all of ``files``, starting at 0.
    - ``payload`` is the document array cast to int32 and serialized with
      ``tobytes()``, i.e. 4 bytes per element.

    Prints each file path as loading of that shard begins.
    """
    def _documents():
        # Lazily load one shard at a time and hand out its documents.
        for shard_path in files:
            print("Loading file:", shard_path)
            yield from split_docs(np.load(shard_path))

    for index, document in enumerate(_documents()):
        # Serialize as fixed-width 4-byte integers so the document can be fingerprinted as bytes.
        yield (index, name), document.astype(np.int32).tobytes()

In [None]:
# Stream WebText documents first, then OpenWebText, into a single LSH index.
corpus = chain.from_iterable(
    corpus_iter(files, name=label)
    for files, label in ((wt_files, 'wt'), (owtc_files, 'owtc'))
)
# NOTE: this rebinds `cache`, shadowing the `lsh.cache` module imported at the top.
hasher, cache = train(corpus, n_jobs=96)

In [None]:
# Candidate duplicate pairs from the LSH index. Each element is unpacked below as
# a pair of (index, corpus_name) ids; presumably these are docs sharing an LSH band.
all_duplicates = cache.get_all_duplicates()

In [None]:
# Raw candidate count: still includes same-corpus pairs and possibly symmetric (a, b)/(b, a) pairs.
len(all_duplicates)

In [None]:
# Persist the raw candidate pairs; the relative filename resolves against the current working directory.
dump(value=all_duplicates, filename='webtext_dups.joblib')

In [None]:
# dump(cache, 'webtext_cache.joblib', compress='zlib')

## Filter candidates

In [None]:
# Drop pairs whose two documents come from the same corpus. Orient every
# remaining pair as (webtext_id, openwebtext_id) so that (a, b) and (b, a)
# collapse into a single set entry. Each id is an (index, corpus_name) tuple.
candidate_duplicates = {
    (left, right) if left[1] == 'wt' else (right, left)
    for left, right in all_duplicates
    if left[1] != right[1]
}

In [None]:
# Number of distinct cross-corpus pairs, each oriented (wt, owtc).
len(candidate_duplicates)

In [None]:
# Minimum Jaccard similarity passed to `filter_candidates`; candidate pairs
# below it are presumably dropped.
MIN_JACCARD = 0.99
filtered_duplicates = cache.filter_candidates(candidate_duplicates, min_jaccard=MIN_JACCARD)

In [None]:
len(filtered_duplicates)  # If this is low, relax min_jaccard (or revisit seeds/bands in `train`) to increase recall

## Look up documents in corpora

In [None]:
# Indexable document lookups for each corpus, keyed by the corpus name used in the doc ids.
corpora = dict(wt=delayed_corpus(wt_meta), owtc=delayed_corpus(owtc_meta))

In [None]:
# Sample a small, reproducible set of candidate pairs for manual inspection.
# Iterating the set directly gives an order that depends on str hashing, which is
# randomized per interpreter run. Sorting first yields the same sample after every
# restart. The fixed count also replaces the old `if i > 100: break`, which
# (appending before the check) collected 102 pairs.
N_SAMPLE_PAIRS = 100
matching_docs = [
    # (webtext_doc, openwebtext_doc) lookups, in the (wt, owtc) order of each pair.
    tuple(corpora[corpus_name][doc_idx] for doc_idx, corpus_name in pair)
    for pair in sorted(candidate_duplicates)[:N_SAMPLE_PAIRS]
]

In [None]:
# Materialize only the first two sampled pairs with dask.
# NOTE: this rebinds `matching_docs` from the list of (presumably lazy) lookups to the
# tuple returned by dask.compute, so re-running this cell operates on computed values.
matching_docs = dask.compute(*matching_docs[:2])

In [None]:
# Element 1 of the WebText side of the first computed pair
# (its meaning depends on what delayed_corpus returns — TODO confirm).
matching_docs[0][0][1]