In [1]:
import tensorflow as tf

import merlin_standard_lib as msl
from merlin.schema import Schema, Tags
import merlin_models.tf as ml
from merlin_models.data.synthetic import generate_recsys_data

def data_from_schema(schema, num_items=1000) -> tf.data.Dataset:
    data_df = generate_recsys_data(num_items, schema)

    targets = {}
    for target in synthetic_music_recsys_data_schema.select_by_tag(Tags.BINARY_CLASSIFICATION):
        targets[target.name] = data_df.pop(target.name)

    dataset = tf.data.Dataset.from_tensor_slices((dict(data_df), targets))
    # dataset = dataset.batch(100)

    return dataset

In [2]:
synthetic_music_recsys_data_schema = Schema(
    [
        # Item
        msl.ColumnSchema.create_categorical(
            "item_id",
            num_items=10000,
            tags=[Tags.ITEM_ID],
        ),
        msl.ColumnSchema.create_categorical(
            "item_category",
            num_items=100,
            tags=[Tags.ITEM],
        ),
        msl.ColumnSchema.create_continuous(
            "item_recency",
            min_value=0,
            max_value=1,
            tags=[Tags.ITEM],
        ),
        msl.ColumnSchema.create_categorical(
            "item_genres",
            num_items=100,
            value_count=msl.schema.ValueCount(1, 20),
            tags=[Tags.ITEM],
        ),

        # User
        msl.ColumnSchema.create_categorical(
            "country",
            num_items=100,
            tags=[Tags.USER],
        ),
        msl.ColumnSchema.create_continuous(
            "user_age",
            is_float=False,
            min_value=18,
            max_value=50,
            tags=[Tags.USER],
        ),
        msl.ColumnSchema.create_categorical(
            "user_genres",
            num_items=100,
            value_count=msl.schema.ValueCount(1, 20),
        ),

        # Targets
        msl.ColumnSchema("click").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
        msl.ColumnSchema("play").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
        msl.ColumnSchema("like").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
    ]
)

dataset = data_from_schema(synthetic_music_recsys_data_schema).batch(100)

2021-11-18 12:02:52.516265: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-18 12:02:52.517511: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-18 12:02:52.552688: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-18 12:02:52.554579: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-18 12:02:52.556447: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from S

In [4]:
def MMOE(expert_block, num_experts: int, output_names, gate_dim: int = 32):
    agg = ml.StackFeatures(axis=1)
    experts = expert_block.repeat_in_parallel(num_experts, prefix="expert_", aggregation=agg)
    gates = ml.MMOEGate(num_experts, dim=gate_dim).repeat_in_parallel(names=output_names)
    mmoe = expert_block.add_with_shortcut(experts).add(gates, block_name="MMOE")

    return mmoe

def build_advanced_ranking_model(schema: Schema) -> ml.Model:
    # TODO: Change msl to be able to make this a single function call.
    # bias_schema = schema.select_by_tag("bias")
    schema = schema.remove_by_tag("bias")

    body = ml.DLRMBlock(
        schema, bottom_block=ml.MLPBlock([512, 128]), top_block=ml.MLPBlock([512, 128])
    )
    # bias_block = ml.MLPBlock.from_schema(bias_schema, [64])
    # body = body.add_in_parallel(bias_block, names=["main", "bias"])

    head = ml.MMOEHead.from_schema(
        schema,
        body,
        task_blocks=ml.MLPBlock([64, 32]),
        expert_block=ml.MLPBlock([64, 32]),
        num_experts=3,
        # bias_block=bias_block,
    )
    # head.add_in_parallel()

    return head.to_model()

model = build_advanced_ranking_model(synthetic_music_recsys_data_schema)
model

Model(
  (heads): _TupleWrapper((MMOEHead(
    (body): SequentialBlock(
      (layers): List(
        (0): TabularFeatures(
          (parallel_layers): Dict(
            (continuous_layer): ContinuousFeatures(item_recency, user_age)
            (categorical_layer): EmbeddingFeatures(
              (feature_config): Dict(
                (item_id): TableConfig(vocabulary_size=10001, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_id')
                (item_category): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_category')
                (item_genres): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_genres')
                (country): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='country')
                (user_genres): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, com

In [5]:
def build_dnn(schema: Schema) -> ml.Model:
    body = ml.inputs(schema, ml.MLPBlock([512, 256]))

    return ml.Head.from_schema(schema, body).to_model()

model = build_dnn(synthetic_music_recsys_data_schema)
model

Model(
  (heads): _TupleWrapper((Head(
    (body): SequentialBlock(
      (layers): List(
        (0): TabularFeatures(
          (parallel_layers): Dict(
            (continuous_layer): ContinuousFeatures(item_recency, user_age)
            (categorical_layer): EmbeddingFeatures(
              (feature_config): Dict(
                (item_id): TableConfig(vocabulary_size=10001, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_id')
                (item_category): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_category')
                (item_genres): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='item_genres')
                (country): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='country')
                (user_genres): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='me

In [6]:
def build_dlrm(schema: Schema) -> ml.Model:
    body = ml.DLRMBlock(
        schema, bottom_block=ml.MLPBlock([512, 128]), top_block=ml.MLPBlock([512, 128])
    )

    return ml.Head.from_schema(synthetic_music_recsys_data_schema, body).to_model()

dlrm = build_dlrm(synthetic_music_recsys_data_schema)
dlrm

Model(
  (heads): _TupleWrapper((Head(
    (body): SequentialBlock(
      (layers): List(
        (0): TabularFeatures(
          (parallel_layers): Dict(
            (continuous_layer): ContinuousFeatures(item_recency, user_age)
            (categorical_layer): EmbeddingFeatures(
              (feature_config): Dict(
                (item_id): TableConfig(vocabulary_size=10001, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_id')
                (item_category): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_category')
                (item_genres): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='item_genres')
                (country): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combiner='mean', name='country')
                (user_genres): TableConfig(vocabulary_size=101, dim=128, initializer=None, optimizer=None, combine

In [7]:
def build_two_tower(schema: Schema, target="play") -> ml.Model:
    body = ml.TwoTowerBlock(schema, query_tower=ml.MLPBlock([512, 256]))

    return ml.Head.from_schema(schema.select_by_name(target), body).to_model()

two_tower = build_two_tower(synthetic_music_recsys_data_schema, target="play")

two_tower

Model(
  (heads): _TupleWrapper((Head(
    (body): ParallelBlock(
      (_aggregation): CosineSimilarity(
        (dot): Dot()
      )
      (parallel_layers): Dict(
        (user): SequentialBlock(
          (layers): List(
            (0): SequentialBlock(
              (layers): List(
                (0): TabularFeatures(
                  (parallel_layers): Dict(
                    (continuous_layer): ContinuousFeatures(user_age)
                    (categorical_layer): EmbeddingFeatures(
                      (feature_config): Dict(
                        (country): TableConfig(vocabulary_size=101, dim=64, initializer=None, optimizer=None, combiner='mean', name='country')
                      )
                      (_pre): SequentialTabularTransformations(
                        (layers): List(
                          (0): Filter(
                            (feature_names): List(
                              (0): 'country'
                            )
                   