In [15]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from primesieve import *

from numba import njit, jit
from itertools import permutations

import random
import math

In [27]:
%%time

print(1 - np.power((1 - np.power(TRESHOLD, PERMUTATIONS / B)), B))
print(PERMUTATIONS/B)

df = pd.read_csv("news_articles_small.csv")
# df = pd.read_csv("news_articles_large.csv")

print(df.head())

# Jaccard similarity

shingles = SHINGLES
shingles_dict = dict()

# convert articles into sets of shingles.
def apply_shingles(input):
    """Return the set of integer ids of the `shingles`-word shingles in `input`.

    The text is split on single spaces; every run of `shingles` consecutive
    tokens, re-joined with a single space, is one shingle. A shingle not yet in
    the global `shingles_dict` is registered there with the next free id (the
    dict's current size), so ids are handed out in first-seen order across calls.
    Texts with fewer than `shingles` tokens yield an empty set.
    """
    tokens = input.split(' ')
    last_start = len(tokens) - shingles
    ids = set()
    for start in range(last_start + 1):
        gram = ' '.join(tokens[start:start + shingles])
        # len() is evaluated before setdefault runs, so a new gram gets the
        # current size as its id; an existing gram keeps its old id.
        ids.add(shingles_dict.setdefault(gram, len(shingles_dict)))
    return ids
        
# Attach each article's set of shingle ids as a new column; apply_shingles
# also registers every newly seen shingle in shingles_dict as a side effect.
df = df.assign(shingles=df["article"].apply(apply_shingles))
print(df.head())

# function to calculate jaccard similarity
def similarity(a, b):
    """Return the Jaccard similarity |a & b| / |a | b| of two sets.

    If both sets are empty the ratio is 0/0; instead of raising
    ZeroDivisionError the two sets are treated as identical and 1.0 is
    returned. Empty shingle sets occur for articles with fewer than
    `SHINGLES` words.
    """
    union_size = len(a | b)
    if union_size == 0:
        return 1.0
    return len(a & b) / union_size

# Exact Jaccard similarity for every unordered pair of distinct documents.
# Documents are identified by their position in df["shingles"]; each pair is
# visited once with i < j, and simdict is keyed by (i, j) with i < j.
# (Iterating j from i + 1 replaces a set that tracked ~n^2/2 visited tuples.)
shingle_sets = list(df["shingles"])
similarities = list()
simdict = dict()
for ind_i in range(len(shingle_sets)):
    for ind_j in range(ind_i + 1, len(shingle_sets)):
        sim = similarity(shingle_sets[ind_i], shingle_sets[ind_j])
        similarities.append(sim)
        simdict[(ind_i, ind_j)] = sim

# Bucket the pairwise similarities into 10 equal-width bins over [0, 1].
hist = np.histogram(similarities, bins=10, range=(0, 1))



# Experiment parameters: bands, words per shingle, hash permutations (how many
# permutations / hashes you want), and the similarity threshold.
# The top of this cell also reads these names (catch-rate print and
# `shingles = SHINGLES`); values assigned here are not seen by those earlier
# reads until the cell runs again.
B, SHINGLES, PERMUTATIONS, TRESHOLD = 1, 2, 2, 0.7


0.9983616
1.0
   News_ID                                            article
0        0  The Supreme Court in Johnnesberg on Friday pos...
1        1  The IG Metall union has decided not to spread ...
2        2  Malaysia said Friday it had no plans to overre...
3        3  South Korea is redoubling its efforts behind K...
4        4  The Philippine subsidiary of US telecommunicat...
   News_ID                                            article  \
0        0  The Supreme Court in Johnnesberg on Friday pos...   
1        1  The IG Metall union has decided not to spread ...   
2        2  Malaysia said Friday it had no plans to overre...   
3        3  South Korea is redoubling its efforts behind K...   
4        4  The Philippine subsidiary of US telecommunicat...   

                                            shingles  
