## Dataset

In [3]:
%load_ext autoreload
%autoreload 2
import transform

vocab = "0123456789+="
char_lookup_table = transform.CharacterTable(vocab)
list(char_lookup_table.generate_samples(5, 2))



[('12+84', '=96'),
 ('19+19', '=38'),
 ('76+724', '=800'),
 ('1+955', '=956'),
 ('79+331', '=410')]

In [3]:
batch = char_lookup_table.get_batch(12)
print(f"inputs shape: {batch['query'].shape}")
print(f"targets shape: {batch['answer'].shape}")
print(f"shape: batch_size, sequence_length, vocab_size")


inputs shape: (12, 8, 14)
targets shape: (12, 8, 14)
shape: batch_size, sequence_length, vocab_size


In [4]:
import jax.numpy as jnp

labels = batch['answer']
eos_row = labels[:, :, char_lookup_table.eos_id] # as we are dealing with a one hot encoded variable and we now the eos_id, we can select the column
# representing the eos_id and find it's argmax. Because if we find it's argmax, we know everything after that index will be padded.

print(labels[0, :]) # looking at a single sample it becomes more obvious. The eos_id = 1, in addition, the pad token is 0. 
# so if you look down the first column until you see the first 1 (i.e. the row containing the eos token), than we know that everything after 
# that row AND in the pad token column (i.e. column 0) should be equal to 1.


[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [5]:
eos_idx = jnp.argmax(eos_row, axis=-1) # get the first occurence when we get the end of sentence id


In [6]:
eos_idx

DeviceArray([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], dtype=int32)

In [7]:

jnp.where(
    eos_row[jnp.arange(eos_row.shape[0]), eos_idx],
    eos_idx + 1, # the +1 makes sure we include the eos token, which will also be needed during inference to know when to stop
    labels.shape[1] # if no eos id is present, use the entire sequence as target
)

DeviceArray([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

lets take an example. If the inputs = `120+123` than the target = `=243`
in addition, lets say we have the following char_to_idx mapping: 
```
{
    0: '_',
    1: '|',
    2: '+',
    3: '0',
    4: '1',
    5: '2',
    6: '3',
    7: '4',
    8: '5',
    9: '6',
    10: '7',
    11: '8',
    12: '9',
    13: '='
 }
```
the target becomes 
```
[
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], <-- =
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], <-- 2
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], <-- 4
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], <-- 3
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- |
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], <-- _
]
```

In [8]:
labels[0]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)

In [9]:
import jax.numpy as jnp

from transform import get_sequence_lengths

def mask_sequences(sequence_batch: jnp.ndarray, lengths: jnp.ndarray) -> jnp.ndarray:
  """Sets positions beyond the length of each sequence to 0."""
  return sequence_batch * (
      lengths[:, jnp.newaxis] > jnp.arange(sequence_batch.shape[1])[jnp.newaxis])[..., jnp.newaxis]

lengths = get_sequence_lengths(labels, char_lookup_table.eos_id)
masked_labels = mask_sequences(labels, lengths)


# note that in the masked_labels, all padded rows are set to 0
# whereas the labels still contain the padded one_hot encoded elements
masked_labels[0], labels[0]



(DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
              [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],            dtype=float32),
 array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],


In [10]:

labels.shape
lengths = get_sequence_lengths(labels, char_lookup_table.eos_id)
lengths_expanded = lengths[:, jnp.newaxis] # from shape (batch_size,) to (batch_size, 1)
jnp.arange(labels.shape[1]) # array for each element in the sequence. shape = (sequence_length,)
sequence_expanded = jnp.arange(labels.shape[1])[jnp.newaxis] # array for each element in the sequence. shape = (1, sequence_length)

# shapes are (batch_size, 1), (1, sequence_length)
print(lengths_expanded.shape, sequence_expanded.shape)

# if we compare them using the > operation, we are creating a boolean matrix
# which returns true for each element where the sequence length > labels index
print(lengths_expanded)
print(sequence_expanded)
print(1.0 * (lengths_expanded > sequence_expanded)) # multiply by 1.0 to turn boolean into float
masks = 1.0 * (lengths_expanded > sequence_expanded)
# masks.shape = (batch_size, sequence_size)
masks.shape 

(12, 1) (1, 8)
[[5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]
 [5]]
[[0 1 2 3 4 5 6 7]]
[[1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]]


(12, 8)

# The Model


inputs --> encoder --> embedding_representation_input_domain --> decoder --> embedding_representation_target_domain --> dense --> predictions


In [11]:
import flax.linen as nn

my interpretation of the nn.scan functionality is as follows
- nn.scan can be seen as a for loop with state carried over at each step
- the first argument is the operation that we want to apply (in our case, we want to feed each step through the lstmcell, so the operation is the nn.LSTMCell).
- We specify the variable_broadcast. This value is shared across each step and has no dependencies. Recall that the weights of an LSTMCell are shared. As such, we can provide these weights at each time step. Stated differently, we are broadcasting them at each time step
- we set split_rngs for the params to false. We don't want to create random parameters at each time step
- now lets move over to the axes. In the aforementioned documentation, I'm constantly talking about timesteps. The axes is what allows me to talk about it in this way. Remember, the inputs are of shape (batch_size, timesteps, n_features). by specifying in_axis=1, we are saying, apply this function over the first dimension (i.e. the timesteps/sequence axis).
- out_axis specifies the axis where the result will be added to (i.e. the dimension to place the results). As such, the output will have an extra dimension/axis at the position indicated by the `out_axis` (e.g. in the case of one sample with 10 output features, the shape would be (1, 10). with out_axis=1 the result would become (1, 1, 10))
 

In [12]:
# Example taken from flax documentation
# --> https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.scan.html?highlight=flax%20scan
import flax
import flax.linen as nn
from jax import random

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",
                   split_rngs={"params": False},
                   in_axes=1,
                   out_axes=1)
    return LSTM()(c, xs)

