In [1]:
import math
import warnings
from typing import Dict, Optional

warnings.filterwarnings("ignore")

import torch
from torch import nn

from merlin.schema import Schema, ColumnSchema, Tags
from merlin.datasets.synthetic import generate_data
import merlin.models.torch as mm


# Synthetic "music-streaming" data; its schema (displayed below) is reused by the models in later cells.
NUM_ROWS = 1_000
train = generate_data("music-streaming", num_rows=NUM_ROWS)
schema = train.schema

schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.domain.min,properties.domain.max,properties.domain.name,properties.value_count.min,properties.value_count.max
0,session_id,"(Tags.CATEGORICAL, Tags.SESSION_ID, Tags.SESSI...","DType(name='int64', element_type=<ElementType....",False,False,0.0,10000.0,session_id,,
1,item_id,"(Tags.CATEGORICAL, Tags.ID, Tags.ITEM, Tags.IT...","DType(name='int64', element_type=<ElementType....",False,False,0.0,10000.0,item_id,,
2,item_category,"(Tags.CATEGORICAL, Tags.ITEM)","DType(name='int64', element_type=<ElementType....",False,False,0.0,100.0,item_category,,
3,item_recency,"(Tags.ITEM, Tags.CONTINUOUS)","DType(name='float64', element_type=<ElementTyp...",False,False,0.0,1.0,item_recency,,
4,item_genres,"(Tags.CATEGORICAL, Tags.ITEM)","DType(name='int64', element_type=<ElementType....",True,True,0.0,100.0,genres,4.0,
5,user_id,"(Tags.CATEGORICAL, Tags.USER_ID, Tags.ID, Tags...","DType(name='int64', element_type=<ElementType....",False,False,0.0,10000.0,user_id,,
6,country,"(Tags.CATEGORICAL, Tags.USER)","DType(name='int64', element_type=<ElementType....",False,False,0.0,100.0,country,,
7,user_age,"(Tags.CONTINUOUS, Tags.USER)","DType(name='int64', element_type=<ElementType....",False,False,18.0,50.0,user_age,,
8,user_genres,"(Tags.CATEGORICAL, Tags.USER)","DType(name='int64', element_type=<ElementType....",True,True,0.0,100.0,genres,4.0,
9,position,"(bias, Tags.CONTINUOUS)","DType(name='int64', element_type=<ElementType....",False,False,1.0,100.0,position,,


### Core Abstractions
- Batch
- Link
- Block
- ParallelBlock
- Model

#### ParallelBlock

In [5]:
class PlusOne(nn.Module):
    """Toy module used to illustrate block composition: adds 1 to its input."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        incremented = inputs + 1
        return incremented
    
# Two named branches, plus one extra PlusOne registered via prepend and one via append.
branch_modules = {"a": PlusOne(), "b": PlusOne()}
block = mm.ParallelBlock(branch_modules)
block.prepend(PlusOne())
block.append(PlusOne())

block.rich_print()

## Retrieval

#### Matrix Factorization

![mf](img/mf.png)

In [6]:
# Matrix-factorization inputs: register EmbeddingTable(100) for each of the two ID tags.
mf_inputs = mm.TabularInputBlock(schema)
mf_inputs.add_for_each(
    [Tags.USER_ID, Tags.ITEM_ID],
    mm.EmbeddingTable(100),
)
# TODO: attach the output head once the NotSupportedError shown in this cell's output
# (TorchScript failing to compile `merlin.schema.Schema`) is resolved:
#   output = mm.ContrastiveOutput((Tags.USER_ID, Tags.ITEM_ID), schema=schema)
#   model = mm.Model(mf_inputs, output)

mf_inputs

NotSupportedError: Comprehension ifs are not supported yet:
  File "/home/marc/anaconda3/envs/torchrec-pip/lib/python3.9/site-packages/merlin/schema/schema.py", line 634
        # must account for same columns in both schemas,
        # use the one with more information for each field
        keys_self_not_other = [
            col_name for col_name in self.column_names if col_name not in other.column_names
        ]
'__torch__.merlin.schema.schema.Schema' is being compiled since it was called from 'EmbeddingTable.forward'
  File "/home/marc/src/merlin/models/merlin/models/torch/block.py", line 389
    def setup_schema(self, schema: Schema):
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        self.schema = schema
        ~~~~~~~~~~~~~~~~~~~~ <--- HERE


#### Two-tower

![tt](img/two-tower.png)

#### YouTube DNN Retrieval

![yt](img/youtube-dnn.png)

In [None]:
# YouTube DNN retrieval: default input block -> MLP tower -> contrastive output over item IDs.
inputs = TabularInputBlock(schema, init="defaults")
output = ContrastiveOutput(Tags.ITEM_ID, schema=schema)
tower = MLPBlock([512, 256])

model = Model(inputs, tower, output)

## Ranking

#### DLRM

![dlrm](img/dlrm.png)

In [None]:
class ShortcutConcatContinuous(Link):
    """Link that re-attaches the "continuous" branch to the wrapped block's output.

    The whole input dict is passed to ``self.output``. The result is then
    concatenated along dim 1, after ``inputs["continuous"]``.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Run the wrapped block before reading the shortcut, as the original did.
        interaction = self.output(inputs)
        continuous = inputs["continuous"]
        return torch.cat((continuous, interaction), dim=1)


