<a href="https://colab.research.google.com/github/asvnpr/ND_Care_Net/blob/master/embed_serv_sim_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Graph Nodes (Content) Embedding

- use Universal Sentence Encoder to embed nodes from our graph, then embed queries and find most similar services


## Colab Setup
Setup our drive, dependencies, etc. so they're accesible in this notebook.

**Ignore this block if you're not running this in a Colab notebook**

You can get most of this block to work by changing some paths and commands

In [0]:
from google.colab import drive
drive.mount('/content/drive/')
prefix = './drive/My Drive/ND_CSE/Year_1/Research:Care-Net/code_and_data'
!echo "Project dir contents:" && ls "$prefix"
!echo -e "\nColab Notebook home dir:" && ls

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
Project dir contents:
211_IN_Data_Parsing.ipynb    nd_care-net.tar.gz
BANE			     Node_Text_Embedding_Doc2Vec.ipynb
BANE_embeddings.ipynb	     Node_Text_Embedding_GPT-2-Copy1.ipynb
binder			     Node_Text_Embedding_sBERT.ipynb
CX_DB8			     README.md
data			     semantic_embeddings_cluster_plot.ipynb
embeddings		     service_embeddings_UMAP.ipynb
embed_serv_sim_search.ipynb  TENE
figures			     TENE_embeddings.ipynb
Graph_Data_Extraction.ipynb  USE_multi_lang_embed_serv_sim_search.ipynb
models

Colab Notebook home dir:
drive  sample_data


In [0]:
# # install conda
# install_conda = ''
# while install_conda not in ('Y', 'N'):
#   install_conda = input("Do you want to install conda?").upper()
# install_conda = True if install_conda == 'Y' else False 
# if install_conda:
#   !if [ ! -f Miniconda3-latest-Linux-x86_64.sh ]; then wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh; fi
#   !bash Miniconda3-latest-Linux-x86_64.sh -bfp /usr/local

In [0]:
# env_file = prefix + '/binder/environment.yml'
# !source activate && conda activate && conda env update --file "$env_file"

## Imports, Global Config, and Data

In [0]:
# !pip3 uninstall -y tensorflow-gpu tensorflow_text
# !pip3 install tensorflow-gpu==2.0rc0  tf-sentencepiece==0.1.83
!pip3 install annoy h2o4gpu tqdm tabulate nltk tensorflow_text==2.0rc0 umap-learn




In [0]:
#following this guide: https://colab.research.google.com/drive/1t4bi7X7zRzwIjdxUrU2hUs7LneqgYLVK#scrollTo=a73qer_zPJLy
# and: https://github.com/tensorflow/hub/blob/master/examples/colab/tf2_semantic_approximate_nearest_neighbors.ipynb

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
  
import tensorflow as tf
import tensorflow_text
import tensorflow_hub as hub
from h2o4gpu.manifold import TSNE
# import matplotlib.pyplot as plt
import numpy as np
from h2o4gpu.metrics.pairwise import cosine_similarity
import seaborn as sns
import annoy
import umap

import json
import os
import csv
import pickle
import pandas as pd

import random
import typing
from itertools import chain
import string

from tqdm import tqdm, trange
from tabulate import tabulate
import logging

# import nltk
# from nltk import sent_tokenize
# from nltk import word_tokenize
# from nltk.corpus import stopwords
# from nltk import download as nltk_dl
# from nltk.stem import WordNetLemmatizer
# nltk_dl('wordnet')

TensorFlow 2.x selected.


In [0]:
# make sure we're using a gpu or hw acceleration before we start embedding!!
# NOTE: not working. clearly connected to GPU
if not tf.test.is_gpu_available() and 'COLAB_TPU_ADDR' not in os.environ:
  print("WARNING!: This notebook is not connected to a GPU nor TPU runtime.")
else:
  print("HW Acceleration should work :)")

Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
HW Acceleration should work :)


In [0]:
# open our datasets
with open(prefix + '/data/services_nodes.json') as sn:
    serv_nodes = json.loads(sn.read())
    
with open(prefix + '/data/services_edgelist.csv') as se:
    serv_edges = csv.reader(se)

# TODO: refacto var name to something more accurate and general
with open(prefix + '/data/HIN_nodes.json') as taxo:
    taxo_nodes = json.loads(taxo.read())

