In [1]:
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from flax.core import freeze, unfreeze
from typing import List, Optional

'''
class DLRM_Net(nn.Module):
    m_spa: int
    ln_emb: List[int]
    ln_bot: List[int]
    ln_top: List[int]
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None

    def setup(self):
        # Define embeddings as a list of embedding tables
        self.embeddings = [nn.Embed(num_embeddings=n, features=self.m_spa) 
                           for n in self.ln_emb]

        # Define MLPs
        self.bot_mlp = self.create_mlp(self.ln_bot, self.sigmoid_bot)
        self.top_mlp = self.create_mlp(self.ln_top, self.sigmoid_top)

    def create_mlp(self, ln, sigmoid_layer):
        """Helper function to create MLP layers."""
        layers = []
        for i in range(len(ln) - 1):
            layers.append(nn.Dense(features=ln[i + 1]))
            if i == sigmoid_layer:
                layers.append(nn.sigmoid)
            else:
                layers.append(nn.relu)
        return nn.Sequential(layers)

    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embeddings lookup for sparse features using offsets."""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            indices = lS_i[k]
            offsets = lS_o[k]
            
            # Perform embedding lookup using indices
            embeds = E(indices)

            print(offsets.shape)
            
            # Sum over ranges defined by the offsets (as we discussed earlier)
            output = []
            for i in range(offsets.shape[0] - 1):
                start, end = offsets[i], offsets[i + 1]
                sum_embeddings = jnp.sum(embeds[start:end], axis=-1)
                output.append(sum_embeddings)

            # Append the summed embeddings for each sparse feature
            ly.append(jnp.stack(output))
        
        return ly

    def interact_features(self, x, ly):
        """Perform feature interactions between dense and sparse features."""
        if self.arch_interaction_op == "dot":
            # Concatenate dense features and sparse embeddings
            T = jnp.concatenate([x] + ly, axis=1).reshape(x.shape[0], -1, x.shape[1])
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            
            Zflat = Z[:, li, lj]
            R = jnp.concatenate([x, Zflat], axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate([x] + ly, axis=1)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        
        return R

    def __call__(self, dense_x, lS_o, lS_i):
        """Forward pass for DLRM."""
        # Apply bottom MLP to dense features
        x = self.bot_mlp(dense_x)
        
        # Apply embedding lookup with offsets for sparse features
        ly = self.apply_embedding(lS_o, lS_i, self.embeddings)
        
        # Interact features between dense and sparse features
        z = self.interact_features(x, ly)
        
        # Apply top MLP for final prediction
        p = self.top_mlp(z)

        # Optionally clip prediction based on loss threshold
        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)

        return p
'''

'\nclass DLRM_Net(nn.Module):\n    m_spa: int\n    ln_emb: List[int]\n    ln_bot: List[int]\n    ln_top: List[int]\n    arch_interaction_op: str\n    arch_interaction_itself: bool = False\n    sigmoid_bot: int = -1\n    sigmoid_top: int = -1\n    loss_threshold: float = 0.0\n    weighted_pooling: Optional[str] = None\n\n    def setup(self):\n        # Define embeddings as a list of embedding tables\n        self.embeddings = [nn.Embed(num_embeddings=n, features=self.m_spa) \n                           for n in self.ln_emb]\n\n        # Define MLPs\n        self.bot_mlp = self.create_mlp(self.ln_bot, self.sigmoid_bot)\n        self.top_mlp = self.create_mlp(self.ln_top, self.sigmoid_top)\n\n    def create_mlp(self, ln, sigmoid_layer):\n        """Helper function to create MLP layers."""\n        layers = []\n        for i in range(len(ln) - 1):\n            layers.append(nn.Dense(features=ln[i + 1]))\n            if i == sigmoid_layer:\n                layers.append(nn.sigmoid)\n       

In [27]:
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from flax.core import freeze, unfreeze
from typing import List, Optional

import pdb

class MLP(nn.Module):
    ln: List[int]
    sigmoid_layer: int = -1
    @nn.compact
    def __call__(self,x):
        for i in range(len(self.ln) - 1):
            x = nn.Dense(features=self.ln[i + 1])(x)
            if i == self.sigmoid_layer:
                x  = nn.sigmoid(x)
            else:
                x = nn.relu(x)
        return nn.Dense(features=self.ln[-1])(x)
    