0  {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...  
1  {512, 513, 514, 515, 516, 517, 518, 519, 520, ...  
2  {526, 527, 528, 529, 530, 531, 532, 533, 534, ...  
3  

In [46]:
class Hash:
    """Callable hash x -> ((a * x + b) mod p) mod len_shingles.

    make_sign_matrix_hash uses one instance per signature row and keeps the
    minimum value over a document's shingle ids.
    """

    def __init__(self, a, b, p, len_shingles):
        self.a, self.b, self.p, self.len_shingles = a, b, p, len_shingles

    def __call__(self, *args, **kwargs):
        # Only the first positional argument (the shingle id) is used.
        x = args[0]
        mixed = (self.a * x + self.b) % self.p
        return mixed % self.len_shingles

def make_hashes(num_perm):
    """Draw `num_perm` random Hash objects ((a*x + b) mod p) mod m.

    m is len(all_shingles), read from the notebook global `all_shingles`, which
    must be defined before this is called. p is the first prime returned by
    primesieve's n_primes(1, m). a is drawn uniformly from [1, p-1] and b from
    [0, p-1] using numpy's global RNG.

    Returns a list of `num_perm` Hash instances (empty if num_perm == 0).
    """
    len_shingles = len(all_shingles)
    p = n_primes(1, len_shingles)[0]
    hashes = []
    for _ in range(num_perm):
        # randint's upper bound is exclusive: pass p (not p - 1) so that p - 1
        # is reachable, and so p == 2 does not fail with low >= high.
        a = np.random.randint(1, p)
        b = np.random.randint(0, p)
        hashes.append(Hash(a, b, p, len_shingles))
    return hashes

def make_sign_matrix_hash(hashes, num_perm):
    """Build the MinHash signature matrix for the global `df.shingles`.

    Row r, column c holds the smallest value of hashes[r] over the shingle ids
    of document c. The matrix has shape (num_perm, number of documents) and
    integer dtype; rows beyond len(hashes) stay 0.
    """
    def lowest(hash_fn, shingle_ids):
        # Manual minimum (no builtin min()): a later notebook cell may rebind
        # the global name `min`.
        best = float('inf')
        for shingle_id in shingle_ids:
            candidate = hash_fn(shingle_id)
            if candidate < best:
                best = candidate
        return best

    documents = df.shingles
    sign_matrix = np.zeros(shape=(num_perm, len(documents)), dtype=int)
    for col, shingle_ids in enumerate(documents):
        for row, hash_fn in enumerate(hashes):
            sign_matrix[row][col] = lowest(hash_fn, shingle_ids)
    return sign_matrix

In [66]:
%%time
data = dict()
for B in np.arange(1, 10):
    for PERMUTATIONS in np.arange(1, 20):
        num_perm = PERMUTATIONS
        hashes = make_hashes(num_perm)
        for TRESHOLD in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
            info = dict()
            if TRESHOLD == 0.5:
                print(B, PERMUTATIONS)
            if (PERMUTATIONS / B) != math.floor(PERMUTATIONS / B):
                break
                
            info["catch rate"] = 1 - np.power((1 - np.power(TRESHOLD, PERMUTATIONS / B)), B)
                
            # set up parameters
            all_shingles = np.array(list(shingles_dict.values()))

            sign_matrix = make_sign_matrix_hash(hashes, num_perm)

            # Locality-Sensitive Hashing

            bands = B

            table = np.split(sign_matrix, bands)
            htable = [dict()] * bands

            candidate_pairs = set()

            for band_index, band in enumerate(table):
                for row_index, row in enumerate(band.T):
                    if row.tostring() in htable[band_index]:
                        for candidate in htable[band_index][row.tostring()]:
                            if row_index == candidate:
                                break
                            candidate_pairs.add((row_index, candidate))
                        htable[band_index][row.tostring()].add(row_index)
                    else:
                        htable[band_index][row.tostring()] = {row_index}

            info["amt cand pairs"] = len(candidate_pairs)
            
            correctness = 0
            for pair in candidate_pairs:
                if pair[0] < pair[1]:
                    if simdict[(pair[0], pair[1])] >= TRESHOLD:
                        correctness += 1
                else:
                    if simdict[(pair[1], pair[0])] >= TRESHOLD:
                        correctness += 1
                        
            correctness /= len(candidate_pairs)
            info["correctness"] = correctness
            
            
            data[(B, PERMUTATIONS, TRESHOLD)] = info

1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
1 13
1 14
1 15
1 16
1 17
1 18
1 19
1 2
2 2
3 2
2 4
5 2
2 6
7 2
2 8
9 2
2 10
11 2
2 12
13 2
2 14
15 2
2 16
17 2
2 18
19 2
1 3
2 3
3 3
4 3
5 3
3 6
7 3
8 3
3 9
10 3
11 3
3 12
13 3
14 3
3 15
16 3
17 3
3 18
19 3
1 4
2 4
3 4
4 4
5 4
6 4
7 4
4 8
9 4
10 4
11 4
4 12
13 4
14 4
15 4
4 16
17 4
18 4
19 4
1 5
2 5
3 5
4 5
5 5
6 5
7 5
8 5
9 5
5 10
11 5
12 5
13 5
14 5
5 15
16 5
17 5
18 5
19 5
1 6
2 6
3 6
4 6
5 6
6 6
7 6
8 6
9 6
10 6
11 6
6 12
13 6
14 6
15 6
16 6
17 6
6 18
19 6
1 7
2 7
3 7
4 7
5 7
6 7
7 7
8 7
9 7
10 7
11 7
12 7
13 7
7 14
15 7
16 7
17 7
18 7
19 7
1 8
2 8
3 8
4 8
5 8
6 8
7 8
8 8
9 8
10 8
11 8
12 8
13 8
14 8
15 8
8 16
17 8
18 8
19 8
1 9
2 9
3 9
4 9
5 9
6 9
7 9
8 9
9 9
10 9
11 9
12 9
13 9
14 9
15 9
16 9
17 9
9 18
19 9
CPU times: user 11min 18s, sys: 114 ms, total: 11min 19s
Wall time: 11min 19s


In [70]:
# Scan the grid results for a preferred (bands, permutations) setting and print
# every entry. `best` holds (bands, permutations, catch@0.3, catch@0.6, catch@0.8).
# It is named `best` rather than `min` so the builtin min() is not shadowed
# for the rest of the kernel session.
best = (float("inf"), float("inf"), float("inf"), float("inf"), float("-inf"))
for key in data:
    bands, perms = key[0], key[1]
    # NOTE(review): 10 is a hard-coded expected candidate count; confirm it
    # matches the dataset in use.
    if (data[key]["correctness"] == 1) and (data[key]["amt cand pairs"] == 10):
        # NOTE(review): the incumbent's score (1 - catch@0.8 - catch@0.6) is
        # compared with the challenger's (catch@0.9 - catch@0.6); these are
        # different quantities, and the initial `best` makes the left side NaN.
        # Logic kept as-is; confirm the intended criterion.
        if perms < best[1] or 1 - best[4] - best[3] < data[(bands, perms, 0.9)]["catch rate"] - data[(bands, perms, 0.6)]["catch rate"]:
            best = (
                bands,
                perms,
                data[(bands, perms, 0.3)]["catch rate"],
                data[(bands, perms, 0.6)]["catch rate"],
                data[(bands, perms, 0.8)]["catch rate"],
            )
    print(key, data[key])
print()
print(best)

(1, 1, 0.1) {'catch rate': 0.09999999999999998, 'amt cand pairs': 1176, 'correctness': 0.011054421768707483}
(1, 1, 0.2) {'catch rate': 0.19999999999999996, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.3) {'catch rate': 0.30000000000000004, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.4) {'catch rate': 0.4, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.5) {'catch rate': 0.5, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.6) {'catch rate': 0.6, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.7) {'catch rate': 0.7, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.8) {'catch rate': 0.8, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 1, 0.9) {'catch rate': 0.9, 'amt cand pairs': 1176, 'correctness': 0.008503401360544218}
(1, 2, 0.1) {'catch rate': 0.010000000000000009, 'amt cand pairs': 35, 'correctness': 0.34285714285714286}
(1, 2, 0