with open(prefix + '/data/code_to_node_num.json') as cn:
    code_trans = json.loads(cn.read())

# load queries file
queries_path = os.path.join(prefix, 'data', 'HIN_references.csv')
with open(queries_path) as qf:
  queries = qf.read()
  queries = queries.split(',')
  # cleanup some bad data. TODO: Fix in notebook that generates the data
  queries = [q for q in queries if q != '']

In [0]:
NODE_TYPE = 'services'
BATCH_SIZE = 256
NUM_NODES = len(queries)
NUM_BATCHES =  NUM_NODES // BATCH_SIZE

MODEL = 'USE'
# MODEL_URL = 'https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3'
# MODEL_URL = 'https://tfhub.dev/google/universal-sentence-encoder-multilingual/2'
MODEL_URL = 'https://tfhub.dev/google/universal-sentence-encoder/4'
MODEL_TYPE = MODEL_URL.split('/')[-2]
MODEL_VER = MODEL_URL.split('/')[-1]

!mkdir -p "$prefix/embeddings/$MODEL/"
print("Using Model {}_{}_v{}".format(MODEL, MODEL_TYPE, MODEL_VER))

Using Model USE_universal-sentence-encoder_v4


In [0]:

tagged_texts = {}

if os.path.exists(os.path.join(prefix, 'data', 'tagged_texts.json')):
  with open(os.path.join(prefix, 'data', 'tagged_texts.json')) as ttf:
    tagged_texts = json.load(ttf)

In [0]:
def save_to_json(data, name):
  path = os.path.join(prefix, 'data', name)
  data_json = json.dumps(data)
  with open(path, 'w') as djf:
    djf.write(data_json)

In [0]:
save_to_json(data=tagged_texts, name='tagged_texts.json')

## Get embeddings from our text data 

In [0]:
embed_path = os.path.join(prefix, 'embeddings', MODEL, "{}_{}_v{}.pkl".format(
    NODE_TYPE, MODEL_TYPE, MODEL_VER
    ))
if os.path.exists(embed_path):
  print("Loading embeddings from {}".format(embed_path))
  with open(embed_path, 'rb') as f:
        tagged_embeds = pickle.load(f)

Loading embeddings from ./drive/My Drive/ND_CSE/Year_1/Research:Care-Net/code_and_data/embeddings/USE/services_universal-sentence-encoder_v4.pkl


In [0]:
def save_to_pickle(data, path=None):
  if path is not None:
    path = os.path.join(prefix, path)
  else:  
    path = os.path.join(prefix, 'embeddings', MODEL, "{}_{}_v{}.pkl".format(
      NODE_TYPE, MODEL_TYPE, MODEL_VER
    ))
  with open(path, 'ab') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

In [0]:
# for tag in tagged_embeds:
    # save_to_pickle(tag)

In [0]:
node_names = [taxo_nodes[idx]['name'] for idx in tagged_texts]
node_embeds = [tagged_embeds[idx]['embed'] for idx in tagged_embeds]
node_texts = ['\n'.join(tagged_texts[idx]) for idx in tagged_embeds]
node_ids = list(tagged_embeds.keys())

In [0]:
def node_to_n_degree_code(node_num, code_names, n=2):
    node = taxo_nodes[node_num]
    codes = node['codes']
    main_code = None
    if len(codes) >= 1:
        main_code = codes[0]
        main_code = main_code[:n]
    else:
        main_code = None

    return main_code

In [0]:
with open(prefix + '/data/code_to_name.json') as ctn:
    code_names = json.loads(ctn.read())