# DLRM: categorical features are embedded; continuous features go through the bottom MLP.
bottom_block = MLPBlock([256, 128])
top_block = MLPBlock([256, 128])

dlrm = TabularInputBlock(schema)
dlrm.add(Tags.CATEGORICAL, Embeddings(128))
dlrm.add(Tags.CONTINUOUS, bottom_block, name="continuous")

# Feature interaction, with the link concatenating the "continuous" branch back in.
# An equivalent spelling with a generic shortcut link would be:
#   dlrm.append(DLRMInteraction(), link=Shortcut(post_fn=lambda x: torch.cat((x["output"], x["continuous"]), dim=1)))
dlrm.append(DLRMInteraction(), link=ShortcutConcatContinuous())
dlrm.append(top_block)

model = Model(dlrm, BinaryOutput())

#### DCN-V2

![dcn](img/dcn-v2.png)

In [None]:
class DenseMaybeLowRank(nn.Module):
    """Placeholder for the per-layer projection of a DCN-V2 cross network.

    ``CrossBlock`` constructs this module as ``DenseMaybeLowRank(low_rank_dim)``.
    The previous empty stub inherited ``nn.Module.__init__``, which rejects that
    positional argument. This version accepts and stores it.

    ``forward`` is still unimplemented, so calling the module raises
    ``NotImplementedError`` (``nn.Module``'s default behaviour).

    Args:
        low_rank_dim: Presumably ``None`` for a full-rank projection and an int
            for a low-rank factorisation. TODO: implement both.
    """

    def __init__(self, low_rank_dim: Optional[int] = None):
        super().__init__()
        self.low_rank_dim = low_rank_dim

