In [None]:
import tensorflow as tf

from merlin.schema import ColumnSchema, Schema, Tags
import merlin_models.tf as ml
from merlin_models.data.synthetic import generate_user_item_interactions
from merlin_models.utils.schema import create_categorical_column, create_continuous_column

def data_from_schema(schema, num_items=1000) -> tf.data.Dataset:
    """Generate synthetic interactions for ``schema`` and wrap them in a ``tf.data.Dataset``.

    Parameters
    ----------
    schema
        Schema describing the columns to generate. Every column it tags with
        ``Tags.BINARY_CLASSIFICATION`` is split off as a target.
    num_items : int, default 1000
        Forwarded unchanged as the second argument of
        ``generate_user_item_interactions``.

    Returns
    -------
    tf.data.Dataset
        Dataset of ``(features, targets)`` pairs, where both elements are dicts
        keyed by column name. No batching is applied here; callers chain
        ``.batch(...)`` themselves.
    """
    data_df = generate_user_item_interactions(schema, num_items)

    # Look the target columns up on the `schema` argument. Reading them from a
    # notebook-level global would fail on a fresh kernel (the global is defined in
    # a later cell) and would pick the wrong columns for any other schema.
    targets = {
        target.name: data_df.pop(target.name)
        for target in schema.select_by_tag(Tags.BINARY_CLASSIFICATION)
    }

    # Whatever is left in `data_df` after popping the targets is a feature.
    return tf.data.Dataset.from_tensor_slices((dict(data_df), targets))

In [None]:
# Schema of the synthetic music-recommendation dataset used by every model below.
# Feature columns are grouped by the entity they describe and tagged accordingly;
# the three target columns are tagged as binary-classification tasks, which is the
# tag `data_from_schema` uses to split targets from features.
synthetic_music_recsys_data_schema = Schema(
    [
        # Item
        create_categorical_column(
            "item_id",
            num_items=10000,
            tags=[Tags.ITEM_ID],
        ),
        create_categorical_column(
            "item_category",
            num_items=100,
            tags=[Tags.ITEM],
        ),
        create_continuous_column(
            "item_recency",
            min_value=0,
            max_value=1,
            tags=[Tags.ITEM],
        ),
        create_categorical_column(
            "item_genres",
            num_items=100,
            # presumably 1 to 20 values per row (multi-hot) -- TODO confirm
            min_value_count=1,
            max_value_count=20,
            tags=[Tags.ITEM],
        ),

        # User
        create_categorical_column(
            "country",
            num_items=100,
            tags=[Tags.USER],
        ),
        create_continuous_column(
            "user_age",
            dtype="int32",
            min_value=18,
            max_value=50,
            tags=[Tags.USER],
        ),
        create_categorical_column(
            "user_genres",
            num_items=100,
            min_value_count=1,
            max_value_count=20,
            # Tagged like the other user columns so that selecting by Tags.USER
            # includes it; it was previously the only untagged feature column.
            tags=[Tags.USER],
        ),

        # Targets
        ColumnSchema("click").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
        ColumnSchema("play").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
        ColumnSchema("like").with_tags(tags=[Tags.BINARY_CLASSIFICATION]),
    ]
)

# Generate synthetic interactions (default `num_items`) and group them in batches of 100.
dataset = data_from_schema(synthetic_music_recsys_data_schema).batch(100)

In [None]:
def MMOE(expert_block, num_experts: int, output_names, gate_dim: int = 32):
    """Assemble a multi-gate mixture-of-experts block around ``expert_block``.

    Parameters
    ----------
    expert_block
        Block that is repeated in parallel ``num_experts`` times.
    num_experts : int
        Number of parallel copies of ``expert_block``; also passed to ``ml.MMOEGate``.
    output_names
        Names handed to the gate's ``repeat_in_parallel`` (one gate per name).
    gate_dim : int, default 32
        Passed to ``ml.MMOEGate`` as ``dim``.

    Returns
    -------
    The block produced by chaining the experts (added with a shortcut) and the
    gates, registered under the block name ``"MMOE"``.
    """
    # Parallel experts; their outputs are combined by stacking along axis 1.
    parallel_experts = expert_block.repeat_in_parallel(
        num_experts,
        prefix="expert_",
        aggregation=ml.StackFeatures(axis=1),
    )

    # One gate per output name.
    gate = ml.MMOEGate(num_experts, dim=gate_dim)
    per_output_gates = gate.repeat_in_parallel(names=output_names)

    with_experts = expert_block.add_with_shortcut(parallel_experts)
    return with_experts.add(per_output_gates, block_name="MMOE")

