In [1]:
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np
from flax.core import freeze, unfreeze
from typing import List, Optional

import pdb

class MLP(nn.Module):
    ln: List[int]
    sigmoid_layer: int = -1
    @nn.compact
    def __call__(self,x):
        for i in range(len(self.ln) - 1):
            x = nn.Dense(features=self.ln[i + 1])(x)
            if i == self.sigmoid_layer:
                x  = nn.sigmoid(x)
            else:
                x = nn.relu(x)
        return nn.Dense(features=self.ln[-1])(x)
    


class DLRM_Net(nn.Module):
    m_spa: int
    ln_emb: List[int]
    ln_bot: List[int]
    ln_top: List[int]
    arch_interaction_op: str
    arch_interaction_itself: bool = False
    sigmoid_bot: int = -1
    sigmoid_top: int = -1
    loss_threshold: float = 0.0
    weighted_pooling: Optional[str] = None
 
    def apply_embedding(self, lS_o, lS_i, embeddings):
        """Embeddings lookup for sparse features using offsets."""
        ly = []
        for k in range(len(embeddings)):
            E = embeddings[k]
            indices = jnp.array(lS_i[k])
            offsets = jnp.array(lS_o[k])
            
            ## Perform embedding lookup using indices
            embeds = jnp.array(E(indices))
            #print(offsets.shape)
            ## Sum over ranges defined by the offsets (as we discussed earlier)
            output = []  # This is not idiomatic JAX, replace with jnp.array([])
            for i in range(offsets.shape[0] - 1):
                start, end = 0, 10 #offsets.take(i), offsets.take(i + 1)
                embed_arr = jax.lax.dynamic_slice_in_dim(embeds, start, end, axis=1)
                sum_embeddings = jnp.sum(embed_arr, axis=-1)
                
                output.append(sum_embeddings)
            # Append the summed embeddings for each sparse feature
            ly.append(jnp.stack(output))
            ly.append(jnp.stack(output))
        
        return ly

    def interact_features(self, x, ly):
        """Perform feature interactions between dense and sparse features."""
        if self.arch_interaction_op == "dot":
            # Concatenate dense features and sparse embeddings
            T = jnp.concatenate((jnp.expand_dims(x,0) , ly[0]),axis=2)#.reshape(x.shape[0], -1, x.shape[1])
            print(T.shape)
            
            Z = jnp.matmul(T, jnp.transpose(T, axes=(0, 2, 1)))
            print(Z.shape)
            #
            offset = 1 if self.arch_interaction_itself else 0
            li = jnp.array([i for i in range(Z.shape[1]) for j in range(i + offset)])
            lj = jnp.array([j for i in range(Z.shape[2]) for j in range(i + offset)])
            #
            Zflat = Z[:, li, lj]
            print(x.shape)
            print(Zflat.shape)
            R = jnp.concatenate((x, Zflat), axis=1)
        elif self.arch_interaction_op == "cat":
            R = jnp.concatenate((jnp.expand_dims(x,0) , ly[0]),axis=2)
        else:
            raise ValueError(f"Unsupported interaction op: {self.arch_interaction_op}")
        
        return R

    @nn.compact
    def __call__(self, dense_x, lS_o, lS_i):
        """Forward pass for DLRM."""
        # Apply bottom MLP to dense features
        x = MLP(self.ln_bot, self.sigmoid_bot)(dense_x)
        
        embeddings = [nn.Embed(num_embeddings=n, features=self.m_spa) 
                           for n in self.ln_emb]
    
    
        # Apply embedding lookup with offsets for sparse features
        ly = self.apply_embedding(lS_o, lS_i, embeddings)
        #
        ## Interact features between dense and sparse features
        z = self.interact_features(x, ly)
        #
        # Apply top MLP for final prediction
        p = MLP(self.ln_top, self.sigmoid_top)(z)
        # Optionally clip prediction based on loss threshold
        if 0.0 < self.loss_threshold < 1.0:
            p = jnp.clip(p, self.loss_threshold, 1.0 - self.loss_threshold)

        return x,ly #p


