In [1]:
import tensorflow as tf

import merlin_standard_lib as msl
from merlin_standard_lib import Schema, Tag
from transformers4rec.data.synthetic import (
    synthetic_ecommerce_data_schema,
    generate_item_interactions
)

import merlin_models.tf as ml

from merlin_models.data.synthetic import generate_recsys_data

def data_from_schema(schema, num_items=1000) -> tf.data.Dataset:
    """Build an unbatched ``tf.data.Dataset`` of synthetic data for ``schema``.

    Synthetic data is produced by ``generate_recsys_data``; every column that
    ``schema`` tags with ``Tag.BINARY_CLASSIFICATION`` is popped out of the
    generated frame and used as a target.

    Args:
        schema: Schema describing the feature and target columns to generate.
        num_items: Forwarded as the first argument of ``generate_recsys_data``
            (presumably the number of synthetic rows — TODO confirm).

    Returns:
        A dataset whose elements are ``(features, targets)`` pairs: ``features``
        is built from the columns left after the targets are removed, and
        ``targets`` maps each target column name to its values. The dataset is
        returned unbatched; callers apply ``.batch(...)`` themselves.
    """
    data_df = generate_recsys_data(num_items, schema)

    # Select targets from the `schema` argument (not a notebook global) so the
    # function works for any schema and on a fresh kernel. `pop` removes each
    # target column from `data_df`, so the features below exclude targets.
    targets = {
        target.name: data_df.pop(target.name)
        for target in schema.select_by_tag(Tag.BINARY_CLASSIFICATION)
    }

    return tf.data.Dataset.from_tensor_slices((dict(data_df), targets))

Init Plugin
Init Graph Optimizer
Init Kernel


In [2]:
# Schema for a synthetic music-recommendation dataset: item features,
# user features, and three binary-classification targets.
synthetic_music_recsys_data_schema = Schema(
    [
        # Item features
        msl.ColumnSchema.create_categorical(
            "item_id",
            num_items=10000,
            tags=[Tag.ITEM_ID],
        ),
        msl.ColumnSchema.create_categorical(
            "item_category",
            num_items=100,
            tags=[Tag.ITEM],
        ),
        msl.ColumnSchema.create_continuous(
            "item_recency",
            min_value=0,
            max_value=1,
            tags=[Tag.ITEM],
        ),
        # List-valued feature; ValueCount(1, 20) presumably bounds the number
        # of genres per row — confirm against merlin_standard_lib.
        msl.ColumnSchema.create_categorical(
            "item_genres",
            num_items=100,
            value_count=msl.schema.ValueCount(1, 20),
            tags=[Tag.ITEM],
        ),

        # User features
        msl.ColumnSchema.create_categorical(
            "country",
            num_items=100,
            tags=[Tag.USER],
        ),
        msl.ColumnSchema.create_continuous(
            "user_age",
            is_float=False,
            min_value=18,
            max_value=50,
            tags=[Tag.USER],
        ),
        # Tagged USER like the other user features so tag-based selection
        # (e.g. select_by_tag(Tag.USER)) includes it.
        msl.ColumnSchema.create_categorical(
            "user_genres",
            num_items=100,
            value_count=msl.schema.ValueCount(1, 20),
            tags=[Tag.USER],
        ),

        # Targets: binary labels, split off from the features by data_from_schema
        msl.ColumnSchema("click").with_tags(tags=[Tag.BINARY_CLASSIFICATION]),
        msl.ColumnSchema("play").with_tags(tags=[Tag.BINARY_CLASSIFICATION]),
        msl.ColumnSchema("like").with_tags(tags=[Tag.BINARY_CLASSIFICATION]),
    ]
)

# Number of examples per batch fed to the model.
BATCH_SIZE = 100

dataset = data_from_schema(synthetic_music_recsys_data_schema).batch(BATCH_SIZE)

Metal device set to: Apple M1


2021-10-11 21:17:06.417677: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2021-10-11 21:17:06.417763: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
def build_dnn(schema: Schema, dimensions=(512, 256)) -> ml.Model:
    """Build a model with an MLP body and a head derived from ``schema``.

    Args:
        schema: Schema used to build both the MLP body's input features and
            the head placed on top of it.
        dimensions: Hidden-layer sizes of the MLP body, in order. Defaults to
            ``(512, 256)``.

    Returns:
        The ``ml.Model`` produced by ``ml.Head.from_schema(...).to_model()``.
    """
    # Pass a fresh list so the immutable default is never shared or mutated.
    body = ml.MLPBlock.from_schema(schema, list(dimensions))

    return ml.Head.from_schema(schema, body).to_model()


model = build_dnn(synthetic_music_recsys_data_schema)
model

Model(
  (heads): _TupleWrapper((Head(
    (body): MLPBlock(
      (layers): List(
        (0): TabularFeatures(
          (_aggregation): ConcatFeatures()
          (to_merge): Dict(
            (continuous_layer): ContinuousFeatures(item_recency, user_age)
            (categorical_layer): EmbeddingFeatures(
              (feature_config): Dict(
                (item_id): TableConfig(vocabulary_size=10001, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_id')
                (item_category): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_category')
                (item_genres): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_genres')
                (country): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='country')
                (user_genres): TableConfig(vocabulary_size=101, dim=64, initializer=None,