In [8]:
# HACK: use project root as the working directory 
from pathlib import Path

while Path.cwd().name != 'language-model-toxicity':
    %cd ..

In [9]:
from typing import List
import logging
from pathlib import Path

import dask
import dask.array as da
import numpy as np
from tqdm.auto import tqdm

from utils.constants import DATA_DIR

# Silence *all* logging: every logger, at CRITICAL and below, not only
# transformers'. The point is to hide noisy library output; undo with
# `logging.disable(logging.NOTSET)`.
logging.disable(logging.CRITICAL)

In [14]:
EOS = 50256  # document separator token id (presumably GPT-2's <|endoftext|>)


def split_docs(tokens: np.ndarray, pad_value: int = 0) -> np.ndarray:
    """Split a flat token stream into documents and right-pad them into a matrix.

    Documents are the runs of tokens between ``EOS`` separators. The ``EOS``
    tokens themselves are dropped. So are empty documents, e.g. from
    consecutive separators or a leading/trailing ``EOS``.

    Args:
        tokens: 1-D array of token ids.
        pad_value: Value used to right-pad shorter documents. Defaults to 0,
            which matches the previous ``np.pad`` behaviour. Note that 0 may
            also be a real token id.

    Returns:
        Array of shape ``(n_docs, max_doc_len)`` with the dtype of ``tokens``.
        If ``tokens`` contains no non-empty document, the shape is ``(0, 0)``.
    """
    tokens = np.asarray(tokens)

    # np.split at every EOS index yields chunks that each start with that EOS.
    # The exception is the first chunk (the tokens before the first EOS), so
    # strip the leading separator only when it is actually there.
    chunks = np.split(tokens, np.flatnonzero(tokens == EOS))
    docs = [chunk[1:] if len(chunk) and chunk[0] == EOS else chunk for chunk in chunks]
    docs = [doc for doc in docs if len(doc)]

    if not docs:
        return np.empty((0, 0), dtype=tokens.dtype)

    max_len = max(map(len, docs))
    padded = np.full((len(docs), max_len), pad_value, dtype=tokens.dtype)
    for row, doc in zip(padded, docs):
        row[:len(doc)] = doc
    return padded

def _count_docs(tokens: np.ndarray) -> int:
    """Count the non-empty runs of tokens between ``EOS`` separators."""
    eos_idx = np.flatnonzero(tokens == EOS)
    # Boundaries include virtual separators just before the start and just
    # after the end; a gap > 1 between two boundaries means >= 1 real token.
    bounds = np.concatenate(([-1], eos_idx, [len(tokens)]))
    return int(np.count_nonzero(np.diff(bounds) > 1))


def load_meta(bpe_dir: Path, files_only: bool = False):
    """List the ``.npy`` token shards in ``bpe_dir`` and optionally gather metadata.

    Args:
        bpe_dir: Directory containing ``.npy`` token arrays; other files are ignored.
        files_only: If True, return only the shard paths without opening any file.

    Returns:
        If ``files_only``: the shard paths, sorted by name so the order is
        deterministic across filesystems.
        Otherwise: ``(files, doc_counts, dtype)``. ``doc_counts[i]`` is the
        number of non-empty EOS-delimited documents in ``files[i]``, and
        ``dtype`` is the dtype of the first shard (all shards are assumed to
        share it; this is not checked).

    Raises:
        ValueError: If metadata is requested but ``bpe_dir`` has no ``.npy`` files.
    """
    # iterdir() order is filesystem-dependent; sort for a reproducible corpus order.
    files = sorted(file for file in bpe_dir.iterdir() if file.suffix == '.npy')
    if files_only:
        return files
    if not files:
        raise ValueError(f'No .npy files found in {bpe_dir}')

    meta = []
    for file in tqdm(files, desc='Loading meta'):
        # Memory-map instead of reading each shard fully into RAM.
        array = np.load(file, mmap_mode='r')
        meta.append((_count_docs(array), array.dtype))
    doc_counts, dtypes = zip(*meta)
    return files, doc_counts, dtypes[0]

## Load Corpora

In [11]:
# WebText shards: collect the .npy paths without opening any of them.
wt_dir = DATA_DIR / 'webtext'
wt_files = load_meta(bpe_dir=wt_dir, files_only=True)

In [12]:
# OpenWebText (BPE) shards: collect the .npy paths without opening any of them.
owtc_dir = DATA_DIR / 'openwebtext_bpe'
owtc_files = load_meta(bpe_dir=owtc_dir, files_only=True)