def getLabelEmbeds(code_len='1'):
    tagged_lbl_emb = {}
    lbl_to_avg_emb = {'No Label': []}
    lbl_cnt = {'No Label': 0}
    # TODO change how code_names is saved so it only has the lbls from 2 char codes

    for lbl in code_names[code_len].values():
        # print(lbl)
        lbl_to_avg_emb[lbl] = []
        lbl_cnt[lbl] = 0

    for node_num in tagged_embeds:

        node_codes = taxo_nodes[node_num]['codes']
        main_code = node_to_n_degree_code(node_num, node_codes, n=int(code_len))
        if main_code == None:
            lbl = 'No Label'
        # pseudo label according to 2 first chars of taxonomy code
        else:
            lbl = code_names[code_len][main_code]
        tagged_lbl_emb[node_num] = {'embed': tagged_embeds[node_num]['embed'], 'label': lbl}
        lbl_to_avg_emb[lbl].append(tagged_embeds[node_num]['embed'])
        # print(len(lbl_to_avg_emb[lbl]))
        lbl_cnt[lbl] += 1
    
    # get the mean of each collected label embeddings
    keys_to_del = []
    for lbl in lbl_to_avg_emb:
        if lbl_to_avg_emb[lbl] == []:
            keys_to_del.append(lbl)
            
        else:   
            avg_embed = np.mean(lbl_to_avg_emb[lbl], axis=0)
            lbl_to_avg_emb[lbl] = avg_embed
            # print(avg_embed.shape)
    for lbl in keys_to_del:
        try:
            lbl_to_avg_emb.pop(lbl)    
            lbl_cnt.pop(lbl)   
        except KeyError:
            print("Key not found")    
    # for lbl in lbl_cnt:
    #     print("Label {} has {} elements".format(lbl, lbl_cnt[lbl]))
    print("Nodes have {} labels".format(len(lbl_to_avg_emb)))
    return tagged_lbl_emb, lbl_to_avg_emb, lbl_cnt

In [0]:
lvl1_tagged_embs, lvl1_avg_embeds, lvl1_cnt = getLabelEmbeds(code_len='1')
lvl2_tagged_embs, lvl2_avg_embeds, lvl2_cnt = getLabelEmbeds(code_len='2')

Nodes have 12 labels
Nodes have 73 labels


In [0]:
# lvl1_avg
# lvl2_embeds = [lvl2_avg_embeds[lbl] for lbl in lvl2_avg_embeds]
# lvl2_embed_sim = cosine_similarity(lvl2_embeds)

label_hierarchy = {}
for lbl in lvl1_avg_embeds:
    label_hierarchy[lbl] = set()

# get order of nodes to display their names when hovering
lbl_node_nums = {}
for lbl in lvl1_avg_embeds:
    lbl_node_nums[lbl] = []

for n_idx in lvl1_tagged_embs:
    lvl1_lbl = lvl1_tagged_embs[n_idx]['label']

    label_hierarchy[lvl1_lbl].add(lvl2_tagged_embs[n_idx]['label'])
    # lbl_node_nums[lvl1_lbl].append(n_idx)

# get the embeddings grouped according to lvl1 labels
reordered_embeds = []
reordered_lbls = []
for lbl in label_hierarchy:

    sub_lbls = label_hierarchy[lbl]
    for sub in sub_lbls:
        reordered_embeds.append(lvl2_avg_embeds[sub])
        reordered_lbls.append(sub)

