In [3]:
from collections import Counter
from typing import Dict, Tuple, List, Iterable, Sequence
import time
import subprocess
from itertools import groupby
from operator import itemgetter

import numpy as np
from scipy.sparse import csc_matrix, spmatrix

from tqdm.notebook import tqdm

from scipy.spatial import KDTree
from sparselsh import LSH

# Utilities

In [4]:
def count_lines(path: str) -> int:
    """Count the newline characters in the file at ``path``.

    This mirrors ``wc -l``: a final line that is not terminated by a newline
    is not counted. The file is read in binary chunks, so no shell is involved
    and paths containing quotes, spaces or ``$`` are handled safely.

    :param path: path of the file to scan.
    :return: number of ``\\n`` bytes in the file.
    :raises OSError: if the file cannot be opened or read.
    """
    chunk_size = 1024 ** 2
    n_newlines = 0
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            n_newlines += chunk.count(b"\n")

    return n_newlines


### Process wiki token counts

In [5]:
# One parsed row of the counts file: (first field, token, count).
# The first field is used as the document key by the grouping code in this notebook.
WikiTokenCount = Tuple[str, str, int]

def parse_wiki_token_count_line(line: str) -> WikiTokenCount:
    """Parse one ``id,token,count`` line into a ``WikiTokenCount``.

    Surrounding whitespace (including the trailing newline) is stripped first.

    :param line: a raw comma-separated line with exactly three fields.
    :return: ``(id, token, count)`` with ``count`` converted to ``int``.
    :raises ValueError: if the line does not have exactly three fields or the
        last field is not an integer.
    """
    fields = line.strip().split(',')
    doc_id, word, raw_count = fields
    return doc_id, word, int(raw_count)


def iter_wiki_counts_file(path: str, n_lines: int = None) -> Iterable[WikiTokenCount]:
    """Lazily yield the parsed rows of a token-counts file, with a progress bar.

    This is a generator: nothing (including the line counting and its log
    messages) happens until the first item is requested.

    :param path: path of the comma-separated counts file.
    :param n_lines: total number of lines, used as the progress-bar total.
        When ``None`` the lines are counted first with ``count_lines``.
    :return: an iterator of ``WikiTokenCount`` tuples, one per line.
    """
    if n_lines is None:
        print("compute number of lines in file")
        n_lines = count_lines(path)
        print(f"counted {n_lines} lines in file: {path}")

    with open(path, 'r', buffering=1024**2) as counts_file:
        for raw_line in tqdm(counts_file, total=n_lines):
            yield parse_wiki_token_count_line(raw_line)


def process_wiki_counts(
        wiki_counts: Iterable[WikiTokenCount],
        min_count: int = 5
) -> Tuple[Dict[str, int], List[int], int]:
    """Compute per-token document frequencies and keep the frequent tokens.

    Rows are grouped by their first field with ``itertools.groupby``, so every
    run of consecutive rows sharing that field is counted as one document.
    A token is counted at most once per document, even if it appears in
    several rows of that document.

    :param wiki_counts: iterable (a generator is fine) of ``(doc_id, token, count)`` rows,
        with the rows of each document adjacent to each other.
    :param min_count: minimum number of documents a token must appear in to be kept.
    :return: ``(tokens_indices, tokens_docs_count, n_docs)`` where ``tokens_indices`` maps
        each kept token to a dense index (ordered by first appearance in the input),
        ``tokens_docs_count[i]`` is the document frequency of the token with index ``i``,
        and ``n_docs`` is the number of document groups seen.
    """
    docs_count_by_token: Dict[str, int] = {}
    n_docs = 0
    for _, doc_rows in groupby(wiki_counts, key=itemgetter(0)):
        n_docs += 1
        # dict.fromkeys de-duplicates the document's tokens while keeping their
        # order, so first-appearance ordering of the vocabulary is deterministic.
        for token in dict.fromkeys(map(itemgetter(1), doc_rows)):
            docs_count_by_token[token] = docs_count_by_token.get(token, 0) + 1

    frequent = [
        (token, docs_count)
        for token, docs_count in docs_count_by_token.items()
        if docs_count >= min_count
    ]
    tokens_indices = {token: i for i, (token, _) in enumerate(frequent)}
    tokens_docs_count = [docs_count for _, docs_count in frequent]

    return tokens_indices, tokens_docs_count, n_docs