In [None]:
# NOTE(review): commented-out eager alternative. It loads every OWTC shard fully
# into memory and extends a Python list with each shard's padded document rows.
# owtc_corpus = []
# for shard in tqdm(map(np.load, owtc_files), total=len(owtc_files)):
#     owtc_corpus.extend(split_docs(shard))

## Load shards lazily with dask

In [None]:
# Lazily load one shard and split it into a padded (n_docs, max_doc_len) matrix.
# Its width varies per shard. Kept for any existing callers.
delayed_load = dask.delayed(lambda f: split_docs(np.load(f)))


def _load_shard_fixed_width(file, width: int) -> np.ndarray:
    """Load one shard as a ``(n_docs, width)`` matrix: right-pad with 0 or truncate each row."""
    docs = split_docs(np.load(file))
    out = np.zeros((docs.shape[0], width), dtype=docs.dtype)
    n_cols = min(width, docs.shape[1])
    out[:, :n_cols] = docs[:, :n_cols]
    return out


def load_corpus(meta, width: int = 1024):
    """Build a lazy 2-D dask array of documents from the shards described by ``meta``.

    ``split_docs`` pads each shard only to its own longest document. Every
    shard is therefore padded or truncated to ``width`` columns here, so the
    shards can be concatenated along the document axis. The declared chunk
    shape then matches what is actually computed.

    Args:
        meta: ``(files, doc_counts, dtype)`` tuple as returned by ``load_meta``.
        width: Number of token columns per document. Longer documents are
            truncated. NOTE(review): 1024 is a placeholder default (GPT-2's
            context length); tune it to the corpus.

    Returns:
        Dask array of shape ``(sum(doc_counts), width)``. This assumes each
        per-file count equals the number of rows ``split_docs`` returns for
        that file (TODO confirm for the shard format).
    """
    files, shapes, dtype = meta

    load_shard = dask.delayed(_load_shard_fixed_width)
    corpus = da.concatenate([
        da.from_delayed(load_shard(file, width), shape=(n_docs, width), dtype=dtype)
        for file, n_docs in zip(files, shapes)
    ])
    return corpus

In [None]:
# `wt_meta` is not computed earlier in the notebook (only the file list is), so
# compute it here. This opens every WebText shard once to count its documents.
wt_meta = load_meta(wt_dir)
wt_corpus = load_corpus(wt_meta)
wt_corpus

## Pairwise distances between the WebText and OpenWebText corpora

In [None]:
# Start a dask.distributed client. Import `Client` here: the cell that imports it
# comes after this one, so a fresh top-to-bottom run would hit a NameError.
from dask.distributed import Client

client = Client()
client

In [None]:
from dask.distributed import Client
import joblib
from dask_ml.metrics.pairwise import pairwise_distances

# `owtc_corpus` is otherwise only built in a commented-out cell, so build it here
# with the same lazy loader used for WebText. This opens every OWTC shard once.
owtc_meta = load_meta(owtc_dir)
owtc_corpus = load_corpus(owtc_meta)

with joblib.parallel_backend('dask'):
    # Keep the result instead of discarding it.
    wt_owtc_distances = pairwise_distances(wt_corpus, owtc_corpus)

wt_owtc_distances

## Edit-distance test on a single shard from each corpus

In [None]:
# Documents of the first WebText shard only, as a padded matrix.
wt_docs = split_docs(tokens=np.load(wt_files[0]))

In [None]:
# Documents of the first OpenWebText shard only, as a padded matrix.
owtc_docs = split_docs(tokens=np.load(owtc_files[0]))

In [None]:
# First WebText document (row 0, including any right padding).
wt_docs[0, :]

In [27]:
from sklearn.metrics import pairwise_distances, pairwise_distances_chunked

In [28]:
import editdistance

In [29]:
# split_docs pads each shard to its own longest document, so the two matrices
# generally have different widths. sklearn requires X and Y to have the same
# number of columns, so right-pad both with 0 (split_docs' pad value) to a
# common width.
# NOTE(review): padding tokens take part in the edit distance, and 0 may also be
# a real token id, so distances between documents of different lengths are
# distorted. Consider trimming the padding inside the metric.
N_JOBS = 8
common_width = max(wt_docs.shape[1], owtc_docs.shape[1])
wt_docs_aligned = np.pad(wt_docs, ((0, 0), (0, common_width - wt_docs.shape[1])))
owtc_docs_aligned = np.pad(owtc_docs, ((0, 0), (0, common_width - owtc_docs.shape[1])))

# Lazy generator of distance-matrix row chunks; nothing is computed until iterated.
gen = pairwise_distances_chunked(wt_docs_aligned, owtc_docs_aligned,
                                 metric=editdistance.eval, n_jobs=N_JOBS)