In [2]:
# Dummy Data Configuration
batch_size = 2  # Batch size for testing
num_dense_features = 10  # Number of dense features
num_sparse_features = 3  # Number of sparse features
num_embeddings = [20, 10, 5]  # Number of embedding entries per sparse feature
m_spa = 8  # Size of the embedding vector

# Create dummy dense features and sparse indices
dense_x = jnp.ones((batch_size, num_dense_features))  # Dense input features
lS_i = jnp.array([jnp.ones((batch_size, 10), dtype=int) for _ in range(num_sparse_features)])  # Sparse indices
lS_o = jnp.array([jnp.arange(0, batch_size * 10, 10) for _ in range(num_sparse_features)])  # Sparse offsets

# Model configuration
ln_bot = [num_dense_features, 64, 32]  # Bottom MLP layers
ln_top = [m_spa * (num_sparse_features + 1), 128, 64, 1]  # Top MLP layers (plus interaction)
arch_interaction_op = 'cat'  # Interaction operation

# Initialize the model
model = DLRM_Net(
    m_spa=m_spa,
    ln_emb=num_embeddings,
    ln_bot=ln_bot,
    ln_top=ln_top,
    arch_interaction_op=arch_interaction_op,
    arch_interaction_itself=False,
    sigmoid_bot=-1,  # No sigmoid in bottom MLP
    sigmoid_top=len(ln_top) - 2  # Sigmoid in the last layer before output
)

# Initialize parameters using a random key
key = jax.random.PRNGKey(0)
params = model.init(key, dense_x, lS_o, lS_i)

test = model.apply(params, dense_x, lS_o, lS_i)
print(test)

2024-10-16 12:13:15.887241: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


