# Low-latency item-to-item recommendation system 

## Part 1 - Creating embeddings


## Overview

This notebook is a part of the [**Low-latency item-to-item recommendation system** ML Engineering blueprint](https://github.com/jarokaz/analytics-componentized-patterns/tree/master/retail/recommendation-system/bqml-ann).

The blueprint provides guidance and code samples for developing and operationalizing a near real-time item-to-item recommendation system that uses BigQuery, BigQuery ML, and the AI Platform ANN Service.

This notebook demonstrates how to create item embeddings using a BigQuery ML (BQML) Matrix Factorization model and how to export them in the JSONL format compatible with the ANN Service's ingestion schema. The notebook walks through the following steps:

1. Preparing the training data based on the public `bigquery-samples.playlists` dataset.
2. Training the BQML Matrix Factorization model. 
3. Exploring the trained embeddings.
4. Exporting the embeddings.

Note that training a BigQuery ML Matrix Factorization model requires slot reservations. For more information, you can read up on how to set up flex slots [programmatically](https://medium.com/google-cloud/optimize-bigquery-costs-with-flex-slots-e06ec5e4aa90) or via the [BigQuery UI](https://cloud.google.com/bigquery/docs/reservations-workload-management#getting-started-with-bigquery-reservations).

This notebook was designed to run on [AI Platform Notebooks](https://cloud.google.com/ai-platform/notebooks/docs) using the standard TensorFlow 2.3 image.

### Dataset

The example dataset used in the blueprint is the BigQuery public dataset `bigquery-samples.playlists`, which contains music playlist data, including song titles, artists, and the playlists each song belongs to.

## Setting up the notebook's environment

### Import notebook dependencies

```python
import datetime
import os
import json
import pandas as pd
import time


import google.auth
import numpy as np
import tensorflow.io as tf_io

from concurrent import futures
from google.cloud import bigquery
from typing import List, Optional, Text, Tuple
```

```python
from IPython.display import clear_output
```

### Configure GCP environment

Set the following constants to the values reflecting your environment:

* `PROJECT_ID` - your GCP project ID
* `BUCKET_NAME` - the name of the bucket to store exported embeddings. If you prefer to use a pre-existing bucket, you can skip the *Create GCS bucket* cell
* `REGION` - the region in which to create the GCS bucket
* `BQ_LOCATION` - the BigQuery location
* `BQ_DATASET_NAME` - the name of the BigQuery dataset that will host the training data and the model. If you prefer to use a pre-existing dataset, you can skip the *Create BigQuery dataset* cell


```python
PROJECT_ID = 'jk-mlops-dev'
BUCKET_NAME = 'jk-ann-staging'
REGION = 'us-central1'
BQ_LOCATION = 'US'
BQ_DATASET_NAME = 'song_embeddings'
```

#### Create GCS bucket

```python
!gsutil mb -l {REGION} gs://{BUCKET_NAME}
```

```
Creating gs://jk-ann-staging/...
```

#### Create BigQuery dataset

```python
client = bigquery.Client(project=PROJECT_ID, location=BQ_LOCATION)
```

```python
# Import the exception explicitly instead of relying on
# `google.cloud.exceptions` being loaded as a side effect of other imports.
from google.cloud.exceptions import NotFound

dataset_id = f'{PROJECT_ID}.{BQ_DATASET_NAME}'

dataset = bigquery.Dataset(dataset_id)
dataset.location = BQ_LOCATION

try:
    client.get_dataset(dataset_id)
    print(f'Dataset {dataset_id} already exists')
except NotFound:
    dataset = client.create_dataset(dataset, timeout=30)
    print(f'Created dataset: {dataset_id}')
```

```
Dataset jk-mlops-dev.song_embeddings already exists
```


## Preparing the training data

In this section of the notebook, you will prepare the data for training a Matrix Factorization model. It is a two-step process:

1. First, you will copy and clean records from the public `bigquery-samples.playlists.playlist` table to your dataset.
2. Then, you will compute item co-occurrence. In the context of the playlist data, item co-occurrence quantifies how often two songs appear on the same playlist.
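The co-occurrence count of two items is the number of groups (here, playlists) that contain both of them. The following pure-Python sketch performs the same counting that `sp_ComputePMI` does later with a SQL self-join; the function name `count_cooccurrences` and the toy `(group_Id, item_Id)` rows are hypothetical, for illustration only:

```python
from collections import Counter, defaultdict
from itertools import combinations
from typing import Dict, Iterable, Tuple


def count_cooccurrences(rows: Iterable[Tuple[int, int]]) -> Dict[Tuple[int, int], int]:
    """Count how many groups contain each unordered pair of distinct items.

    `rows` are (group_id, item_id) pairs, like the rows of `vw_item_groups`.
    Pairs are keyed as (smaller_id, larger_id), mirroring the
    `a.item_Id < b.item_Id` condition of the SQL self-join.
    Duplicate rows within a group are ignored because items go into a set.
    """
    items_by_group = defaultdict(set)
    for group_id, item_id in rows:
        items_by_group[group_id].add(item_id)

    cooc = Counter()
    for items in items_by_group.values():
        # sorted() guarantees item1 < item2 in every emitted pair
        for item1, item2 in combinations(sorted(items), 2):
            cooc[(item1, item2)] += 1
    return dict(cooc)


# Hypothetical toy data: three playlists (groups) and four songs (items).
toy_rows = [(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (3, 20), (3, 40)]
print(count_cooccurrences(toy_rows))
# {(10, 20): 2, (10, 30): 1, (20, 30): 1, (20, 40): 1}
```

Keying each pair as (smaller id, larger id) counts every unordered pair once; the stored procedure adds the mirrored pairs in a separate step.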

### Define a helper function to wait for a BigQuery job to complete

```python
def wait_for_result(query_job):
    """Block until a BigQuery job finishes, printing the elapsed time."""
    print("Executing query with job ID: {}".format(query_job.job_id))
    start_time = time.time()
    while True:
        print("\rQuery executing: {:0.2f}s".format(time.time() - start_time), end="")
        try:
            # Raises TimeoutError if the job is still running after 0.5s.
            query_job.result(timeout=0.5)
            break
        except futures.TimeoutError:
            continue
    print("\nQuery complete after {:0.2f}s".format(time.time() - start_time))
    return query_job
```
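`wait_for_result` relies on `result(timeout=0.5)` raising `concurrent.futures.TimeoutError` while the job is still running, which lets the loop refresh the elapsed-time display between checks. The same polling pattern can be tried locally with a standard-library `Future` standing in for a BigQuery job; this is an illustration under that assumption, and the helper name `poll_until_done` is hypothetical:

```python
import time
from concurrent import futures


def poll_until_done(future: futures.Future, interval: float = 0.1):
    """Poll `future` every `interval` seconds and return its result.

    Mirrors the loop in `wait_for_result`: each short `result(timeout=...)`
    call either returns the value or raises TimeoutError, in which case the
    elapsed time is printed and the loop tries again.
    """
    start_time = time.time()
    while True:
        try:
            value = future.result(timeout=interval)
            break
        except futures.TimeoutError:
            print("\rWaiting: {:0.2f}s".format(time.time() - start_time), end="")
    print("\nDone after {:0.2f}s".format(time.time() - start_time))
    return value


with futures.ThreadPoolExecutor(max_workers=1) as executor:
    # A slow task stands in for a long-running BigQuery job.
    slow_task = executor.submit(lambda: time.sleep(0.35) or "finished")
    print(poll_until_done(slow_task))
```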

### Clean the public `playlist` table and copy a subset to your dataset

The sample `playlist` table contains over 12 million records. Although embedding quality generally improves with the size of the training corpus, training a Matrix Factorization model on all records takes a long time.

You can control the size of the corpus used for training the model by setting the `number_of_records` variable. To use all records set the variable to `None`.
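The next cell appends a `LIMIT` clause only when `number_of_records` is an integer no larger than 12,304,975, the upper bound hard-coded in the cell; otherwise all records are used. A minimal sketch of that decision as a standalone function (the name `build_limit_clause` is hypothetical):

```python
from typing import Optional

# Upper bound hard-coded in the notebook cell (the public table has over 12 million rows).
TOTAL_PLAYLIST_RECORDS = 12304975


def build_limit_clause(number_of_records: Optional[int]) -> str:
    """Return a LIMIT clause, or an empty string to keep all records."""
    if isinstance(number_of_records, int) and number_of_records <= TOTAL_PLAYLIST_RECORDS:
        return f"LIMIT {number_of_records}"
    # None, or a value larger than the bound, means "use everything".
    return ""


print(repr(build_limit_clause(3000000)))  # 'LIMIT 3000000'
print(repr(build_limit_clause(None)))     # ''
```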


```python
number_of_records = 3000000

query = f"""
    CREATE OR REPLACE TABLE `{BQ_DATASET_NAME}.playlist`
    AS
    SELECT DISTINCT
        id list_Id, 
        tracks_data_id track_Id, 
        tracks_data_title track_title,
        tracks_data_artist_name track_artist
    FROM `bigquery-samples.playlists.playlist`
    WHERE tracks_data_title IS NOT NULL AND tracks_data_id > 0
"""

if isinstance(number_of_records, int) and number_of_records <= 12304975:
    limit_clause = f'LIMIT {number_of_records}'
    query = query + limit_clause
    
query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: c339145f-9269-4a33-aad9-013325bd0f87
Query executing: 16.01s
Query complete after 16.58s
```


### Explore the created table

```python
query = f"""
    SELECT * 
    FROM
    `{BQ_DATASET_NAME}.playlist`
    LIMIT 10
"""

query_job = client.query(query)
query_job.to_dataframe()
```

| | list_Id | track_Id | track_title | track_artist |
|---|---|---|---|---|
| 0 | 9354888 | 3141579 | Galvanize | The Chemical Brothers |
| 1 | 8085136 | 2321655 | California Dreamin' | The Mamas & The Papas |
| 2 | 2233197 | 5902134 | Incense And Peppermints | Strawberry Alarm Clock |
| 3 | 1174313 | 125479 | Mary Hartman, Mary Hartman | Television Theme Songs |
| 4 | 884920 | 574738 | Hit The Floor | Bullet for My Valentine |
| 5 | 773635 | 912580 | River In The Road | Queens of the Stone Age |
| 6 | 9695428 | 93175 | String Quartet No. 2 in D major, K155: I. Allegro | Wolfgang Amadeus Mozart |
| 7 | 7914373 | 1151554 | No Woman, No Cry (Live At The Lyceum, London/1... | Bob Marley & The Wailers |
| 8 | 4604492 | 1151565 | Jamming | Bob Marley & The Wailers |
| 9 | 1407613 | 1151537 | Iron Lion Zion | Bob Marley & The Wailers |


### Create a view to abstract the `playlist` table

```python
query = f"""
    CREATE OR REPLACE VIEW `{BQ_DATASET_NAME}.vw_item_groups`
    AS
    SELECT
      list_Id AS group_Id,
      track_Id AS item_Id
    FROM  
      `{BQ_DATASET_NAME}.playlist` 
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: afee00a5-601b-4dc5-a654-c9bb173df045
Query executing: 0.00s
Query complete after 0.34s
```


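### Compute item co-occurrence

You will now compute how often each pair of items appears in the same group and derive the pointwise mutual information (PMI) of each pair from those counts.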

#### Create the stored procedure that encapsulates calculation logic

```python
query = f"""
    CREATE OR REPLACE PROCEDURE {BQ_DATASET_NAME}.sp_ComputePMI(
    IN min_item_frequency INT64,
    IN max_group_size INT64
    )

    BEGIN

    DECLARE total INT64;

    # Keep only the item-group rows whose item and group pass the filters
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.valid_item_groups
    AS

    # Create valid item set
    WITH 
    valid_items AS (
        SELECT item_Id, COUNT(group_Id) AS item_frequency
        FROM {BQ_DATASET_NAME}.vw_item_groups
        GROUP BY item_Id
        HAVING item_frequency >= min_item_frequency
    ),

    # Create valid group set
    valid_groups AS (
        SELECT group_Id, COUNT(item_Id) AS group_size
        FROM {BQ_DATASET_NAME}.vw_item_groups
        WHERE item_Id IN (SELECT item_Id FROM valid_items)
        GROUP BY group_Id
        HAVING group_size BETWEEN 2 AND max_group_size
    )

    SELECT item_Id, group_Id
    FROM {BQ_DATASET_NAME}.vw_item_groups
    WHERE item_Id IN (SELECT item_Id FROM valid_items)
    AND group_Id IN (SELECT group_Id FROM valid_groups);

    # Compute pairwise cooc
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.item_cooc
    AS
    SELECT item1_Id, item2_Id, SUM(cooc) AS cooc
    FROM
    (
        SELECT
        a.item_Id item1_Id,
        b.item_Id item2_Id,
        1 as cooc
        FROM {BQ_DATASET_NAME}.valid_item_groups a
        JOIN {BQ_DATASET_NAME}.valid_item_groups b
        ON a.group_Id = b.group_Id
        AND a.item_Id < b.item_Id
    )
    GROUP BY  item1_Id, item2_Id;

    ###################################
    
    # Compute item frequencies
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.item_frequency
    AS
    SELECT item_Id, COUNT(group_Id) AS frequency
    FROM {BQ_DATASET_NAME}.valid_item_groups
    GROUP BY item_Id;

    ###################################
    
    # Compute total frequency |D|
    SET total = (
        SELECT SUM(frequency)  AS total
        FROM {BQ_DATASET_NAME}.item_frequency
    );
    
    ###################################
    
    # Add mirror item-pair cooc and same item frequency as cooc
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.item_cooc
    AS
    SELECT item1_Id, item2_Id, cooc
    FROM {BQ_DATASET_NAME}.item_cooc
    UNION ALL
    SELECT item2_Id as item1_Id, item1_Id AS item2_Id, cooc
    FROM {BQ_DATASET_NAME}.item_cooc
    UNION ALL
    SELECT item_Id as item1_Id, item_Id AS item2_Id, frequency as cooc
    FROM {BQ_DATASET_NAME}.item_frequency;

    ###################################
    
    # Compute PMI
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.item_cooc
    AS
    SELECT
        a.item1_Id,
        a.item2_Id,
        a.cooc,
        LOG(a.cooc, 2) - LOG(b.frequency, 2) - LOG(c.frequency, 2) + LOG(total, 2) AS pmi
    FROM {BQ_DATASET_NAME}.item_cooc a
    JOIN {BQ_DATASET_NAME}.item_frequency b
    ON a.item1_Id = b.item_Id
    JOIN {BQ_DATASET_NAME}.item_frequency c
    ON a.item2_Id = c.item_Id; 
    END
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: c8d94122-64a6-4b5e-b1a9-3a736eefca13
Query executing: 0.55s
Query complete after 1.14s
```
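After counting, `sp_ComputePMI` adds mirrored pairs plus a diagonal entry for each item (its frequency used as its self co-occurrence), then computes base-2 pointwise mutual information, `PMI(i, j) = log2(cooc(i, j)) - log2(freq(i)) - log2(freq(j)) + log2(|D|)`, where `|D|` is the sum of all item frequencies. A pure-Python sketch of those final steps, assuming the pair counts and item frequencies are already in dictionaries (the function name `compute_pmi` and the toy values are hypothetical):

```python
import math
from typing import Dict, Tuple

Pair = Tuple[int, int]


def compute_pmi(cooc: Dict[Pair, int], frequency: Dict[int, int]) -> Dict[Pair, float]:
    """Replicate the last steps of sp_ComputePMI on in-memory counts.

    `cooc` holds one entry per unordered pair (item1 < item2), as produced by
    the self-join; `frequency` maps each item to the number of groups it is in.
    """
    # |D|: sum of all item frequencies (the `total` variable in the procedure).
    total = sum(frequency.values())

    # Add mirrored pairs and the diagonal (each item paired with itself).
    full_cooc: Dict[Pair, int] = {}
    for (item1, item2), count in cooc.items():
        full_cooc[(item1, item2)] = count
        full_cooc[(item2, item1)] = count
    for item, freq in frequency.items():
        full_cooc[(item, item)] = freq

    # Base-2 PMI, term by term as in the SQL expression.
    return {
        (item1, item2): math.log2(count)
        - math.log2(frequency[item1])
        - math.log2(frequency[item2])
        + math.log2(total)
        for (item1, item2), count in full_cooc.items()
    }


# Toy counts for two groups: {10, 20, 30} and {10, 20}.
toy_cooc = {(10, 20): 2, (10, 30): 1, (20, 30): 1}
toy_frequency = {10: 2, 20: 2, 30: 1}
pmi = compute_pmi(toy_cooc, toy_frequency)
print(round(pmi[(10, 20)], 3), round(pmi[(20, 10)], 3))  # 1.322 1.322
```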


#### Execute the stored procedure

```python
query = f"""
    DECLARE min_item_frequency INT64;
    DECLARE max_group_size INT64;

    SET min_item_frequency = 15;
    SET max_group_size = 100;

    CALL {BQ_DATASET_NAME}.sp_ComputePMI(min_item_frequency, max_group_size);
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: 1928be93-cb79-42e8-9c01-a0f69abc80b5
Query executing: 151.43s
Query complete after 151.99s
```
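The two arguments control the filtering step at the start of `sp_ComputePMI`: an item is kept only if it appears in at least `min_item_frequency` groups, and a group is kept only if it contains between 2 and `max_group_size` of those items. A pure-Python sketch of that filter (the function name `filter_item_groups` and the toy rows are hypothetical):

```python
from collections import Counter
from typing import List, Tuple

Row = Tuple[int, int]  # (group_id, item_id)


def filter_item_groups(rows: List[Row], min_item_frequency: int, max_group_size: int) -> List[Row]:
    """Mimic the valid_items / valid_groups CTEs of sp_ComputePMI."""
    # valid_items: items that appear in at least `min_item_frequency` rows.
    item_frequency = Counter(item for _, item in rows)
    valid_items = {item for item, freq in item_frequency.items() if freq >= min_item_frequency}

    # valid_groups: groups holding between 2 and `max_group_size` valid items.
    group_size = Counter(group for group, item in rows if item in valid_items)
    valid_groups = {group for group, size in group_size.items() if 2 <= size <= max_group_size}

    return [(group, item) for group, item in rows if item in valid_items and group in valid_groups]


toy_rows = [(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (3, 20), (3, 40)]
print(filter_item_groups(toy_rows, min_item_frequency=2, max_group_size=100))
# [(1, 10), (1, 20), (2, 10), (2, 20)]
```

Raising `min_item_frequency` shrinks the set of items that receive embeddings, while `max_group_size` keeps very large playlists from contributing a quadratic number of pairs to the self-join.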


#### Explore the co-occurrence data

```python
query = f"""
    SELECT 
        a.item1_Id, 
        a.item2_Id, 
        b.frequency AS freq1,
        c.frequency AS freq2,
        a.cooc,
        a.pmi,
        a.cooc * a.pmi AS score
    FROM {BQ_DATASET_NAME}.item_cooc a
    JOIN {BQ_DATASET_NAME}.item_frequency b
    ON a.item1_Id = b.item_Id
    JOIN {BQ_DATASET_NAME}.item_frequency c 
    ON a.item2_Id = c.item_Id
    WHERE a.item1_Id != a.item2_Id
    ORDER BY score DESC
    LIMIT 10
"""

query_job = client.query(query)
query_job.to_dataframe()
```

| | item1_Id | item2_Id | freq1 | freq2 | cooc | pmi | score |
|---|---|---|---|---|---|---|---|
| 0 | 721785 | 790948 | 1003 | 927 | 182 | 8.602703 | 1565.691987 |
| 1 | 790948 | 721785 | 927 | 1003 | 182 | 8.602703 | 1565.691987 |
| 2 | 907016 | 907020 | 1159 | 1666 | 192 | 7.625565 | 1464.108474 |
| 3 | 907020 | 907016 | 1666 | 1159 | 192 | 7.625565 | 1464.108474 |
| 4 | 677232 | 676183 | 1313 | 1094 | 172 | 7.893657 | 1357.708924 |
| 5 | 676183 | 677232 | 1094 | 1313 | 172 | 7.893657 | 1357.708924 |
| 6 | 1581670 | 1581664 | 1025 | 818 | 144 | 8.414 | 1211.615968 |
| 7 | 1581664 | 1581670 | 818 | 1025 | 144 | 8.414 | 1211.615968 |
| 8 | 908984 | 955246 | 1348 | 1453 | 161 | 7.350933 | 1183.500231 |
| 9 | 955246 | 908984 | 1453 | 1348 | 161 | 7.350933 | 1183.500231 |
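In the output above, each pair appears twice because the procedure adds mirrored pairs, and `score` is `cooc * pmi`, the same expression used as the rating column when training the model. A quick arithmetic check on the first row, using the values as displayed (the displayed `pmi` is rounded, so the product differs from `score` only in the fifth decimal place):

```python
# Values copied from the first row of the output above.
cooc = 182
pmi = 8.602703  # displayed (rounded) PMI
displayed_score = 1565.691987

score = cooc * pmi
print(round(score, 4))  # 1565.6919
```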


## Training the Matrix Factorization model

### Create a stored procedure that encapsulates the training statement

```python
query = f"""
    CREATE OR REPLACE PROCEDURE {BQ_DATASET_NAME}.sp_TrainItemMatchingModel(
        IN dimensions INT64
    )

    BEGIN

    CREATE OR REPLACE MODEL {BQ_DATASET_NAME}.item_matching_model
    OPTIONS(
        MODEL_TYPE='matrix_factorization', 
        FEEDBACK_TYPE='implicit',
        WALS_ALPHA=1,
        NUM_FACTORS=(dimensions),
        USER_COL='item1_Id', 
        ITEM_COL='item2_Id',
        RATING_COL='score',
        DATA_SPLIT_METHOD='no_split'
    )
    AS
    SELECT 
        item1_Id, 
        item2_Id, 
        cooc * pmi AS score
    FROM {BQ_DATASET_NAME}.item_cooc;
    END
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: fd874318-327b-4e8f-a12f-d45c6daa101f
Query executing: 0.54s
Query complete after 1.11s
```


### Start training

Be patient: training can take a while. In the run shown below, it took about two hours (7,260 seconds).

```python
query = f"""
    DECLARE dimensions INT64 DEFAULT 50;
    CALL {BQ_DATASET_NAME}.sp_TrainItemMatchingModel(dimensions)
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: 18e6f6b6-2bba-4a88-8b32-8d54d0c01846
Query executing: 7260.71s
Query complete after 7260.78s
```


## Exporting the trained embeddings

### Create a stored procedure that extracts embeddings from the model

```python
query = f"""
    CREATE OR REPLACE PROCEDURE {BQ_DATASET_NAME}.sp_ExractEmbeddings() 
    BEGIN
    CREATE OR REPLACE TABLE {BQ_DATASET_NAME}.item_embeddings AS
        WITH 
        step1 AS
        (
            SELECT 
                feature AS item_Id,
                factor_weights
            FROM
                ML.WEIGHTS(MODEL `{BQ_DATASET_NAME}.item_matching_model`)
            WHERE feature != 'global__INTERCEPT__'
        ),

        step2 AS
        (
            SELECT 
                item_Id, 
                factor, 
                SUM(weight) AS weight
            FROM step1,
            UNNEST(step1.factor_weights) AS embedding
            GROUP BY 
            item_Id,
            factor 
        )

        SELECT 
            item_Id as id, 
            ARRAY_AGG(weight ORDER BY factor ASC) embedding
        FROM step2
        GROUP BY item_Id;
    END 
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: 6d1c238e-120b-4664-8422-e450d2d14883
Query executing: 0.57s
Query complete after 1.20s
```
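For a matrix factorization model, `ML.WEIGHTS` returns factor weights per feature value of each input column, so an id that occurs both as `item1_Id` (the user column) and as `item2_Id` (the item column) gets two factor vectors. The `GROUP BY item_Id, factor` with `SUM(weight)` adds those vectors together, and `ARRAY_AGG(... ORDER BY factor)` lays the result out as a dense array. A pure-Python sketch of the same aggregation, assuming rows shaped like `ML.WEIGHTS` output with a `factor_weights` list of `{factor, weight}` records; the rows and the function name `extract_embeddings` are hypothetical:

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical rows shaped like ML.WEIGHTS output for a 2-factor model.
weights_rows = [
    {"processed_input": "item1_Id", "feature": "10",
     "factor_weights": [{"factor": 1, "weight": 0.5}, {"factor": 2, "weight": -0.25}]},
    {"processed_input": "item2_Id", "feature": "10",
     "factor_weights": [{"factor": 2, "weight": 0.75}, {"factor": 1, "weight": 0.5}]},
    {"processed_input": None, "feature": "global__INTERCEPT__", "factor_weights": []},
]


def extract_embeddings(rows: List[dict]) -> Dict[str, List[float]]:
    """Sum factor weights per (item, factor) and order them by factor."""
    summed = defaultdict(lambda: defaultdict(float))  # item -> factor -> weight
    for row in rows:
        if row["feature"] == "global__INTERCEPT__":  # step1's WHERE clause
            continue
        for fw in row["factor_weights"]:  # UNNEST(factor_weights)
            summed[row["feature"]][fw["factor"]] += fw["weight"]  # SUM(weight)
    # ARRAY_AGG(weight ORDER BY factor ASC)
    return {item: [factors[f] for f in sorted(factors)] for item, factors in summed.items()}


print(extract_embeddings(weights_rows))  # {'10': [1.0, 0.5]}
```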


### Extract embeddings to a table

```python
query = f"""
CALL {BQ_DATASET_NAME}.sp_ExractEmbeddings()
"""

query_job = client.query(query)
query_job = wait_for_result(query_job)
```

```
Executing query with job ID: 97e3d9e7-02ac-4ffb-ba29-70f398324aaf
Query executing: 7.34s
Query complete after 7.90s
```


#### Verify the number of embeddings

```python
query = f"""
    SELECT COUNT(*) embedding_count
    FROM {BQ_DATASET_NAME}.item_embeddings;
"""

query_job = client.query(query)
query_job.to_dataframe()
```

| | embedding_count |
|---|---|
| 0 | 5000 |
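The overview lists exploring the trained embeddings as a step. One quick sanity check is to look up an item's nearest neighbours by cosine similarity, the kind of similarity lookup the ANN Service approximates at scale. A brute-force NumPy sketch, assuming the embeddings have already been loaded into memory as a dict from id to vector (for example from `item_embeddings` via `to_dataframe()`); the function name `nearest_neighbours` and the toy vectors are hypothetical, and zero-length vectors are not handled:

```python
from typing import Dict, List, Tuple

import numpy as np


def nearest_neighbours(embeddings: Dict[str, List[float]], query_id: str, k: int = 2) -> List[Tuple[str, float]]:
    """Return the k items most similar to `query_id` by cosine similarity."""
    ids = list(embeddings)
    matrix = np.array([embeddings[i] for i in ids], dtype=float)
    # Normalise rows so that a dot product equals cosine similarity.
    matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
    scores = matrix @ matrix[ids.index(query_id)]
    # Highest similarity first, excluding the query item itself.
    ranked = [(ids[j], float(scores[j])) for j in np.argsort(-scores) if ids[j] != query_id]
    return ranked[:k]


toy_embeddings = {
    "a": [1.0, 0.0],
    "b": [0.9, 0.1],
    "c": [0.0, 1.0],
    "d": [-1.0, 0.0],
}
print([item for item, _ in nearest_neighbours(toy_embeddings, "a")])  # ['b', 'c']
```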


### Export embeddings to GCS

```python
file_name_pattern = 'embedding-*.json'
destination_uri = f'gs://{BUCKET_NAME}/embeddings/{file_name_pattern}'
table_id = 'item_embeddings'

# Pin the client to the notebook's project and location instead of relying
# on the environment's default project.
client = bigquery.Client(project=PROJECT_ID, location=BQ_LOCATION)
dataset_ref = bigquery.DatasetReference(PROJECT_ID, BQ_DATASET_NAME)
table_ref = dataset_ref.table(table_id)

# Export as newline-delimited JSON (JSONL), one embedding record per line.
job_config = bigquery.ExtractJobConfig()
job_config.destination_format = bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON

extract_job = client.extract_table(
    table_ref,
    destination_uris=destination_uri,
    job_config=job_config,
    location=BQ_LOCATION,
)  # API request

extract_job = wait_for_result(extract_job)
```

```
Executing query with job ID: a38a263e-9ea6-400a-805c-81afadc6d551
Query executing: 10.04s
Query complete after 10.60s
```
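The extract job writes one JSON object per line, keyed by the `id` and `embedding` columns of `item_embeddings`, which is the layout this notebook assumes the ANN Service ingests. A minimal sketch that parses such lines and checks that every embedding has the expected dimensionality (50 here, the `dimensions` value used for training); the sample lines and the function name `load_embeddings` are made up for illustration, and with real files you would read from `gs://{BUCKET_NAME}/embeddings/`, for example with `tf_io.gfile` (already imported in this notebook):

```python
import io
import json
from typing import Dict, Iterable, List


def load_embeddings(lines: Iterable[str], expected_dim: int) -> Dict[str, List[float]]:
    """Parse JSONL embedding records and validate their dimensionality."""
    embeddings = {}
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        vector = record["embedding"]
        if len(vector) != expected_dim:
            raise ValueError(f"line {line_number}: expected {expected_dim} values, got {len(vector)}")
        embeddings[str(record["id"])] = [float(v) for v in vector]
    return embeddings


# Made-up sample in the exported layout; real records have 50 values each.
sample_file = io.StringIO(
    '{"id":"1001","embedding":[0.12,-0.03,0.4]}\n'
    '{"id":"1002","embedding":[0.1,-0.01,0.38]}\n'
)
print(sorted(load_embeddings(sample_file, expected_dim=3)))  # ['1001', '1002']
```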


## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product, but sample code provided for educational purposes.**