In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =====

## 3. Customize and Extend Merlin Models

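In this lab, we customize and extend recommender models built with Merlin Models. First, we add a hashed feature-cross branch (`HashedCross`) to a DLRM model. Then, we replace DLRM's dot-product interaction with a DCN-style `CrossBlock`.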

**Learning Objectives of this lab**

- Customize and extend recommender models with Merlin Models

**Import Required Libraries**

In [2]:
import os

import glob
import cudf 
import pandas as pd
import numpy as np
import nvtabular as nvt
from nvtabular.ops import *
import gc

from merlin.schema.tags import Tags
import merlin.models.tf as mm
from merlin.io.dataset import Dataset

import tensorflow as tf

2022-08-19 15:08:18.909473: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-19 15:08:21.351720: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8080 MB memory:  -> device: 0, name: Tesla V100-SXM2-16GB-N, pci bus id: 0000:06:00.0, compute capability: 7.0
2022-08-19 15:08:21.511773: I tensorflow/stream_executor/cuda/cuda_driver.cc:739] failed to allocate 7.89G (8472494080 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2022-08-19 15:08:21.597146: I tensorflow/stream_executor/cuda/cuda_driver.cc:739] failed to allocate 7.10G (7625244672 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2022-08-19 15:08:21.60167

In [3]:
# Seed TensorFlow and NumPy so weight initialisation and any random sampling are reproducible.
seed = 42
for set_global_seed in (tf.random.set_seed, np.random.seed):
    set_global_seed(seed)

In [4]:
# Root directory of the e-commerce dataset. It can be overridden with the ECOM_DATA_PATH
# environment variable so the notebook runs on other machines; the default keeps the
# original container location.
data_path = os.environ.get("ECOM_DATA_PATH", "/workspace/data/ecom/")
# Sub-folder holding the NVTabular-processed "train"/"valid" parquet splits read below.
output_path = os.path.join(data_path, "processed_nvt")

Read processed parquet files as Dataset objects.

In [5]:
# Wrap the processed parquet splits as Merlin Datasets (part_size is the target size of
# each partition that is read at a time).
train, valid = (
    Dataset(os.path.join(output_path, split, "*.parquet"), part_size="500MB")
    for split in ("train", "valid")
)

# The schema (column tags, dtypes, cardinalities) stored with the processed data drives
# how the model blocks below are built.
schema = train.schema



In [6]:
# Look up the column(s) tagged as the target and keep the first one.
target_columns = schema.select_by_tag(Tags.TARGET).column_names
target_column = target_columns[0]
target_column

'target'

In [7]:
# Inspect the schema: per-column tags, dtype and embedding cardinality/dimension properties.
schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.num_buckets,properties.freq_threshold,properties.max_size,properties.start_index,properties.cat_path,properties.embedding_sizes.cardinality,properties.embedding_sizes.dimension,properties.domain.min,properties.domain.max
0,user_id,"(Tags.CATEGORICAL, Tags.USER_ID, Tags.USER)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.user_id.parquet,351050.0,512.0,0.0,351050.0
1,ts_weekday,"(Tags.CATEGORICAL, Tags.USER)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.ts_weekday.parquet,8.0,16.0,0.0,8.0
2,ts_hour,"(Tags.CATEGORICAL, Tags.USER)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.ts_hour.parquet,25.0,16.0,0.0,25.0
3,product_id,"(Tags.CATEGORICAL, Tags.ITEM_ID, Tags.ITEM)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.product_id.parquet,51425.0,512.0,0.0,51425.0
4,cat_0,"(Tags.CATEGORICAL, Tags.ITEM)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.cat_0.parquet,14.0,16.0,0.0,14.0
5,cat_1,"(Tags.CATEGORICAL, Tags.ITEM)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.cat_1.parquet,61.0,16.0,0.0,61.0
6,cat_2,"(Tags.CATEGORICAL, Tags.ITEM)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.cat_2.parquet,90.0,20.0,0.0,90.0
7,brand,"(Tags.CATEGORICAL, Tags.ITEM)",int32,False,False,,0.0,0.0,0.0,.//categories/unique.brand.parquet,2654.0,132.0,0.0,2654.0
8,price,"(Tags.CONTINUOUS, Tags.ITEM)",float32,False,False,,,,,,,,,
9,relative_price,"(Tags.CONTINUOUS, Tags.ITEM)",float32,False,False,,,,,,,,,


In [8]:
# Small, non-shuffled batch of input features only (no targets), used below to
# smoke-test individual blocks eagerly.
batch = mm.sample_batch(
    train,
    batch_size=16,
    include_targets=False,
    shuffle=False,
)

### 1. Add HashedCross features to DLRM Model

In [16]:
# Continuous (Tags.CONTINUOUS) features feed a "bottom" MLP whose 64-unit output has the
# same width as the 64-dim categorical embeddings used for the DLRM interaction.
continuous_block = mm.ContinuousFeatures.from_schema(schema, tags=Tags.CONTINUOUS)
bottom_block = continuous_block.connect(
    mm.MLPBlock([128, 64])
)

In [24]:
# One embedding table per categorical column, all with the same fixed dimension (64).
# A fixed `dim` is required here rather than per-column inferred sizes: the DLRM
# dot-product interaction takes dot products between these vectors and the 64-unit
# bottom-MLP output, and dot products need equal-length vectors.
embeddings_block = mm.Embeddings(
    schema.select_by_tag(Tags.CATEGORICAL),
    dim=64,
)

In [25]:
# Display the block: one EmbeddingTable per categorical column of the schema.
embeddings_block

ParallelBlock(
  (parallel_layers): Dict(
    (user_id): EmbeddingTable(
      (table): Embedding()
    )
    (ts_weekday): EmbeddingTable(
      (table): Embedding()
    )
    (ts_hour): EmbeddingTable(
      (table): Embedding()
    )
    (product_id): EmbeddingTable(
      (table): Embedding()
    )
    (cat_0): EmbeddingTable(
      (table): Embedding()
    )
    (cat_1): EmbeddingTable(
      (table): Embedding()
    )
    (cat_2): EmbeddingTable(
      (table): Embedding()
    )
    (brand): EmbeddingTable(
      (table): Embedding()
    )
  )
)

In [26]:
# DLRM inputs: embeddings and the bottom-MLP output stay as separate (non-aggregated)
# branches, so the interaction block below receives the individual feature vectors.
dlrm_branches = {"embeddings": embeddings_block, "bottom_block": bottom_block}
dlrm_input_block = mm.ParallelBlock(dlrm_branches)

In [27]:
from merlin.models.tf.blocks.dlrm import DotProductInteractionBlock

# Dot-product interaction over the DLRM inputs; the "bottom_block" branch is also routed
# around the interaction as a shortcut and concatenated with the interaction output.
bottom_shortcut = mm.Filter("bottom_block")
dlrm_interaction = dlrm_input_block.connect_with_shortcut(
    DotProductInteractionBlock(),
    shortcut_filter=bottom_shortcut,
    aggregation="concat",
)

In [None]:
# Toy example: hash the combination of cat_0 x cat_1 into 10 one-hot buckets.
crossed_features = ["cat_0", "cat_1"]
cross_schema = schema.select_by_name(names=crossed_features)
cross = mm.HashedCross(cross_schema, num_bins=10, output_mode="one_hot")

In [None]:
# Apply the toy hashed cross to the sample batch to inspect its one-hot encoded output.
cross(batch)

In [None]:
# Wide-style feature cross: HashedCross creates a new feature by hashing the
# (cat_0, cat_1) combination into `cross_num_bins` one-hot buckets. A single linear
# unit (no activation on its last layer) then learns one weight per bucket, i.e. the
# crossed feature contributes a learned weighted sum to the model.
cross_num_bins = 1000
cross_body = mm.HashedCross(cross_schema, num_bins=cross_num_bins, output_mode="one_hot").connect(
    mm.MLPBlock([1], no_activation_last_layer=True), block_name="cross_model"
)

In [28]:
# Run the DLRM interaction and the wide cross branch side by side and concatenate them.
dlrm_with_crossbody = mm.ParallelBlock(
    dict(dlrm_interaction=dlrm_interaction, cross_body=cross_body),
    aggregation="concat",
)

In [29]:
# Top MLP applied to the concatenated DLRM + cross representation.
top_mlp = mm.MLPBlock([64, 128, 256])
dlrm_with_cross = dlrm_with_crossbody.connect(top_mlp)

In [30]:
from merlin.models.tf.core.transformations import LogitsTemperatureScaler

# Binary classification head built from the schema; its `pre` transformation rescales
# the logits with a temperature of 2.
logits_scaler = LogitsTemperatureScaler(temperature=2)
binary_task = mm.BinaryClassificationTask(schema, pre=logits_scaler)

In [31]:
# Assemble the end-to-end model: DLRM + hashed-cross body followed by the binary prediction task.
model = mm.Model(dlrm_with_cross, binary_task)

In [32]:
%%time 
model.compile(optimizer='adam', run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
model.fit(train, validation_data=valid, batch_size=4096, epochs=2)

Epoch 1/2
Epoch 2/2
CPU times: user 59.2 s, sys: 6.98 s, total: 1min 6s
Wall time: 39.1 s


<keras.callbacks.History at 0x7fed66227790>

### 2. Replace `DotProductInteractionBlock` with `CrossBlock`

In [23]:
# Fresh continuous-feature block and bottom MLP for the second model (new layer
# instances, so no weights are shared with the model trained above).
bottom_mlp_units = [128, 64]
continuous_block = mm.ContinuousFeatures.from_schema(schema, tags=Tags.CONTINUOUS)
bottom_block = continuous_block.connect(mm.MLPBlock(bottom_mlp_units))

In [24]:
# Fresh 64-dim embedding table per categorical column for the DCN-style model.
# (The previously defined-but-unused TruncatedNormal initializer was dead code and
# has been removed; the tables use the library's default initializer.)
embeddings_block = mm.Embeddings(
    schema.select_by_tag(Tags.CATEGORICAL),
    dim=64,
)

In [25]:
# Display the new embeddings block (one EmbeddingTable per categorical column).
embeddings_block

ParallelBlock(
  (parallel_layers): Dict(
    (user_id): EmbeddingTable(
      (table): Embedding()
    )
    (ts_weekday): EmbeddingTable(
      (table): Embedding()
    )
    (ts_hour): EmbeddingTable(
      (table): Embedding()
    )
    (product_id): EmbeddingTable(
      (table): Embedding()
    )
    (cat_0): EmbeddingTable(
      (table): Embedding()
    )
    (cat_1): EmbeddingTable(
      (table): Embedding()
    )
    (cat_2): EmbeddingTable(
      (table): Embedding()
    )
    (brand): EmbeddingTable(
      (table): Embedding()
    )
  )
)

In [26]:
# Run the embedding layer on the sample batch: the result maps each categorical
# feature name to its embedding tensor.
embeddings = embeddings_block(batch)
(embeddings.keys(), embeddings["user_id"].shape)

(dict_keys(['user_id', 'ts_weekday', 'ts_hour', 'product_id', 'cat_0', 'cat_1', 'cat_2', 'brand']),
 TensorShape([16, 64]))

In [27]:
# Even the low-cardinality cat_0 feature gets the same fixed embedding width.
cat_0_shape = embeddings["cat_0"].shape
embeddings.keys(), cat_0_shape

(dict_keys(['user_id', 'ts_weekday', 'ts_hour', 'product_id', 'cat_0', 'cat_1', 'cat_2', 'brand']),
 TensorShape([16, 64]))

In [33]:
# DCN inputs: concatenate all categorical embeddings and the bottom-MLP output into
# a single tensor (the cross network presumably operates on one dense vector per row).
# With aggregation="concat" this block returns a tensor rather than a dict, so the old
# per-branch shape printout no longer applies and has been removed.
dlrm_input_block = mm.ParallelBlock(
    {"embeddings": embeddings_block, "bottom_block": bottom_block},
    aggregation="concat",
)

In [35]:
# Sanity check of the concatenated input width. Expected: (16, 8 * 64 + 64) = (16, 576)
# for the 8 categorical features shown above plus the 64-unit bottom MLP.
dlrm_input_block(batch).shape

In [36]:

# Stacked DCN: the cross layers are applied directly on top of the concatenated inputs.
num_cross_layers = 2
dcn_body = dlrm_input_block.connect(mm.CrossBlock(num_cross_layers))

In [27]:
# Sanity check: cross layers are expected to keep the input width, i.e. (16, 576).
dcn_body(batch).shape

In [37]:
# Concatenate the cross-network output with the bottom-MLP output.
# NOTE(review): `bottom_block` is the same layer instance already used inside `dcn_body`,
# so both branches share its weights and it is evaluated twice per forward pass.
dlrm_interaction = mm.ParallelBlock(
    dict(dcn_body=dcn_body, bottom_block=bottom_block),
    aggregation="concat",
)

In [38]:
# Deep MLP on top of the cross + bottom representation, then a forward pass on the
# sample batch as a smoke test.
deep_mlp_units = [64, 128, 512]
deep_dlrm_interaction = dlrm_interaction.connect(mm.MLPBlock(deep_mlp_units))
deep_dlrm_interaction(batch)

<tf.Tensor: shape=(16, 512), dtype=float32, numpy=
array([[0.0249111 , 0.02031652, 0.        , ..., 0.00059985, 0.        ,
        0.        ],
       [0.01873957, 0.01652094, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00175109, 0.03225915, 0.0024677 , ..., 0.00148939, 0.        ,
        0.        ],
       ...,
       [0.        , 0.02493227, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.01736411, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.02748184, 0.        , ..., 0.        , 0.        ,
        0.        ]], dtype=float32)>

In [44]:
from merlin.models.tf.core.transformations import LogitsTemperatureScaler

# A new task instance for the second model, so no layers/weights are shared with the
# first model's task; `pre` rescales the logits with temperature 2.
binary_task = mm.BinaryClassificationTask(
    schema,
    pre=LogitsTemperatureScaler(temperature=2),
)

In [45]:
# Assemble the DCN-style model: cross network + deep MLP followed by the binary prediction task.
model = mm.Model(deep_dlrm_interaction, binary_task)

In [46]:
%%time 
model.compile(optimizer='adam', run_eagerly=False, metrics=[tf.keras.metrics.AUC()])
model.fit(train, validation_data=valid, batch_size=4096, epochs=2)

Epoch 1/2
Epoch 2/2
CPU times: user 52.3 s, sys: 7.52 s, total: 59.8 s
Wall time: 37.8 s


<keras.callbacks.History at 0x7f835e8caaf0>

### Summary 

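In this hands-on lab we learned how to customize and extend Merlin Models building blocks:

- adding a `HashedCross` feature-cross branch alongside a DLRM interaction and concatenating the two with a `ParallelBlock`;
- replacing DLRM's `DotProductInteractionBlock` with a stacked `CrossBlock` (DCN-style) followed by a deep MLP;
- training both models end-to-end with a `BinaryClassificationTask` and tracking AUC.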