In [6]:
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
wiki_count_path = "/Users/ronpick/studies/stance/data/wikipedia-tfidf/tfidf/wp_word_count.csv"
# wiki_idf_path = "/Users/ronpick/studies/stance/data/wikipedia-tfidf/tfidf/wp_word_idfs.csv"

# Count the lines once so the progress bar has a total and later passes can reuse it.
n_lines = count_lines(wiki_count_path)
# The row iterator gets its own name: binding it to `count_lines` would replace the
# `count_lines` function in the kernel and make this cell fail when run a second time.
wiki_counts = iter_wiki_counts_file(wiki_count_path, n_lines=n_lines)
token2index, tokens_docs_count, n_docs = process_wiki_counts(wiki_counts)
print(f"Finished parsing {n_docs} docs, with {len(token2index)} unique tokens")


HBox(children=(FloatProgress(value=0.0, max=458462175.0), HTML(value='')))


Finished parsing 4334806 docs, with 918912 unique tokens


In [7]:
def filter_tokens_by_min_count(
        tokens_indices: Dict[str, int],
        tokens_docs_count: List[int],
        min_count: int = 5
) -> Tuple[Dict[str, int], List[int]]:
    """Drop tokens whose document count is below ``min_count`` and re-index the rest.

    :param tokens_indices: mapping of token to its position in ``tokens_docs_count``.
    :param tokens_docs_count: document count of each token, by token index.
    :param min_count: minimum document count a token needs to be kept.
    :return: ``(tokens_indices, tokens_docs_count)`` restricted to the kept tokens, with
        indices renumbered densely from 0 in the original index order.
    """
    # Materialise the tokens ordered by index: a lazy `map` object cannot be
    # subscripted, which is what the filtering below needs.
    tokens = [token for token, _ in sorted(tokens_indices.items(), key=itemgetter(1))]

    freq_token_indices = [i for i, token_count in enumerate(tokens_docs_count) if token_count >= min_count]
    tokens = [tokens[i] for i in freq_token_indices]
    tokens_docs_count = [tokens_docs_count[i] for i in freq_token_indices]
    tokens_indices = {token: i for i, token in enumerate(tokens)}

    return tokens_indices, tokens_docs_count

### Compute Count vectors of tokens for each document

In [8]:
def iter_count_vectors_batches(
        wiki_counts: Iterable[WikiTokenCount],
        tokens_indices: Dict[str, int],
        batch_size: int = 10_000
) -> Iterable[Tuple[csc_matrix, List[str]]]:
    """Yield sparse token-count matrices for batches of documents.

    Rows are grouped by their first field with ``itertools.groupby``; each group
    becomes one matrix row. Tokens missing from ``tokens_indices`` are skipped.

    :param wiki_counts: iterable (a generator is fine) of ``(doc_id, token, count)`` rows,
        with the rows of each document adjacent to each other.
    :param tokens_indices: mapping of token to matrix column.
    :param batch_size: number of documents per yielded batch; the last batch holds
        whatever documents remain.
    :return: an iterator of ``(matrix, doc_ids)`` pairs, where ``matrix`` has one row per
        entry of ``doc_ids`` and ``len(tokens_indices)`` columns.
    """
    n_tokens = len(tokens_indices)  # number of cols
    docs, tokens, data, doc_ids = [], [], [], []
    for doc_id, counts in groupby(wiki_counts, key=itemgetter(0)):

        # A full batch is flushed only when another document arrives, so the
        # final yield after the loop always carries the remaining documents.
        if len(doc_ids) == batch_size:
            yield csc_matrix((data, (docs, tokens)), shape=(len(doc_ids), n_tokens)), doc_ids
            docs, tokens, data, doc_ids = [], [], [], []

        doc_row = len(doc_ids)  # row of this document inside the current batch
        doc_ids.append(doc_id)
        for _, token, count in counts:
            token_index = tokens_indices.get(token)
            if token_index is not None:
                docs.append(doc_row)
                tokens.append(token_index)
                data.append(count)

    yield csc_matrix((data, (docs, tokens)), shape=(len(doc_ids), n_tokens)), doc_ids

