In [1]:
# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

<img src="https://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_models-transformers-net-item-prediction/nvidia_logo.png" style="width: 90px; float: right;">

# Transformer-based architecture for next-item prediction task with pretrained embeddings

This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container.

## Overview

In this use case, we will train a Transformer-based architecture for the next-item prediction task using pretrained embeddings.

**You can choose to download the full dataset manually or use synthetic data.**

We will use the [SIGIR eCom Data Challenge dataset](https://www.coveo.com/en/ailabs/sigir-ecom-data-challenge) released by Coveo to train a session-based model. It consists of two tables:
- `browsing_train.csv` is a log of browsing events. Each event belongs to a session identified by `session_id_hash`.
- `sku_to_content.csv` holds per-product metadata such as `category_hash`, `price_bucket`, and the pretrained `description_vector` embeddings.

We will reshape the data to organize it into 'sessions'. Each session will be the sequence of browsing events sharing a `session_id_hash`, in chronological order (by `server_timestamp_epoch_ms`). The goal will be to predict the next product (`product_sku_hash`) interacted with in each session.


### Learning objectives

- Training a Transformer-based architecture for the next-item prediction task

## Downloading and preparing the dataset

You can download the full dataset by registering [here](https://www.coveo.com/en/ailabs/sigir-ecom-data-challenge). If you choose to download the data, place the `train` directory (containing `browsing_train.csv` and `sku_to_content.csv`) under `/workspace/sigir_dataset/` (you might have to create it). The next cell reads the files from that location.

To process the downloaded data, run the cells below. If you are using synthetic data instead, skip ahead to the synthetic data section.

In [2]:
import os

import nvtabular as nvt

# Root directory of the downloaded SIGIR eCom data challenge files.
# Defaults to /workspace/sigir_dataset; set the SIGIR_DATA_DIR environment
# variable to read the files from somewhere else.
SIGIR_DATA_DIR = os.environ.get('SIGIR_DATA_DIR', '/workspace/sigir_dataset')

# The browsing log is read in partitions of roughly `part_size` each.
train = nvt.Dataset(os.path.join(SIGIR_DATA_DIR, 'train', 'browsing_train.csv'), part_size='500MB')
skus = nvt.Dataset(os.path.join(SIGIR_DATA_DIR, 'train', 'sku_to_content.csv'))

2023-06-06 00:11:43.482404: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")


In [3]:
from merlin.schema import Tags

# Preprocessing workflow for the browsing events.
# The session key groups events into sessions; it is not the item to predict,
# so it is tagged as a session id. `product_sku_hash` is the item, and it
# carries the ITEM_ID tag. The SKU workflow below tags the same column the
# same way. Column order matches the original cell.
out = ['session_id_hash'] >> nvt.ops.Categorify() >> nvt.ops.AddTags(tags=[Tags.SESSION_ID])
out += ['event_type', 'product_action'] >> nvt.ops.Categorify()
out += ['product_sku_hash'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID()
out += ['hashed_url'] >> nvt.ops.Categorify()
out += ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()

# NOTE(review): the Categorify on `product_sku_hash` here is fit separately
# from the one in the SKU workflow, so the two integer encodings are not
# guaranteed to agree. Confirm how the pretrained SKU vectors are matched to
# these item ids before relying on them.
wf = nvt.Workflow(out)

# This replaces the raw `train` dataset with the transformed one. Re-running
# this cell alone would re-fit the workflow on already-transformed columns, so
# re-run the loading cell first.
train = wf.fit_transform(train)



The SKU table also contains an `image_vector` column, which we won't be using, so we don't include it in the workflow below.

In [4]:
# Preprocessing workflow for the SKU metadata table.
# `product_sku_hash` identifies the item. `category_hash`,
# `description_vector` and `price_bucket` describe it.
out = ['product_sku_hash'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID()
# category_hash is an item attribute, not a second item id.
out += ['category_hash'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemFeatures()
# NOTE(review): judging by the preview output, this column appears to be read
# from the CSV as the string form of a list, and it has missing values. Confirm
# it is parsed into numeric vectors before using it as pretrained embeddings.
out += ['description_vector'] >> nvt.ops.TagAsItemFeatures()
out += ['price_bucket'] >> nvt.ops.NormalizeMinMax()

wf = nvt.Workflow(out)
skus = wf.fit_transform(skus)



In [5]:
# Preview the first rows of the transformed browsing events.
train.head()

Unnamed: 0,session_id_hash,event_type,product_action,product_sku_hash,hashed_url,server_timestamp_epoch_ms
0,66851,2,1,55384,374,0.431877
1,66851,2,1,25546,195,0.431877
2,66851,1,0,0,195,0.431877
3,66851,2,1,55384,374,0.431877
4,66851,1,0,0,374,0.431877


In [6]:
# Preview the first rows of the transformed SKU table.
skus.head()

Unnamed: 0,product_sku_hash,category_hash,description_vector,price_bucket
0,10223,0,,
1,25974,0,,
2,42472,10,"[0.27629122138023376, -0.15763211250305176, 0....",0.666666667
3,4202,109,"[0.4058118760585785, -0.03595402091741562, 0.2...",0.777777778
4,39729,1,"[-0.3206155300140381, 0.01991105079650879, 0.0...",0.111111111


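To use synthetically generated data instead, run the following cells:
- The first one installs `merlin-models` from the local source checkout in `/workspace`.
- The later ones generate synthetic versions of the browsing and SKU tables with `generate_data`.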

In [7]:
%%bash

cd /workspace && pip install . 

Processing /workspace
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing wheel metadata: finished with status 'done'


Building wheels for collected packages: merlin-models
  Building wheel for merlin-models (PEP 517): started
  Building wheel for merlin-models (PEP 517): finished with status 'done'
  Created wheel for merlin-models: filename=merlin_models-23.5.dev0+40.g1e01e265.dirty-py3-none-any.whl size=424050 sha256=49fd364c59d55aa01363ad8602f3daf56064f5868cf2bff54373d71ca4d77291
  Stored in directory: /tmp/pip-ephem-wheel-cache-dkiam1rd/wheels/59/14/70/d94958f41745fe226f3bc60bb3cabbbc8a98e4d6679e91038a
Successfully built merlin-models
Installing collected packages: merlin-models
  Attempting uninstall: merlin-models
    Found existing installation: merlin-models 0+unknown
    Can't uninstall 'merlin-models'. No files were found to uninstall.
Successfully installed merlin-models-23.5.dev0+40.g1e01e265.dirty


In [8]:
from merlin.datasets.synthetic import KNOWN_DATASETS

In [9]:
# Mapping of built-in dataset names to the paths they resolve to.
KNOWN_DATASETS

{'e-commerce': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/ecommerce/small'),
 'e-commerce-large': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/ecommerce/large'),
 'music-streaming': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/music_streaming'),
 'social': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/social'),
 'testing': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/testing'),
 'sequence-testing': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/testing/sequence_testing'),
 'movielens-25m': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/25m'),
 'movielens-1m': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/1m'),
 'movielens-1m-raw-ratings': PosixPath('/usr/local/lib/python3.8/dist-packages/merlin/datasets/entertainment/movielens/1m-raw/ratings'),
 'movielens-100k': 

In [10]:
from merlin.datasets.synthetic import generate_data

# Keep the generated dataset in a variable (instead of only displaying it) so
# it can stand in for the real browsing data downstream.
NUM_SYNTHETIC_BROWSING_ROWS = 1000
synthetic_browsing = generate_data('sigir-browsing', NUM_SYNTHETIC_BROWSING_ROWS)
synthetic_browsing.head()



Unnamed: 0,session_id_hash,event_type,product_action,product_sku_hash,hashed_url,server_timestamp_epoch_ms
0,2,0,2,324,468,0.151202
1,11,0,3,757,492,0.14985
2,7,0,1,889,126,0.865948
3,6,1,1,930,693,0.869109
4,32,0,0,318,480,0.617824


In [11]:
# Synthetic counterpart of the SKU metadata table, kept in a variable so it
# can stand in for the real SKU data downstream.
NUM_SYNTHETIC_SKU_ROWS = 1000
synthetic_skus = generate_data('sigir-sku', NUM_SYNTHETIC_SKU_ROWS)
synthetic_skus.head()



Unnamed: 0,category_hash,product_sku_hash,description_vector,price_bucket
0,28,156,"[0.19691553495989195, -0.1834484775978868, 0.5...",0.290628
1,15,81,"[-0.051316745930657215, -0.02616168732559826, ...",0.692825
2,3,12,"[0.5836391698545325, -0.248344824994722, -0.43...",0.771884
3,35,197,"[0.20024049029083574, -0.319046437918984, 0.01...",0.725247
4,57,324,"[0.1901695010328885, 0.19299795066133935, -0.3...",0.484631
