In [1]:
# %%bash
# cd /core
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

# cd /dataloader
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

# cd /nvtabular
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

# cd /models
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

# cd /systems
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

# cd /transformers4rec
# git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main
# pip install . --no-deps

In [2]:
# Copyright 2022 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Each user is responsible for checking the content of datasets and the
# applicable licenses and determining if suitable for the intended use.

<img src="https://developer.download.nvidia.com/notebooks/dlsw-notebooks/merlin_models-transformers-net-item-prediction/nvidia_logo.png" style="width: 90px; float: right;">

# Transformer-based architecture for next-item prediction task with pretrained embeddings

This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container.

## Overview

In this use case we will train a Transformer-based architecture for next-item prediction task with pretrained embeddings.

**You can choose to download the full dataset manually or use synthetic data.**

We will use the [booking.com dataset](https://github.com/bookingcom/ml-dataset-mdt) to train a session-based model. The dataset contains 1,166,835 anonymized hotel reservations in the train set and 378,667 in the test set. Each reservation is part of a customer's trip (identified by `utrip_id`), which includes consecutive reservations.

We will reshape the data to organize it into 'sessions'. Each session will be a full customer itinerary in chronological order. The goal will be to predict the city_id of the final reservation of each trip.


### Learning objectives

- Training a Transformer-based architecture for next-item prediction task

## Downloading and preparing the dataset

In [3]:
import nvtabular as nvt
import cudf
import numpy as np

2023-06-12 00:45:47.445300: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  warn(f"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}")
2023-06-12 00:45:48.786568: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-12 00:45:48.786958: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-12 00:45:48.787112: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executo

You can download the full dataset by registering [here](https://www.coveo.com/en/ailabs/sigir-ecom-data-challenge). If you choose to download the data, please place it alongside this notebook in the `sigir_dataset` directory and extract it.

To process the downloaded data uncomment the cell below.

Otherwise, if you'd prefer to use synthetically generated data, uncomment the second cell below.

In [4]:
# Uncomment this cell to use the SIGIR dataset.

import ast

import pandas as pd

# Browsing events, wrapped in an NVTabular dataset with 500MB partitions.
train = nvt.Dataset('/workspace/sigir_dataset/train/browsing_train.csv', part_size='500MB')

# SKU metadata is read with pandas so the vector columns can be parsed below.
skus = pd.read_csv('/workspace/sigir_dataset/train/sku_to_content.csv')

# The vector columns are stored as stringified lists; missing values become
# empty lists. `ast.literal_eval` only accepts Python literals, so unlike
# `eval` it cannot execute code embedded in the CSV.
for vector_col in ['description_vector', 'image_vector']:
    skus[vector_col] = skus[vector_col].replace(np.nan, '')
    skus[vector_col] = skus[vector_col].apply(
        lambda x: [] if len(x) == 0 else ast.literal_eval(x)
    )

In [5]:
# df = pd.read_csv('/workspace/sigir_dataset/train/browsing_train.csv')

# df['product_sku_hash'].isna().mean()

# df['hashed_url'].isna().mean()

In [6]:
# The item id gets its own Categorify instance, bound to `cat_op`, so the same
# op object can be applied again to another dataset after it has been fitted here.
cat_op = nvt.ops.Categorify()
item_id = ['product_sku_hash'] >> cat_op >> nvt.ops.TagAsItemID()

# The remaining categorical columns share a second Categorify op.
other_categoricals = (
    ['event_type', 'product_action', 'session_id_hash', 'hashed_url']
    >> nvt.ops.Categorify()
)

# The event timestamp goes through NormalizeMinMax.
scaled_timestamp = ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()

out = item_id + other_categoricals + scaled_timestamp
wf = nvt.Workflow(out)

# Fit the workflow on the browsing events and replace `train` with the transformed dataset.
train = wf.fit_transform(train)

In [7]:
from merlin.schema import ColumnSchema, Schema, Tags

In [8]:
# cat_op = nvt.ops.Categorify()
# out = ['product_sku_hash'] >> cat_op >> nvt.ops.TagAsItemID()
# out += ['event_type', 'product_action', 'session_id_hash'] >> nvt.ops.Categorify()
# out += ['hashed_url'] >> nvt.ops.Categorify() >> nvt.ops.AddMetadata(tags=Tags.TARGET)
# out += ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()

# wf = nvt.Workflow(out)

# train = wf.fit_transform(train)

The `skus` dataset contains the mapping from the `product_sku_hash` (essentially an item id) to the `description_vector` -- an embedding obtained from the description.

To use this information in our model, we need to map the `product_sku_hash` information to an id.

But we need to make sure that the way we process `skus` and the `train` dataset (event information) is consistent: the same `product_sku_hash` must be mapped to the same id both when processing `skus` and when processing `train`.

We do so by defining and fitting a `Categorify` op and using it to process both datasets.

Now that we have processed the train set, we can use the mapping preserved in the `cat_op` to process the `skus` dataset containing the embeddings we are after.

Let's now `Categorify` the `product_sku_hash` in `skus` and grab just the description embedding information.

In [9]:
skus.head()

Unnamed: 0,product_sku_hash,description_vector,category_hash,image_vector,price_bucket
0,26ce7b47f4c46e4087e83e54d2f7ddc7ea57862fed2e2a...,[],,[],
1,6383992be772b204a9ab75f86c86f5583d1bdd1222952d...,[],,[],
2,a2c3e2430c6ef9770b903ad08fa067a6b2b9db28f06e1b...,"[0.27629122138023376, -0.15763211250305176, 0....",06fa312761d4b39e2f649781514ac69a4c1505c221fc46...,"[340.3592564184389, -220.19025864725685, 154.0...",7.0
3,1028ef615e425c328e7b95010dfb1fb93cf63749a1bc80...,"[0.4058118760585785, -0.03595402091741562, 0.2...",115a6a7017ee55752b8487c77dfde92b0d501d10a2e69c...,"[180.3463662921092, 222.702322343354, -8.88703...",8.0
4,9870c682d0d52d635501249da0eeaa118fad430b695ea1...,"[-0.3206155300140381, 0.01991105079650879, 0.0...",0665a81d19c89281cc00e7f7d779ded2ed42c933838602...,"[-114.81079301576219, 84.55770104232334, 85.51...",2.0


The data contains `image_vector` information which we won't be using and hence we don't include it in the workflow below.

In [10]:
# Keep only the SKUs whose description vector is non-empty.
has_description = skus['description_vector'].map(len) > 0
skus = skus.loc[has_description]

In [11]:
# Reuse `cat_op`, the Categorify op that was part of the workflow fitted on
# `train`, so `product_sku_hash` is encoded the same way in both datasets.
out = ['product_sku_hash'] >> cat_op
# `description_vector` is added to the graph as-is, next to the encoded id.
wf = nvt.Workflow(out + 'description_vector')
# Only `transform` is called here -- this workflow is not fitted on `skus`.
skus_ds = wf.transform(nvt.Dataset(skus))

In [12]:
skus_ds.head()

Unnamed: 0,product_sku_hash,description_vector
0,6207,"[0.27629122138023376, -0.15763211250305176, 0...."
1,7691,"[0.4058118760585785, -0.03595402091741562, 0.2..."
2,10812,"[-0.3206155300140381, 0.01991105079650879, 0.0..."
3,21238,"[-0.1854386031627655, 0.19424490630626678, -0...."
4,8776,"[-0.24601778388023376, -0.12155783176422119, -..."


In [13]:
skus_ds.to_npy('skus.npy')

In [14]:
np.load('skus.npy')[:5, 0]

array([ 6207.,  7691., 10812., 21238.,  8776.])

In [15]:
from merlin.dataloader.tensorflow import Loader

In [16]:
embeddings = np.load('skus.npy')

In [57]:
# Sessions with fewer events than this are dropped.
MINIMUM_SESSION_LENGTH = 5

# Per-session aggregations: every feature becomes a list, and `hashed_url` is
# additionally counted (producing the `hashed_url_count` column).
session_aggs = {
    'product_sku_hash': ['list'],
    'event_type': ['list'],
    'product_action': ['list'],
    'hashed_url': ['list', 'count'],
    'server_timestamp_epoch_ms': ['list'],
}

# Group all columns of `train` by session, sorting on the event timestamp.
all_columns = train.head().columns.tolist()
groupby_features = all_columns >> nvt.ops.Groupby(
    groupby_cols=['session_id_hash'],
    aggs=session_aggs,
    sort_cols="server_timestamp_epoch_ms",
)


def _has_minimum_length(df):
    """Return a boolean mask selecting sessions with enough events."""
    return df["hashed_url_count"] >= MINIMUM_SESSION_LENGTH


filtered_sessions = groupby_features >> nvt.ops.Filter(f=_has_minimum_length)

In [58]:
wf = nvt.Workflow(filtered_sessions)
train_processed = wf.fit_transform(train)

In [59]:
train_processed.head()

Unnamed: 0,session_id_hash,product_sku_hash_list,event_type_list,product_action_list,hashed_url_list,server_timestamp_epoch_ms_list,hashed_url_count
0,19,"[1, 3, 1, 17792, 3, 3, 1, 3, 3, 3, 1, 1, 1, 1,...","[3, 4, 3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 3, 3, 3, ...","[1, 5, 1, 5, 5, 5, 1, 5, 5, 5, 1, 1, 1, 1, 1, ...","[3, 5, 5, 183199, 183199, 5, 5, 157277, 157277...","[0.017701422675646814, 0.017701613740402588, 0...",200
1,28,"[49, 1, 1, 1, 1, 2779, 1, 1, 1, 1, 1, 1, 1, 1,...","[4, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[126, 126, 19, 11, 4241, 3689, 3689, 4241, 695...","[0.08774476140013235, 0.08774476140013235, 0.0...",200
2,31,"[1, 1, 1, 1, 381, 1, 1, 1, 1, 1, 1, 1, 1, 1086...","[3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, ...","[1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, ...","[3, 77, 3802, 663, 663, 3802, 77, 77, 95, 77, ...","[0.41732644544485625, 0.4173890392881526, 0.41...",200
3,47,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 297, 1, 413, 1,...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 3, ...","[5648, 2648, 2648, 13, 53, 14323, 22235, 12033...","[0.7742340701027528, 0.7742346497301769, 0.774...",200
4,73,"[1, 1, 3287, 1, 1, 2349, 3287, 1, 1, 2349, 328...","[3, 3, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 3, ...","[1, 1, 3, 1, 1, 3, 3, 1, 1, 3, 3, 1, 1, 3, 1, ...","[40, 13, 4411, 4411, 3075, 3075, 4411, 4411, 3...","[0.3928006023156755, 0.3928008996561809, 0.392...",200


In [20]:
# df = cudf.DataFrame(data={'id': [0, 0, 1, 2], 'cat': list(range(4)), 'val': [10, 20, 20, 30]})

# ds = nvt.Dataset(df)

# cat = ['cat'] >> nvt.ops.Categorify()
# groupby_features = cat + ['id', 'val'] >> nvt.ops.Groupby(
#     groupby_cols=['id'],
#     aggs={
#         'cat': ['list', 'count'],
#         'val': ['list']
#     }
# )

# wf = nvt.Workflow(groupby_features)
# ds_transformed = wf.fit_transform(ds)

In [60]:
from merlin.dataloader.ops.embeddings import EmbeddingOperator

In [61]:
embeddings

array([[ 6.20700000e+03,  2.76291221e-01, -1.57632113e-01, ...,
         1.37491841e-02, -6.54154569e-02,  2.93405820e-02],
       [ 7.69100000e+03,  4.05811876e-01, -3.59540209e-02, ...,
        -5.11377156e-02,  1.23936273e-02, -3.28034014e-02],
       [ 1.08120000e+04, -3.20615530e-01,  1.99110508e-02, ...,
         4.23921552e-03,  1.48407519e-02, -8.13788921e-03],
       ...,
       [ 2.51510000e+04, -1.71209499e-01,  1.72444582e-01, ...,
         3.26483995e-02, -1.27059463e-02, -2.08097659e-02],
       [ 2.00000000e+00, -1.91504419e-01, -6.23516217e-02, ...,
        -2.50775218e-02, -2.31876411e-02,  4.92795147e-02],
       [ 2.87040000e+04, -1.97609365e-01,  4.44645047e-01, ...,
         4.00494635e-02,  3.62532213e-03,  4.55308110e-02]])

In [66]:
# Drop the session id and the event count from the schema attached to
# `train_processed`; only the schema is changed here, in two steps.
trimmed_schema = train_processed.schema.remove_col('session_id_hash')
trimmed_schema = trimmed_schema.remove_col('hashed_url_count')
train_processed.schema = trimmed_schema

In [67]:
# Column 0 of `embeddings` holds the categorified item id; the remaining
# columns hold the description vector.
item_ids = embeddings[:, 0].astype(int)
item_vectors = embeddings[:, 1:]

# Looks up the pretrained vector for every id in `product_sku_hash_list`.
pretrained_lookup = EmbeddingOperator(
    item_vectors,
    id_lookup_table=item_ids,
    lookup_key="product_sku_hash_list",
)

loader = Loader(
    train_processed,
    batch_size=10,
    transforms=[pretrained_lookup],
    shuffle=True,
)



In [68]:
train_processed.head()

Unnamed: 0,session_id_hash,product_sku_hash_list,event_type_list,product_action_list,hashed_url_list,server_timestamp_epoch_ms_list,hashed_url_count
0,19,"[1, 3, 1, 17792, 3, 3, 1, 3, 3, 3, 1, 1, 1, 1,...","[3, 4, 3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 3, 3, 3, ...","[1, 5, 1, 5, 5, 5, 1, 5, 5, 5, 1, 1, 1, 1, 1, ...","[3, 5, 5, 183199, 183199, 5, 5, 157277, 157277...","[0.017701422675646814, 0.017701613740402588, 0...",200
1,28,"[49, 1, 1, 1, 1, 2779, 1, 1, 1, 1, 1, 1, 1, 1,...","[4, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[3, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[126, 126, 19, 11, 4241, 3689, 3689, 4241, 695...","[0.08774476140013235, 0.08774476140013235, 0.0...",200
2,31,"[1, 1, 1, 1, 381, 1, 1, 1, 1, 1, 1, 1, 1, 1086...","[3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, ...","[1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, ...","[3, 77, 3802, 663, 663, 3802, 77, 77, 95, 77, ...","[0.41732644544485625, 0.4173890392881526, 0.41...",200
3,47,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 297, 1, 413, 1,...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 3, ...","[5648, 2648, 2648, 13, 53, 14323, 22235, 12033...","[0.7742340701027528, 0.7742346497301769, 0.774...",200
4,73,"[1, 1, 3287, 1, 1, 2349, 3287, 1, 1, 2349, 328...","[3, 3, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 3, ...","[1, 1, 3, 1, 1, 3, 3, 1, 1, 3, 3, 1, 1, 3, 1, ...","[40, 13, 4411, 4411, 3075, 3075, 4411, 4411, 3...","[0.3928006023156755, 0.3928008996561809, 0.392...",200


In [69]:
loader.peek()

({'product_sku_hash_list__values': <tf.Tensor: shape=(159,), dtype=int64, numpy=
  array([    1,     1,  2662,     1,  2470,     1,     1,  2662,     1,
             1,  6193,     1,     1,  7100,     1,     1,  5529,     1,
             1,     1,     1,     1,  3007,     1,     1,  6492,     1,
          5657,  6492,     1,     1,     1,     1, 10428,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,  1185,     1,   124,
             1,     1,   114,  1121,     1,   414,     1,   473,     1,
             1,  1167,     1,  1270,     1,   473,     1,    98,     1,
           334,     1,   669,     1,  3934,   669,     1,     1,    98,
           146,     1,     1,    98,   669,     1,   334,     1,   669,
             1,  2824,     1,   669,     1,   146,     1,     1,   669,
             1,   334,     1,   669,   334,     1,     1,   669,   146,
             1,   669,     1,     1,    98,     1,   66

In [70]:
import merlin.models.tf as mm

In [71]:
schema = loader.output_schema

In [28]:
# schema = schema.remove_col('session_id_hash')
# schema = schema.remove_col('product_sku_hash_count')

In [72]:
schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.num_buckets,properties.freq_threshold,properties.max_size,properties.cat_path,properties.domain.min,properties.domain.max,properties.domain.name,properties.embedding_sizes.cardinality,properties.embedding_sizes.dimension,properties.value_count.min,properties.value_count.max
0,product_sku_hash_list,"(Tags.ITEM, Tags.CATEGORICAL, Tags.ID)","DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_sku_hash.parquet,0.0,57485.0,product_sku_hash,57486.0,512.0,0,
1,event_type_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.event_type.parquet,0.0,4.0,event_type,5.0,16.0,0,
2,product_action_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_action.parquet,0.0,6.0,product_action,7.0,16.0,0,
3,hashed_url_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.hashed_url.parquet,0.0,489302.0,hashed_url,489303.0,512.0,0,
4,server_timestamp_epoch_ms_list,(Tags.CONTINUOUS),"DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,0,
5,embeddings,"(Tags.ITEM, Tags.EMBEDDING)","DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,0,


In [73]:
# if broadcast_non_seq_features:
#     # TODO: Check if it would be possible for EmbeddingOperator to keep the Tags.SEQUENCE
#     # in the output embedding schema if lookup_key has that tag
#     seq_schema = schema.select_by_tag(Tags.SEQUENCE) + schema.select_by_name(
#         "pretrained_item_id_embeddings"
#     )
#     non_seq_schema = schema.select_by_name(["user_country", "pretrained_user_id_embeddings"])
#     input_kwargs = {"post": mm.BroadcastToSequence(non_seq_schema, seq_schema)}

# Embedding tables for every categorical column; no sequence combiner is applied.
categorical_embeddings = mm.Embeddings(
    schema.select_by_tag(Tags.CATEGORICAL),
    sequence_combiner=None,
)

# Pretrained vectors added by the dataloader (tagged EMBEDDING), configured with
# the "l2-norm" normalizer and an output dimension of 128 for the `embeddings` column.
pretrained_item_embeddings = mm.PretrainedEmbeddings(
    schema.select_by_tag(Tags.EMBEDDING),
    sequence_combiner=None,
    normalizer="l2-norm",
    output_dims={"embeddings": 128},
)

input_block = mm.InputBlockV2(
    schema,
    embeddings=categorical_embeddings,
    pretrained_embeddings=pretrained_item_embeddings,
    # post=mm.BroadcastToSequence(non_seq_schema, seq_schema)
)

In [74]:
loader.input_schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.num_buckets,properties.freq_threshold,properties.max_size,properties.cat_path,properties.domain.min,properties.domain.max,properties.domain.name,properties.embedding_sizes.cardinality,properties.embedding_sizes.dimension,properties.value_count.min,properties.value_count.max
0,product_sku_hash_list,"(Tags.ITEM, Tags.CATEGORICAL, Tags.ID)","DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_sku_hash.parquet,0.0,57485.0,product_sku_hash,57486.0,512.0,0,
1,event_type_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.event_type.parquet,0.0,4.0,event_type,5.0,16.0,0,
2,product_action_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_action.parquet,0.0,6.0,product_action,7.0,16.0,0,
3,hashed_url_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.hashed_url.parquet,0.0,489302.0,hashed_url,489303.0,512.0,0,
4,server_timestamp_epoch_ms_list,(Tags.CONTINUOUS),"DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,0,


In [75]:
loader.output_schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.num_buckets,properties.freq_threshold,properties.max_size,properties.cat_path,properties.domain.min,properties.domain.max,properties.domain.name,properties.embedding_sizes.cardinality,properties.embedding_sizes.dimension,properties.value_count.min,properties.value_count.max
0,product_sku_hash_list,"(Tags.ITEM, Tags.CATEGORICAL, Tags.ID)","DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_sku_hash.parquet,0.0,57485.0,product_sku_hash,57486.0,512.0,0,
1,event_type_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.event_type.parquet,0.0,4.0,event_type,5.0,16.0,0,
2,product_action_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.product_action.parquet,0.0,6.0,product_action,7.0,16.0,0,
3,hashed_url_list,(Tags.CATEGORICAL),"DType(name='int64', element_type=<ElementType....",True,True,,0.0,0.0,.//categories/unique.hashed_url.parquet,0.0,489302.0,hashed_url,489303.0,512.0,0,
4,server_timestamp_epoch_ms_list,(Tags.CONTINUOUS),"DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,0,
5,embeddings,"(Tags.ITEM, Tags.EMBEDDING)","DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,0,


In [76]:
inputs = mm.sample_batch(loader, batch_size=10, include_targets=False, prepare_features=True)
input_batch = input_block(inputs)

In [77]:
input_batch.shape

TensorShape([10, None, 233])

In [78]:
target = 'hashed_url_list'

In [79]:
# Width shared by the MLP output and the transformer.
dmodel = 128

# Two dense layers; the last one has no activation.
mlp_block = mm.MLPBlock(
    [128, dmodel],
    activation='relu',
    no_activation_last_layer=True,
)

transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)

# Output head over the column named by `target`.
prediction_head = mm.CategoricalOutput(
    train_processed.schema.select_by_name(target),
    default_loss="categorical_crossentropy",
)

model = mm.Model(input_block, mlp_block, transformer_block, prediction_head)

In [80]:
# Compile with the Adam optimizer and categorical cross-entropy, with eager execution disabled.
model.compile(run_eagerly=False, optimizer='adam', loss="categorical_crossentropy")
# `pre` applies SequenceMaskRandom on `target` with masking_prob=0.3 before each training step.
# NOTE(review): `loader` was created with batch_size=10, and the captured progress bar
# shows 188411 steps per epoch -- confirm whether `batch_size=64` has any effect when a
# Loader is passed, and drop it or rebuild the loader so the two agree.
model.fit(loader, batch_size=64, epochs=5, pre=mm.SequenceMaskRandom(schema=schema, target=target, masking_prob=0.3, transformer=transformer_block))

2023-06-12 01:00:02.308265: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8700


Epoch 1/5


2023-06-12 01:00:24.143373: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model_2/xl_net_block_2/sequential_block_21/replace_masked_embeddings_4/RaggedWhere/Assert/AssertGuard/branch_executed/_111


   996/188411 [..............................] - ETA: 2:21:59 - loss: 10.2552 - recall_at_10: 0.1281 - mrr_at_10: 0.0678 - ndcg_at_10: 0.0822 - map_at_10: 0.0678 - precision_at_10: 0.0128 - regularization_loss: 0.0000e+00 - loss_batch: 10.2335

KeyboardInterrupt: 

In [None]:
    # NOTE(review): this cell is an indented function body with no enclosing `def`,
    # and it uses names that are not defined anywhere in the visible notebook
    # (`npy_path`, `make_df`, `string_ids`, `cpu`, `emb_res`, `Workflow`, `Dataset`).
    # Presumably pasted from a dataloader test as a usage reference -- confirm, and
    # remove or convert to markdown before publishing. It would also rebind `embeddings`.
    embeddings = np.load(npy_path)
    # second workflow that categorifies the embedding table data
    df = make_df({"string_id": np.random.choice(string_ids, 30)})
    graph2 = ["string_id"] >> cat_op
    train_res = Workflow(graph2).transform(Dataset(df, cpu=(cpu is not None)))

    # Same Loader + EmbeddingOperator pattern as the training loader: column 0 of
    # `embeddings` is the id lookup table, the remaining columns are the vectors.
    data_loader = Loader(
        train_res,
        batch_size=1,
        transforms=[
            EmbeddingOperator(
                embeddings[:, 1:],
                id_lookup_table=embeddings[:, 0].astype(int),
                lookup_key="string_id",
            )
        ],
        shuffle=False,
        device=cpu,
    )
    # Reference result: left-join the ids with the embedding table on `string_id`.
    origin_df = train_res.to_ddf().merge(emb_res.to_ddf(), on="string_id", how="left").compute()
    # With batch_size=1 and shuffle=False, batch `idx` is compared against row `idx`.
    for idx, batch in enumerate(data_loader):
        batch
        b_df = batch[0].to_df()
        org_df = origin_df.iloc[idx]
        if not cpu:
            assert (b_df["string_id"].to_numpy() == org_df["string_id"].to_numpy()).all()
            assert (b_df["embeddings"].list.leaves == org_df["embeddings"].list.leaves).all()
        else:
            assert (b_df["string_id"].values == org_df["string_id"]).all()
            assert b_df["embeddings"].values[0] == org_df["embeddings"].tolist()

In [None]:
import cudf
import numpy as np

In [None]:
df = cudf.DataFrame(data={'id': [1,2,3], 'val': [0, np.nan, 10], 'another_col': ['a', 'b', 'c']})

In [None]:
df.val[df.val.isna()]

In [None]:
df.val[~df.val.isna()]

In [None]:
ds = nvt.Dataset(df)

out = ['val'] >> nvt.ops.Filter(f=lambda col: ~col.isna())

wf = nvt.Workflow(out)
wf.fit_transform(ds).compute()

In [None]:
out = ['id', 'val'] >> nvt.ops.Filter(f=lambda df: ~df['val'].isna())

wf = nvt.Workflow(out)
wf.fit_transform(ds).compute()

In [None]:
out = ['id', 'val'] >> nvt.ops.Filter(f=lambda df: ~df['val'].isna())

wf = nvt.Workflow(out + ['another_col'])
wf.fit_transform(ds).compute()

In [None]:
train.head()

In [None]:
# NOTE(review): earlier in this notebook `to_npy` is called on `skus_ds` (the transformed
# dataset), while `skus` is a pandas DataFrame when cells run in order -- presumably this
# call fails as written; confirm and remove or point it at `skus_ds`.
skus.to_npy('embeddings.npy')

In [None]:
# Alternative SKU workflow: a fresh Categorify (not the `cat_op` fitted on `train`)
# over both hash columns, followed by TagAsItemID.
# NOTE(review): because a new Categorify is fitted here, the resulting ids presumably
# do not match the ids assigned to `product_sku_hash` in `train` -- confirm intent.
out = ['product_sku_hash', 'category_hash'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID()
out += ['description_vector'] >> nvt.ops.TagAsItemFeatures()
out += ['price_bucket'] >> nvt.ops.NormalizeMinMax()

wf = nvt.Workflow(out)
# NOTE(review): elsewhere in this notebook `fit_transform` receives an `nvt.Dataset`;
# `skus` is a pandas DataFrame when cells run in order -- confirm whether it must be
# wrapped first. This line also rebinds `skus`.
skus = wf.fit_transform(skus)

In [None]:
train.head()

In [None]:
skus.head()

In [111]:
# Small experiment: tag a column as a target and look at what the dataloader returns.
# NOTE(review): this cell rebinds `df`, `ds`, `wf` and -- most importantly -- `loader`,
# which is the training loader used by the model cells above.
df = cudf.DataFrame(data={'id': [1,2,3], 'label': [1,2,1]})
ds = nvt.Dataset(df)

# Attach Tags.TARGET to `label`.
out = ['label'] >> nvt.ops.AddMetadata(Tags.TARGET)

# `id` is added to the graph without any op applied to it.
wf = nvt.Workflow(out + ['id'])

ds_out = wf.fit_transform(ds)

loader = Loader(
    ds_out,
    batch_size=1,
)

# Display the first batch.
loader.peek()

To use synthetically generated data, uncomment the following cell:

In [None]:
%%bash

cd /workspace && pip install . 

In [None]:
from merlin.datasets.synthetic import KNOWN_DATASETS

In [None]:
KNOWN_DATASETS

In [None]:
from merlin.datasets.synthetic import generate_data

generate_data('sigir-browsing', 1000).head()

In [None]:
generate_data('sigir-sku', 1000).head()