def build_advanced_ranking_model(schema: Schema) -> ml.Model:
    """Build a ranking model: a DLRM body feeding an MMoE head with three experts.

    Parameters
    ----------
    schema : Schema
        Dataset schema. It is first passed through ``remove_by_tag("bias")``; the
        result is used for both the body and the head.

    Returns
    -------
    ml.Model
        The model obtained from ``ml.MMOEHead.from_schema(...).to_model()``.
    """
    # TODO: Change msl to be able to make this a single function call.
    # TODO: bias-tagged columns are not wired in yet. The intended design is a
    # separate bias block (an MLP over the "bias" sub-schema) that runs in parallel
    # with the main body and is handed to the head.
    features_schema = schema.remove_by_tag("bias")

    # DLRM body.
    bottom_mlp = ml.MLPBlock([512, 128])
    top_mlp = ml.MLPBlock([512, 128])
    dlrm_body = ml.DLRMBlock(features_schema, bottom_block=bottom_mlp, top_block=top_mlp)

    # MMoE head on top of the body.
    task_tower = ml.MLPBlock([64, 32])
    expert_tower = ml.MLPBlock([64, 32])
    mmoe_head = ml.MMOEHead.from_schema(
        features_schema,
        dlrm_body,
        task_blocks=task_tower,
        expert_block=expert_tower,
        num_experts=3,
    )

    return mmoe_head.to_model()

# Instantiate the DLRM + MMoE ranking model on the synthetic schema.
model = build_advanced_ranking_model(synthetic_music_recsys_data_schema)
# Bare expression as the last line of the cell: the notebook displays the model.
model

In [None]:
def build_dnn(schema: Schema) -> ml.Model:
    """Build a plain feed-forward model: schema inputs followed by a 512 -> 256 MLP.

    Parameters
    ----------
    schema : Schema
        Dataset schema, used both to create the input block and to derive the head.

    Returns
    -------
    ml.Model
        The model obtained from ``ml.Head.from_schema(schema, body).to_model()``.
    """
    mlp = ml.MLPBlock([512, 256])
    dnn_body = ml.inputs(schema, mlp)

    head = ml.Head.from_schema(schema, dnn_body)
    return head.to_model()

# Instantiate the DNN baseline.
# NOTE(review): this rebinds `model`, replacing the model built in the previous cell.
model = build_dnn(synthetic_music_recsys_data_schema)
# Bare expression as the last line of the cell: the notebook displays the model.
model

In [None]:
def build_dlrm(schema: Schema) -> ml.Model:
    """Build a DLRM model whose body and head are both derived from ``schema``.

    Parameters
    ----------
    schema : Schema
        Dataset schema, used for the DLRM body and for the head.

    Returns
    -------
    ml.Model
        The model obtained from ``ml.Head.from_schema(schema, body).to_model()``.
    """
    body = ml.DLRMBlock(
        schema, bottom_block=ml.MLPBlock([512, 128]), top_block=ml.MLPBlock([512, 128])
    )

    # Derive the head from the `schema` argument rather than a notebook-level
    # global, so body and head always agree and the function works for any schema.
    return ml.Head.from_schema(schema, body).to_model()

# Instantiate the DLRM model on the synthetic schema.
dlrm = build_dlrm(synthetic_music_recsys_data_schema)
# Bare expression as the last line of the cell: the notebook displays the model.
dlrm

In [None]:
def build_two_tower(schema: Schema, target="play") -> ml.Model:
    """Build a two-tower model with a head for a single target column.

    Parameters
    ----------
    schema : Schema
        Dataset schema handed to ``ml.TwoTowerBlock``.
    target : str, default "play"
        Name of the column the head is built for; selected from ``schema`` with
        ``select_by_name``.

    Returns
    -------
    ml.Model
        The model obtained from ``ml.Head.from_schema(...).to_model()``.
    """
    # Only the query tower is given explicitly; the item tower is left to
    # TwoTowerBlock's defaults.
    query_mlp = ml.MLPBlock([512, 256])
    towers = ml.TwoTowerBlock(schema, query_tower=query_mlp)

    # Restrict the head to the requested target column.
    target_schema = schema.select_by_name(target)
    return ml.Head.from_schema(target_schema, towers).to_model()

# Instantiate the two-tower model with a head for the "play" target.
two_tower = build_two_tower(synthetic_music_recsys_data_schema, target="play")

# Bare expression as the last line of the cell: the notebook displays the model.
two_tower