# compute TF-IDF vectors into a sparse matrix

In [9]:
# Module-level IDF row vector: log(n_docs / docs_count) for every kept token,
# computed as n_docs * (1 / docs_count) in float32 and shaped (1, n_tokens).
docs_counts = np.array(tokens_docs_count).astype(np.float32)
idf = np.log(n_docs * np.reciprocal(docs_counts))
idf = idf.reshape(1, -1)

def compute_tfidf_batches(
        count_vectors_batches: Iterable[Tuple[spmatrix, List[str]]],
        docs_counts: List[int],
        n_docs: int
) -> Iterable[Tuple[spmatrix, List[str]]]:
    """Scale every batch of count vectors by the tokens' inverse document frequency.

    The IDF row is ``log(n_docs * (1 / docs_count))`` computed in float32 and is
    broadcast over the rows of each batch with an element-wise multiply.

    :param count_vectors_batches: iterable of ``(count_matrix, doc_ids)`` pairs.
    :param docs_counts: document count of each token, by column index.
    :param n_docs: total number of documents.
    :return: an iterator of ``(tfidf_matrix, doc_ids)`` pairs, one per input batch.
    """
    doc_freqs = np.array(docs_counts).astype(np.float32)
    inverse_doc_freqs = np.reciprocal(doc_freqs)
    idf_row = np.log(n_docs * inverse_doc_freqs).reshape(1, -1)
    yield from (
        (batch_counts.multiply(idf_row), batch_doc_ids)
        for batch_counts, batch_doc_ids in count_vectors_batches
    )

In [10]:
# Second pass over the counts file: rows -> count-vector batches -> TF-IDF batches.
# Everything here is lazy; the file is only read when `tfidf_batches` is consumed.
batch_size = 50_000
# Use a dedicated name for the row iterator so the `count_lines` function is not
# replaced by a generator in the kernel namespace.
wiki_counts = iter_wiki_counts_file(wiki_count_path, n_lines=n_lines)
count_vecs_batches = iter_count_vectors_batches(wiki_counts, token2index, batch_size=batch_size)
tfidf_batches = compute_tfidf_batches(count_vecs_batches, tokens_docs_count, n_docs=n_docs)

# Perform LSH

In [11]:
# Build the LSH object whose random planes are used below to hash TF-IDF vectors.
# NOTE(review): the first argument (8) is presumably the number of hash bits and the
# second the input dimension (vocabulary size) — confirm against the sparselsh API.
# NOTE(review): no random seed is set in this notebook, so the planes presumably
# differ between kernel sessions — confirm how sparselsh draws them.
lsh = LSH(
    8,
    len(token2index),
    num_hashtables=1,
    storage_config={"dict":None}
)

lsh

<sparselsh.lsh.LSH at 0x1565ce250>

In [12]:
def hash_docs(m: spmatrix, lsh: LSH) -> np.ndarray:
    """Hash the rows of ``m`` into boolean sign patterns.

    Each row of ``m`` is projected onto the first set of planes of ``lsh`` and
    every projection is reduced to ``True`` when it is strictly positive.

    :param m: sparse matrix with one vector per row.
    :param lsh: object exposing ``uniform_planes``; only ``uniform_planes[0]`` is used.
    :return: dense boolean array with one row per row of ``m``.
    """
    hyperplanes = lsh.uniform_planes[0]
    projections = hyperplanes.dot(m.T).T
    return projections.toarray() > 0


def encode_docs(m: spmatrix, lsh: LSH) -> Sequence[int]:
    """Hash the rows of ``m`` and pack each sign pattern into one unsigned integer.

    The projections onto ``lsh.uniform_planes[0]`` are thresholded at zero; the
    resulting bits are shifted into a ``uint32`` accumulator plane by plane, so
    the first plane ends up in the most significant used bit.

    :param m: sparse matrix with one vector per row.
    :param lsh: object exposing ``uniform_planes``; only ``uniform_planes[0]`` is used.
    :return: 1-D ``uint32`` array with one packed code per row of ``m``.
    """
    hyperplanes = lsh.uniform_planes[0]
    sign_bits = hyperplanes.dot(m.T).toarray() > 0  # one row per plane, one column per doc
    n_codes = sign_bits.shape[1]
    codes = np.zeros(n_codes, dtype=np.uint32)
    for plane_bits in sign_bits:
        codes = codes << 1 | plane_bits

    return codes