class Cross(Link):
    """Link applying a DCN-V2 style cross update with each module in ``self.output.values``.

    The original input is kept as the base ``x0``. Each module ``f`` then updates
    the running value as ``x_next = x0 * f(x_l) + x_l``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = x
        running = x
        for layer in self.output.values:
            running = base * layer(running) + running
        return running


def CrossBlock(n: int, low_rank_dim: Optional[int] = None):
    """Build a DCN-V2 style cross network of ``n`` layers.

    Args:
        n: Number of repeated ``DenseMaybeLowRank`` layers.
        low_rank_dim: Forwarded to ``DenseMaybeLowRank``.

    Returns:
        The result of ``Block.append``, presumably the ``Block`` named
        "CrossBlock", whose repeated layers are combined by the ``Cross`` link.
    """
    cross_layers = Block(DenseMaybeLowRank(low_rank_dim)).repeat(n)

    # The previous version had an unbalanced ")" here, a SyntaxError that
    # prevented the whole cell from running.
    return Block(name="CrossBlock").append(cross_layers, link=Cross())


inputs = TabularInputBlock(schema, init="defaults")

# Parallel DCN, spelling 1: build the block imperatively.
dcn_parallel = ParallelBlock(name="DCNParallel")
dcn_parallel.prepend("concat")
# Previously misspelled `dnc_parallel`, which raised NameError.
dcn_parallel.branches["cross"] = CrossBlock(5)
dcn_parallel.branches["deep"] = MLPBlock([256, 128])
dcn_parallel.append("concat")

# Parallel DCN, spelling 2: the (intended-equivalent) declarative constructor.
# NOTE: this rebinds `dcn_parallel`, replacing the block built above.
dcn_parallel = ParallelBlock(
    {"cross": CrossBlock(5), "deep": MLPBlock([256, 128])},
    name="DCNParallel",
    pre="concat",
    post="concat",
)

# Stacked DCN: the cross network followed by a deep MLP.
dcn_stacked = Block(CrossBlock(5), name="DCNStacked").append(MLPBlock([256, 128]))

#### Multi-task

![mmoe](img/mmoe.png)

In [None]:
# Expert tower that is repeated below to form the mixture of experts.
expert = MLPBlock([256, 128])

# Output heads built from the schema, with an MLP prepended to each head.
outputs = OutputBlock(schema, init="defaults")
outputs.prepend_for_each(MLPBlock([256, 128]))

# Five parallel copies of `expert`, aggregated by stacking.
expert_branches = {"experts": Block.parse(expert).repeat_parallel(5, agg="stack")}
experts = ParallelBlock(expert_branches, shortcut=True)

# This creates a ParallelBlock with one gate per output head.
# NOTE(review): `pre_gate` is not defined anywhere in this notebook - TODO define it before running.
gates = Block(ExpertGate(pre_gate, len(outputs))).repeat_parallel_like(outputs)

mmoe = Block(experts, gates, outputs)

## Session-based

#### Current: broadcast context features to the sequence

In [None]:
class BroadcastToSequence:
    """Sketch of a block that presumably broadcasts context features over sequence features.

    Both methods are stubs and currently do nothing.
    NOTE(review): unlike the other blocks here, this class does not subclass
    ``nn.Module``; confirm whether ``append`` requires that.
    """

    def __init__(
        self,
        context_selection=Tags.CONTEXT,
        sequence_selection=Tags.SEQUENCE,
        schema: Optional[Schema] = None,
    ):
        """Intended to record the context/sequence column selections (stub: does nothing yet)."""

    def setup_schema(self, schema: Schema):
        """Intended to resolve the selections against ``schema`` (stub: does nothing yet)."""


# Continuous features are added with no module attached; categorical features get Embeddings(128).
inputs = TabularInputBlock(schema)
inputs.add(Tags.CONTINUOUS)
inputs.add(Tags.CATEGORICAL, Embeddings(128))

# Finally, broadcast the context features along the sequence.
inputs.append(BroadcastToSequence())

#### Proposed alternative: condition the sequence on its context via cross-attention

In [None]:
class ConditionedTransformer(nn.Module):
    """Transformer encoder whose input sequence first cross-attends to a context tensor.

    ``forward`` reads two tensors from a dict:
    - the sequence under ``input_key``, used as attention queries;
    - the conditioning tensor under ``conditioning_key``, used as keys and values.

    The cross-attention output is then passed through a stack of self-attention
    encoder layers.

    Args:
        input_key: Dict key of the sequence to encode.
        conditioning_key: Dict key of the context to attend to.
        d_model: Feature size of both tensors; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        num_layers: Number of self-attention encoder layers.
        dim_feedforward: Hidden size of each encoder layer's feed-forward block.
        dropout: Dropout used in the attention and feed-forward sublayers.
        batch_first: If True, tensors are laid out as (batch, seq, d_model).
            Otherwise they are (seq, batch, d_model), which is PyTorch's default
            and this module's previous, unchanged behaviour.
    """

    def __init__(
        self,
        input_key: str = "inputs",
        conditioning_key: str = "context",
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        batch_first: bool = False,
    ):
        super().__init__()

        # Cross-attention: queries from the inputs, keys/values from the context.
        self.cross_attention = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )

        # Self-attention stack applied to the conditioned sequence.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=batch_first
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.input_key = input_key
        self.conditioning_key = conditioning_key

    def forward(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Cross-attend ``input_dict[input_key]`` to ``input_dict[conditioning_key]``, then self-attend.

        Returns:
            A tensor with the same shape as ``input_dict[input_key]``.
        """
        inputs = input_dict[self.input_key]
        context = input_dict[self.conditioning_key]

        conditioned, _ = self.cross_attention(inputs, context, context)

        return self.transformer(conditioned)


