In [12]:
import pickle
import numpy as np
from scipy.sparse import csr_matrix

with open("data/mol_bits.pkl", "rb") as f:
    data = pickle.load(f)


# Extracting unique features for sparse matrix dims
all_feats = set()
for feats_set in data.values():
    all_feats.update(feats_set)

n_feats = max(all_feats) + 1  # 0-based indexing

# Creating sparse matrix mol feats
rows, cols = [], []
for i, (uniprot_id, feats_set) in enumerate(data.items()):
    rows.extend([i] * len(feats_set))
    cols.extend(feats_set)

mol_ids = list(data.keys())

mol_feats = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(len(data), n_feats))

mol_feats = mol_feats.T  # Transposing to have features as rows

# LSH implementation
- MinHash
- Buckets


In [31]:
import hashlib


def create_hash(vector_size=mol_feats.shape[0]):
    return np.random.permutation(vector_size)


def get_signature_matrix(sparse_mat, sig_size=50):
    permutations = np.zeros((sparse_mat.shape[0], sig_size), dtype=int)
    for i in range(sig_size):
        permutations[:, i] = create_hash()

    # signature matrix
    signature_matrix = np.zeros((permutations.shape[1], sparse_mat.shape[1]))

    # minhashing
    for i in range(permutations.shape[1]):
        perm = permutations[:, i]
        sigs = sparse_mat[perm, :].argmax(axis=0)
        signature_matrix[i, :] = sigs

    return signature_matrix


def split_sig(sig: np.array, b: int):
    return np.vsplit(sig, b)


def hash_bands_to_buckets(bands, num_buckets=1000):

    hashed_bands = {}

    n_bands = len(bands)

    for i in range(n_bands):
        band = bands[i]
        for col in range(band.shape[1]):
            col_hash = hashlib.sha256(band[:, col].tobytes()).hexdigest()
            bucket = int(col_hash, 16) % num_buckets

            if bucket not in hashed_bands.keys():
                hashed_bands[bucket] = set()
            hashed_bands[bucket].add(col)

    return hashed_bands


def jaccard_similarity(ids):
    id1, id2 = ids
    set1 = set(data[mol_ids[id1]])
    set2 = set(data[mol_ids[id2]])
    return len(set1.intersection(set2)) / len(set1.union(set2))


def fetch_pairs(buckets, threshold=0.8):
    candidates = []
    for proteins in buckets.values():
        if len(proteins) > 1 and len(proteins) < 3:
            sim = jaccard_similarity(tuple(proteins))
            if sim >= threshold:
                candidates.append(tuple(proteins))
    return candidates

In [30]:
# SIG
SIG = get_signature_matrix(mol_feats, sig_size=100)

# split signature into bands
b = 10
bands = split_sig(SIG, b)

# hash bands to buckets
buckets = hash_bands_to_buckets(bands, num_buckets=1_000_000)

candidates = fetch_pairs(buckets, threshold=0.8)

4 49669
CHEMBL2022247 CHEMBL466679
68516 5
CHEMBL488885 CHEMBL2022248
45637 13
CHEMBL2164609 CHEMBL2022256
15400 18
CHEMBL49231 CHEMBL2022578
34566 23
CHEMBL259337 CHEMBL2022583
25 52314
CHEMBL1688473 CHEMBL223744
35 50612
CHEMBL2047158 CHEMBL3086744
59044 36
CHEMBL1802030 CHEMBL2047159
34674 37
CHEMBL516188 CHEMBL2047160
40 21180
CHEMBL2047163 CHEMBL4465318
41 8492
CHEMBL2047164 CHEMBL4218625
43 11207
CHEMBL2047166 CHEMBL420517
44 54775
CHEMBL2047167 CHEMBL140103
9872 45
CHEMBL216464 CHEMBL2047168
63223 47
CHEMBL1972659 CHEMBL2047170
48 68772
CHEMBL2048610 CHEMBL1824048
3042 52
CHEMBL3897499 CHEMBL2048614
60 26109
CHEMBL2048621 CHEMBL4067246
37282 62
CHEMBL110652 CHEMBL2048623
35409 63
CHEMBL1671923 CHEMBL2048624
8866 69
CHEMBL4469537 CHEMBL1688452
72 10262
CHEMBL449588 CHEMBL317157
73 1450
CHEMBL604126 CHEMBL3670591
74 37165
CHEMBL1829173 CHEMBL342117
76 1486
CHEMBL2164242 CHEMBL3659198
81 85
CHEMBL4082711 CHEMBL4097527
113 86
CHEMBL4792027 CHEMBL3915620
91 62612
CHEMBL4091552 CHEMBL

In [None]:
# hash bands to buckets
buckets = hash_bands_to_buckets(bands, num_buckets=1_000_000)

In [None]:
buckets

{692071: {0},
 671828: {1, 2, 48902},
 555745: {3, 69099, 71866},
 10627: {4, 15},
 488106: {5, 3215, 38402},
 553865: {6},
 321395: {7},
 249920: {8, 19310},
 474602: {9},
 498538: {10, 5934, 52342},
 194688: {11, 26272},
 433367: {12, 61995, 69769},
 584879: {13, 17607},
 336179: {14},
 721959: {16},
 391399: {17, 26, 27},
 351766: {18},
 865776: {19},
 850254: {20},
 403688: {21, 30187, 40569, 40582, 71970},
 348851: {22, 22367},
 333105: {23},
 843190: {24},
 106201: {25},
 940389: {28, 29, 30},
 636979: {31, 32, 33, 44, 3611, 64390},
 803902: {34},
 133562: {35, 21991},
 814358: {36, 68, 78, 33903},
 907231: {37, 13706, 19984},
 461011: {38, 45572, 45577},
 270331: {39, 45357},
 375281: {40},
 123844: {41},
 618193: {42, 31359, 39202},
 345932: {43, 45},
 959278: {46},
 424069: {47},
 242843: {48, 17850},
 998387: {49},
 461541: {50, 37809, 37810, 37813, 67545},
 582297: {51, 65341},
 571045: {52},
 505297: {53, 21601, 45732},
 371760: {54, 16122, 30831},
 618680: {55},
 700402: {