In [None]:
from collections import Counter
from typing import Dict, Tuple, List, Iterable, Sequence, Union
import time
import subprocess
from itertools import groupby
from operator import itemgetter

import numpy as np
import wikipedia
from scipy.sparse import csc_matrix, spmatrix, csr_matrix, save_npz, load_npz
from scipy.spatial.distance import cosine

from tqdm.notebook import tqdm

from scipy.spatial import KDTree
from sparselsh import LSH

# Utilities

In [None]:
def count_lines(path: str) -> int:
    """Count the newline characters in the file at ``path``.

    Gives the same number as ``wc -l`` (a last line without a trailing newline is not counted),
    but reads the file directly instead of interpolating ``path`` into a shell command. Paths
    containing quotes or shell metacharacters are therefore safe, and no external ``wc`` binary
    is required.

    :param path: path of the file to scan.
    :return: number of newline bytes found in the file.
    :raises OSError: if the file cannot be opened or read.
    """
    newline_count = 0
    with open(path, 'rb') as f:
        # Fixed-size binary chunks keep memory flat no matter how large the file is.
        for chunk in iter(lambda: f.read(1024 ** 2), b''):
            newline_count += chunk.count(b'\n')
    return newline_count


### Process wiki token counts

In [None]:
# One row of the word-count file: (document id, token, occurrences of the token in that document).
WikiTokenCount = Tuple[str, str, int]


def parse_wiki_token_count_line(line: str) -> WikiTokenCount:
    """Turn one ``doc_id,token,count`` text line into a typed tuple.

    Surrounding whitespace (including the trailing newline) is stripped before splitting.

    :param line: raw line read from the counts file.
    :return: ``(doc_id, token, count)`` with ``count`` converted to ``int``.
    :raises ValueError: if the line does not have exactly three comma-separated fields or the
        last field is not an integer.
    """
    fields = line.strip().split(',')
    doc_id, word, raw_count = fields
    return doc_id, word, int(raw_count)


def iter_wiki_counts_file(path: str, n_lines: int = None) -> Iterable[WikiTokenCount]:
    """Lazily yield the parsed rows of a wiki token-count file, showing a progress bar.

    :param path: path of the ``doc_id,token,count`` file.
    :param n_lines: total number of lines, used only as the progress-bar total; when ``None``
        it is computed with ``count_lines`` (an extra pass over the file).
    :return: generator of ``(doc_id, token, count)`` tuples in file order.
    """
    if n_lines is None:
        print("compute number of lines in file")
        n_lines = count_lines(path)
        print(f"counted {n_lines} lines in file: {path}")

    # 1 MiB buffer; the file stays open only while the generator is being consumed.
    with open(path, 'r', buffering=1024**2) as counts_file:
        for raw_line in tqdm(counts_file, total=n_lines):
            yield parse_wiki_token_count_line(raw_line)


def process_wiki_counts(
        wiki_counts: Iterable[Tuple[str, str, int]],
        min_count: int = 5
) -> Tuple[Dict[str, int], List[int], int]:
    """Build the vocabulary and per-token document frequencies from wiki token counts.

    Rows are ``(doc_id, token, count)`` tuples (the ``WikiTokenCount`` shape) and must be grouped
    by document: consecutive rows with the same ``doc_id`` are treated as one document.

    :param wiki_counts: any iterable (list or generator) of ``(doc_id, token, count)`` rows.
    :param min_count: tokens appearing in fewer than this many documents are dropped.
    :return: ``(tokens_indices, tokens_docs_count, n_docs)`` where ``tokens_indices`` maps each
        kept token to a dense index (first-seen order), ``tokens_docs_count[i]`` is the number of
        documents containing the token with index ``i``, and ``n_docs`` is the number of documents
        seen (before any filtering).
    """
    doc_freq: Dict[str, int] = {}     # token -> number of documents containing it (first-seen order)
    n_docs = 0
    for _, doc_rows in groupby(wiki_counts, key=itemgetter(0)):
        n_docs += 1
        # De-duplicate tokens within the document (order-preserving) so that a token listed twice
        # for the same document still contributes only 1 to its document frequency.
        for token in dict.fromkeys(token for _, token, _ in doc_rows):
            doc_freq[token] = doc_freq.get(token, 0) + 1

    tokens_indices: Dict[str, int] = {}
    tokens_docs_count: List[int] = []     # docs_counts_by_token_index
    for token, n_docs_with_token in doc_freq.items():
        if n_docs_with_token >= min_count:
            tokens_indices[token] = len(tokens_indices)
            tokens_docs_count.append(n_docs_with_token)

    return tokens_indices, tokens_docs_count, n_docs


def filter_tokens_by_min_count(
        tokens_indices: Dict[str, int],
        tokens_docs_count: List[int],
        min_count: int = 5
) -> Tuple[Dict[str, int], List[int]]:
    """Drop rare tokens and re-index the remaining vocabulary densely.

    :param tokens_indices: mapping token -> index into ``tokens_docs_count``.
    :param tokens_docs_count: document frequency of each token, by token index.
    :param min_count: minimum document frequency a token needs in order to be kept.
    :return: ``(tokens_indices, tokens_docs_count)`` restricted to the kept tokens; the relative
        order of kept tokens is preserved and indices are renumbered from 0.
    """
    # Materialise the index-ordered tokens as a list: a lazy ``map`` object is not subscriptable,
    # so indexing it (as the previous version did) raised TypeError.
    tokens = [token for token, _ in sorted(tokens_indices.items(), key=itemgetter(1))]

    freq_token_indices = [i for i, token_count in enumerate(tokens_docs_count) if token_count >= min_count]
    tokens = [tokens[i] for i in freq_token_indices]
    tokens_docs_count = [tokens_docs_count[i] for i in freq_token_indices]
    tokens_indices = {token: i for i, token in enumerate(tokens)}

    return tokens_indices, tokens_docs_count


In [None]:
# Location of the raw `doc_id,token,count` file.
# NOTE(review): hard-coded absolute path on one developer's machine; other users must edit it.
wiki_count_path = "/Users/ronpick/studies/stance/data/wikipedia-tfidf/tfidf/wp_word_count.csv"
# # wiki_idf_path = "/Users/ronpick/studies/stance/data/wikipedia-tfidf/tfidf/wp_word_idfs.csv"
#
# Line count of the raw file; it is passed as `n_lines` (progress-bar total) in the disabled lines below.
n_lines = count_lines(wiki_count_path)
# Disabled: full parse of the raw counts file that builds the vocabulary / document frequencies.
# wiki_lines = iter_wiki_counts_file(wiki_count_path, n_lines=n_lines)
# token2index, tokens_docs_count, n_docs = process_wiki_counts(wiki_lines)
# print(f"Finished parsing {n_docs} docs, with {len(token2index)} unique tokens")

In [None]:
def save_token_doc_counts(tokens_indices: Dict[str, int], tokens_docs_count: List[int], n_docs: int, path: str):
    """Write the vocabulary's document frequencies to a small CSV-like text file.

    File layout: a header line ``*,<n_docs>`` followed by one ``<token>,<count>`` line per token,
    in token-index order, with no trailing newline.

    :param tokens_indices: mapping token -> index into ``tokens_docs_count``.
    :param tokens_docs_count: document frequency of each token, by token index.
    :param n_docs: total number of documents, stored in the header line.
    :param path: destination file (overwritten).
    """
    ordered_tokens = sorted(tokens_indices, key=tokens_indices.get)
    rows = [f"*,{n_docs}"]
    rows.extend(f"{token},{count}" for token, count in zip(ordered_tokens, tokens_docs_count))
    with open(path, 'w') as out:
        out.write("\n".join(rows))


def load_token_doc_counts(path: str) -> Tuple[Dict[str, int], List[int], int]:
    """Read back a file written by ``save_token_doc_counts``.

    The first line is the header ``*,<n_docs>``; every following non-empty line is
    ``<token>,<count>`` and the token's index is its position in the file.

    :param path: file to read.
    :return: ``(tokens_indices, tokens_docs_count, n_docs)``; both collections are empty when the
        file contains only the header.
    :raises StopIteration: if the file is completely empty (no header line).
    :raises ValueError: if a count (or the header's document count) is not an integer, or a
        token line contains no comma.
    """
    tokens_indices: Dict[str, int] = {}
    counts: List[int] = []
    with open(path, 'r') as f:
        n_docs = int(next(f).split(',')[1])
        for line in f:
            line = line.rstrip('\n')
            if not line:
                continue
            # The count is always the last field, so split from the right: a token that itself
            # contains a comma still round-trips.
            token, count = line.rsplit(',', 1)
            tokens_indices[token] = len(counts)
            counts.append(int(count))
    return tokens_indices, counts, n_docs

In [None]:
# Cache file for the vocabulary's document frequencies (the save_token_doc_counts / load_token_doc_counts format).
token_df_path = "tokens_df.csv"

In [None]:
# save_token_doc_counts(token2index, tokens_docs_count, n_docs, token_df_path)

In [None]:
# Load the cached vocabulary, per-token document frequencies and document count.
token2index, tokens_docs_count, n_docs = load_token_doc_counts(token_df_path)
# `tokens[i]` is the token whose index is i (inverse of token2index).
tokens = list(map(itemgetter(0), sorted(token2index.items(), key=itemgetter(1))))
print(f"Loaded data for {n_docs} docs, with {len(tokens)} unique tokens")

### Calculating IDF values
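Compute the inverse document frequency (IDF) of each token: $\mathrm{idf}(t) = \ln\left(N / \mathrm{df}(t)\right)$, where $N$ is the number of documents and $\mathrm{df}(t)$ is the number of documents that contain token $t$.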

In [None]:
# Per-token document frequencies as float32, indexed by token index.
docs_counts = np.array(tokens_docs_count).astype(np.float32)
# idf = log(n_docs / document frequency), shaped (1, n_tokens) so it broadcasts across the rows
# of a documents-by-tokens matrix.
idf = np.log(n_docs * np.reciprocal(docs_counts)).reshape(1, -1)
idf

## compute TF-IDF vectors into a sparse matrix
- Compute Count vectors of tokens for each document
- multiply each count vector by the idf values (element-wise)
- Do that in batches for memory efficiency

In [None]:
def iter_count_vectors_batches(
        wiki_counts: Iterable[Tuple[str, str, int]],
        tokens_indices: Dict[str, int],
        batch_size: int = 10_000,
        norm: bool = True
) -> Iterable[Tuple[csc_matrix, List[str]]]:
    """Yield sparse (documents x tokens) count matrices, ``batch_size`` documents at a time.

    Rows are ``(doc_id, token, count)`` tuples (the ``WikiTokenCount`` shape) and must be grouped
    by document: consecutive rows with the same ``doc_id`` form one matrix row. Tokens missing
    from ``tokens_indices`` are ignored.

    :param wiki_counts: any iterable (list or generator) of ``(doc_id, token, count)`` rows.
    :param tokens_indices: mapping token -> column index.
    :param batch_size: maximum number of documents per yielded matrix.
    :param norm: when true, each document's counts are divided by the sum of its in-vocabulary
        counts (term frequency); otherwise raw counts are stored.
    :return: generator of ``(matrix, doc_ids)`` pairs, where ``doc_ids[r]`` is the document of
        matrix row ``r``. A final (possibly smaller, possibly empty) batch is always yielded.
    """
    n_tokens = len(tokens_indices)  # number of cols
    docs, tokens, data, doc_ids = [], [], [], []
    for doc_id, counts in groupby(wiki_counts, key=itemgetter(0)):

        # Flush a full batch before starting the next document; fresh lists are bound afterwards
        # so the yielded `doc_ids` list is never mutated again.
        if len(doc_ids) == batch_size:
            yield csc_matrix((data, (docs, tokens)), shape=(len(doc_ids), n_tokens)), doc_ids
            docs, tokens, data, doc_ids = [], [], [], []

        doc_row = len(doc_ids)
        doc_ids.append(doc_id)
        current_counts = []
        for _, token, count in counts:
            token_index = tokens_indices.get(token)
            if token_index is not None:
                docs.append(doc_row)
                tokens.append(token_index)
                current_counts.append(count)

        sum_counts = sum(current_counts) if norm else 1
        data.extend(c / sum_counts for c in current_counts)

    yield csc_matrix((data, (docs, tokens)), shape=(len(doc_ids), n_tokens)), doc_ids


def compute_tfidf_batches(
        count_vectors_batches: Iterable[Tuple[spmatrix, List[str]]],
        idfs: np.ndarray
) -> Iterable[Tuple[spmatrix, List[str]]]:
    """Lazily weight each batch of count vectors by the IDF values.

    :param count_vectors_batches: iterable of ``(count_matrix, doc_ids)`` pairs.
    :param idfs: IDF values, multiplied element-wise with every count matrix.
    :return: generator of ``(tfidf_matrix, doc_ids)`` pairs, one per input batch, with the
        ``doc_ids`` passed through unchanged.
    """
    for batch in count_vectors_batches:
        term_counts, batch_doc_ids = batch
        weighted = term_counts.multiply(idfs)
        yield weighted, batch_doc_ids


In [None]:
# batch_size = n_docs + 1
# count_lines = iter_wiki_counts_file(wiki_count_path, n_lines=n_lines)
# count_vecs_batches = iter_count_vectors_batches(count_lines, token2index, batch_size=batch_size)
# tfidf_batches = compute_tfidf_batches(count_vecs_batches, idf)

In [None]:
# tfidf, all_doc_ids = next(tfidf_batches)
# print("convert to csr format")
# tfidf = tfidf.tocsr()


In [None]:
def save_doc_ids(doc_ids: List[str], path: str):
    """Write the document ids to ``path``, one id per line (each followed by a newline).

    :param doc_ids: ids to store, in matrix-row order.
    :param path: destination file (overwritten).
    """
    with open(path, 'w') as out:
        out.writelines(f"{doc_id}\n" for doc_id in doc_ids)


def load_doc_ids(path: str) -> List[str]:
    """Read document ids from ``path``, one per line, with surrounding whitespace stripped.

    :param path: file with one id per line.
    :return: the ids in file order.
    """
    with open(path, 'r') as src:
        return [line.strip() for line in src]


def save_sparse_matrix(path: str, m: csr_matrix):
    """Store the three CSR arrays of ``m`` in a compressed ``.npz`` archive.

    Only ``indptr``, ``indices`` and ``data`` are written; the matrix shape is not stored.

    :param path: destination archive path.
    :param m: matrix in CSR format.
    """
    csr_arrays = {"indptr": m.indptr, "indices": m.indices, "data": m.data}
    np.savez_compressed(path, **csr_arrays)


def load_sparse_matrix(path: str, shape: Tuple[int, int]) -> csr_matrix:
    """Rebuild a CSR matrix from an archive holding ``indptr``, ``indices`` and ``data`` arrays.

    :param path: ``.npz`` archive containing the three CSR arrays.
    :param shape: ``(n_rows, n_cols)`` of the matrix; it must be supplied by the caller because
        this function reads only the three arrays from the archive.
    :return: the reconstructed ``csr_matrix``.
    """
    # The context manager closes the underlying archive file; the arrays are read out while the
    # archive is still open, so they remain valid afterwards.
    with np.load(path, allow_pickle=False) as arrays:
        indptr = arrays["indptr"]
        indices = arrays["indices"]
        data = arrays["data"]
    return csr_matrix((data, indices, indptr), shape=shape)

In [None]:
# Cache files: document ids (one per line) and the TF-IDF matrix's CSR arrays.
doc_ids_path = "docids.txt"
docs_repr_path = "docs_repr.npz"

In [None]:
# save_doc_ids(all_doc_ids, doc_ids_path)
# save_sparse_matrix(docs_repr_path, tfidf)

In [None]:
# Load the cached document ids (one per line of the file).
# presumably all_doc_ids[r] is the id of TF-IDF row r — verify against how the cache was written.
all_doc_ids = load_doc_ids(doc_ids_path)

In [None]:
# Rebuild the cached TF-IDF matrix. The shape is not stored in the archive, so it is passed
# explicitly as (number of documents, vocabulary size).
tfidf = load_sparse_matrix(docs_repr_path, shape=(n_docs, len(token2index)))
tfidf


In [None]:
# from sklearn.preprocessing import normalize
#
# tfidf_norm = normalize(tfidf, norm='l2', axis=1)

# Perform LSH

In [None]:
def generate_hyperplanes(n_hashes: int, rng: np.random.Generator, n_dims: Union[int, None] = None) -> csr_matrix:
    """Draw random +1/-1 hyperplanes for sign-based LSH.

    Half of all entries (rounded down) are -1 and the rest +1; their positions are a random
    permutation drawn from ``rng``.

    :param n_hashes: number of hyperplanes (rows).
    :param rng: NumPy random generator used for the permutation.
    :param n_dims: dimensionality of each hyperplane (columns). Defaults to the size of the
        notebook-level ``token2index`` vocabulary, which was previously the only option.
    :return: ``csr_matrix`` of shape ``(n_hashes, n_dims)`` whose entries are all +1 or -1.
    """
    if n_dims is None:
        n_dims = len(token2index)
    total_size = n_hashes * n_dims
    planes = np.ones(total_size)
    planes[:total_size//2] = -1
    planes = rng.permutation(planes).reshape(n_hashes, n_dims)
    return csr_matrix(planes)

def compute_signatures(m: spmatrix, planes: np.ndarray) -> np.ndarray:
    """Compute the sign bits of every row of ``m`` against every hyperplane.

    :param m: sparse matrix with one vector per row.
    :param planes: hyperplanes, one per row; the product's ``.toarray()`` is called, so this
        has to be a sparse matrix in practice despite the annotation.
    :return: ``uint8`` array of shape ``(n_planes, n_rows_of_m)`` holding 1 where the dot
        product is strictly positive and 0 otherwise.
    """
    projections = planes.dot(m.transpose())
    is_positive = projections.toarray() > 0
    return is_positive.astype(np.uint8)


def compute_buckets(signatures: np.ndarray, hashes_per_bucket: int = 20) -> np.ndarray:
    """Pack groups of signature bits into one integer bucket id per document.

    The rows of ``signatures`` are split into consecutive groups of ``hashes_per_bucket`` rows
    (the last group may be shorter). Within a group the bits of each document are concatenated,
    first row as the most significant bit.

    :param signatures: 2-D array of 0/1 values with shape ``(n_hashes, n_docs)``.
    :param hashes_per_bucket: number of bits packed into each bucket id; must be between 1 and
        32 because ids are accumulated in ``uint32``.
    :return: array of shape ``(n_groups, n_docs)`` with the bucket id of every document in every
        group; shape ``(0, n_docs)`` when ``signatures`` has no rows.
    :raises ValueError: if ``hashes_per_bucket`` is outside ``1..32`` (previously a non-positive
        value looped forever and a value above 32 silently overflowed).
    """
    if not 1 <= hashes_per_bucket <= 32:
        raise ValueError(f"hashes_per_bucket must be between 1 and 32, got {hashes_per_bucket}")

    n_hashes, n_docs = signatures.shape
    buckets = []
    for i in range(0, n_hashes, hashes_per_bucket):
        sub_signatures = signatures[i: i + hashes_per_bucket]
        print(f"Calculate buckets for signatures of length {len(sub_signatures)}: {i} - {i + hashes_per_bucket}")
        packed = np.zeros(n_docs, dtype=np.uint32)
        for bits in sub_signatures:
            # Shift the accumulated id left and append this hash's bit for every document.
            packed = packed << 1 | bits

        buckets.append(packed)

    if not buckets:
        # np.vstack rejects an empty list; return an explicitly empty result instead.
        return np.zeros((0, n_docs), dtype=np.uint32)
    return np.vstack(buckets)


def create_hashtables(buckets_per_doc: Sequence[Sequence[int]]) -> List[Dict[int, List[int]]]:
    """Build one hash table per row of bucket ids.

    :param buckets_per_doc: one sequence per hash group; element ``d`` of a sequence is the
        bucket id of document ``d`` in that group.
    :return: one dict per group, mapping bucket id -> document indices (ascending) that fall
        into that bucket.
    """
    def _group_docs_by_bucket(bucket_ids: Sequence[int]) -> Dict[int, List[int]]:
        # Map each bucket id to the positions at which it occurs.
        table: Dict[int, List[int]] = {}
        for doc_i, bucket_id in enumerate(bucket_ids):
            if bucket_id in table:
                table[bucket_id].append(doc_i)
            else:
                table[bucket_id] = [doc_i]
        return table

    return [_group_docs_by_bucket(bucket_ids) for bucket_ids in buckets_per_doc]

### Create random hyperplanes

In [None]:
# num_hashes = 100
# rng = np.random.default_rng(1919)
# planes = generate_hyperplanes(num_hashes, rng)
# planes

### calculate docs signatures

In [None]:
# print("calculating hashes")
# start = time.time()
# signatures = compute_signatures(tfidf, planes)
# duration = time.time() - start
# print(f"Done - [took {duration} sec]")
# signatures

### Store LSH data

In [None]:
# Cache file holding the LSH hyperplanes ("planes") and the document signatures ("signatures").
lsh_data_path = "lsh.npz"

In [None]:
# np.savez_compressed(lsh_data_path, planes=planes.toarray(), signatures=signatures)

In [None]:
# Load the cached LSH data.
# NOTE(review): `data` is a very generic name and is read again by the annoy cell further down;
# the archive handle returned by np.load is also never closed — confirm whether that matters here.
data = np.load(lsh_data_path)
# Re-wrap the stored hyperplanes as a sparse matrix (compute_signatures calls .toarray() on the product).
planes = csr_matrix(data["planes"])
signatures = data["signatures"]

### pack signatures to int hashes

In [None]:
# Pack every group of 10 signature bits into one integer bucket id per document
# (one row per group, one column per document).
buckets_per_doc = compute_buckets(signatures, hashes_per_bucket=10)
buckets_per_doc

### Create a hash table for each group of hashes

In [None]:
%time
hashtables = create_hashtables(buckets_per_doc)
len(hashtables)
for ht in hashtables:
    print(len(ht))

In [None]:
# Stray debug expression, disabled: `nn` is only assigned in the nearest-neighbour cell further
# down, so evaluating it here raises NameError on a fresh kernel (Restart & Run All).
# nn

In [None]:
# Spot check: look up the Wikipedia page whose id is stored at position 232251 of all_doc_ids.
# NOTE(review): the index is a hand-picked magic number; presumably this performs a network
# request through the `wikipedia` package — confirm.
page = wikipedia.page(pageid=all_doc_ids[232251])
print(page.pageid, page.title)
#
# page = wikipedia.page(pageid=all_doc_ids[2537209])
# print(page.pageid, page.title)

In [None]:
# Spot check: pull two TF-IDF rows (hand-picked row numbers) and print the `indices` array of
# each, i.e. the token columns stored for that document.
r1 = tfidf[36]
print(r1.indices)
r2 = tfidf[2537209]
print(r2.indices)

In [None]:
# `tokens[i]` is the token with index i.
# NOTE(review): this is the same expression the vocabulary-loading cell already uses to build
# `tokens`; the cell looks redundant on a top-to-bottom run.
tokens = list(map(itemgetter(0), sorted(token2index.items(), key=itemgetter(1))))

In [None]:
# Spot check of token index 8264 (hand-picked): its document frequency, then the token itself
# together with log(n_docs / document frequency).
print(tokens_docs_count[8264])
tokens[8264], np.log(n_docs / tokens_docs_count[8264])

In [None]:
# Stored TF-IDF weight of token index 8264 in each of the two inspected rows.
r1.toarray()[0][8264], r2.toarray()[0][8264]

In [None]:
# hashed_batches = []
# all_doc_ids = []
# for tfidf, doc_ids in tfidf_batches:
#     tfidf_csr = tfidf.tocsr()
#     hashed_docs = hash_docs(tfidf_csr, lsh)
#     hashed_batches.append(hashed_docs)
#     all_doc_ids.extend(doc_ids)
#
# print("concat data")
# data = np.vstack(hashed_batches)
# data.shape


In [None]:
# NOTE(review): import placed mid-notebook rather than in the import cell at the top.
from annoy import AnnoyIndex

# NOTE(review): on a top-to-bottom run `data` is the object the LSH-loading cell bound with
# `np.load(lsh_data_path)`, not the stacked hashes built by the commented-out cell above —
# presumably that object has no `.shape`, so this cell looks stale; confirm the intended input.
# NOTE(review): this also rebinds the notebook-level `n_docs` loaded with the vocabulary.
n_docs = data.shape[0]
dim = data.shape[1]
# Annoy index over `dim`-dimensional items compared with the "hamming" metric.
tree = AnnoyIndex(dim, "hamming")
for i in tqdm(range(n_docs), total=n_docs):
    tree.add_item(i, data[i])

print("Building annoy tree")
# 100 is the argument passed to build() — presumably the number of trees; confirm against annoy docs.
tree.build(100)


In [None]:
# Code points of 'a' and 'z': the inclusive range accepted by is_ascii_letter.
ascii_lower_start = 97
ascii_lower_end = 122


def is_ascii_letter(c: str) -> bool:
    """Return True when the single character ``c`` is a lowercase ASCII letter (a-z).

    Uppercase letters are rejected.
    """
    code_point = ord(c)
    return not (code_point < ascii_lower_start or code_point > ascii_lower_end)


def get_longest_letters_sequence(token: str) -> str:
    """Return the longest run of consecutive a-z letters in the lowercased ``token``.

    When several runs share the maximum length the first one wins; an empty string is returned
    when the token contains no ASCII letter at all.
    """
    longest = ""
    for is_letter_run, chars in groupby(token.lower(), key=is_ascii_letter):
        if not is_letter_run:
            continue
        run = "".join(chars)
        # Strict comparison keeps the earliest run on ties.
        if len(run) > len(longest):
            longest = run
    return longest


def is_valid_token(token: str) -> bool:
    """Return True for tokens of at least three characters."""
    return len(token) >= 3


def tokenize_clean_text(text: str) -> List[str]:
    """Split ``text`` on whitespace and reduce each piece to its longest letter run.

    Pieces whose cleaned form fails ``is_valid_token`` are dropped.

    :param text: raw text.
    :return: cleaned tokens in their original order (duplicates kept).
    """
    cleaned = (get_longest_letters_sequence(piece) for piece in text.split())
    return [candidate for candidate in cleaned if is_valid_token(candidate)]


def get_tfidf_vector(text: str, token2index: Dict[str, int], idf: np.ndarray) -> csr_matrix:
    """Represent ``text`` as a single-row vector of (raw count x IDF) weights.

    The text is tokenized with ``tokenize_clean_text``; tokens missing from ``token2index`` are
    ignored. Counts are raw occurrence counts (not normalised by document length).

    :param text: raw text to vectorize.
    :param token2index: mapping token -> column index.
    :param idf: IDF values, multiplied element-wise with the count vector.
    :return: ``csr_matrix`` of shape ``(1, len(token2index))``; all zeros when the text contains
        no known token.
    """
    token_ids = (token2index.get(token) for token in tokenize_clean_text(text))
    # Compare against None explicitly: index 0 is a valid column and a truthiness filter
    # (the previous `filter(bool, ...)`) silently discarded it.
    counts = Counter(token_id for token_id in token_ids if token_id is not None)

    # Build explicit integer arrays; this also works when `counts` is empty, where unpacking
    # `zip(*...)` used to raise ValueError.
    sorted_cols = sorted(counts)
    cols = np.array(sorted_cols, dtype=np.int64)
    data = np.array([counts[col] for col in sorted_cols], dtype=np.int64)
    rows = np.zeros(len(sorted_cols), dtype=np.int64)
    counts_vec = csc_matrix((data, (rows, cols)), shape=(1, len(token2index)))
    return counts_vec.multiply(idf).tocsr()


In [None]:
text = "Guns should be banned because they are not needed in any domestic issue." \
       "The second ammendment was put in place because of fear that the british might invade america again or take " \
       "control of the government. " \
       "if this were the case the people would need weapons to defend themselves and regain america. " \
       "The british aren't going to invade so we don't need to protect our selves. " \
       "even in the this day and age america remains increadible safe compared to many other nations. " \
       "we have no close enemies. " \
       "if a major army were to attack us a few men with pistols or shotguns wouldn't do much against a soldier with " \
       "an ak47 or tanks or bombersGuns in America just make it easier for crimes to be committed. " \
       "Some guns should never be considered allowed and this includes all semi automatic weapons as well as " \
       "shotgunsPoverty, drugs, and lack of education are the reasons people turn to guns to kill. " \
       "guns give you power to take life and should not be allowed to float around so that our students or citizens " \
       "can use them against one another"

text2 = """Vasily Ivanovich Chuikov (Russian: Васи́лий Ива́нович Чуйко́в; About this soundlisten (help·info); 12 February [O.S. 31 January] 1900 – 18 March 1982) was a Soviet military commander and Marshal of the Soviet Union. He is best known for commanding the 62nd Army which saw heavy combat during the Battle of Stalingrad in the Second World War.

Born to a peasant family near Tula, Chuikov earned his living as a factory worker from the age of 12. After the Russian Revolution of 1917, he joined the Red Army and distinguished himself during the Russian Civil War. After graduating from the Frunze Military Academy, Chuikov worked as a military attaché and intelligence officer in China and the Russian Far East. At the outbreak of the Second World War, Chuikov commanded the 4th Army during the Soviet invasion of Poland, and the 9th Army during the Winter War against Finland. In December 1940, he was again appointed military attaché to China in support of Chiang Kai-shek and the Nationalists in the war against Japan.

In March 1942, Chuikov was recalled from China to command the 62nd Army in defense of Stalingrad. Tasked with holding the city at all costs, Chuikov adopted the hugging tactic, keeping the Soviet front-line positions as close to the Germans as physically possible. This served as an effective countermeasure against the Wehrmacht's combined-arms tactics, but by mid-November 1942 the Germans had captured most of the city after months of slow advance. In late November Chuikov's 62nd Army joined the rest of the Soviet forces in a counter-offensive, which led to the surrender of the German army in early 1943. After Stalingrad, Chuikov led his forces into Poland during Operation Bagration and the Vistula–Oder Offensive before advancing on Berlin. He personally accepted the unconditional surrender of German forces in Berlin on 2 May 1945.

After the war, Chuikov served as Chief of the Group of Soviet Forces in Germany (1949–53), commander of the Kiev Military District (1953–60), Chief of the Soviet Armed Forces and Deputy Minister of Defense (1960–64), and head of the Soviet Civil Defense Forces (1961–72). Chuikov was twice awarded the titles Hero of the Soviet Union (1944 and 1945) and was awarded the Distinguished Service Cross by the United States for his actions during the Battle of Stalingrad. In 1955, he was named a Marshal of the Soviet Union. Following his death in 1982, Chuikov was interred at the Stalingrad memorial at Mamayev Kurgan, which had been the site of heavy fighting."""

text3 = """Mark Rutherford School is a mixed secondary school and sixth form in Bedford, England. The school is named in honour of the Bedford-born writer William Hale White (1831-1913), who used Mark Rutherford as a pseudonym.

Mark Rutherford school educates pupils from age 11 through to 16. In addition, the school offers a sixth form provision for pupils age 16 to 19 wishing to study courses such as A levels. The school has a specialism in performing arts, and offers a range of courses related to the specialism."""

text4 = """Abortion is the ending of a pregnancy by removal or expulsion of an embryo or fetus.[note 1] An abortion that occurs without intervention is known as a miscarriage or "spontaneous abortion" and occurs in approximately 30% to 40% of pregnancies.[1][2] When deliberate steps are taken to end a pregnancy, it is called an induced abortion, or less frequently "induced miscarriage". The unmodified word abortion generally refers to an induced abortion.[3][4]

When properly done, abortion is one of the safest procedures in medicine,[5]:1 [6]:1 but unsafe abortion is a major cause of maternal death, especially in the developing world,[7] while making safe abortion legal and accessible reduces maternal deaths.[8][9] It is safer than childbirth, which has a 14 times higher risk of death in the United States.[10]

Modern methods use medication or surgery for abortions.[11] The drug mifepristone in combination with prostaglandin appears to be as safe and effective as surgery during the first and second trimester of pregnancy.[11][12] The most common surgical technique involves dilating the cervix and using a suction device.[13] Birth control, such as the pill or intrauterine devices, can be used immediately following abortion.[12] When performed legally and safely on a woman who desires it, induced abortions do not increase the risk of long-term mental or physical problems.[14] In contrast, unsafe abortions (those performed by unskilled individuals, with hazardous equipment, or in unsanitary facilities) cause 47,000 deaths and 5 million hospital admissions each year.[14][15] The World Health Organization states that "access to legal, safe and comprehensive abortion care, including post-abortion care, is essential for the attainment of the highest possible level of sexual and reproductive health".[16]

Around 56 million abortions are performed each year in the world,[17] with about 45% done unsafely.[18] Abortion rates changed little between 2003 and 2008,[19] before which they decreased for at least two decades as access to family planning and birth control increased.[20] As of 2018, 37% of the world's women had access to legal abortions without limits as to reason.[21][22] Countries that permit abortions have different limits on how late in pregnancy abortion is allowed.[22] Abortion rates are similar between countries that ban abortion and countries that allow it.[23]

Historically, abortions have been attempted using herbal medicines, sharp tools, forceful massage, or through other traditional methods.[24] Abortion laws and cultural or religious views of abortions are different around the world. In some areas abortion is legal only in specific cases such as rape, problems with the fetus, poverty, risk to a woman's health, or incest.[25] There is debate over the moral, ethical, and legal issues of abortion.[26][27] Those who oppose abortion often argue that an embryo or fetus is a human with a right to life, and they may compare abortion to murder.[28][29] Those who support the legality of abortion often hold that it is part of a woman's right to make decisions about her own body.[30] Others favor legal and accessible abortion as a public health measure.[31]"""

text5 = """Same-sex marriage, also known as gay marriage, is the marriage of two people of the same sex or gender, entered into in a civil or religious ceremony. There are records of same-sex marriage dating back to the first century. In the modern era, marriage equality was first granted to same-sex couples in the Netherlands on 1 April 2001.

As of 2021, same-sex marriage is legally performed and recognized in 29 countries (nationwide or in some jurisdictions):

Argentina
Australia
Austria
Belgium
Brazil
Canada
Colombia
Costa Rica
Denmark
Ecuador
Finland
France
Germany
Iceland
Ireland
Luxembourg
Malta
Mexico[a]
Netherlands[b]
New Zealand[c]
Norway
Portugal
South Africa
Spain
Sweden
Taiwan
United Kingdom[d]
United States[e]
Uruguay
The introduction of same-sex marriage (also called marriage equality) has varied by jurisdiction, and came about through legislative change to marriage law, court rulings based on constitutional guarantees of equality, recognition that it is allowed by existing marriage law,[1] or by direct popular vote (via referendums and initiatives). The recognition of same-sex marriage is considered to be a human right and a civil right as well as a political, social, and religious issue.[2] The most prominent supporters of same-sex marriage are human rights and civil rights organizations as well as the medical and scientific communities, while the most prominent opponents are religious fundamentalist groups. Polls consistently show continually rising support for the recognition of same-sex marriage in all developed democracies and in some developing democracies.

Scientific studies show that the financial, psychological, and physical well-being of gay people are enhanced by marriage, and that the children of same-sex parents benefit from being raised by married same-sex couples within a marital union that is recognized by law and supported by societal institutions.[3] Social science research indicates that the exclusion of homosexuals from marriage stigmatizes and invites public discrimination against them, with research also repudiating the notion that either civilization or viable social orders depend upon restricting marriage to heterosexuals.[4] Same-sex marriage can provide those in committed same-sex relationships with relevant government services and make financial demands on them comparable to that required of those in opposite-sex marriages, and also gives them legal protections such as inheritance and hospital visitation rights.[5] Opposition to same-sex marriage is based on claims such as that homosexuality is unnatural and abnormal, that the recognition of same-sex unions will promote homosexuality in society, and that children are better off when raised by opposite-sex couples.[6] These claims are refuted by scientific studies, which show that homosexuality is a natural and normal variation in human sexuality, and that sexual orientation is not a choice. Many studies have shown that children of same-sex couples fare just as well as the children of opposite-sex couples; some studies have shown benefits to being raised by same-sex couples.[7]

A study of nationwide data from across the United States from January 1999 to December 2015 revealed that the establishment of same-sex marriage is associated with a significant reduction in the rate of attempted suicide among children, with the effect being concentrated among children of a minority sexual orientation, resulting in about 134,000 fewer children attempting suicide each year in the United States.[8]"""

In [None]:
# Vectorize the query text (text5) and list its tokens with their weights, heaviest first.
text_repr = get_tfidf_vector(text5, token2index, idf)
# `indices` of the single-row CSR vector: the token columns stored for the query.
cols = text_repr.indices
t = [tokens[col] for col in cols]
sorted((zip(t, text_repr.data)), key=itemgetter(1), reverse=True)



In [None]:
# Hash the query with the same hyperplanes and the same group size (10 bits per bucket id)
# used for the corpus, so its bucket ids are comparable with the stored hash tables.
test_signatures = compute_signatures(text_repr, planes)
test_buckets = compute_buckets(test_signatures, hashes_per_bucket=10)
test_buckets

In [None]:
# Vote counting: for each hash table, every document sharing the query's bucket gets one vote.
# NOTE(review): this rebinds `docs_counts`, which the IDF cell uses for the array of per-token
# document frequencies; re-running that cell's dependants afterwards would see the Counter.
docs_counts = Counter()
for i, b in enumerate(test_buckets):
    # b[0]: the bucket id of the first (only) column — the query is a single vector.
    docs_counts.update(hashtables[i].get(b[0], []))

docs_counts.most_common(10)

In [None]:
# Candidate neighbours: the 10 documents that share the most buckets with the query.
nn = list(map(itemgetter(0), docs_counts.most_common(10)))

# Densify the query once (loop-invariant) and flatten it: scipy's `cosine` is documented for
# 1-D vectors, whereas `.toarray()` on a single-row sparse matrix yields a (1, n) array.
text_dense = text_repr.toarray().ravel()

# Each entry is (document row, dot product with the query, scipy cosine distance to the query).
distances = []
for i in nn:
    doc_vec = tfidf[i]
    dot = text_repr.dot(doc_vec.transpose()).toarray()[0][0]
    distances.append((i, dot, cosine(text_dense, doc_vec.toarray().ravel())))

# Largest dot product first.
distances = sorted(distances, key=itemgetter(1), reverse=True)
for d in distances:
    print(d)

In [None]:
# Resolve each candidate to its Wikipedia page and print it next to its scores.
for i, dot, cos in distances:
    pageid = all_doc_ids[i]
    try:
        page = wikipedia.page(pageid=pageid)
        print(i, page.pageid, page.title)
        # `cos` is the value returned by scipy's `cosine` for this candidate; `dot` is the raw dot product.
        print(cos, dot)
    except wikipedia.DisambiguationError as e:
        # The lookup was ambiguous: show the alternatives carried by the exception.
        print(pageid, e.options)
    except wikipedia.PageError:
        print(pageid, "Doesn't exist")
    print()


In [None]:
# encoded_docs = encode_docs(tfidf_csr, lsh)
# start = time.time()
# tfidf_dense = tfidf.toarray()
# for i, (doc_hash, doc_id) in tqdm(enumerate(zip(doc_ids, doc_ids)), total=len(doc_ids), leave=False):
#     storage.append_val(doc_hash, (tfidf_dense[0], doc_id))
# duration = time.time() - start
# print(f"done encoding {batch_size} docs into lsh [{duration} sec]")