# Low-latency item-to-item recommendation system 

## Part 2 - Creating ANN index

## Overview

This notebook is a part of the [**Low-latency item-to-item recommendation system** ML Engineering blueprint](https://github.com/jarokaz/analytics-componentized-patterns/tree/master/retail/recommendation-system/bqml-ann).

The blueprint provides guidance and code samples for developing and operationalizing a near real-time item-to-item recommendation system that uses BigQuery, BigQuery ML, and AI Platform ANN Service.

This notebook demonstrates how to create and deploy an ANN index using embeddings exported in Part 1 - Creating embeddings. In the notebook you go through the following steps.

1. Creating an ANN Index using the embeddings data in the JSONL format.
2. Creating an ANN Endpoint.
3. Deploying the ANN Index to the ANN Endpoint.

This notebook was designed to run on [AI Platform Notebooks](https://cloud.google.com/ai-platform/notebooks/docs) using the standard TensorFlow 2.3 image. Your notebook instance should be in the same project as the AI Platform ANN Service.

While AI Platform ANN Service is in the Experimental stage, your project must be allow-listed before you can use the service. Contact `gcp-ann-feedback@` to allow-list your project and user ID.

## Setting up the notebook's environment

### Configuring AI Platform ANN Service

#### Enable the required Cloud APIs

You need to enable the following APIs to use the ANN service:
* aiplatform.googleapis.com
* servicenetworking.googleapis.com
* compute.googleapis.com

#### Prepare a VPC network

In the experimental release, ANN service is only accessible using private endpoints. Before using the service you need to have a [VPC network](https://cloud.google.com/vpc) configured with [private services access](https://cloud.google.com/vpc/docs/configure-private-services-access). You can use the `default` VPC or create a new one.

The instructions below are for a VPC that was created with auto subnets and regional dynamic routing mode (the defaults). Run the commands from Cloud Shell, using an account with the appropriate permissions (`roles/compute.networkAdmin`).

1. Set environment variables for your project ID, the name of your VPC network, and the name of your reserved range of addresses. The name of the reserved range can be an arbitrary name. It is for display only.

```
PROJECT_ID=<your-project-id>
gcloud config set project $PROJECT_ID
NETWORK_NAME=<your-VPC-network-name>
PEERING_RANGE_NAME=google-reserved-range

```

2. Reserve an IP range for Google services. The reserved range should be large enough to accommodate all peered services. The command below reserves a CIDR block with a /16 mask.

```
gcloud compute addresses create $PEERING_RANGE_NAME \
  --global \
  --prefix-length=16 \
  --description="peering range for Google service: AI Platform Online Prediction" \
  --network=$NETWORK_NAME \
  --purpose=VPC_PEERING \
  --project=$PROJECT_ID

```

3. Create a private connection to establish a VPC Network Peering between your VPC network and the Google services network.

```
gcloud services vpc-peerings connect \
  --service=servicenetworking.googleapis.com \
  --network=$NETWORK_NAME \
  --ranges=$PEERING_RANGE_NAME \
  --project=$PROJECT_ID

```
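
To get a feel for the size of the range reserved in step 2, the following sketch uses Python's standard `ipaddress` module to count the addresses in a /16 block. The `10.0.0.0/16` network is a hypothetical example, not the range that will actually be allocated for your VPC.

```python
import ipaddress

# Hypothetical example network; not the range allocated for your VPC.
example_range = ipaddress.ip_network('10.0.0.0/16')

# A /16 prefix leaves 32 - 16 = 16 host bits, i.e. 2**16 addresses.
address_count = example_range.num_addresses
print(address_count)  # 65536
```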


### Importing notebook dependencies

In [1]:
import base64
import datetime
import logging
import os
import json
import pandas as pd
import time

import grpc
import match_pb2
import match_pb2_grpc

import google.auth
import google.auth.transport.requests  # Needed for AuthorizedSession in ANNClient.
import numpy as np
import tensorflow.io as tf_io

from google.cloud import bigquery
from typing import List, Optional, Text, Tuple

### Configure GCP environment

Set the following constants to the values reflecting your environment:

* `PROJECT_ID` - your GCP project ID
* `PROJECT_NUMBER` - your GCP project number
* `DATA_LOCATION` - a GCS location of the embeddings (JSONL) files exported in Part 1
* `VPC_NAME` - a name of the GCP VPC to use for the index deployments. Use the name of the VPC prepared in the previous section. 
* `REGION` - a compute region. Don't change the default - `us-central1` - while the ANN Service is in the experimental stage


In [2]:
PROJECT_ID = 'jk-mlops-dev'
PROJECT_NUMBER = '895222332033'
DATA_LOCATION = 'gs://jk-ann-staging/embeddings' 
REGION = 'us-central1'
VPC_NAME = 'default'

MATCH_SERVICE_PORT = 10000
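
Before building an index, it can help to confirm that the exported files have the expected shape. The sketch below parses JSONL lines and checks the embedding length. It assumes each record has `id` and `embedding` fields, mirroring the columns of the `recommendations.item_embeddings` table queried later in this notebook; the function name `check_embedding_records` and the sample lines are hypothetical.

```python
import json
from typing import Iterable, List, Tuple


def check_embedding_records(
        lines: Iterable[str], dimensions: int) -> List[Tuple[int, str]]:
    """Returns (line_number, problem) pairs for records that look malformed."""
    problems = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # Ignore blank lines.
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            problems.append((line_number, 'not valid JSON'))
            continue
        # Assumed record layout: {"id": ..., "embedding": [...]}.
        if not isinstance(record, dict) or 'id' not in record:
            problems.append((line_number, 'missing id'))
            continue
        embedding = record.get('embedding')
        if not isinstance(embedding, list) or len(embedding) != dimensions:
            problems.append((line_number, 'unexpected embedding length'))
    return problems


# Made-up sample lines with 3-dimensional embeddings.
sample_lines = [
    '{"id": "a", "embedding": [0.1, 0.2, 0.3]}',
    '{"id": "b", "embedding": [0.1, 0.2]}',
    'not json',
]
print(check_embedding_records(sample_lines, dimensions=3))
# [(2, 'unexpected embedding length'), (3, 'not valid JSON')]
```

In the notebook, the lines would come from the files under `DATA_LOCATION`, and `dimensions` would match the value used in the index configuration.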

## Creating an ANN index deployment

You will use the REST interface to invoke the AI Platform ANN Service control plane APIs.

### Define helper classes to encapsulate the ANN Service REST API.

Currently, there is no Python client that encapsulates the ANN Service API. The below code snippet defines a simple wrapper that encapsulates a subset of REST APIs used in this notebook.

In [5]:
class ANNClient(object):
    """Base ANN Service client."""
    
    def __init__(self, project_id, project_number, region):
        credentials, _ = google.auth.default()
        self.authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
        self.ann_endpoint = f'{region}-aiplatform.googleapis.com'
        self.ann_parent = f'https://{self.ann_endpoint}/v1alpha1/projects/{project_id}/locations/{region}'
        self.project_id = project_id
        self.project_number = project_number
        self.region = region
        
    def wait_for_completion(self, operation_id, message, sleep_time):
        """Waits for a completion of a long running operation."""
        
        api_url = f'{self.ann_parent}/operations/{operation_id}'

        start_time = datetime.datetime.utcnow()
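        while True:
            response = self.authed_session.get(api_url)
            operation = response.json()
            # A finished long-running operation reports 'done': true.
            if operation.get('done'):
                logging.info('Operation completed!')
                break
            elapsed_time = datetime.datetime.utcnow() - start_time
            logging.info('{}. Elapsed time since start: {}.'.format(
                message, str(elapsed_time)))
            time.sleep(sleep_time)

        # Surface a failed operation instead of hitting a KeyError on 'response'.
        if 'error' in operation:
            raise RuntimeError(operation['error'])
        return operation['response']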


class IndexClient(ANNClient):
    """Encapsulates a subset of control plane APIs 
    that manage ANN indexes."""

    def __init__(self, project_id, project_number, region):
        super().__init__(project_id, project_number, region)

    def create_index(self, display_name, description, metadata):
        """Creates an ANN Index."""
    
        api_url = f'{self.ann_parent}/indexes'
    
        request_body = {
            'display_name': display_name,
            'description': description,
            'metadata': metadata
        }
    
        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
        
        return operation_id

    def list_indexes(self, display_name=None):
        """Lists all indexes with a given display name or
        all indexes if the display_name is not provided."""
    
        if display_name:
            api_url = f'{self.ann_parent}/indexes?filter=display_name="{display_name}"'
        else:
            api_url = f'{self.ann_parent}/indexes'

        response = self.authed_session.get(api_url).json()

        return response['indexes'] if response else []
    
    def delete_index(self, index_id):
        """Deletes an ANN index."""
        
        api_url = f'{self.ann_parent}/indexes/{index_id}'
        response = self.authed_session.delete(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)


class IndexDeploymentClient(ANNClient):
    """Encapsulates a subset of control plane APIs 
    that manage ANN endpoints and deployments."""
    
    def __init__(self, project_id, project_number, region):
        super().__init__(project_id, project_number, region)

    def create_endpoint(self, display_name, vpc_name):
        """Creates an ANN endpoint."""
    
        api_url = f'{self.ann_parent}/indexEndpoints'
        network_name = f'projects/{self.project_number}/global/networks/{vpc_name}'

        request_body = {
            'display_name': display_name,
            'network': network_name
        }

        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
    
        return operation_id
    
    def list_endpoints(self, display_name=None):
        """Lists all ANN endpoints with a given display name or
        all endpoints in the project if the display_name is not provided."""
        
        if display_name:
            api_url = f'{self.ann_parent}/indexEndpoints?filter=display_name="{display_name}"'
        else:
            api_url = f'{self.ann_parent}/indexEndpoints'

        response = self.authed_session.get(api_url).json()
 
        return response['indexEndpoints'] if response else []
    
    def delete_endpoint(self, endpoint_id):
        """Deletes an ANN endpoint."""
        
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'
        
        response = self.authed_session.delete(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)
        
        return response.json()
    
    def create_deployment(self, display_name, deployment_id, endpoint_id, index_id):
        """Deploys an ANN index to an endpoint."""
    
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:deployIndex'
        index_name = f'projects/{self.project_number}/locations/{self.region}/indexes/{index_id}'

        request_body = {
            'deployed_index': {
                'id': deployment_id,
                'index': index_name,
                'display_name': display_name
            }
        }

        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        operation_id = response.json()['name'].split('/')[-1]
        
        return operation_id
    
    def get_deployment_grpc_ip(self, endpoint_id, deployment_id):
        """Returns a private IP address for a gRPC interface to 
        an Index deployment."""
  
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}'

        response = self.authed_session.get(api_url)
        if response.status_code != 200:
            raise RuntimeError(response.text)
            
        endpoint_ip = None
        if 'deployedIndexes' in response.json().keys():
            for deployment in response.json()['deployedIndexes']:
                if deployment['id'] == deployment_id:
                    endpoint_ip = deployment['privateEndpoints']['matchGrpcAddress']
                    
        return endpoint_ip

    
    def delete_deployment(self, endpoint_id, deployment_id):
        """Undeploys an index from an endpoint."""
        
        api_url = f'{self.ann_parent}/indexEndpoints/{endpoint_id}:undeployIndex'
        
        request_body = {
            'deployed_index_id': deployment_id
        }
    
        response = self.authed_session.post(api_url, data=json.dumps(request_body))
        if response.status_code != 200:
            raise RuntimeError(response.text)
        
        return response
    
    
index_client = IndexClient(PROJECT_ID, PROJECT_NUMBER, REGION)
deployment_client = IndexDeploymentClient(PROJECT_ID, PROJECT_NUMBER, REGION)
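
The wrapper methods above pass around short IDs, while the API returns full resource names. A minimal sketch of the conversion in both directions follows. The helper names `resource_id` and `index_resource_name` are hypothetical, and the sample name is taken from the cell output shown later in this notebook.

```python
def resource_id(resource_name: str) -> str:
    """Returns the last path segment of a full resource name."""
    return resource_name.rstrip('/').split('/')[-1]


def index_resource_name(project_number: str, region: str, index_id: str) -> str:
    """Builds a full index name in the form used by create_deployment."""
    return f'projects/{project_number}/locations/{region}/indexes/{index_id}'


name = 'projects/895222332033/locations/us-central1/indexes/1160802803954745344'
example_index_id = resource_id(name)
print(example_index_id)  # 1160802803954745344

# Rebuilding the name from its parts gives back the original string.
rebuilt_name = index_resource_name('895222332033', 'us-central1', example_index_id)
print(rebuilt_name == name)  # True
```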

### Create an index

In [6]:
index_display_name = 'Song embeddings test 2'
index_description = 'Song embeddings test 2'
index_metadata = {
    'contents_delta_uri': DATA_LOCATION,
    'config': {
        'dimensions': 50,
        'approximate_neighbors_count': 50,
        'distance_measure_type': 'DOT_PRODUCT_DISTANCE',
        'feature_norm_type': 'UNIT_L2_NORM',
        'tree_ah_config': {
            'child_node_count': 1000,
            'max_leaves_to_search': 100
         }
    }
}

logging.getLogger().setLevel(logging.INFO)
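
The configuration combines `UNIT_L2_NORM` with `DOT_PRODUCT_DISTANCE`. Assuming `UNIT_L2_NORM` means that vectors are scaled to unit length before they are compared, the dot product of two normalized vectors equals their cosine similarity. The sketch below illustrates this with two made-up vectors in plain Python; it does not call the ANN Service.

```python
import math
from typing import List


def l2_normalize(vector: List[float]) -> List[float]:
    """Scales a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def dot(a: List[float], b: List[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Made-up example vectors.
a = [3.0, 4.0]
b = [4.0, 3.0]

cosine = dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))
normalized_dot = dot(l2_normalize(a), l2_normalize(b))
print(round(cosine, 6), round(normalized_dot, 6))  # 0.96 0.96
```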

In [None]:
operation_id = index_client.create_index(index_display_name, 
                                          index_description,
                                          index_metadata)

response = index_client.wait_for_completion(operation_id, 'Creating index', 20)
print(response)
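
The `wait_for_completion` helper polls until the operation reports `done` and has no upper bound on the wait. The following sketch shows the same polling pattern with a timeout. It is independent of the ANN Service: `poll_until_done` and its `fetch` callable are hypothetical, and `fetch` stands in for the HTTP GET that returns the operation as a dictionary.

```python
import time
from typing import Callable, Dict


def poll_until_done(
        fetch: Callable[[], Dict], sleep_time: float, timeout: float) -> Dict:
    """Calls fetch() until the returned operation has a truthy 'done' field."""
    deadline = time.monotonic() + timeout
    while True:
        operation = fetch()
        if operation.get('done'):
            return operation
        # Give up if the next sleep would go past the deadline.
        if time.monotonic() + sleep_time > deadline:
            raise TimeoutError('Operation did not complete in time.')
        time.sleep(sleep_time)


# Fake operation that completes on the third poll.
states = iter([{}, {}, {'done': True, 'response': {'name': 'example'}}])
result = poll_until_done(lambda: next(states), sleep_time=0.01, timeout=5)
print(result['response'])  # {'name': 'example'}
```

Index creation can take a long time, so any timeout used with the real service would need to be generous.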

#### List all indexes

In [7]:
#indexes = index_client.list_indexes(index_display_name)
indexes = index_client.list_indexes()

for index in indexes:
    print(index['name'])

if indexes: 
    # Use the last index in the list for the deployment.
    index_id = indexes[-1]['name'].split('/')[-1]
    print(f'Index: {index_id}')
    print('will be used for the deployment.')
else:
    print('No indexes available for deployment')

projects/895222332033/locations/us-central1/indexes/1160802803954745344
Index: 1160802803954745344
will be used for the deployment.


### Create the index deployment

#### Create an index endpoint

In [None]:
deployment_display_name = 'Song embeddings endpoint test 2'

In [None]:
operation_id = deployment_client.create_endpoint(deployment_display_name, VPC_NAME)

response = deployment_client.wait_for_completion(operation_id, 'Waiting for endpoint', 10)
print(response)

#### List all endpoints

In [19]:
#endpoints = deployment_client.list_endpoints(deployment_display_name)
endpoints = deployment_client.list_endpoints()

for endpoint in endpoints:
    print(endpoint)
    print(endpoint['name'])
    
if endpoints: 
    # Use the last endpoint in the list for the deployment.
    endpoint_id = endpoints[-1]['name'].split('/')[-1]
    print(f'Endpoint: {endpoint_id}')
    print('will be used for the deployment.')
else:
    print('No endpoints available for deployment')


No endpoints available for deployment


#### Deploy the index to the endpoint

##### Set the deployment ID

The ID must be unique within your project.

In [None]:
deployment_display_name = 'Songs index deployment test 2'
deployed_index_id = 'songs_index_deployed_test_2'

##### Deploy the index

In [None]:
operation_id = deployment_client.create_deployment(deployment_display_name, 
                                                   deployed_index_id,
                                                   endpoint_id,
                                                   index_id)

response = deployment_client.wait_for_completion(operation_id, 'Waiting for deployment', 10)
print(response)

#### Retrieve the gRPC private endpoint for ANN Match service

In [None]:
deployed_index_ip = deployment_client.get_deployment_grpc_ip(endpoint_id, deployed_index_id)
endpoint = f'{deployed_index_ip}:{MATCH_SERVICE_PORT}'
print(f'gRPC endpoint for the {deployed_index_id} deployment: {endpoint}')
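
`get_deployment_grpc_ip` returns `None` when the deployment is not found, in which case the formatted address above would read `None:10000`. A small guard, sketched below with a hypothetical helper name and a made-up IP address, makes that failure explicit.

```python
from typing import Optional


def build_grpc_endpoint(ip_address: Optional[str], port: int) -> str:
    """Formats host:port and fails early if no IP address was returned."""
    if not ip_address:
        raise ValueError('No gRPC address found; is the index deployed?')
    return f'{ip_address}:{port}'


# Made-up private IP address for illustration.
example_endpoint = build_grpc_endpoint('10.0.0.5', 10000)
print(example_endpoint)  # 10.0.0.5:10000
```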

## Querying the ANN service

You will use the gRPC interface to query the deployed index.

### Create a helper wrapper around the Match Service gRPC API.

In [None]:
class MatchService(object):
    def __init__(self, endpoint, deployed_index_id):
        self.endpoint = endpoint
        self.deployed_index_id = deployed_index_id

    def single_match(
        self,
        embedding: List[float],
        num_neighbors: int) -> List[Tuple[str, float]]:
    
        match_request = match_pb2.MatchRequest(deployed_index_id=self.deployed_index_id,
                                               float_val=embedding,
                                               num_neighbors=num_neighbors)
        with grpc.insecure_channel(self.endpoint) as channel:
            stub = match_pb2_grpc.MatchServiceStub(channel)
            response = stub.Match(match_request)
    
        return [(neighbor.id, neighbor.distance) for neighbor in response.neighbor]


    def batch_match(
        self,
        embeddings: List[List[float]],
        num_neighbors: int) -> List[List[Tuple[str, float]]]:
    
        match_requests = [
            match_pb2.MatchRequest(deployed_index_id=self.deployed_index_id,
                                   float_val=embedding,
                                   num_neighbors=num_neighbors)
            for embedding in embeddings]
    
        batches_per_index = [
            match_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
                deployed_index_id=self.deployed_index_id,
                requests=match_requests)]
    
        batch_match_request = match_pb2.BatchMatchRequest(
            requests=batches_per_index)
    
        with grpc.insecure_channel(self.endpoint) as channel:
            stub = match_pb2_grpc.MatchServiceStub(channel)
            response = stub.BatchMatch(batch_match_request)
        
        matches = []
        for batch_per_index in response.responses:
            for match in batch_per_index.responses:
                matches.append(
                    [(neighbor.id, neighbor.distance) for neighbor in match.neighbor])
        
        return matches
    
match_service = MatchService(endpoint, deployed_index_id)

### Prepare sample data

In [None]:
%%bigquery df_embeddings

SELECT id, embedding
FROM `recommendations.item_embeddings` 
LIMIT 10

In [None]:
df_embeddings

In [None]:
sample_embeddings = [list(embedding) for embedding in df_embeddings.iloc[0:2]['embedding']]

### Run a single match query

In [None]:
single_match = match_service.single_match(sample_embeddings[0], 10)
single_match

### Run a batch match query

In [None]:
batch_match = match_service.batch_match(sample_embeddings, 5)
batch_match
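
When the query vector is the embedding of an item that is itself in the index, that item is likely to come back as its own nearest neighbor. For item-to-item recommendations it is usually dropped from the results. A minimal sketch follows, using a hypothetical helper and made-up IDs and distances in the `(id, distance)` tuple format returned by `single_match`.

```python
from typing import List, Tuple


def drop_query_item(
        matches: List[Tuple[str, float]], query_id: str) -> List[Tuple[str, float]]:
    """Removes the query item from a list of (id, distance) matches."""
    return [(item_id, distance) for item_id, distance in matches
            if item_id != query_id]


# Made-up matches; 'song_1' plays the role of the query item.
example_matches = [('song_1', 1.0), ('song_7', 0.93), ('song_3', 0.88)]
filtered_matches = drop_query_item(example_matches, 'song_1')
print(filtered_matches)  # [('song_7', 0.93), ('song_3', 0.88)]
```

Requesting one extra neighbor in the query keeps the result size constant after filtering.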

## Clean up

**WARNING**

The below code will delete all ANN deployments, endpoints, and indexes in the configured project.

### Delete index deployments and endpoints

In [18]:
for endpoint in deployment_client.list_endpoints():
    endpoint_id = endpoint['name'].split('/')[-1]
    if 'deployedIndexes' in endpoint.keys():
        for deployment in endpoint['deployedIndexes']:
            print('   Deleting index deployment: {} in the endpoint: {} '.format(deployment['id'], endpoint_id))
            deployment_client.delete_deployment(endpoint_id, deployment['id'])
    print('Deleting endpoint: {}'.format(endpoint['name']))
    deployment_client.delete_endpoint(endpoint_id)

Deleting endpoint: projects/895222332033/locations/us-central1/indexEndpoints/4009329568266584064
Deleting endpoint: projects/895222332033/locations/us-central1/indexEndpoints/2270940112101572608
Deleting endpoint: projects/895222332033/locations/us-central1/indexEndpoints/6315172577480278016
Deleting endpoint: projects/895222332033/locations/us-central1/indexEndpoints/1703486559052890112


### Delete indexes
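
A minimal sketch of this step, mirroring the endpoint cleanup above. It assumes a client object with the `list_indexes` and `delete_index` methods defined by `IndexClient`; the `delete_all_indexes` helper and the `FakeIndexClient` stand-in are hypothetical and exist only to make the example self-contained.

```python
def delete_all_indexes(client) -> list:
    """Deletes every index returned by client.list_indexes()."""
    deleted = []
    for index in client.list_indexes():
        index_id = index['name'].split('/')[-1]
        print('Deleting index: {}'.format(index['name']))
        client.delete_index(index_id)
        deleted.append(index_id)
    return deleted


class FakeIndexClient:
    """In-memory stand-in for IndexClient, used only for this illustration."""

    def __init__(self, names):
        self.names = list(names)

    def list_indexes(self):
        return [{'name': name} for name in self.names]

    def delete_index(self, index_id):
        self.names = [n for n in self.names if not n.endswith('/' + index_id)]


fake_client = FakeIndexClient(['projects/1/locations/us-central1/indexes/42'])
deleted_ids = delete_all_indexes(fake_client)
print(deleted_ids)  # ['42']
```

In the notebook, the call would be `delete_all_indexes(index_client)`. Run it after the endpoint cleanup, since an index that is still deployed may not be deletable.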

## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product but sample code provided for an educational purpose**