In [13]:
# tfidf, doc_ids = next(tfidf_batches)
# tfidf

In [14]:
# planes = lsh.uniform_planes[0]
# hashed = planes.dot(tfidf.T).toarray() > 0
# hashed
# packed = np.zeros(hashed.shape[1], dtype=np.uint32)
# print(packed.shape)
# for i in range(hashed.shape[0]):
#     packed = packed << 1 | hashed[i]
#
# packed

In [15]:
# Consume the TF-IDF batches: hash every batch with the LSH planes and remember
# the document ids in the same order as the hashed rows.
# NOTE: `tfidf_batches` is a generator, so this cell can only be run once per
# creation of that generator.
hashed_batches = []
all_doc_ids = []
for tfidf, doc_ids in tfidf_batches:
    # convert to CSR before hashing
    tfidf_csr = tfidf.tocsr()
    hashed_docs = hash_docs(tfidf_csr, lsh)
    hashed_batches.append(hashed_docs)
    all_doc_ids.extend(doc_ids)

print("concat data")
# Stack the per-batch boolean hash arrays into one (n_docs, n_bits) array.
data = np.vstack(hashed_batches)
data.shape

HBox(children=(FloatProgress(value=0.0, max=458462175.0), HTML(value='')))


concat data


(4334806, 8)

In [16]:
# Re-encode the boolean hash bits as uint8 values and display the array.
data = data.astype(np.uint8)
data

array([[0, 1, 1, ..., 1, 0, 0],
       [1, 0, 1, ..., 1, 0, 1],
       [0, 0, 0, ..., 1, 1, 1],
       ...,
       [1, 1, 0, ..., 1, 1, 1],
       [1, 1, 0, ..., 1, 0, 0],
       [1, 1, 0, ..., 0, 1, 0]], dtype=uint8)

In [None]:
from annoy import AnnoyIndex

# Index every hashed document in an Annoy index using the Hamming metric.
# The vector length is taken from the data instead of being hard-coded, so it
# always matches the number of hash bits produced above.
tree = AnnoyIndex(data.shape[1], "hamming")
for i in tqdm(range(data.shape[0]), total=data.shape[0]):
    tree.add_item(i, data[i])

print("Building annoy tree")
tree.build(100)

In [18]:
import sys
# Raise the interpreter's recursion limit to 10,000 frames.
# NOTE(review): presumably a workaround for the recursive KD-Tree construction in
# the next cell — confirm it is still needed.
sys.setrecursionlimit(10_000)

In [20]:
%time
print("Create KD-Tree")
tree = KDTree(data, leafsize=300)

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 4.77 µs
Create KD-Tree


RecursionError: maximum recursion depth exceeded while calling a Python object

In [None]:
# Inclusive code-point bounds of the ASCII upper-case and lower-case letters.
ascii_upper_start = 65
ascii_upper_end = 90
ascii_lower_start = 97
ascii_lower_end = 122

def is_ascii_letter(c: str) -> bool:
    """Return True when the single character ``c`` is an ASCII letter (A-Z or a-z)."""
    code = ord(c)
    return (
        ascii_upper_start <= code <= ascii_upper_end
        or ascii_lower_start <= code <= ascii_lower_end
    )


def get_longest_letters_sequence(token: str) -> str:
    """Return the longest run of consecutive ASCII letters in ``token``.

    A run that extends to the end of the token is considered too, so a token
    made only of letters is returned unchanged. When several runs share the
    maximal length the first one is returned.

    :param token: the string to scan.
    :return: the longest ASCII-letter run, or ``""`` if ``token`` has no letters.
    """
    letter_runs = (
        "".join(run_chars)
        for is_letter, run_chars in groupby(token, key=is_ascii_letter)
        if is_letter
    )
    # max() keeps the first maximal element, i.e. the earliest longest run.
    return max(letter_runs, key=len, default="")