# reordered_names = [taxo_nodes[n_idx]['name'] for n_idx in reordered_node_ids]
embed_sims = cosine_similarity(reordered_embeds)
tick_vals = []
tick_total = 0
for lbl in lvl1_cnt:
    val = lvl1_cnt[lbl]
    tick_vals.append(val // 2)
    tick_total += val

print("Have {} names for x,y axis".format(len(reordered_lbls)))
# print(reordered_embeds)
print("Have {}x{} similarity matrix".format(len(embed_sims), len(embed_sims[0])))

print(list(lvl1_avg_embeds.keys()))

Have 73 names for x,y axis
Have 73x73 similarity matrix
['No Label', 'Basic Needs', 'Consumer Services', 'Criminal Justice and Legal Services', 'Education', 'Environment and Public Health/Safety', 'Health Care', 'Income Support and Employment', 'Individual and Family Life', 'Mental Health and Substance Use Disorder Services', 'Organizational/Community/International Services', 'Target Populations']


In [0]:
import plotly.graph_objects as go
import datetime
import numpy as np
np.random.seed(1)

ticks = list(label_hierarchy.keys())
sim_threshold = 0.6
threshold_sims = [list(map(lambda x: x if x >= sim_threshold else 0.0, row)) for row in embed_sims]

y = reordered_lbls
x = reordered_lbls
z = threshold_sims

def split_text(text, delim=' '):
    txt = ''
    for j, word in enumerate(text.split(delim)):
        if j % 2 == 0:
            txt += word + ' <br>'
        else:
            txt += word + ''
    tick_text[i] = txt

tick_text = list(lvl1_avg_embeds.keys())
for i, text in enumerate(tick_text):
    if '/' in text:
        split_text(text, delim='/')
    else:
        split_text(text)


dx_dy = 6

x_axis = dict(
        tickmode = 'array',
        tickvals = list(range(0, len(tick_text)*dx_dy, dx_dy)),
        ticktext = tick_text
)

y_axis = dict(
        tickmode = 'array',
        tickvals = list(range(0, len(tick_text)*dx_dy, dx_dy)),
        ticktext = tick_text
)

color_scales = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
    'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
    'brwnyl', 'bugn', 'bupu', 'burg', 'burgyl', 'cividis', 'curl',
    'darkmint', 'deep', 'delta', 'dense', 'earth', 'edge', 'electric',
    'emrld', 'fall', 'geyser', 'gnbu', 'gray', 'greens', 'greys',
    'haline', 'hot', 'hsv', 'ice', 'icefire', 'inferno', 'jet',
    'magenta', 'magma', 'matter', 'mint', 'mrybm', 'mygbm', 'oranges',
    'orrd', 'oryel', 'peach', 'phase', 'picnic', 'pinkyl', 'piyg',
    'plasma', 'plotly3', 'portland', 'prgn', 'pubu', 'pubugn', 'puor',
    'purd', 'purp', 'purples', 'purpor', 'rainbow', 'rdbu', 'rdgy',
    'rdpu', 'rdylbu', 'rdylgn', 'redor', 'reds', 'solar', 'spectral',
    'speed', 'sunset', 'sunsetdark', 'teal', 'tealgrn', 'tealrose',
    'tempo', 'temps', 'thermal', 'tropic', 'turbid', 'twilight',
    'viridis', 'ylgn', 'ylgnbu', 'ylorbr', 'ylorrd']

color_scale = random.choice(color_scales)
print("Using color scale: {}".format(color_scale))


fig = go.Figure(data=go.Heatmap(
        z=z,
        x=x,
        y=y,
        colorscale = color_scale))

fig.update_layout(
    title='Service-to-Service Similarity',
    xaxis_nticks=len(tick_text),
    yaxis_nticks=len(tick_text),
    width = 1000,
    height = 800,
    xaxis_type = 'category',
    yaxis_type = 'category',
    xaxis = x_axis,
    yaxis = y_axis,
    margin=dict(l=20, r=30, t=50, b=40),
    paper_bgcolor="LightSteelBlue",
    font=dict(
        family="Courier New, monospace",
        size=12,
        # color="#7f7f7f"
        )
    )
    

fig.show()

Using color scale: dense


## Semantic Similarity Search: 
get the most similar items for each of our queries

### Approximate Nearest Neighbors

Using [ANNOY](https://github.com/spotify/annoy)

In [0]:
# adapted from: https://github.com/tensorflow/hub/blob/master/examples/colab/tf2_semantic_approximate_nearest_neighbors.ipynb
# builds an approximate nearest neighbor index (with ANNOY lib)
# used to avoid searching through the entirety of our data when comparing queries
# TODO: experiment with param values


def build_ann_index(embed_data, index_path, vector_length, 
    metric='angular', num_trees=50):
  # Builds an ANNOY index

  annoy_index = annoy.AnnoyIndex(vector_length, metric=metric)
  # Mapping between the item and its identifier in the index
  mapping = {}

  item_counter = 0
  for idx, name, text, embed in embed_data:
      mapping[item_counter] = {'id': idx, 'name': name, 'text': text}
      annoy_index.add_item(item_counter, embed)
      item_counter += 1
      if item_counter % 1000 == 0:
        print('{} items loaded to the index'.format(item_counter))

  print('A total of {} items added to the index'.format(item_counter))

  print('Building the index with {} trees...'.format(num_trees))
  annoy_index.build(n_trees=num_trees)
  print('Index is successfully built.')
  
  print('Saving index to disk...')
  annoy_index.save(index_path)
  print('Index is saved to disk.')
  print("Index file size: {} MB".format(
    round(os.path.getsize(index_path) / float(1024 ** 2), 2)))
  annoy_index.unload()

  print('Saving mapping to disk...')
  with open(index_path + '.mapping', 'wb') as handle:
    pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)
  print('Mapping is saved to disk.')
  print("Mapping file size: {} MB".format(
    round(os.path.getsize(index_path + '.mapping') / float(1024 ** 2), 2)))

