In [1]:
# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

<img src="https://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_models-transformers-net-item-prediction/nvidia_logo.png" style="width: 90px; float: right;">

# Transformer-based architecture for next-item prediction task with pretrained embeddings

This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container.

## Overview

In this use case, we will train a Transformer-based architecture for the next-item prediction task using pretrained embeddings.

**You can choose to download the full dataset manually or use synthetic data.**

We will use the [booking.com dataset](https://github.com/bookingcom/ml-dataset-mdt) to train a session-based model. The dataset contains 1,166,835 anonymized hotel reservations in the train set and 378,667 in the test set. Each reservation is part of a customer's trip (identified by `utrip_id`), which consists of consecutive reservations.

We will reshape the data to organize it into 'sessions'. Each session will be a full customer itinerary in chronological order. The goal will be to predict the city_id of the final reservation of each trip.


### Learning objectives

- Training a Transformer-based architecture for next-item prediction task

## Downloading and preparing the dataset

You can download the full dataset by registering [here](https://www.coveo.com/en/ailabs/sigir-ecom-data-challenge). If you choose to download the data, please place it alongside this notebook in the `data` directory (you might have to create it).

To process the downloaded data, uncomment the cell below.

In [16]:
import nvtabular as nvt

# Input CSV files of the downloaded SIGIR dataset.
browsing_csv = '/workspace/sigir_dataset/train/browsing_train.csv'
sku_csv = '/workspace/sigir_dataset/train/sku_to_content.csv'

# The browsing file is loaded with an explicit partition size; the SKU file
# uses the library default.
train = nvt.Dataset(browsing_csv, part_size='500MB')
skus = nvt.Dataset(sku_csv)

In [17]:
# Preprocessing graph for the browsing events.
# The session key is encoded and tagged as a session ID, not as the item ID:
# the item ID of this dataset is 'product_sku_hash' (see the SKU workflow).
out = ['session_id_hash'] >> nvt.ops.Categorify() >> nvt.ops.AddMetadata(tags=['session_id'])
# NOTE(review): 'product_sku_hash' is encoded here by a Categorify fitted on the
# browsing data only and carries no item-ID tag in this workflow — confirm the
# encoding lines up with the one fitted on the SKU table.
out += ['event_type', 'product_action', 'product_sku_hash', 'hashed_url'] >> nvt.ops.Categorify()
# Apply min-max normalization to the event timestamp.
out += ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()

wf = nvt.Workflow(out)

# Fit the workflow on the raw browsing data and rebind `train` to the
# transformed dataset.
train = wf.fit_transform(train)

In [18]:
# Preprocessing graph for the SKU table, built as one expression:
#   - the SKU hash is categorified and tagged as the item ID,
#   - the description/image vectors are tagged as item features,
#   - the price bucket is min-max normalized.
out = (
    (['product_sku_hash'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID())
    + (['description_vector', 'image_vector'] >> nvt.ops.TagAsItemFeatures())
    + (['price_bucket'] >> nvt.ops.NormalizeMinMax())
)

# Fit the workflow on the raw SKU data and rebind `skus` to the transformed dataset.
wf = nvt.Workflow(out)
skus = wf.fit_transform(skus)

In [19]:
# Preview the first rows of the transformed browsing events.
train.head()

Unnamed: 0,session_id_hash,event_type,product_action,product_sku_hash,hashed_url,server_timestamp_epoch_ms
0,66853,4,3,55386,376,0.431877
1,66853,4,3,25548,197,0.431877
2,66853,3,1,1,197,0.431877
3,66853,4,3,55386,376,0.431877
4,66853,3,1,1,376,0.431877


In [20]:
# Preview the first rows of the transformed SKU table.
skus.head()

Unnamed: 0,product_sku_hash,description_vector,image_vector,price_bucket
0,10225,,,
1,25976,,,
2,42474,"[0.27629122138023376, -0.15763211250305176, 0....","[340.3592564184389, -220.19025864725685, 154.0...",0.666666667
3,4204,"[0.4058118760585785, -0.03595402091741562, 0.2...","[180.3463662921092, 222.702322343354, -8.88703...",0.777777778
4,39731,"[-0.3206155300140381, 0.01991105079650879, 0.0...","[-114.81079301576219, 84.55770104232334, 85.51...",0.111111111


To use synthetically generated data, uncomment the following cell:

In [1]:
%%bash

# Install the package located in /workspace with pip.
cd /workspace && pip install . 

Processing /workspace
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing wheel metadata: finished with status 'done'


Building wheels for collected packages: merlin-models
  Building wheel for merlin-models (PEP 517): started
  Building wheel for merlin-models (PEP 517): finished with status 'done'
  Created wheel for merlin-models: filename=merlin_models-23.5.dev0+39.g13c11411.dirty-py3-none-any.whl size=424028 sha256=71ff8361ab8a3b94e2c5ccfc05228dbdd11e5d1709177faca6f0ee7e68191e10
  Stored in directory: /tmp/pip-ephem-wheel-cache-72fp30go/wheels/59/14/70/d94958f41745fe226f3bc60bb3cabbbc8a98e4d6679e91038a
Successfully built merlin-models


ERROR: transformers4rec 23.5.0+2.g8a44ce43 requires torchmetrics>=0.10.0, which is not installed.


Installing collected packages: merlin-models
  Attempting uninstall: merlin-models
    Found existing installation: merlin-models 0+unknown
    Not uninstalling merlin-models at /workspace, outside environment /usr
    Can't uninstall 'merlin-models'. No files were found to uninstall.
Successfully installed merlin-models-23.5.dev0+39.g13c11411.dirty


In [2]:
from merlin.datasets.synthetic import KNOWN_DATASETS

2023-06-02 11:37:33.672226: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")


In [3]:
# Display the mapping of known synthetic dataset names to their on-disk locations.
KNOWN_DATASETS

{'e-commerce': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/ecommerce/small'),
 'e-commerce-large': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/ecommerce/large'),
 'music-streaming': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/music_streaming'),
 'social': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/social'),
 'testing': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/testing'),
 'sequence-testing': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/testing/sequence_testing'),
 'movielens-25m': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/25m'),
 'movielens-1m': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/1m'),
 'movielens-1m-raw-ratings': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/1m-raw/ratings'),
 'movielens-100k': 

In [4]:
from merlin.datasets.synthetic import generate_data

# Generate synthetic 'sigir-browsing' data (size argument 1000) and preview it.
browsing_sample = generate_data('sigir-browsing', 1000)
browsing_sample.head()

Unnamed: 0,session_id_hash,event_type,product_action,product_sku_hash,hashed_url,server_timestamp_epoch_ms
0,76,0,4,854,277,0.320487
1,27,2,2,507,57,0.074733
2,8,3,3,423,654,0.689865
3,44,0,0,297,494,0.797353
4,20,2,5,128,923,0.477495


In [5]:
# Generate synthetic 'sigir-sku' data (size argument 1000) and preview it.
sku_sample = generate_data('sigir-sku', 1000)
sku_sample.head()

Unnamed: 0,product_sku_hash,description_vector,image_vector,price_bucket
0,43,"[-0.10591142379148188, -0.03795119909217676, 0...","[154.24838911819757, -220.94911488369894, 706....",0.105784
1,11,"[0.39833330585464594, 0.40623340973191585, -0....","[447.34390006248134, 111.04788008656772, 41.86...",0.31185
2,11,"[0.26045993213506796, 0.14917678863577694, 0.4...","[286.8398628984362, 719.1353210910496, -169.89...",0.50959
3,42,"[-0.17769888574348358, 0.25437945800846457, -0...","[220.39659322932948, 305.26070651703776, 157.6...",0.073677
4,10,"[0.11854471618404411, 0.2769747120238893, 0.19...","[620.6735804239985, 529.3302886904164, 222.665...",0.531253