def is_valid_token(token: str) -> bool:
    """Return True when ``token`` has at least three characters."""
    min_length = 3
    return len(token) >= min_length


def tokenize_clean_text(text: str) -> List[str]:
    """Split ``text`` on whitespace and reduce every piece to a clean token.

    Each piece is replaced by the result of ``get_longest_letters_sequence`` and
    kept only if ``is_valid_token`` accepts it.

    :param text: the raw text to tokenize.
    :return: the accepted tokens, in their order of appearance.
    """
    cleaned_words = (get_longest_letters_sequence(word) for word in text.split())
    return [word for word in cleaned_words if is_valid_token(word)]


def get_tfidf_vector(text: str, token2index: Dict[str, int], idf: np.ndarray) -> spmatrix:
    """Build the TF-IDF row vector of ``text`` over the given vocabulary.

    The text is tokenized with ``tokenize_clean_text``; tokens missing from
    ``token2index`` are ignored. The count of every known token is placed in
    its vocabulary column and scaled element-wise by ``idf``.

    :param text: the raw text to vectorize.
    :param token2index: mapping of token to column index.
    :param idf: array broadcastable against a ``(1, len(token2index))`` matrix.
    :return: sparse matrix of shape ``(1, len(token2index))``.
    """
    tokens = tokenize_clean_text(text)
    # Compare with None explicitly: column index 0 is valid but falsy, so a
    # truthiness filter would silently drop the first vocabulary token.
    token_columns = (token2index.get(token) for token in tokens)
    counts = Counter(col for col in token_columns if col is not None)

    # Split the counter into parallel column/count arrays (a list of pairs cannot
    # be unpacked into two names directly); this also works when no token is known.
    cols = np.array(sorted(counts), dtype=np.int64)
    data = np.array([counts[col] for col in cols.tolist()], dtype=np.int64)
    rows = np.zeros(len(cols), dtype=np.int64)  # everything goes in the single row 0
    counts_vec = csc_matrix((data, (rows, cols)), shape=(1, len(token2index)))
    return counts_vec.multiply(idf)



In [None]:
# Sample argument text used to try out tokenization and the nearest-neighbour query.
# The literal is kept verbatim, including its spelling and spacing.
text = "Guns should be banned because they are not needed in any domestic issue." \
       "The second ammendment was put in place because of fear that the british might invade america again or take " \
       "control of the government. " \
       "if this were the case the people would need weapons to defend themselves and regain america. " \
       "The british aren't going to invade so we don't need to protect our selves. " \
       "even in the this day and age america remains increadible safe compared to many other nations. " \
       "we have no close enemies. " \
       "if a major army were to attack us a few men with pistols or shotguns wouldn't do much against a soldier with " \
       "an ak47 or tanks or bombersGuns in America just make it easier for crimes to be committed. " \
       "Some guns should never be considered allowed and this includes all semi automatic weapons as well as " \
       "shotgunsPoverty, drugs, and lack of education are the reasons people turn to guns to kill. " \
       "guns give you power to take life and should not be allowed to float around so that our students or citizens " \
       "can use them against one another"

In [None]:
# Display the tokens extracted from the sample text.
tokenize_clean_text(text)

In [None]:
# Vectorize the sample text with the module-level `idf`, hash it with the LSH
# planes, and ask the index for its 10 nearest neighbours.
text_repr = get_tfidf_vector(text, token2index, idf)
hashed_text = hash_docs(text_repr, lsh)
# NOTE(review): `tree` is whichever index cell ran last; `.query(..., k=...)` matches
# the KDTree API — confirm the intended index before relying on this call.
tree.query(hashed_text, k=10, )

In [None]:
# encoded_docs = encode_docs(tfidf_csr, lsh)
# start = time.time()
# tfidf_dense = tfidf.toarray()
# for i, (doc_hash, doc_id) in tqdm(enumerate(zip(doc_ids, doc_ids)), total=len(doc_ids), leave=False):
#     storage.append_val(doc_hash, (tfidf_dense[0], doc_id))
# duration = time.time() - start
# print(f"done encoding {batch_size} docs into lsh [{duration} sec]")