In [1]:
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from flax.core import freeze, unfreeze
from typing import List, Optional


class DLRM_Net(nn.Module):
    m_spa: int
    ln_emb: List[int]
    ln_bot: List[int]
    ln_top: List[int]
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None

    def setup(self):
        self.embeddings = [nn.Embed(num_embeddings=n, features=self.m_spa) 
                           for n in self.ln_emb]

        self.bot_mlp = self.create_mlp(self.ln_bot, self.sigmoid_bot)
        self.top_mlp = self.create_mlp(self.ln_top, self.sigmoid_top)

    def create_mlp(self, ln, sigmoid_layer):
        layers = []
        for i in range(len(ln) - 1):
            layers.append(nn.Dense(features=ln[i + 1]))
            if i == sigmoid_layer:
                layers.append(nn.sigmoid)
            else:
                layers.append(nn.relu)
        return nn.Sequential(layers)

    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embeddings lookup for sparse features using jax.lax.gather"""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            # Using lax.gather for embedding lookup
            gather_indices = jnp.expand_dims(lS_i[k], axis=-1)  # Shape (batch_size, 1)
            
            # Define the gather operation
            gathered_embeddings = lax.gather(
                E,
                gather_indices,
                dimension_numbers=lax.GatherDimensionNumbers(
                    offset_dims=(1,),  # Offsets in the embedding dimension
                    collapsed_slice_dims=(0,),  # Collapsing the index dimension
                    start_index_map=(0,)  # Mapping indices to the first dimension
                ),
                slice_sizes=(1, 128)  # Gather 1 slice along the first dimension, and full along the second
            )
            
            # Perform sum over the range of gathered embeddings specified by lS_o
            V = jax.vmap(lambda g, o: jnp.sum(g[:o], axis=0), in_axes=(0, 0))(gathered_embeddings, lS_o[k])
            ly.append(V)
        
        return ly

    def interact_features(self, x, ly):
        if self.arch_interaction_op == "dot":
            T = jnp.concatenate([x] + ly, axis=1).reshape(x.shape[0], -1, x.shape[1])
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            Zflat = Z[:, li, lj]
            R = jnp.concatenate([x, Zflat], axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate([x] + ly, axis=1)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        return R

    def __call__(self, dense_x, lS_o, lS_i):
        x = self.bot_mlp(dense_x)
        ly = self.apply_embedding(lS_o, lS_i, self.embeddings)
        z = self.interact_features(x, ly)
        p = self.top_mlp(z)

        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)

        return p


In [2]:
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from flax import linen as nn
from flax.training import train_state
import optax

# Dummy Data Configuration
batch_size = 1  # Batch size for testing
num_dense_features = 10  # Number of dense features
num_sparse_features = 3  # Number of sparse features
num_embeddings = [20, 10, 5]  # Number of embedding entries per sparse feature
m_spa = 8  # Size of the embedding vector

# Create dummy dense features and sparse indices
dense_x = jnp.ones((batch_size, num_dense_features))  # Dense input features
lS_i = [jnp.ones((batch_size, 10), dtype=int) for _ in range(num_sparse_features)]  # Sparse indices
lS_o = [jnp.arange(0, batch_size * 10, 10) for _ in range(num_sparse_features)]  # Sparse offsets

# Model configuration
ln_bot = [num_dense_features, 64, 32]  # Bottom MLP layers
ln_top = [m_spa * (num_sparse_features + 1), 128, 64, 1]  # Top MLP layers (plus interaction)
arch_interaction_op = 'dot'  # Interaction operation

# Initialize the model
model = DLRM_Net(
    m_spa=m_spa,
    ln_emb=num_embeddings,
    ln_bot=ln_bot,
    ln_top=ln_top,
    arch_interaction_op=arch_interaction_op,
    arch_interaction_itself=False,
    sigmoid_bot=-1,  # No sigmoid in bottom MLP
    sigmoid_top=len(ln_top) - 2  # Sigmoid in the last layer before output
)

# Initialize parameters using a random key
key = jax.random.PRNGKey(0)
params = model.init(key, dense_x, lS_o, lS_i)

# Forward pass to test the model
output = model.apply(params, dense_x, lS_o, lS_i)

# Print the output
print("Output of the DLRM model (logits):")
print(output)


2024-10-01 03:04:30.466529: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


TypeError: Invalid start_index_map; domain is [0, 0), got: 0->0.