In [0]:
# get size of vectors
embed_dim = len(node_embeds[0])
embed_data = list(zip(node_ids, node_names, node_texts, node_embeds))
index_path = os.path.join(prefix, 'data', "{}_{}_{}_ANNOY_index".format(NODE_TYPE, MODEL, MODEL_TYPE))

!rm "$index_path"
!rm "$index_path".mapping

%time build_ann_index(embed_data, index_path, embed_dim)

1000 items loaded to the index
2000 items loaded to the index
3000 items loaded to the index
4000 items loaded to the index
5000 items loaded to the index
6000 items loaded to the index
7000 items loaded to the index
8000 items loaded to the index
9000 items loaded to the index
10000 items loaded to the index
11000 items loaded to the index
12000 items loaded to the index
13000 items loaded to the index
14000 items loaded to the index
15000 items loaded to the index
16000 items loaded to the index
A total of 16547 items added to the index
Building the index with 50 trees...
Index is successfully built.
Saving index to disk...
Index is saved to disk.
Index file size: 44.56 MB
Saving mapping to disk...
Mapping is saved to disk.
Mapping file size: 12.67 MB
CPU times: user 2.06 s, sys: 95.2 ms, total: 2.16 s
Wall time: 2.25 s


In [0]:
# load the index and mapping files
index = annoy.AnnoyIndex(embed_dim, metric='angular')
index.load(index_path, prefault=True)
print('Annoy index is loaded.')
with open(index_path + '.mapping', 'rb') as handle:
  mapping = pickle.load(handle)
print('Mapping file is loaded.')

# print('\nRandom sample of queries\n\n', '\n'.join(random.choices(queries, k=10)))

Annoy index is loaded.
Mapping file is loaded.


In [0]:
# Finds similar items to a given embedding in the ANN index
def find_similar_items(embedding, num_matches=10):
    ids, distances = index.get_nns_by_vector(
    embedding, num_matches, search_k=-1, include_distances=True)
    items = [mapping[i] for i in ids]

    uniq_matches = set()
    filter_items = []
    filter_dists = []
    i = 0
    while len(uniq_matches) < num_matches and i < len(items):
        name = items[i]['name']
        if name not in uniq_matches:
            # print(name)
            uniq_matches.add(name)
            filter_items.append(items[i])
            filter_dists.append(distances[i])
        i += 1
    return filter_items, filter_dists

In [0]:
# Load the TF-Hub module
print("Loading the TF-Hub {} module...".format(MODEL))
%time embed_fn = hub.load(MODEL_URL)
print("TF-Hub module is loaded.")

def extract_embed(queries):
    # Generates the embedding for the query
    query_embedding =  embed_fn(queries).numpy()
    # print(len(query_embedding))
    return query_embedding

def k_query_matches(queries, sample_size=10, k=10, all_queries=False, out=True):
    if all_queries:
        sample_size = len(queries)

    query_samples = random.choices(queries, k=sample_size)
    for query in query_samples:
        # print(query)
        query_embed = extract_embed([query])[0]
        items, dists = find_similar_items(query_embed, num_matches=k)

        if out:
            print("Top-{} most similar items (w.o. duplicates) to query \"{}\":".format(len(items), query))
            for i in range(len(items)):
                item = items[i]
                dist = dists[i]
                name = item['name']
                print("({}) {} (dist={})".format(i+1, name, dist))
            print()
    return items, dists

Loading the TF-Hub USE module...
CPU times: user 13 s, sys: 3.29 s, total: 16.3 s
Wall time: 26.2 s
TF-Hub module is loaded.


In [0]:
print(len(queries))

matches, distances = k_query_matches(queries)

43949
Top-10 most similar items (w.o. duplicates) to query "Westville - New Durham Township Public Library":
(1) Public Library - Noblesville (dist=1.0075379610061646)
(2) Public Library - Fishers (dist=1.0102190971374512)
(3) Public Libraries (dist=1.023892879486084)
(4) Public Library - Harris Branch (dist=1.0322362184524536)
(5) Public Library - Central Branch (dist=1.0325456857681274)
(6) Public Library - Main Branch (dist=1.0376297235488892)
(7) Public Library - Clark Pleasant Branch (dist=1.0448436737060547)
(8) Remington Carpenter Township Public Library (dist=1.0452905893325806)
(9) Public Library Services (dist=1.04533851146698)
(10) Public Library (dist=1.0455808639526367)