(Array([[ 0.27021882,  0.6180628 , -0.14204143,  0.46539277,  0.22652283,
         0.289755  , -0.3446836 , -0.57154536,  1.0664508 ,  0.38619608,
        -0.2058066 ,  0.08506541,  1.0736796 ,  0.12039478, -0.08093123,
         0.23242348, -0.16108775, -0.6048284 ,  0.37975907, -1.076916  ,
        -0.21785256,  0.432528  , -0.21895711, -0.6404866 ,  0.530484  ,
        -0.906347  , -0.6153416 , -0.8004427 , -0.5313179 , -0.8547554 ,
        -0.7557017 , -1.23112   ],
       [ 0.27021882,  0.6180628 , -0.14204143,  0.46539277,  0.22652283,
         0.289755  , -0.3446836 , -0.57154536,  1.0664508 ,  0.38619608,
        -0.2058066 ,  0.08506541,  1.0736796 ,  0.12039478, -0.08093123,
         0.23242348, -0.16108775, -0.6048284 ,  0.37975907, -1.076916  ,
        -0.21785256,  0.432528  , -0.21895711, -0.6404866 ,  0.530484  ,
        -0.906347  , -0.6153416 , -0.8004427 , -0.5313179 , -0.8547554 ,
        -0.7557017 , -1.23112   ]], dtype=float32), [Array([[[-0.805229, -0.805229, -0.8

# Get Data

We use the data for pytorch dlrm, namely the criteo advertising challenge dataset.

In [3]:
#!wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
#!md5sum criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
# df9b1b3766d9ff91d5ca3eb3d23bed27
#!mkdir data
!tar -xzf criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz -C data

tar: Ignoring unknown extended header keyword 'SCHILY.dev'
tar: Ignoring unknown extended header keyword 'SCHILY.ino'
tar: Ignoring unknown extended header keyword 'SCHILY.nlink'
tar: data: Cannot open: No such file or directory
tar: Error is not recoverable: exiting now


  pid, fd = os.forkpty()


In [16]:
# experiment params
num_train_trials = 5
num_train_warmups = 1
num_jit_trials = 10
num_jit_warmups = 2
num_inference_trials = 10000
num_inference_warmups = 1000

# Timing

Run the JIT timing loop

In [5]:
import tqdm
import time

def time_jit(num_jit_runs: int, num_jit_warmups: int) -> float:
    times = []
    for i in tqdm.trange(1, num_jit_runs + 1):
        start_jit_time = time.time()
        
        jit_model = jax.jit(model.apply, backend='gpu').lower(params, dense_x, lS_o, lS_i)
        compiled_model = jit_model.compile()

        end_jit_time = time.time()
        jax.clear_caches()
        if i >= num_jit_warmups:
            times.append(end_jit_time - start_jit_time)

    return sum(times) / len(times)


average_jit_time = time_jit(num_jit_trials, num_jit_warmups)
print(f"Average JIT time: {average_jit_time}")

100%|████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.03it/s]

Average JIT time: 0.1318470901913113





In [6]:
jit_model = jax.jit(model.apply, backend='gpu').lower(params, dense_x, lS_o, lS_i)
compiled_model = jit_model.compile()

res = compiled_model(params, dense_x, lS_o, lS_i)

In [12]:
import optax

# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched, o, i):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y, o, i):
    pred = model.apply(params, x, o, i)
    # TODO unsure how to calculate loss here
    return 0.3
  # Vectorize the previous to compute the average of the loss on all samples.
  #return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
  return squared_error(x_batched, y_batched, o, i)


def time_train(num_train_runs: int, num_train_warmups: int, 
               params, lS_o, lS_i, res_x):

    optim = optax.adam(learning_rate=0.1)
    opt_state = optim.init(params)
    loss_grad_fn = jax.value_and_grad(mse)

    for i in tqdm.trange(num_train_trials):
        times = []
        start_train_time = time.time()
    
        for i in range(1000):
            loss_val, grads = loss_grad_fn(params, dense_x, res_x, lS_o, lS_i)
            updates, opt_state = optim.update(grads, opt_state)
            params = optax.apply_updates(params, updates)

        end_train_time = time.time()
        if i >= num_train_warmups:
            times.append(end_train_time - start_train_time)

    return sum(times) / len(times)

average_training_time = time_train(num_train_trials, num_train_warmups, 
                                params, lS_o, lS_i, res)
print(f"Average training time: {average_training_time}")

100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:45<00:00,  9.05s/it]

Average training time: 10.200698375701904





In [17]:
def time_inference(num_inference_trials, num_inference_warmups):
    times = []
    for i in tqdm.trange(num_inference_trials):
        start_inference_time = time.time()

        compiled_model(params, dense_x, lS_o, lS_i)

        end_inference_time = time.time()
        if i >= num_inference_warmups:
            times.append(end_inference_time - start_inference_time)

    return sum(times) / len(times)

average_inference_time = time_inference(num_inference_trials, num_inference_warmups)
print(f"Average inference time: {average_inference_time}")

100%|███████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 38468.17it/s]

Average inference time: 2.4001598358154297e-05





In [18]:
# export the python file
!jupyter nbconvert --to script DLRM-JAX.ipynb
!mv DLRM-JAX.py gen_dlrm_jax.py

# write the results to the results.csv file
import csv
with open('results.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Model", "Framework", "Evaluation", "Trials", "Warmups", "Time", "Notes"])
    csvwriter.writerow(["DLRM", "JAX", "train", 
                       num_train_trials, 
                       num_train_warmups, 
                       average_training_time, 
                       ""])
    csvwriter.writerow(["DLRM", "JAX", "inference", 
                       num_inference_trials, 
                       num_inference_warmups, 
                       average_inference_time, 
                       ""])
    csvwriter.writerow(["DLRM", "JAX", "JIT",
                        num_jit_trials,
                        num_jit_warmups,
                        average_jit_time,
                        ""])

  pid, fd = os.forkpty()


[NbConvertApp] Converting notebook DLRM-JAX.ipynb to script
[NbConvertApp] Writing 9975 bytes to DLRM-JAX.py