class DLRM_Net(nn.Module):
    m_spa: int
    ln_emb: List[int]
    ln_bot: List[int]
    ln_top: List[int]
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None
 
    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embeddings lookup for sparse features using offsets."""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            indices = lS_i[k]
            offsets = lS_o[k]
            
            ## Perform embedding lookup using indices
            embeds = E(indices)
            #print(offsets.shape)
            ## Sum over ranges defined by the offsets (as we discussed earlier)
            output = []
            for i in range(offsets.shape[0] - 1):
                start, end = offsets[i], offsets[i + 1]
                sum_embeddings = jnp.sum(embeds[start:end], axis=-1)
                output.append(sum_embeddings)
            # Append the summed embeddings for each sparse feature
            ly.append(jnp.stack(output))
        
        return ly

    def interact_features(self, x, ly):
        breakpoint()
        """Perform feature interactions between dense and sparse features."""
        if self.arch_interaction_op == "dot":
            # Concatenate dense features and sparse embeddings
            T = jnp.concatenate((jnp.expand_dims(x,0) , ly[0]),axis=2)#.reshape(x.shape[0], -1, x.shape[1])
            print(T.shape)
            
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            print(Z.shape)
            #
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            #
            Zflat = Z[:, li, lj]
            print(x.shape)
            print(Zflat.shape)
            R = jnp.concatenate((x, Zflat), axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate((jnp.expand_dims(x,0) , ly[0]),axis=2)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        
        return R

    @nn.compact
    def __call__(self, dense_x, lS_o, lS_i):
        """Forward pass for DLRM."""
        # Apply bottom MLP to dense features
        x = MLP(self.ln_bot, self.sigmoid_bot)(dense_x)
        
        embeddings = [nn.Embed(num_embeddings=n, features=self.m_spa) 
                           for n in self.ln_emb]
    
    
        # Apply embedding lookup with offsets for sparse features
        ly = self.apply_embedding(lS_o, lS_i, embeddings)
        #
        ## Interact features between dense and sparse features
        z = self.interact_features(x, ly)
        #
        # Apply top MLP for final prediction
        p = MLP(self.ln_top, self.sigmoid_top)(z)
        # Optionally clip prediction based on loss threshold
        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)

        return x,ly #p


In [29]:
# Dummy Data Configuration
batch_size = 2  # Batch size for testing
num_dense_features = 10  # Number of dense features
num_sparse_features = 3  # Number of sparse features
num_embeddings = [20, 10, 5]  # Number of embedding entries per sparse feature
m_spa = 8  # Size of the embedding vector

# Create dummy dense features and sparse indices
dense_x = jnp.ones((batch_size, num_dense_features))  # Dense input features
lS_i = jnp.array([jnp.ones((batch_size, 10), dtype=int) for _ in range(num_sparse_features)])  # Sparse indices
lS_o = jnp.array([jnp.arange(0, batch_size * 10, 10) for _ in range(num_sparse_features)])  # Sparse offsets

# Model configuration
ln_bot = [num_dense_features, 64, 32]  # Bottom MLP layers
ln_top = [m_spa * (num_sparse_features + 1), 128, 64, 1]  # Top MLP layers (plus interaction)
arch_interaction_op = 'cat'  # Interaction operation

# Initialize the model
model = DLRM_Net(
    m_spa=m_spa,
    ln_emb=num_embeddings,
    ln_bot=ln_bot,
    ln_top=ln_top,
    arch_interaction_op=arch_interaction_op,
    arch_interaction_itself=False,
    sigmoid_bot=-1,  # No sigmoid in bottom MLP
    sigmoid_top=len(ln_top) - 2  # Sigmoid in the last layer before output
)

# Initialize parameters using a random key
key = jax.random.PRNGKey(0)
params = model.init(key, dense_x, lS_o, lS_i)

test = model.apply(params, dense_x, lS_o, lS_i)
print(test)

(Array([[ 0.27021882,  0.6180628 , -0.14204143,  0.46539277,  0.22652283,
         0.289755  , -0.3446836 , -0.57154536,  1.0664508 ,  0.38619608,
        -0.2058066 ,  0.08506541,  1.0736796 ,  0.12039478, -0.08093123,
         0.23242348, -0.16108775, -0.6048284 ,  0.37975907, -1.076916  ,
        -0.21785256,  0.432528  , -0.21895711, -0.6404866 ,  0.530484  ,
        -0.906347  , -0.6153416 , -0.8004427 , -0.5313179 , -0.8547554 ,
        -0.7557017 , -1.23112   ],
       [ 0.27021882,  0.6180628 , -0.14204143,  0.46539277,  0.22652283,
         0.289755  , -0.3446836 , -0.57154536,  1.0664508 ,  0.38619608,
        -0.2058066 ,  0.08506541,  1.0736796 ,  0.12039478, -0.08093123,
         0.23242348, -0.16108775, -0.6048284 ,  0.37975907, -1.076916  ,
        -0.21785256,  0.432528  , -0.21895711, -0.6404866 ,  0.530484  ,
        -0.906347  , -0.6153416 , -0.8004427 , -0.5313179 , -0.8547554 ,
        -0.7557017 , -1.23112   ]], dtype=float32), [Array([[[-0.805229, -0.805229, -0.8

# Get Data

We use the data for pytorch dlrm, namely the criteo advertising challenge dataset.

In [None]:
#!wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
#!md5sum criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
# df9b1b3766d9ff91d5ca3eb3d23bed27
#!mkdir data
!tar -xzf criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz -C data

In [None]:
# experiment params
num_train_trials = 5
num_train_warmups = 1
num_jit_trials = 10
num_jit_warmups = 2
num_inference_trials = 100
num_inference_warmups = 10

# JIT

Run the JIT timing loop