In [3]:
import jax
import jax.ops
import jax.numpy as np

import numpy as onp

import flax
from flax import linen as nn
import optax


In [5]:
def get_text(fname):
  with open(fname, 'r') as reader:
    data = reader.read()

  return data

def id_bridge(iterable):
  """ provides mapping to and from ids """
  return ({elem: id for id, elem in enumerate(iterable)}, 
          {id: elem for id, elem in enumerate(iterable)})

In [7]:
# very simple example to test model can learn
data = 'abcd...abcd...'

def prep_data(data):
  chars = list(set(data))
  vocab_size = len(chars)
  char_to_id, id_to_char = id_bridge(chars)
  # data converted to ids
  data_id = [char_to_id[char] for char in data]
  return data_id, char_to_id, id_to_char 

In [8]:
data_id, char_to_id, id_to_char = prep_data(data)
data_id[:10]

[0, 4, 1, 2, 3, 3, 3, 0, 4, 1]

In [9]:

def encode(char):
  return nn.one_hot(char_to_id[char], len(char_to_id))

def decode(predictions, id_to_char):
  # for simplicity, pick the most likely character
  # this can be replaced by sampling weighted
  # by the probability of each character
  return id_to_char[int(np.argmax(predictions))]

In [18]:
class RNNCell(nn.Module):
    @nn.compact
    def __call__(self, state, x):
        x = np.concatenate([state, x])
        new_state = np.tanh(nn.Dense(state.shape[0])(x))
        return new_state
    
class ChaRNN(nn.Module):
    state_size:int
    vocab_size: int

    @nn.compact
    def __call__(self, state, i):
        x = nn.one_hot(i, self.vocab_size)
        new_state = []

        new_state_1 = RNNCell()(state[0], x)
        new_state_2 = RNNCell()(state[1], new_state_1)
        new_state_3 = RNNCell()(state[2], new_state_2)
        predictions = nn.softmax(nn.Dense(self.vocab_size)(new_state_3))
        return [new_state_1, new_state_2, new_state_3], predictions
    
    def init_state(self):
        return [np.zeros(self.state_size)] * 3
        

In [19]:
def sample(model, params, bridge, initial='', max_length=100):
    """
    Sample from the model by greedily selecting next characters
    """
    char_to_id, id_to_char = bridge
    state = model.init_state()
    output = initial

    for char in initial[:-1]:
        state, _ = model.apply(params, state, char_to_id[char])
    
    next_char = initial[-1]
    for i in range(max_length):
        state, predictions = model.apply(params, state, char_to_id[next_char])
        next_char = decode(predictions, id_to_char)
        output += next_char

        return output
        

In [20]:
def sample(model, params, bridge, initial='', max_length=100):
  """
  Sample from the model by greedily selecting next characters

  To do: make more efficient by JIT-ing
  """
  char_to_id, id_to_char = bridge
  state = model.init_state()
  output = initial
  for char in initial[:-1]:
    state, _ = model.apply(params, state, char_to_id[char])

  next_char = initial[-1]
  for i in range(max_length):
    state, predictions = model.apply(params, state, char_to_id[next_char])
    next_char = decode(predictions, id_to_char)
    output += next_char

  return output

In [21]:
state_size = 8

key, subkey = jax.random.split(key)
model = ChaRNN(state_size, len(char_to_id))
params = model.init(subkey, model.init_state(), 0)

print(f"Model state size: {model.state_size}, vocab size: {model.vocab_size}")

# run a single example through the model to test that it works
new_state, predictions = model.apply(params, model.init_state(), 0)
assert predictions.shape[0] == model.vocab_size

Model state size: 8, vocab size: 5


In [22]:
# calling sample on random model leads to random output
sample(model, params, (char_to_id, id_to_char), 'abc', max_length=10)

'abc..dddda.a.'

In [None]:
def f(c, x, h):
    model
    

jax.lax.scan()