In [None]:
# HACK: use project root as the working directory
import os
from pathlib import Path

# Walk up one directory at a time until the current directory is the project root.
while Path.cwd().name != 'language-model-toxicity':
    # At the filesystem root the parent of the directory is the directory itself,
    # so going up again would loop forever; fail loudly instead.
    if Path.cwd().parent == Path.cwd():
        raise RuntimeError(
            "Could not find a parent directory named 'language-model-toxicity'; "
            "start the notebook from inside the project."
        )
    os.chdir(Path.cwd().parent)

In [2]:
import logging

from joblib import Memory, Parallel, delayed
from lsh import cache, minhash
from itertools import repeat, chain

from utils.constants import DATA_DIR, OUTPUT_DIR
from utils.webtext import load_meta, delayed_corpus

# Silence log output (intended to mute transformers). Note that logging.disable()
# is process-wide: every logging call of severity CRITICAL or lower, from any
# library, is suppressed -- not only the ones coming from transformers.
logging.disable(logging.CRITICAL)

# Create the joblib Memory used to cache results for this notebook; it is
# rooted at OUTPUT_DIR / 'cache' / 'webtext_overlap'.
mem = Memory(OUTPUT_DIR / 'cache' / 'webtext_overlap')

In [None]:
# Memoize load_meta through the joblib Memory created above. A Memory object is
# not itself a decorator; the memoizing wrapper is obtained with Memory.cache().
cached_meta = mem.cache(load_meta)

# Build the (not yet computed) corpora from the cached metadata of each dataset.
wt_corpus = delayed_corpus(cached_meta(DATA_DIR / 'webtext'))
owtc_corpus = delayed_corpus(cached_meta(DATA_DIR / 'openwebtext_bpe'))

In [10]:
def train(document_feed, char_ngram=3, seeds=100, bands=5, hashbytes=4, n_jobs=1):
    """Fingerprint every document with MinHash and index the fingerprints in an LSH cache.

    Args:
        document_feed: iterable of ``(corpus_name, doc)`` pairs. Each ``doc`` must
            provide ``tobytes()`` (presumably a numpy array of token ids -- TODO confirm).
        char_ngram: forwarded to ``minhash.MinHasher``.
        seeds: forwarded to ``minhash.MinHasher``; must be a multiple of ``bands``.
        bands: forwarded to ``cache.Cache`` as ``num_bands``.
        hashbytes: forwarded to ``minhash.MinHasher``.
        n_jobs: number of joblib workers used to compute the fingerprints.

    Returns:
        The tuple ``(hasher, lshcache)``.

    Raises:
        ValueError: if ``seeds`` is not a multiple of ``bands``.
    """
    # Validate the configuration before building anything.
    if seeds % bands != 0:
        raise ValueError('Seeds has to be a multiple of bands. {} % {} != 0'.format(seeds, bands))

    hasher = minhash.MinHasher(seeds=seeds, char_ngram=char_ngram, hashbytes=hashbytes)
    lshcache = cache.Cache(num_bands=bands, hasher=hasher)

    def hash_bytes(doc):
        """Fingerprint the raw bytes of a single document."""
        return hasher.fingerprint(doc.tobytes())

    # Only the document is sent to the worker: hash_bytes takes a single argument,
    # so passing any extra keyword (e.g. a doc id) would raise TypeError.
    fingerprints = Parallel(n_jobs=n_jobs, verbose=1)(
        delayed(hash_bytes)(doc)
        for _corpus_name, doc in document_feed
    )

    # TODO: use document ids
    # For now a document is identified by its position in document_feed; the
    # corpus name is not recorded in the cache.
    for i, fingerprint in enumerate(fingerprints):
        lshcache.add_fingerprint(fingerprint, doc_id=i)

    return hasher, lshcache

In [42]:
def with_corpus_name(corpus, name):
    """Pair every document of ``corpus`` with ``name``, giving ``(name, doc)`` tuples."""
    return zip(repeat(name), corpus)


# Compute both corpora and concatenate them into a single labelled document feed.
corpus = chain(
    with_corpus_name(wt_corpus.compute(num_workers=20), 'webtext'),
    with_corpus_name(owtc_corpus.compute(num_workers=20), 'owtc')
)

# Bind the results to names so that later cells can use them without relying
# on `_` / `Out[N]`.
hasher, lshcache = train(corpus, n_jobs=80)

[Parallel(n_jobs=80)]: Using backend LokyBackend with 80 concurrent workers.
[Parallel(n_jobs=80)]: Done  41 tasks      | elapsed:    0.4s
[Parallel(n_jobs=80)]: Done 422 tasks      | elapsed:    1.0s
[Parallel(n_jobs=80)]: Done 1122 tasks      | elapsed:    1.8s
[Parallel(n_jobs=80)]: Done 2022 tasks      | elapsed:    2.9s
[Parallel(n_jobs=80)]: Done 3122 tasks      | elapsed:    4.2s
[Parallel(n_jobs=80)]: Done 4422 tasks      | elapsed:    5.7s
[Parallel(n_jobs=80)]: Done 5922 tasks      | elapsed:    7.8s
[Parallel(n_jobs=80)]: Done 7622 tasks      | elapsed:    9.9s
[Parallel(n_jobs=80)]: Done 9522 tasks      | elapsed:   12.1s
[Parallel(n_jobs=80)]: Done 11622 tasks      | elapsed:   14.7s
[Parallel(n_jobs=80)]: Done 13922 tasks      | elapsed:   17.5s
[Parallel(n_jobs=80)]: Done 16422 tasks      | elapsed:   20.8s
[Parallel(n_jobs=80)]: Done 19122 tasks      | elapsed:   24.0s
[Parallel(n_jobs=80)]: Done 22022 tasks      | elapsed:   27.5s
[Parallel(n_jobs=80)]: Done 25122 task

KeyboardInterrupt: 