##### Copyright 2025 Google LLC.
Licensed under the Apache 2.0 License.

In [None]:
# @title Licensed under the Apache 2.0 License (the "License"); { display-mode: "form" }
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Implementation of a data adapter for multimodal LLM input

One way to train an LLM to use non-text data is to feed that data into the LLM's
context window via an adapter. When you input natural language text into an LLM,
it is first transformed into tokens and then into token embeddings. Each token
embedding is a one-dimensional tensor of size D, so an input of K tokens maps to
a K by D matrix of inputs to the transformer part of the LLM.

Given this, a simple way to inject non-text data into the input stream of the
LLM is to use an adapter to map the non-text data into a set of vectors of size
D, which represent virtual tokens. For example, let's say we have a matrix
dataset of dimension 10 by 100 that represents some features relevant to a
single input to the LLM (a single sample). We feed the 10x100 matrix into an
adapter that maps it into a T by D matrix, where T is some number of virtual
tokens. Then we simply prepend this to our input matrix yielding a (K+T) by D
matrix, which is input to the LLM.

For training and evaluation of patient-reported outcomes (PROs) in the PH-LLM
paper (https://arxiv.org/abs/2406.06474), each input sample had a
data input of size 40, which represents the concatenation of 20 mean values and
20 variance values for 20 sensors. In this case, we learned an MLP adapter
(implemented below) that maps the 1D tensor of length 40 to a 10 by 128
dimensional matrix of virtual token embeddings. We then concatenated the virtual
token matrix to the input matrix for the LLM.

Here we provide a simple implementation of an adapter layer in Pax (built with
the Praxis layer library) that could be used to generate the virtual tokens.

## Install relevant packages

In [None]:
!pip install numpy
!pip install praxis

In [None]:
import jax
from jax import numpy as jnp
import numpy as np
from praxis import base_layer
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis.layers import linears

JTensor = base_layer.JTensor
NestedMap = py_utils.NestedMap

## Define Constants

In [None]:
# Size of the token embeddings used in the LLM transformer.
MODEL_DIM = 128

# Dimension of the input data to the adapter.
DATA_INPUT_DIM = 40

# Number of virtual tokens to use in the adapter.
NUM_VIRTUAL_TOKENS = 10

# Number of MLP layers to use in the data adapter.
NUM_MLP_LAYERS = 5

# Batch size used in training.
BATCH_SIZE = 8

# Random seed
RANDOM_SEED = 123

In [None]:
class SimpleAdapter(base_layer.BaseLayer):
  """An MLP adapter that maps a data vector to virtual token embeddings.

  Each input vector of length `input_dims` is passed through an MLP whose
  output width is `num_virtual_tokens * token_embedding_dim`. The flat result
  is then reshaped into `num_virtual_tokens` embeddings of width
  `token_embedding_dim`, which can be prepended to an LLM's token embeddings.
  """

  # The dimension of input data (length of the input's last axis).
  input_dims: int = 1
  # The dimension of the token embeddings.
  token_embedding_dim: int = 128
  # The number of virtual tokens to allocate to data encoding.
  num_virtual_tokens: int = 10
  # Number of layers for the MLP used in the adapter.
  num_layers: int = 5

  def setup(self) -> None:
    """Creates the MLP child layer `mlp_layer`."""
    # The MLP's output width is sized to hold every virtual-token embedding
    # for a sample as one flat vector; __call__ reshapes it afterwards.
    ffn_p = pax_fiddle.Config(
        linears.FeedForward,
        name='ffn',
        input_dims=self.input_dims,
        output_dims=self.token_embedding_dim * self.num_virtual_tokens,
    )
    mlp_p = pax_fiddle.Config(
        linears.MLPBlock,
        num_layers=self.num_layers,
        name='mlp',
        # Skip the activation on the final layer's output.
        activate_final=False,
        ff_tpl=ffn_p,
    )
    self.create_child('mlp_layer', mlp_p)

  def __call__(self, input: JTensor) -> JTensor:
    """Maps data vectors to virtual token embeddings.

    Args:
      input: Data tensor whose last axis has length `input_dims`, typically of
        shape [batch_size, input_dims]. Leading axes are preserved in the
        output.

    Returns:
      Tensor of shape [..., num_virtual_tokens, token_embedding_dim], where
      `...` are the leading axes of `input`.

    Raises:
      ValueError: If `input` is a scalar or its last axis is not `input_dims`.
    """
    # `input` shadows the builtin; the name is kept so keyword callers still
    # work.
    if input.ndim == 0 or input.shape[-1] != self.input_dims:
      raise ValueError(
          f'Expected input with last dimension {self.input_dims}, got shape'
          f' {input.shape}.'
      )
    leading_dims = tuple(input.shape[:-1])
    x = self.mlp_layer(input)
    # Split the flat MLP output into one embedding per virtual token.
    return jnp.reshape(
        x,
        (*leading_dims, self.num_virtual_tokens, self.token_embedding_dim),
    )

## Test the layer with synthetic data.

In [None]:
# Build the adapter's fiddle config field by field from the constants above,
# then materialize it into a layer instance.
_adapter_config = pax_fiddle.Config(SimpleAdapter, name='adapter_config')
_adapter_config.input_dims = DATA_INPUT_DIM
_adapter_config.token_embedding_dim = MODEL_DIM
_adapter_config.num_virtual_tokens = NUM_VIRTUAL_TOKENS
_adapter_config.num_layers = NUM_MLP_LAYERS
g_adapter_layer = base_layer.instantiate(_adapter_config)

In [None]:
# We create an input matrix of size BATCH_SIZE by DATA_INPUT_DIM.
# For example, with the default constants this represents
# a matrix of 8 samples where each has a data input vector of size 40.
# The NumPy generator is seeded with RANDOM_SEED so the synthetic inputs (and
# therefore the outputs) are reproducible on Restart & Run All.
_np_rng = np.random.default_rng(RANDOM_SEED)
_inputs_np = _np_rng.normal(size=(BATCH_SIZE, DATA_INPUT_DIM))
_inputs = jnp.asarray(_inputs_np)

print(f'We expect {BATCH_SIZE} samples by {DATA_INPUT_DIM}')
print(f'_inputs: {_inputs.shape}')
print('')
assert _inputs.shape == (BATCH_SIZE, DATA_INPUT_DIM)


with base_layer.JaxContext.new_context():
  # Split the seed key so parameter initialization (PARAMS) and the layer's
  # RANDOM stream receive different keys.
  _prng_key = jax.random.PRNGKey(seed=RANDOM_SEED)
  _prng_key, _subkey = jax.random.split(_prng_key)
  _pax_initial_vars = g_adapter_layer.init(
      {base_layer.PARAMS: _prng_key, base_layer.RANDOM: _subkey}, _inputs
  )
  _pax_initial_vars = NestedMap(_pax_initial_vars)
  _outputs = g_adapter_layer.apply(
      _pax_initial_vars, _inputs, rngs={base_layer.RANDOM: _subkey}
  )

# After pushing the input matrix through the adapter we expect that
# the output of the adapter should be BATCH_SIZE by NUM_VIRTUAL_TOKENS
# by MODEL_DIM or the token embedding size.
print(
    f'We expect {BATCH_SIZE} samples by {NUM_VIRTUAL_TOKENS} tokens by'
    f' {MODEL_DIM} dim.'
)
print(f'_outputs: {_outputs.shape}')
assert _outputs.shape == (BATCH_SIZE, NUM_VIRTUAL_TOKENS, MODEL_DIM)

We expect 8 samples by 40
_inputs: (8, 40)

We expect 8 samples by 10 tokens by 128 dim.
_outputs: (8, 10, 128)


## Where to go next

The output of the adapter layer can be prepended to the token embeddings of the
LLM input, yielding the (K+T) by D input matrix described above. The details of
how this is done depend on the specific LLM implementation, but in general it
amounts to a concatenation along the token axis.