<a href="https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/srihari-dev/notebooks/tpu_index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!gdown --id 1-8nsWLseynVj6Z9-E12w1ywnffnqJftm
!gdown --id 1UszbNYQnlNrAcPQkBwvb1wKX21oRPiqb

Downloading...
From: https://drive.google.com/uc?id=1-8nsWLseynVj6Z9-E12w1ywnffnqJftm
To: /content/Uembeds306Epochs.npy
2.59GB [00:21, 119MB/s]
Downloading...
From: https://drive.google.com/uc?id=1UszbNYQnlNrAcPQkBwvb1wKX21oRPiqb
To: /content/Vembeds306Epochs.npy
2.59GB [00:38, 67.5MB/s]


In [1]:
%tensorflow_version 2.x
from concurrent.futures import ProcessPoolExecutor
import numpy as np
import tensorflow as tf
from time import time
from tqdm import tqdm_notebook as tqdm
print('TensorFlow:', tf.__version__)

TensorFlow 2.x selected.
TensorFlow: 2.1.0-rc1


In [0]:
# Connect to a TPU if one is attached, otherwise fall back to MirroredStrategy.
# NOTE(review): later cells hard-code '/device:TPU:<i>' device strings, so the
# fallback path presumably cannot run the rest of the notebook — confirm.
try:
    # The except clause below treats ValueError as "no TPU available".
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    tpu = None

if tpu is None:
    strategy = tf.distribute.MirroredStrategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [3]:
# One TensorFlow device string per TPU core (8 cores on task 0).
workers = ['/job:worker/replica:0/task:0/device:TPU:{}'.format(core) for core in range(8)]
workers

['/job:worker/replica:0/task:0/device:TPU:0',
 '/job:worker/replica:0/task:0/device:TPU:1',
 '/job:worker/replica:0/task:0/device:TPU:2',
 '/job:worker/replica:0/task:0/device:TPU:3',
 '/job:worker/replica:0/task:0/device:TPU:4',
 '/job:worker/replica:0/task:0/device:TPU:5',
 '/job:worker/replica:0/task:0/device:TPU:6',
 '/job:worker/replica:0/task:0/device:TPU:7']

In [0]:
class Index:
    """Exhaustive (brute-force) nearest-neighbour index over one shard of embeddings.

    Each stored vector is ``l2_normalize(u) + l2_normalize(v)`` computed row by row.
    ``search`` ranks every stored vector by squared Euclidean distance to a query.
    """

    def __init__(self, u, v, worker):
        """Build the index.

        Args:
            u: 2-D array of U embeddings for this shard, one row per item.
            v: 2-D array of V embeddings, same shape as ``u`` (rows are summed
                pairwise with ``u``).
            worker: TensorFlow device string on which ``search`` places its ops,
                e.g. '/job:worker/replica:0/task:0/device:TPU:0'.
        """
        self.embeddings = tf.math.l2_normalize(u, axis=1) + tf.math.l2_normalize(v, axis=1)
        # ||e||^2 for every stored row, shaped [1, N] so it broadcasts against the query.
        self.squared_norms_embeddings = tf.expand_dims(tf.square(tf.norm(self.embeddings, axis=1)), axis=0)
        self.worker = worker

    @tf.function
    def search(self, query_vector, top_k=None):
        """Find the stored vectors closest to ``query_vector``.

        Args:
            query_vector: a single query of shape [1, D], D matching the stored vectors.
            top_k: Python int, number of neighbours to keep; None (or 0) keeps all N rows.
                It is evaluated at trace time, so it must not be a tensor.

        Returns:
            (nearest_distances, sorted_indices): 1-D tensors of length ``top_k``
            (or N) holding squared Euclidean distances (clamped at 0) and the
            matching row numbers local to this shard, nearest first.
        """
        # Place the ops on this shard's own device (self.worker), not on whatever
        # a module-level variable happens to name.
        with tf.device(self.worker):
            squared_norms_query_vector = tf.expand_dims(tf.square(tf.norm(query_vector, axis=1)), axis=0)  # [1, 1]
            dot_product = tf.reduce_sum(tf.multiply(self.embeddings, query_vector), axis=1)  # [N]
            # ||e - q||^2 = ||e||^2 + ||q||^2 - 2 e.q; the clamp removes negative rounding error.
            distances = tf.maximum(self.squared_norms_embeddings + squared_norms_query_vector - 2 * dot_product, 0)  # [1, N]
            sorted_indices = tf.argsort(distances)  # [1, N], ascending along the last axis
            if top_k:
                sorted_indices = sorted_indices[..., :top_k]
            nearest_distances = tf.gather(distances[0], sorted_indices[0])
            return nearest_distances, sorted_indices[0]

In [5]:
# Load the U and V embedding matrices downloaded at the top of the notebook and
# show their shapes (they are later combined row by row, so they must match).
u_embeddings, v_embeddings = [np.load(path) for path in ('Uembeds306Epochs.npy', 'Vembeds306Epochs.npy')]
u_embeddings.shape, v_embeddings.shape

((1262996, 512), (1262996, 512))

In [0]:
# Shard the embeddings evenly, one shard per TPU core. np.split requires an exact
# division, so the trailing remainder rows (4 rows for 1,262,996 rows over 8 shards)
# are discarded; those rows can never be returned by a search.
num_shards = len(workers)
assert u_embeddings.shape == v_embeddings.shape, 'U and V embeddings must have the same shape'
vecs_per_index = u_embeddings.shape[0] // num_shards
num_kept = vecs_per_index * num_shards  # slicing with [:num_kept] is also correct when nothing is dropped
u_embeddings = np.split(u_embeddings[:num_kept], num_shards, axis=0)
v_embeddings = np.split(v_embeddings[:num_kept], num_shards, axis=0)

In [7]:
# Build one Index per TPU core; shard i of U and V is created under device workers[i].
# NOTE: keep the loop variable named `worker` — other cells may read the leftover global.
indices = []
for i, worker in enumerate(workers):
    u_shard, v_shard = u_embeddings[i], v_embeddings[i]
    with tf.device(worker):
        print('Building index with {} vectors on {}'.format(len(u_shard), worker))
        indices.append(Index(u_shard, v_shard, worker))

Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:0
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:1
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:2
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:3
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:4
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:5
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:6
Building index with 157874 vectors on /job:worker/replica:0/task:0/device:TPU:7


In [0]:
def search(xq, top_k=1, verbose=True):
    """Search every shard index for ``xq`` and merge the per-shard results.

    Args:
        xq: query tensor of shape [1, D], built the same way as the stored vectors.
        top_k: number of candidates to take from each shard (default 1, which
            reproduces the original one-best-per-shard behaviour).
        verbose: print each shard's device as it is searched.

    Returns:
        (D, I): numpy arrays of length ``len(indices) * top_k`` holding candidate
        squared distances and global row ids, sorted by ascending distance.
        The first ``top_k`` entries are the exact global top-k neighbours; the
        remaining entries are per-shard candidates, not a global ranking.
    """
    shard_distances, shard_ids = [], []
    for shard, index in enumerate(indices):
        if verbose:
            print('Search running in index: {}'.format(index.worker))
        d, idx = index.search(xq, top_k)
        shard_distances.append(d.numpy())
        # Shard-local row -> global id; every shard holds exactly vecs_per_index rows.
        shard_ids.append(shard * vecs_per_index + idx.numpy())

    D = np.concatenate(shard_distances)
    I = np.concatenate(shard_ids)
    order = np.argsort(D)
    return D[order], I[order]

In [28]:
# Query with a vector that is already stored in the index: it should come back
# as its own nearest neighbour at distance 0.
n = 10086   # row within the chosen shard
split = 7   # shard to take the query from, in [0, 7]
actual_n = vecs_per_index*split + n  # global row id of the query vector

# Build the query exactly like the stored vectors (sum of the L2-normalised U and V rows), shape [1, D].
xq = tf.nn.l2_normalize(u_embeddings[split][n]) + tf.nn.l2_normalize(v_embeddings[split][n])
xq = tf.reshape(xq, [1, -1])

s = time()
D, I = search(xq)
e = time()

print('\nActual ID   :', actual_n)
print('Result ID   :', I[0], '\n')

# `search(xq)` returns the best match from each shard, so the list below holds
# per-shard candidates sorted by distance — not a global top-8 ranking.
print('Neighbours   :', I )
print('Distances   :', np.round(D, 4))
print('\nTime taken  :', np.round(e-s, 2), 'secs')
# The first run is slow because tf.function traces each Index.search on its
# first call; later calls reuse the traced graph and run faster.

Search running in index: /job:worker/replica:0/task:0/device:TPU:0
Search running in index: /job:worker/replica:0/task:0/device:TPU:1
Search running in index: /job:worker/replica:0/task:0/device:TPU:2
Search running in index: /job:worker/replica:0/task:0/device:TPU:3
Search running in index: /job:worker/replica:0/task:0/device:TPU:4
Search running in index: /job:worker/replica:0/task:0/device:TPU:5
Search running in index: /job:worker/replica:0/task:0/device:TPU:6
Search running in index: /job:worker/replica:0/task:0/device:TPU:7

Actual ID   : 1115204
Result ID   : 1115204 

Neighbours   : [1115204   84123  604532  881265  190994 1046794  466390  683556]
Distances   : [0.     1.9193 1.9222 1.9401 1.9465 1.9475 1.9555 1.9613]

Time taken  : 0.24 secs


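#### Checking accuracy

Sanity check: each query below is a vector that is already stored in the index, so a correct search must return that same row id as its nearest neighbour (at distance 0).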

In [0]:
# Sample n_test random (row-within-shard, shard) pairs to use as self-retrieval queries.
# A fixed seed keeps the accuracy check reproducible across Restart & Run All.
n_test = 5000
rng = np.random.RandomState(0)
random_n = rng.randint(0, vecs_per_index, n_test)   # row within a shard, in [0, vecs_per_index)
random_split = rng.randint(0, len(indices), n_test)  # shard number, in [0, number of shards)

In [25]:
# Run every sampled query against all shards and record the merged ranking.
# The merge is inlined (rather than calling search()) to avoid per-shard prints.
y_true = []
y_pred = []
s = time()
for n, split in tqdm(zip(random_n, random_split), total=n_test):
    xq = tf.nn.l2_normalize(u_embeddings[split][n]) + tf.nn.l2_normalize(v_embeddings[split][n])
    xq = tf.reshape(xq, [1, -1])
    actual_n = vecs_per_index*split + n  # global row id of the query vector
    D, I = [], []
    for i, index in enumerate(indices):
        d, idx = index.search(xq, 1)
        D.append(d.numpy()[0])
        I.append(i*vecs_per_index + idx.numpy()[0])
    id_sorted = np.argsort(D)
    y_pred.append(np.array(I)[id_sorted])
    y_true.append(actual_n)
e = time()
y_true = np.array(y_true)
y_pred = np.array(y_pred)
# Divide first, then round: rounding the total before dividing printed float
# noise such as 0.24776199999999998.
print('\nTime taken per search    :', np.round((e - s) / n_test, 4), 'secs')
print('Accuracy                 :', np.mean(y_true == y_pred[:, 0]))

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))



Time taken per search    : 0.24776199999999998 secs
Accuracy                 : 1.0