Top-8 most similar items (w.o. duplicates) to query "Saint John's Evangelical Lutheran Church":
(1) Adoption Search (dist=1.251501202583313)
(2) Legislative Information Center (dist=1.289251685142517)
(3) General Legal Aid - Indianapolis (dist=1.2913445234298706)
(4) Clerk's Office - Mishawaka (dist=1.2985

In [0]:
# test multilingual aspect
# consultas = [
#   'abogado',
#   'comida',
#   'asistencia legal',
#   'doctor barato',
#   'asistencia menores',
#   'violencia domestica',
#   'adicción drogas',
#   'vacunas',
#   'cuido niños'
# ]

user_query = [input("Write a query to get recommended services: ")]
# k_query_matches(consultas)
matches, distances = k_query_matches(user_query, sample_size=len(user_query))

### Exhaustive Similarity Search
Compare the query against ALL embeddings

In [0]:
def process_batch_queries(queries):

    descr = 'Embedding all queries...'
    with tqdm(total=len(queries), dynamic_ncols=True, desc=descr) as pbar:
        
        y = 0
        tagged_queries = []
        for i in range(0, len(queries), BATCH_SIZE):

            
            # grab BATCH_SIZE number of samples
            # separate list of tuples into separate lists
            qs = queries[i:i+BATCH_SIZE]

            query_embeds = extract_embed(queries=qs)
            # print(query_embeds[0][:10])
            q_ids = list(range(i, i+BATCH_SIZE))
            
            tagged_queries.extend(list(zip(q_ids, query_embeds)))

            y += 1
            pbar.update(BATCH_SIZE)
            if y % 20 == 0 or (NUM_BATCHES - y) < 20:
                print("Processed Batch {}/{}".format(y, NUM_BATCHES))
    return tagged_queries

In [0]:
def semantic_search(tagged_queries, tagged_embeds, metric=cosine_similarity, q_path=None):

    query_dists = []
    if q_path == None:
        q_path = os.path.join('data', 'exhaustive_query_dists_{}_{}.pkl'.format(MODEL, len(queries)))
        full_qpath = os.path.join(prefix,q_path)

    query_dists_len = 0
    if os.path.exists(full_qpath):
        print("Loading query distances from {}".format(q_path))
        with open(full_qpath, 'rb') as f:
            try:
                descr = 'Loading previously calculated query distances...'
                with tqdm(total=len(queries), dynamic_ncols=True, desc=descr) as pbar:
                    while True:
                        # just count the number of lines to avoid filling up memory
                        query = pickle.load(f)
                        query_dists_len += 1
                        pbar.update(1)
            except EOFError:
                pass
            # query_dists = pickle.load(f)
            # query_dists_len = len(query_dists)
            # query_dists = []
        print("Loaded query distances of length: {}".format(query_dists_len))
    else:
        query_dists = []
    
    if query_dists_len == len(tagged_queries):
        return query_dists

    descr="Calculating query distances..."
    with tqdm(total=len(queries), dynamic_ncols=True, desc=descr) as pbar:
        pbar.update(query_dists_len)

        for i in range(query_dists_len, len(tagged_queries)):
            tag_q = tagged_queries[i]
            q_id, q_embed = tag_q
            # will hold list of tuple of (node_id, node_embed_dist)
            # query_dists[q_id] = []
            # Compare against all nodes
            # put all embeddings in a matrix to compare in a single call
            # (also speeds up if using gpu)
            embed_keys = list(tagged_embeds.keys())
            embeds = list(map(lambda k: tagged_embeds[k]['embed'], embed_keys))
            dists = metric(embeds, [q_embed])
            tagged_dists = list(zip(embed_keys, dists))
            query_dists.append(tagged_dists)
                    # pbar2.update(1)
            if q_id % 1000 == 0:
                # save "checkpoints" periodically in case (Colab) runtime is
                # killed or times out (tqdm estimates 32h to process!!)
                full_q_path = os.path.join(q_path)
                for qd in query_dists:
                    save_to_pickle(qd, path=full_q_path)
                
                # reset size so memory doesn't fill up
                query_dists = []
            pbar.update(1)

    return query_dists