# Width shared by the embeddings, the context encoder and the conditioned transformer.
# It must match everywhere, or the attention/LayerNorm layers will reject the inputs.
D_MODEL = 128

# "embeddings-128" presumably sets the embedding width; keep it tied to D_MODEL.
inputs = TabularInputBlock(schema, init=f"embeddings-{D_MODEL}")
encoder = inputs.to_router()

# nn.Transformer is an encoder-decoder whose forward requires both `src` and `tgt`,
# so it cannot sit in a single-input Block. Use an encoder-only stack instead,
# with the same head and layer counts as nn.Transformer's defaults.
context_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=8), num_layers=6
)
encoder.add(Tags.CONTEXT, Block("stack", nn.LayerNorm(D_MODEL), context_encoder), name="context")
encoder.add(Tags.SEQUENCE, Block("stack"), name="inputs")

# ConditionedTransformer defaults to d_model=512, which would not match the 128-d features.
encoder.append(ConditionedTransformer(d_model=D_MODEL), link="residual")

## Tabular Transformers

#### ExcelFormer

![excel](img/excelformer.png)

In [None]:
class ExcelEmbeddings(nn.Module):
    """GLU-style per-feature embedding, sketched after ExcelFormer.

    Each scalar feature ``x_j`` is mapped to a ``dim``-sized token using that
    feature's own weights::

        z_j = tanh(x_j * W1[j] + b1[j]) * (x_j * W2[j] + b2[j])

    Parameters are created in :meth:`setup_schema`, because their first
    dimension is the number of columns in the schema.

    Args:
        dim: Embedding size produced for every feature.
        schema: Optional schema. When given, parameters are created immediately;
            otherwise ``setup_schema`` must be called before ``forward``.
    """

    def __init__(self, dim: int, schema: Optional[Schema] = None):
        super().__init__()
        self.dim = dim
        # NOTE(review): `Concat` is neither defined nor imported in this notebook;
        # presumably a merlin concat block - confirm.
        self.concat = Concat()

        # Previously the `schema` argument was silently ignored, leaving the
        # module without parameters.
        if schema is not None:
            self.setup_schema(schema)

    def setup_schema(self, schema: Schema):
        """Create one (weight, bias) row per schema column for both GLU branches."""
        num_features = len(schema)

        self.W1 = nn.Parameter(torch.empty(num_features, self.dim))
        self.W2 = nn.Parameter(torch.empty(num_features, self.dim))
        self.b1 = nn.Parameter(torch.empty(num_features, self.dim))
        self.b2 = nn.Parameter(torch.empty(num_features, self.dim))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise like ``nn.Linear``: Kaiming-uniform weights, fan-in-bounded uniform biases."""
        nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W1)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.b1, -bound, bound)
        nn.init.uniform_(self.b2, -bound, bound)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed every feature into its own ``dim``-sized token.

        Assumes ``self.concat(inputs)`` yields (batch_size, num_features), as the
        original comment stated (TODO confirm). Under that assumption, the
        result has shape (batch_size, num_features, dim).
        """
        x = self.concat(inputs).unsqueeze(-1)  # (batch_size, num_features, 1)

        # Broadcasting against the (num_features, dim) parameters gives every
        # feature its own weights. The previous matmul summed over all features
        # instead, producing tokens that differed only by their bias.
        z1 = torch.tanh(x * self.W1 + self.b1)
        z2 = x * self.W2 + self.b2

        # No squeeze: the old squeeze(1) collapsed the feature axis only when
        # there was a single feature, making the output rank inconsistent.
        return z1 * z2