batch_size, seq_len, in_feat, out_feat = 16, 20, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)
out_val.shape 
# for each sample in the the batch (16 in total), at each time step (20 in total)
# we have the knowledge 'encoded' in a 5 dimensional vector


  new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs))
  jax.tree_map(get_scopes_inner, attrs)
  scopes, treedef = jax.tree_flatten(scope_tree)
  group[col] = jax.tree_map(lambda x: x, xs[col])
  lengths = jax.tree_map(find_length, in_axes, args)
  leaves = jax.tree_leaves(x)
  lengths = set(jax.tree_leaves(lengths))
  xs = jax.tree_map(transpose_to_front, in_axes, args)
  return jax.tree_map(trans, xs)
  carry_avals = jax.tree_map(
  scan_avals = jax.tree_map(
  in_avals, in_tree = jax.tree_flatten(input_avals)
  xs = jax.tree_map(lambda ax, arg, x: (arg if ax is broadcast else x),
  variables = jax.tree_map(lambda x: x, variables)
  new_args, new_kwargs = jax.tree_map(
  new_attrs = jax.tree_map(set_scopes_inner, attrs)
  jax.tree_leaves(tree)))
  ys = jax.tree_map(lambda ax, y: (y if ax is broadcast else ()),
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)
  ys 

(16, 20, 5)

In [13]:
init_carry[0].shape
out_carry[1].shape


(16, 5)

## Build the encoder part of the model

def select_carried_state(new_state, old_state):
    return jnp.where(is_eos[:, np.newaxis], old_state, new_state)



initial_

In [4]:
import functools
from typing import Any, Tuple

from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np

Array = Any
PRNGKey = Any

class EncoderLSTM(nn.Module):
  """EncoderLSTM Module wrapped in a lifted scan transform."""
  eos_id: int

  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry: Tuple[Array, Array],
               x: Array) -> Tuple[Tuple[Array, Array], Array]:
    """Applies the module."""
    lstm_state, is_eos = carry
    new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
    # Pass forward the previous state if EOS has already been reached.
    def select_carried_state(new_state, old_state):
      return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
    # LSTM state is a tuple (cell_state, hidden_state).
    # we need to select what state to carry for both 
    # the cell_state as the hidden_state
    carried_lstm_state = tuple(
        select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
    # Update `is_eos`.
    is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
    return (carried_lstm_state, is_eos), y

  @staticmethod
  def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
    # Use a dummy key since the default state init fn is just zeros.
    return nn.LSTMCell.initialize_carry(
        jax.random.PRNGKey(0), (batch_size,), hidden_size)


class Encoder(nn.Module):
  """LSTM encoder, returning state after finding the EOS token in the input."""
  hidden_size: int
  eos_id: int

  @nn.compact
  def __call__(self, inputs: Array):
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size = inputs.shape[0]
    lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
    init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
    # We use the `is_eos` array to determine whether the encoder should carry
    # over the last lstm state, or apply the LSTM cell on the previous state.
    init_is_eos = jnp.zeros(batch_size, dtype=bool)
    init_carry = (init_lstm_state, init_is_eos)
    (final_state, _), _ = lstm(init_carry, inputs)
    return final_state

In [5]:
nn.LSTMCell().initialize_carry(jax.random.PRNGKey(0), (16, ), 12)

(DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0

In [6]:
xs[:, eos_idx].shape

NameError: name 'xs' is not defined

In [7]:
batch.keys()

NameError: name 'batch' is not defined

In [8]:
batch = char_lookup_table.get_batch(16) 
print(batch['query'].shape)
print(batch['query'][:, eos_idx].shape)
print(char_lookup_table.decode_onehot(batch['query'][:1]))
print(batch['query'][0][:,eos_idx])
char_lookup_table._char_indices

# char_lookup_table.decode(batch['answer'])

TypeError: get_batch() missing 1 required positional argument: 'step'

In [9]:
# in the network we initialize the is_eos with all zeros
# then we compute the logical_or based on the input sequence where the id == eos_id
# Finally, we take the xor, updating the is_eos matrix to contain ones, where the 
# is_eos matrix is one or the input_matrix is one 
jnp.logical_or(jnp.array([1,0,1,0]), jnp.array([1,1,0, 0]))

DeviceArray([ True,  True,  True, False], dtype=bool)

In [10]:
batch_size = 32
hidden_size = 512
batch = char_lookup_table.get_batch(batch_size)

rng = jax.random.PRNGKey(0)

encoder_model = Encoder(hidden_size, char_lookup_table.eos_id)
example = jnp.ones((1, char_lookup_table.max_input_len, char_lookup_table.vocab_size), jnp.float32)
variables = encoder_model.init(rng, example)
model_params = variables['params']

TypeError: get_batch() missing 1 required positional argument: 'step'

In [11]:
batch['query'].shape

NameError: name 'batch' is not defined

In [12]:

encoding = encoder_model.apply({
    "params": model_params
}, batch['query'])

NameError: name 'encoder_model' is not defined

In [13]:
encoding[0].shape, encoding[1].shape

NameError: name 'encoding' is not defined

In [14]:
batch.keys()

NameError: name 'batch' is not defined

# Decoder Model

In [15]:

class DecoderLSTM(nn.Module):
  """DecoderLSTM Module wrapped in a lifted scan transform.
  Attributes:
    teacher_force: See docstring on Seq2seq module.
    vocab_size: Size of the vocabulary.
  """
  teacher_force: bool
  vocab_size: int

  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False, 'lstm': True})
  @nn.compact
  def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
    """Applies the DecoderLSTM model."""
    lstm_state, last_prediction = carry
    if not self.teacher_force:
      x = last_prediction
    lstm_state, y = nn.LSTMCell()(lstm_state, x)
    logits = nn.Dense(features=self.vocab_size)(y)
    # Sample the predicted token using a categorical distribution over the
    # logits.
    categorical_rng = self.make_rng('lstm')
    predicted_token = jax.random.categorical(categorical_rng, logits)
    # Convert to one-hot encoding.
    prediction = jax.nn.one_hot(
        predicted_token, self.vocab_size, dtype=jnp.float32)

    return (lstm_state, prediction), (logits, prediction)


class Decoder(nn.Module):
  """LSTM decoder.
  Attributes:
    init_state: [batch_size, hidden_size]
      Initial state of the decoder (i.e., the final state of the encoder).
    teacher_force: See docstring on Seq2seq module.
    vocab_size: Size of the vocabulary.
  """
  init_state: Tuple[Any]
  teacher_force: bool
  vocab_size: int

  @nn.compact
  def __call__(self, inputs: Array) -> Tuple[Array, Array]:
    """Applies the decoder model.
    Args:
      inputs: [batch_size, max_output_len-1, vocab_size]
        Contains the inputs to the decoder at each time step (only used when not
        using teacher forcing). Since each token at position i is fed as input
        to the decoder at position i+1, the last token is not provided.
    Returns:
      Pair (logits, predictions), which are two arrays of respectively decoded
      logits and predictions (in one hot-encoding format).
    """
    lstm = DecoderLSTM(teacher_force=self.teacher_force,
                       vocab_size=self.vocab_size)
    init_carry = (self.init_state, inputs[:, 0])
    _, (logits, predictions) = lstm(init_carry, inputs)
    return logits, predictions

In [16]:

d = Decoder(init_state=encoding, teacher_force=False, vocab_size=char_lookup_table.vocab_size)

NameError: name 'encoding' is not defined

In [17]:
init_decoder_input = char_lookup_table.one_hot(char_lookup_table.encode('=')[0:1])

In [18]:

init_decoder_inputs = jnp.tile(init_decoder_input,
                                (batch['query'].shape[0], char_lookup_table.max_output_len, 1))

NameError: name 'batch' is not defined

In [19]:
init_decoder_input.shape, init_decoder_inputs.shape

NameError: name 'init_decoder_inputs' is not defined

In [20]:
key_1, key_2 = jax.random.split(jax.random.PRNGKey(0))
decoder_params = d.init({"params":key_1, "lstm": key_2}, init_decoder_inputs)
# d.apply()

NameError: name 'd' is not defined

In [21]:
lstm_rng = jax.random.PRNGKey(0)
step = 200
# we need a random key as our decoder samples 
# using a categorical distibution based on the logits
lstm_key = jax.random.fold_in(lstm_rng, step) 

logits, predictions = d.apply(decoder_params, init_decoder_inputs, rngs={'lstm': lstm_key})

NameError: name 'd' is not defined

In [None]:
logits.shape, predictions.shape

NameError: name 'logits' is not defined

In [22]:
# model isn't trained, but here are the results (query, target, prediction)
print(char_lookup_table.decode_onehot(batch['query']))
print(char_lookup_table.decode_onehot(batch['answer']))
print(char_lookup_table.decode_onehot(predictions))

NameError: name 'batch' is not defined

# The model

In [23]:
class Seq2seq(nn.Module):

    teacher_force: bool
    hidden_size: int
    vocab_size: int
    eos_id: int = 1

    @nn.compact
    def __call__(self, encoder_inputs, decoder_inputs):
        initial_decoder_state = Encoder(
            hidden_size=self.hidden_size,
            eos_id=self.eos_id
        )(encoder_inputs)

        logits, predictions = Decoder(
            init_state=initial_decoder_state,
            teacher_force=self.teacher_force,
            vocab_size=self.vocab_size
        )(decoder_inputs[:, :-1])
        return logits, predictions

    # @nn.compact
    # def __call__(self, encoder_inputs, decoder_inputs):
    #     initial_decoder_state = Encoder(
    #         hidden_size=self.hidden_size,
    #         eos_id=self.eos_id
    #     )(encoder_inputs)

    #     logits, predictions = Decoder(
    #         init_state=initial_decoder_state,
    #         teacher_force=self.teacher_force,
    #         vocab_size=self.vocab_size
    #     )(decoder_inputs[:, :-1])
    #     return logits, predictions


In [24]:
teacher_force = False
hidden_size = 512
eos_id = char_lookup_table.eos_id
vocab_size = char_lookup_table.vocab_size
model = Seq2seq(
    teacher_force=teacher_force,
    hidden_size=hidden_size,
    vocab_size=vocab_size
)

In [36]:
rng = jax.random.PRNGKey(0)
params_key, lstm_key = jax.random.split(rng)
max_input_sequence_len = char_lookup_table.max_input_len
max_output_sequence_len = char_lookup_table.max_output_len
variables = model.init(
    {
        "params": params_key,
        "lstm": lstm_key
    },
    jnp.ones((batch_size, max_input_sequence_len, vocab_size), dtype=jnp.float32),
    jnp.ones((batch_size, max_output_sequence_len, vocab_size), dtype=jnp.float32)
)

In [37]:
variables.keys()

frozen_dict_keys(['params'])

In [38]:
model.apply(variables, batch['query'], batch['answer'], rngs={
    "lstm": lstm_key
})

(DeviceArray([[[ 0.08307394, -0.06417654,  0.11232382, ...,  0.04796918,
                 0.08337983, -0.04808454],
               [-0.04887169,  0.02889351,  0.08376509, ..., -0.0573354 ,
                 0.02569693, -0.01780605],
               [ 0.07617846, -0.06507105,  0.15715244, ...,  0.07221642,
                 0.12518902, -0.0459436 ],
               ...,
               [ 0.08691359, -0.09904549,  0.05844625, ...,  0.03085802,
                -0.13066357, -0.01850299],
               [ 0.17952195,  0.04801203,  0.10827539, ..., -0.02468754,
                 0.0370799 , -0.06572822],
               [ 0.03389015,  0.00849534,  0.0068223 , ...,  0.00689863,
                 0.0393358 , -0.11326101]],
 
              [[ 0.09038381, -0.08036304,  0.10141592, ...,  0.02821084,
                 0.10218308, -0.02035185],
               [ 0.13380422,  0.05083488,  0.04915942, ..., -0.10524669,
                 0.10610627,  0.05437368],
               [ 0.1284994 ,  0.04025817,  0.1039

In [42]:

# model.apply(variables, batch['query'], batch['answer'], rngs={
#     "lstm": lstm_key
# })
# during inference you provide the start_of_sentence token
# which in our case is the '=' character 
batch['answer'].shape # <-- input_shape

(32, 8, 14)

In [40]:
init_decoder_inputs.shape

(32, 6, 14)

In [41]:
init_decoder_inputs[0], char_lookup_table._char_indices

(DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],            dtype=float32),
 {'+': 2,
  '0': 3,
  '1': 4,
  '2': 5,
  '3': 6,
  '4': 7,
  '5': 8,
  '6': 9,
  '7': 10,
  '8': 11,
  '9': 12,
  '=': 13})

## Training

# Saving

In [34]:

import functools
from typing import Callable, Dict, List

import absl

import flax
from flax import linen as nn
from flax.metrics import tensorboard
import flax.optim

import jax
from jax import numpy as jnp
from jax.experimental import jax2tf
from mlteacher.mlops.transform import CharacterTable
import numpy as np

import tensorflow as tf


# The transformed feature names

# Type abbreviations: (B is the batch size)
_Array = np.ndarray
_InputBatch = Dict[str,
                   _Array]  # keys are _FEATURE_KEYS_XF and values f32[B, 1]
_LogitBatch = _Array  # of shape f32[B, 3]
_LabelBatch = _Array  # of shape int64[B, 1]
_Params = Dict[str, _Array]


class _SavedModelWrapper(tf.train.Checkpoint):
  """Wraps a function and its parameters for saving to a SavedModel.
  Implements the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  This class contains all the code needed to convert a Flax model to a
  TensorFlow saved model.
  """

  def __init__(self,
               tf_graph: Callable[[_InputBatch], _Array],
               param_vars: Dict[str, tf.Variable]):
    """Builds the tf.Module.
    Args:
      tf_graph: a tf.function taking one argument (the inputs), which can be be
        tuples/lists/dictionaries of np.ndarray or tensors. The function may
        have references to the tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
        to be saved as the variables of the SavedModel.
    """
    super().__init__()
    # Implement the interface from
    # https://www.tensorflow.org/hub/reusable_saved_models
    self.variables = tf.nest.flatten(param_vars)
    self.trainable_variables = [v for v in self.variables if v.trainable]
    self._tf_graph = tf_graph

  @tf.function
  def __call__(self, inputs):
    return self._tf_graph(inputs)




from mlteacher.mlops import train, models

def save_model():
    rng = jax.random.PRNGKey(0)
    ctable = CharacterTable("0123456789=+")
    model = models.Seq2seq(teacher_force=False, hidden_size=512, vocab_size=ctable.vocab_size)
    params = train.get_initial_params(model, rng, ctable)

    batch_size = 1
    step = 1
    batch = ctable.get_batch(batch_size=batch_size, step=step)
    model.apply({"params": params}, batch['query'], batch['answer'], rngs={"lstm":jax.random.PRNGKey(1)})
    predict_fn = lambda params, input: model.apply({"params": params}, *input, rngs={"lstm":jax.random.PRNGKey(1)})
    tf_fn = jax2tf.convert(predict_fn, with_gradient=False, enable_xla=True)

    return tf_fn, params
tf_fn, trained_params = save_model()                        

  new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs))
  jax.tree_map(get_scopes_inner, attrs)
  scopes, treedef = jax.tree_flatten(scope_tree)
  group[col] = jax.tree_map(lambda x: x, xs[col])
  lengths = jax.tree_map(find_length, in_axes, args)
  leaves = jax.tree_leaves(x)
  lengths = set(jax.tree_leaves(lengths))
  xs = jax.tree_map(transpose_to_front, in_axes, args)
  return jax.tree_map(trans, xs)
  carry_avals = jax.tree_map(
  scan_avals = jax.tree_map(
  in_avals, in_tree = jax.tree_flatten(input_avals)
  xs = jax.tree_map(lambda ax, arg, x: (arg if ax is broadcast else x),
  variables = jax.tree_map(lambda x: x, variables)
  new_args, new_kwargs = jax.tree_map(
  new_attrs = jax.tree_map(set_scopes_inner, attrs)
  jax.tree_leaves(tree)))
  ys = jax.tree_map(lambda ax, y: (y if ax is broadcast else ()),
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)
  ys 

In [35]:

param_vars = tf.nest.map_structure(
    # Due to a bug in SavedModel it is not possible to use tf.GradientTape
    # on a function converted with jax2tf and loaded from SavedModel.
    # Thus, we mark the variables as non-trainable to ensure that users of
    # the SavedModel will not try to fine tune them.
    lambda param: tf.Variable(param, trainable=False),
    trained_params)
tf_graph = tf.function(
    lambda inputs, decoder_inputs: tf_fn(param_vars, (inputs, decoder_inputs)),
    autograph=False,
    experimental_compile=True)

signatures = {}
# This signature is needed for TensorFlow Serving use.
batch_size = 1
ctable = CharacterTable("0123456789+=")
ctable.get_batch(1, 1)['query'].shape
input_signatures = [
            tf.TensorSpec((batch_size,) + (ctable.max_input_len, ctable.vocab_size), tf.float32),
            tf.TensorSpec((batch_size,) + (ctable.max_output_len, ctable.vocab_size), tf.float32)
          ]
signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = tf_graph.get_concrete_function(input_signatures[0], input_signatures[1])

tf_model = _SavedModelWrapper(tf_graph, param_vars)

2022-10-20 19:56:40.363922: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-10-20 19:56:40.363962: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)


In [36]:
signatures

{'serving_default': <ConcreteFunction <lambda>(inputs, decoder_inputs) at 0x7F11267867F0>}

In [73]:
from datetime import datetime

In [74]:
def save_model(model, signatures, serving_dir):
    latest_model = datetime.now().strftime("%Y%m%d%H%M%S")
    serving_dir = os.path.join(serving_dir, latest_model)
    tf.saved_model.save(tf_model, serving_dir, signatures=signatures)

In [75]:
save_model(tf_model, signatures,  "../../../mlteacher/models/brain/")

  new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs))
  jax.tree_map(get_scopes_inner, attrs)
  scopes, treedef = jax.tree_flatten(scope_tree)
  group[col] = jax.tree_map(lambda x: x, xs[col])
  lengths = jax.tree_map(find_length, in_axes, args)
  leaves = jax.tree_leaves(x)
  lengths = set(jax.tree_leaves(lengths))
  xs = jax.tree_map(transpose_to_front, in_axes, args)
  return jax.tree_map(trans, xs)
  carry_avals = jax.tree_map(
  scan_avals = jax.tree_map(
  in_avals, in_tree = jax.tree_flatten(input_avals)
  xs = jax.tree_map(lambda ax, arg, x: (arg if ax is broadcast else x),
  variables = jax.tree_map(lambda x: x, variables)
  new_args, new_kwargs = jax.tree_map(
  new_attrs = jax.tree_map(set_scopes_inner, attrs)
  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)
  jax.tree_leaves(tree)))
  ys = jax.tree_map(lambda ax, y: (y if ax is broadcast else ()),
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)
  ys 

In [9]:
import tensorflow as tf
loaded_model = tf.saved_model.load("/workspaces/MLOpsWithJax/src/mlteacher/models/brain/20221020124847")

In [10]:
serving_default = loaded_model.signatures['serving_default']

In [11]:
from mlteacher import config

ctable = transform.CharacterTable(
    '0123456789+= ', config.TrainConfig.max_len_query_digit)
# ctable = CharacterTable("0123456789=+")
b = ctable.get_batch(1, 2)
q, a = b['query'], b['answer']

In [12]:
q.tolist()

[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]

In [13]:
import jax.numpy as jnp
decoder_input_a = jnp.zeros_like(a)
decoder_input_a[:, :6, :].tolist()

[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]

In [14]:
serving_default(inputs=q, decoder_inputs=a)

{'output_0': <tf.Tensor: shape=(1, 5, 14), dtype=float32, numpy=
 array([[[ 0.04452905, -0.02319412,  0.02610207,  0.03594843,
          -0.01173355, -0.08520072, -0.00829763,  0.10091518,
          -0.06154622,  0.01647665,  0.02164995, -0.01271162,
           0.00583454,  0.06530214],
         [ 0.00173595,  0.05773558,  0.14703731,  0.01773059,
           0.04543342, -0.1435977 , -0.09084816, -0.02898849,
           0.02558784,  0.05243998,  0.13085285, -0.1132612 ,
          -0.03834904, -0.00568175],
         [-0.04256568,  0.0139187 ,  0.06434848,  0.08148213,
           0.01206123, -0.1935127 ,  0.00626391,  0.03290744,
           0.00289503,  0.04359363,  0.13530481, -0.03515957,
          -0.01334131, -0.07012346],
         [-0.11977271,  0.08828357,  0.10907937, -0.06655134,
           0.00992293, -0.08000052,  0.0345731 ,  0.0150109 ,
          -0.06789728, -0.09775098,  0.07344852, -0.03305475,
           0.01745135,  0.01585141],
         [-0.030447  ,  0.09042929,  0.0235

In [None]:
MODEL_NAME=brain
docker run -t --rm -p 8501:8501 --mount type=bind,source=$PWD/models/$MODEL_NAME,target=/models/$MODEL_NAME/ -e MODEL_NAME=$MODEL_NAME -e --xla_cpu_compilation_enabled=true docker.io/tensorflow/serving:latest --xla_cpu_compilation_enabled=true &
curl -X POST -d @sample.json http://127.0.0.1:8501/v1/models/$MODEL_NAME:predict

In [2]:
import os
credential = DefaultAzureCredential()

class AzureMLConfig:
    subscription_id = os.environ.get("SUBSCRIPTION_ID", "0abb6ec5-9030-4b3f-af04-09183c688576")
    resource_group_name = os.environ.get("RESOURCE_GROUP", "csu-nl-intelligence")
    workspace_name = os.environ.get("AZUREML_WORKSPACE_NAME", "mlpatterns")



ml_client = MLClient(
    credential=credential,
    subscription_id=AzureMLConfig.subscription_id,
    resource_group_name=AzureMLConfig.resource_group_name,
    workspace_name=AzureMLConfig.workspace_name,
)



Class RegistryOperations: This is an experimental class, and may change at any time. Please see https://aka.ms/azuremlexperimental for more information.


In [58]:

from azure.ai.ml import command
from azure.ai.ml import Input

model = Model(
    path="/workspaces/MLOpsWithJax/src/mlteacher/models/brain",
    type="custom_model",
    name="brain",
    version="1",
    description="A JAX sequence2sequence model served with Azure ML and Tensorflow Serving"
)

registered_model = ml_client.models.create_or_update(model=model)

In [75]:
os.path.exists("/workspaces/MLOpsWithJax/src/mlteacher/models/brain")

True

In [66]:
# Creating a unique endpoint name with current datetime to avoid conflicts
import datetime
# import required libraries
from azure.ai.ml import MLClient
from azure.ai.ml.entities import (
   ManagedOnlineEndpoint,
   ManagedOnlineDeployment,
   Model,
   Environment,
   CodeConfiguration,
)
from azure.identity import DefaultAzureCredential

# online_endpoint_name = "endpoint-" + datetime.datetime.now().strftime("%m%d%H%M%f")
online_endpoint_name = "jax-online-endpoint"
# create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=online_endpoint_name,
    description="this is a sample online endpoint",
    auth_mode="key",
    tags={"foo": "bar"},
)