def semantic_results(q_idx, query_dists, top_k=10):
  # sort tagged dists 
  results = sorted(query_dists, key=lambda x: x[1][0], reverse=True)[:top_k]
  ds = [qd[1][0] for qd in results]
  print(ds[:10])
  uniq_res = set()
  dists = []
  node_ids = []
  # print(max(results, key=lambda x: x[1]))
  i = 0
  while len(uniq_res) < top_k and i < len(results):
    node_idx, q_e_dist = query_dists[i]
    node_name = taxo_nodes[node_idx]['name']
    if node_name not in uniq_res:
        uniq_res.add(node_name)
        dists.append(q_e_dist)
        node_ids.append(node_idx)
    i+=1
    
  
  return node_ids, list(uniq_res), dists

In [0]:
q_samples = random.choices(queries, k=len(queries))

q_embed_path = 'embeddings/{}/tagged_query_embeds.pkl'.format(MODEL)
if os.path.exists(os.path.join(prefix, q_embed_path)):
    with open(os.path.join(prefix, q_embed_path), 'rb') as qef:
        tagged_q_embeds = pickle.load(qef)
else:
    tagged_q_embeds = process_batch_queries(q_samples)

In [0]:
# save_to_pickle(tagged_q_embeds, path=q_embed_path)

In [0]:
query_dists = semantic_search(tagged_q_embeds, tagged_embeds)
k = 10

Loading previously calculated query distances...:   0%|          | 0/43949 [00:00<?, ?it/s]

Loading query distances from data/exhaustive_query_dists_USE_43949.pkl


Loading previously calculated query distances...:  22%|██▏       | 9455/43949 [04:55<17:56, 32.05it/s]
Calculating query distances...:  22%|██▏       | 9456/43949 [00:00<00:00, 85740.71it/s]

Loaded query distances of length: 9455


Calculating query distances...: 100%|██████████| 43949/43949 [1:26:45<00:00, 11.54it/s]


In [0]:
query_dists[0]

In [0]:
# list of (node_id, distance_arr)
q_dists_samples = random.choices(query_dists, k=10)
for i, q_dists in enumerate(q_dists_samples):
  # query = queries[q_idx]
  query = q_samples[i]
  node_ids, top_k, dists = semantic_results(q_idx, q_dists)

  print("Top-{} most similar items to query \"{}\":".format(k, query))
  for i in range(len(top_k)):
    n_idx, node_name, dist = node_ids[i], top_k[i], dists[i]
    print(dist)
    print("({}) {} (dist={})".format(i+1, node_name, dist[0]))
  print()

[0.11564456, 0.11564456, 0.11512642, 0.115126416, 0.11093547, 0.10876865, 0.10403679, 0.10298613, 0.102604076, 0.101212904]
Top-10 most similar items to query "Victim-Offender Conferencing":
[-0.00485693]
(1) School District (dist=-0.004856929183006287)
[-0.05541714]
(2) Computer Classes (dist=-0.05541714280843735)
[0.03469025]
(3) Clothing Pantry (dist=0.0346902534365654)
[0.0170226]
(4) Ex-Offender Services (dist=0.017022598534822464)
[0.00292454]
(5) Homeless Services (dist=0.002924541011452675)
[0.0575737]
(6) Job Search Assistance (dist=0.057573698461055756)
[0.0344977]
(7) Special Education Advocacy (dist=0.0344977006316185)
[-0.02212443]
(8) God's Kitchen (dist=-0.022124433889985085)
[0.00489016]
(9) Prosecutor (dist=0.004890155047178268)
[-0.02624138]
(10) Food Pantry (dist=-0.026241380721330643)

[0.1189892, 0.11256465, 0.11157247, 0.11030922, 0.10884496, 0.10884493, 0.10884493, 0.10884493, 0.10884492, 0.10884492]
Top-10 most similar items to query "City Police":
[0.07481994]


In [0]:
q_path="data/{}_query_distances_{}_{}_v{}.pkl".format(len(query_dists), MODEL, MODEL_TYPE, MODEL_VER)
save_to_pickle(query_dists, path=q_path)