In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
import tensorflow as tf
from tasks import chef_config
from tasks import recipe_reader

In [34]:
# example of how to use the recipe reader
path_to_text_file = "./recipes/recipes.txt"
number_of_characters = 50000
number_of_characters_test = 10000


text = (
    open(path_to_text_file, "rb")
    .read()
    .decode(encoding="utf-8")
)

train_text = text[:number_of_characters]
test_text = text[-number_of_characters_test:]

imitation_ds_training, stoi, itos = recipe_reader.imitate_chefs(train_text, chef_config.ChefConfig())
imitation_ds_exam, _, _ = recipe_reader.imitate_chefs(test_text, chef_config.ChefConfig())
for x, y in imitation_ds_training.take(1):
    print(x.shape)
    print(y.shape)
    print(recipe_reader.decode(x, itos)[0])
    print(recipe_reader.decode(y, itos)[0])





(32, 128)
(32, 128)
tf.Tensor(b'eam cheese add the milk and beat smooth add lemon juice pineapple and nuts beat well pour into graham cracker crust and chi', shape=(), dtype=string)
tf.Tensor(b'am cheese add the milk and beat smooth add lemon juice pineapple and nuts beat well pour into graham cracker crust and chil', shape=(), dtype=string)


In [35]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import chex


class PositionalEncoding(nn.Module):
    # positional encoding enables the model to know the position of the token in the sequence
    # we do this by adding a positional encoding vector to the feature vector.
    # normally, we would add it to an embedding vector, but since we working with time series data,
    # we don't have an embedding vector, so we add it to the feature vector.

    max_seq_len: int
    brain_size: int

    def setup(self):
        # first, we create a matrix of shape (max_seq_len, brain_size)
        pos_enc = jnp.zeros((self.max_seq_len, self.brain_size))

        raw_position = jnp.arange(0, self.max_seq_len)[:, jnp.newaxis]
        chex.assert_shape(raw_position, (self.max_seq_len, 1))
        # pos / (10000 ** (2 * i / brain_size))
        divider = 10000 ** (2 * jnp.arange(0, self.brain_size, 2) / self.brain_size)

        sin_encoding = jnp.sin(raw_position / divider)
        cos_encoding = jnp.cos(raw_position / divider)

        # overwrite the feature dimension.
        pos_enc = pos_enc.at[:, 0::2].set(sin_encoding)
        pos_enc = pos_enc.at[:, 1::2].set(cos_encoding)

        # add batch dimension
        self.pos_enc = pos_enc[jnp.newaxis, :, :]

    def __call__(self, x, add_to_inputs=True):
        # select all elements in batch, up to timeseries length, and all features
        if add_to_inputs:
            return x + self.pos_enc[:, :x.shape[1], :]
        return self.pos_enc[:, :x.shape[1], :]


class StructureInformation(nn.Module):
    max_seq_len: int
    brain_size: int
    chef_vocab_size: int

    @nn.compact
    def __call__(self, x, training=True):
        # x: (batch_size, seq_len)
        # we want to add a layer norm to the features dimension
        B, T = jnp.shape(x)
        
        token_embeddings = nn.Embed(self.chef_vocab_size, self.brain_size)(x) # (batch_size, seq_len, brain_size)
        position_embeddings = nn.Embed(self.max_seq_len, self.brain_size)(jnp.arange(0, T)) 

        x = token_embeddings + position_embeddings
        
        # x = nn.LayerNorm()(x)
        # # we want to add a dropout layer
        # x = nn.Dropout(0.1)(x, deterministic=not training)

        # # convert to (batch_size, seq_len, brain_size)
        # x = nn.Dense(features=self.brain_size)(x)

        # # add positional encoding
        # x = PositionalEncoding(self.max_seq_len, self.brain_size)(x)
        return x




# Could have been a function, but we want to use the @nn module
class Thought(nn.Module):
    # Scaled dot product attention
    brain_size: int

    def setup(self):
        self.questions_to_improve_knowledge = nn.Dense(features=self.brain_size) # queries
        self.knowledge_index = nn.Dense(features=self.brain_size) # keys
        self.information_in_knowledge_index = nn.Dense(features=self.brain_size) # values


    @nn.compact
    def __call__(self, x): #given_a_question, knowledge_index, information_in_knowledge_index):
        B, T, C = jnp.shape(x)

        formulated_question = self.questions_to_improve_knowledge(x)
        knowledge_index = self.knowledge_index(x)
        information_in_indices = self.information_in_knowledge_index(x)

        attention_to_relevant_indexes = jnp.matmul(formulated_question, jnp.swapaxes(knowledge_index, -2, -1)) / jnp.sqrt(self.brain_size)
        # mask out the future tokens
        mask = jnp.tril(jnp.ones((T, T))) # tril = lower triangle
        attention_to_relevant_indexes = jnp.where(mask == 0, -jnp.inf, attention_to_relevant_indexes)

        
        probability_knowledge_index_is_relevant = nn.softmax(attention_to_relevant_indexes, axis=-1)
        answer_to_question_based_on_information_in_knowledge_index = jnp.matmul(probability_knowledge_index_is_relevant, information_in_indices)
        # return both for debugging purposes
        return answer_to_question_based_on_information_in_knowledge_index, attention_to_relevant_indexes


In [36]:
class BrainStorm(nn.Module):
    n_ideas: int
    brain_size: int

    def setup(self) -> None:
        chex.assert_equal(self.brain_size % self.n_ideas, 0)
        self.idea_size = self.brain_size // self.n_ideas

        self.thoughts = [Thought(self.idea_size) for _ in range(self.n_ideas)]
        
        self.filter_interesting_thoughts = nn.Dense(
            features=self.brain_size,
            kernel_init=nn.initializers.xavier_uniform(),
            use_bias=False,
        )

    def __call__(self, x):
        thoughs = jnp.concatenate([thought(x)[0] for thought in self.thoughts], axis=-1)
        interesting_thoughts = self.filter_interesting_thoughts(thoughs)
        return interesting_thoughts


In [37]:

class ProjectIdeas(nn.Module):
    brain_size: int
    dropout_rate: float

    @nn.compact
    def __call__(self, x, training=True):
        x = nn.Dense(
            features=4 * self.brain_size,
            kernel_init=nn.initializers.xavier_uniform(),
            use_bias=False,
        )(x)
        x = nn.relu(x)
        x = nn.Dense(
            features=self.brain_size,
            kernel_init=nn.initializers.xavier_uniform(),
            use_bias=False,
        )(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not training)
        return x


class CreativityBlock(nn.Module):
    brain_size: int
    n_ideas: int
    dropout_rate: float

    def setup(self):
        # ln -> mha -> projection -> ln
        # Andrej Karpathy min 1:35:30 https://www.youtube.com/watch?v=kCc8FmEb1nY --> move layernorm before mha
        self.normalize_content1 = nn.LayerNorm()
        self.brainstrom = BrainStorm(self.n_ideas, self.brain_size)
        self.normalize_content2 = nn.LayerNorm()
        self.project_ideas = ProjectIdeas(self.brain_size, self.dropout_rate)
    
    def __call__(self, x, training=True):
        # lha -> brainstrom
        x = self.normalize_content1(x)
        x = self.brainstrom(x)

        # residual connection
        x = x + x

        # ln -> projection
        x = self.normalize_content2(x)
        x = self.project_ideas(x, training=training)

        # residual connection
        x = x + x
        return x

class IdeaIteration(nn.Module):
    n_moldings: int
    brain_size: int
    n_ideas: int
    dropout_rate: float

    def setup(self):
        self.creativity_blocks = [
            CreativityBlock(self.brain_size, self.n_ideas, self.dropout_rate)
            for _ in range(self.n_moldings)
        ]
    
    def __call__(self, x, training=True):
        for creativity_block in self.creativity_blocks:
            x = creativity_block(x, training=training)
        return x
    

class IdeaArticulation(nn.Module):
    max_seq_len: int
    brain_size: int
    dropout_rate: float
    chef_vocabulary_size: int

    @nn.compact
    def __call__(self, x, training=True):
        # flatten
        x = jnp.reshape(x, (x.shape[0], -1)) 
        x = nn.LayerNorm()(x)
        x = nn.Dense(features=self.brain_size, kernel_init=nn.initializers.xavier_uniform())(x)
        x = nn.relu(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not training)
        x = nn.Dense(features=self.max_seq_len*self.chef_vocabulary_size, kernel_init=nn.initializers.xavier_uniform())(x)
        x = jnp.reshape(x, (x.shape[0], self.max_seq_len, self.chef_vocabulary_size))
        return x

In [38]:
class ChefBrain(nn.Module):
    max_seq_len: int
    brain_size: int
    n_ideas: int
    n_moldings: int
    dropout_rate: float
    chef_vocabulary_size: int

    @nn.compact
    def __call__(self, x, training=True):
        x = StructureInformation(self.max_seq_len, self.brain_size, self.chef_vocabulary_size)(x, training=training)
        x  = IdeaIteration(self.n_moldings, self.brain_size, self.n_ideas, self.dropout_rate)(x, training=training)
        x = IdeaArticulation(self.max_seq_len, self.brain_size, self.dropout_rate, self.chef_vocabulary_size)(x, training=training)
        return x

In [45]:
import optax
from flax.training import train_state

def prepare_chef_for_training(chef_config: chef_config.ChefConfig):
    rng = jax.random.PRNGKey(chef_config.kitchen_seed)

    model_rng, dropout_rng = jax.random.split(rng)
    total_steps = chef_config.n_times_to_imitate_chefs * chef_config.n_recipes_to_sample
    # warmup_cosine_decay_scheduler
    learn_by_scanning_followed_by_deepdive = optax.warmup_cosine_decay_schedule(init_value=0.0001, peak_value=0.001,
                                                                   warmup_steps=int(total_steps*0.2),
                                                                   decay_steps=total_steps, end_value=0.00001)
    tx = optax.adam(learn_by_scanning_followed_by_deepdive)

    model = ChefBrain(
        max_seq_len=chef_config.max_seq_len,
        brain_size=chef_config.brain_size,
        n_ideas=chef_config.n_ideas,
        n_moldings=chef_config.n_moldings,
        dropout_rate=chef_config.dropout_rate,
        chef_vocabulary_size=chef_config.chef_vocab_size
    )

    variables = model.init(
        {
            'params': model_rng,
            'dropout': dropout_rng
        },
        jnp.ones((chef_config.batch_size, chef_config.max_seq_len), dtype=jnp.int32)
    )
    
    state = train_state.TrainState.create(
            apply_fn = model.apply,
            tx=tx,
            params=variables['params']
        )
    return state

In [61]:
def train_step(state, x, y, dropout_rng=None):
    dropout_rng = jax.random.fold_in(dropout_rng, state.step)
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, x, rngs={"dropout": dropout_rng})
        loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, (loss, logits)

@jax.jit
def train_step_jitted(state, x, y, dropout_rng=None):
    dropout_rng = jax.random.fold_in(dropout_rng, state.step)
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, x, rngs={"dropout": dropout_rng})
        loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, (loss, logits)

    
@jax.jit
def eval_step_jitted(state, x, y, dropout_rng=None):
    logits = state.apply_fn({"params": state.params}, x, rngs={"dropout": dropout_rng}, training=False)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    return loss, logits


In [72]:
from orbax import checkpoint

def train_and_eval(chef, train_ds, test_ds, chef_config: chef_config.ChefConfig):
    rng = jax.random.PRNGKey(chef_config.kitchen_seed)
    orbax_checkpointer = checkpoint.PyTreeCheckpointer()
    options = checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
    checkpoint_manager = checkpoint.CheckpointManager(chef_config.chef_state_path, orbax_checkpointer, options)

    ckpt = {"model": chef}
    for replication in range(chef_config.n_times_to_imitate_chefs):
        for batch_idx, (recipe_instructions, recipe_by_master) in enumerate(train_ds.take(chef_config.n_recipes_to_sample).as_numpy_iterator()):
            rng, dropout_rng = jax.random.split(rng)
            recipe_by_master = jax.nn.one_hot(recipe_by_master, chef_config.chef_vocab_size)
            chef, (loss, logits) = train_step_jitted(chef, recipe_instructions, recipe_by_master, dropout_rng)

            if batch_idx % 100 == 0:
                print(f"replication: {replication}, batch: {batch_idx}, loss: {loss}")
    
        for batch_idx, (recipe_instructions, recipe_by_master) in enumerate(test_ds.take(chef_config.n_exam_recipes).as_numpy_iterator()):
            rng, dropout_rng = jax.random.split(rng)
            recipe_by_master = jax.nn.one_hot(recipe_by_master, chef_config.chef_vocab_size)
            loss, logits = eval_step_jitted(chef, recipe_instructions, recipe_by_master, dropout_rng)

        print(f"Evaluation --- replication: {replication}, batch: {batch_idx}, loss: {loss}")
        checkpoint_manager.save(replication, ckpt)
    return chef

In [64]:
config = chef_config.ChefConfig(chef_vocab_size=len(stoi.get_vocabulary()))
chef = prepare_chef_for_training(config)
trained_chef = train_and_eval(chef, imitation_ds_training, imitation_ds_exam, config)

replication: 0, batch: 0, loss: 3.4302899837493896
replication: 0, batch: 100, loss: 3.074660539627075
replication: 0, batch: 200, loss: 2.970586061477661
replication: 0, batch: 300, loss: 2.957038402557373
replication: 0, batch: 400, loss: 2.9560329914093018
Evaluation --- replication: 0, batch: 49, loss: 2.9171864986419678
replication: 1, batch: 0, loss: 2.946465015411377
replication: 1, batch: 100, loss: 2.9506020545959473
replication: 1, batch: 200, loss: 2.946089506149292
replication: 1, batch: 300, loss: 2.9453012943267822
replication: 1, batch: 400, loss: 2.917570114135742
Evaluation --- replication: 1, batch: 49, loss: 2.9020333290100098
replication: 2, batch: 0, loss: 2.9265527725219727
replication: 2, batch: 100, loss: 2.91237211227417
replication: 2, batch: 200, loss: 2.897219181060791
replication: 2, batch: 300, loss: 2.8907010555267334
replication: 2, batch: 400, loss: 2.8671491146087646
Evaluation --- replication: 2, batch: 49, loss: 2.8463873863220215
replication: 3, bat

In [70]:
dropout_rng = jax.random.PRNGKey(0)
for batch_idx, (x, y) in enumerate(imitation_ds_training.take(1).as_numpy_iterator()):
    y = jax.nn.one_hot(y, len(itos.get_vocabulary()))
    loss_eval, logits_eval = eval_step_jitted(chef, x, y, dropout_rng=dropout_rng)

In [71]:
recipe_reader.decode(jnp.argmax(jnp.squeeze(logits_eval), axis=-1), itos)

<tf.Tensor: shape=(32,), dtype=string, numpy=
array([b'gvgeurixr?jqzhhbj?oyimglmufzacutk|lrb ellehmdstvvdlwykrzodufisnlxgpp?hdfd??ogmamncnwhzizwwlwpwizcywnswyaie chisgwbgtuvvbe  ',
       b'gvieuvkcr?jozohcfm?ouumglbufurpstk|lhtpaul|hmastvvclckrzouufisn?qfpqahrbd ?odw mncstzikwxlw|wfzfywmugykibbnqpikmwqgfbdxiezn',
       b'kviuzkbxrzjozoqhfm?oujdmglbufzezszvbrbaul?zmdiisvdluk|zobufisnlxcpqacrzd??odt mncjtzikwxlw?izdy|nugyk?c|qpiktvqgdukebexe',
       b'kvi uvnflzjozhhhfm?oujmglt?fujcstztlyzkeuuezmditsvdlulxpzogufisnpnjoq?h bdh?odc mkcstzikuxln?wizcowfuwyriudoqpikmvqctbdvcoxn',
       b'gvieecihf?jozhhlfs?ouumgbbjfuazutkvlhzpszuehmdstvvdlcykyzodufksn?nmkx?hrd ?od zmncajtzi?wol?|wfzfywnswylibbqiijgwqgqb eie n',
       b'gvieuvihc?jazhhlbf?ohumgbbvqurosbay|ytcefu|hmhst?lcuayopgdfetksnpt|ox?hdtd ?oipzknckwtzfhuwl??wizfowfsdraiddwriwmjqc|qvxi zn',
       b'gvkeumkcr?wozhzjff?olidmghm?fnrbozztlrb all?zmdstzvdlgykysodufasnpnsoqscr?x?oi amncnwyhijeoln?wikcywnuwyr?edwpvfgvqgoz xbexu',
       b'

In [67]:
recipe_reader.decode(jnp.argmax(jnp.squeeze(logits_eval), axis=-1), itos)

<tf.Tensor: shape=(32,), dtype=string, numpy=
array([b'gviegmkxlrjozhhlfs?ouhdmgtmufbrvstw|lhbcallehmdstsvclwykpyodufsnpn|px?crad ?od zmncashzizwwln?wfzcywnuwyricboqiijmwqgtzdxie  ',
       b'gvieernjf?jozhhlbs?oujmgbnufurcstytlrzcauuehmdst?vpuuyxpzdglfksnpxfqr?hradh?oicamncnjtzizqwln|wfzfyrfsdfaiubvdiijmjqcfqv?caxn',
       b'gviuvicrzjozhzgbj?oujagtm?fujpszktlrb eluehmdsivvdluykyzofufiszpmtpr?jdad??ogmagkcnjyzizexlw?wizc|wnswya?edqhikgvqgtqvxcex ',
       b'kviurxbxlzjozhhjfm?ouhdmglm?fzavol|lrba l?hmdstsydluyk|yocufisnlntoq?crlj??og  mncshzikwolw?izdywnuwyk?ercqpijtvogfuxxbexn',
       b'zvieurxnxlzjozhhfm?ouhmgbm?fbacslydhbeaflehmdstsvdlzkpzofufisn?nmpx?c ad?od zmncstzi?doln?wfkcywnsgykivbqpijmwagfuvxiezn',
       b'zvieurkxr?jozhh fm?otumgbmubzcsty|lrzaxl|hmdstvvdlukrzodufisnptmpx?hd???odwzmxckwtzikdwln?izfywfsdrkivdnqpifmwqgbqvyiazn',
       b'zvieurifczjozhhlbm?ouumgbmvqnacsty|drzcaxu|hmdst?vcucyopzdfctksmlxcpxrhad ?odzzmnckjtzikqwln?wizfyrfsdrkidfqiijmwqcfbvui xn',
       b'gvi

In [69]:



prng = jax.random.PRNGKey(0)
prng, dropout_rng = jax.random.split(prng)
chef = prepare_chef_for_training(config)

In [29]:

for epoch in range(200):
    ckpt = {'model': chef}
    for batch_idx, (x, y) in enumerate(imitation_ds_training.take(200).as_numpy_iterator()):
        y = jax.nn.one_hot(y, len(itos.get_vocabulary()))
        chef, (loss, logits) = train_step_jitted(chef, x, y, dropout_rng=dropout_rng)
        if batch_idx % 100 == 0:
            print(f"epoch: {epoch}, batch: {batch_idx}, loss: {loss}")
    checkpoint_manager.save(batch_idx, ckpt)


epoch: 0, batch: 0, loss: 3.509216785430908


KeyboardInterrupt: 

In [19]:
jax.devices()

[CpuDevice(id=0)]

In [19]:
import os
os.listdir('/tmp/flax_ckpt/orbax/managed')

['31.orbax-checkpoint-tmp-1696064671479759', '29', '30']

In [23]:
ls_step = checkpoint_manager.latest_step()
restored_chef = checkpoint_manager.restore(ls_step)



In [29]:
restored_chef["model"]

{'opt_state': [{'count': array(0, dtype=int32),
   'mu': {'IdeaArticulation_0': {'Dense_0': {'bias': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0

In [37]:
reloaded_logits = chef.apply_fn({"params": restored_chef["model"]["params"]}, x, rngs={"dropout": dropout_rng})
recipe_reader.decode(jnp.argmax(jnp.squeeze(reloaded_logits), axis=-1), itos)

<tf.Tensor: shape=(32,), dtype=string, numpy=
array([b'cy cooeqpzjzuaaqejvejwlwdd w| igf?poysxwypcgmat ?pc|rqccvmcssmkfryggdeltkfuzbnit?zjvfr||rsmke zbnk|kef?yrkfjibslfatsbhhoc jn',
       b'fjvxeqpzwlkprjmjscjaaskq wrzqgevloymxqxycqoatebh|jsshxvopgrxtifstghttdznkfmdtn?? znvuxm|smzinajbnk|qfq?b kfjsfrqowjcumhazcmln',
       b'dyhdteq mwvbafsvpejakisqcx|xhwkr?oywycpatmateupn|rshynmxervirfs|gmkdrlyrumbtnxmhzuafxkeisnrskhbnkbkfw?c ?txiastnatrwhfyn?jn',
       b'd vvkoeqnzd?illge?vojadpsvy qcixq pyawycxytqdq p|zehjly?erlcxzivglqtlytlqmulmi??ziaenm|mhnzboh|knkl|?tkhakfbdgym|faxsdhrvn|bn',
       b'cvdboeypzjvilrqrppoj|vpgdvw|cice paybywx|nqlqr?ptvqhmkrskmtkrpwvzpndtunrfdrbmitecxz?fkqpgziuthulu|keqkhrpfbitsivfycsxhac|ln',
       b'ctvvzeqlzntiprgvpbjakpdvxx|ei?fvjyyoxwxvaqkfvyipq|sshiy|?ermiske gmujrpyrfudtn??r fkmsksmr|utiunnlb?frdrtvjsaxipfaxsdhac?lm',
       b'chfyoyqppbzkhrg psejta?waiu| hzhohowwxwsylqgatr qc|rshtvzrvb??jfbkkzbbnrj zbni?acxv hn|rylziqhiuqklkesmdrfsbshxlmfimqhhajcpl',
   

In [None]:

new_chef = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/managed')
for batch_idx, (x, y) in enumerate(imitation_ds.as_numpy_iterator()):
    y = jax.nn.one_hot(y, len(itos.get_vocabulary()))
    new_chef, (loss, logits) = train_step_jitted(new_chef, x, y, dropout_rng=dropout_rng)

In [190]:
logits.shape, y.shape


((32, 128, 30), (32, 128, 30))

In [191]:
rng = jax.random.PRNGKey(0)
for i in range(10000):
    rng, dropout_rng = jax.random.split(rng)
    chef, (loss, logits) = train_step_jitted(chef, x, y, dropout_rng=dropout_rng)
    if i % 50 == 0:
        print(loss)
        print(jnp.argmax(jnp.squeeze(logits), axis=-1), jnp.argmax(jnp.squeeze(y), axis=-1))


3.4373765
[[21 12  9 ...  1 10 19]
 [14 27  9 ...  1 10  5]
 [14 10  9 ...  1 27  9]
 ...
 [25  2  9 ... 24  2 16]
 [ 6 10 26 ...  1 10 19]
 [25 27  9 ...  1 29 25]] [[23  2 11 ... 19 11  8]
 [ 5  3 11 ...  4  2  6]
 [11  2  8 ... 20  6 23]
 ...
 [ 2  6  4 ... 14  4  7]
 [10  2  4 ... 13 12  3]
 [23  3  2 ...  3 10  2]]
2.0332406
[[23  8  2 ... 19 12 17]
 [ 5  2  9 ...  4  9  7]
 [ 2  2  2 ...  7  6  7]
 ...
 [ 8  5  2 ...  5 15  8]
 [ 4  8  2 ... 23 12  7]
 [ 4  3  2 ... 20  9  8]] [[23  2 11 ... 19 11  8]
 [ 5  3 11 ...  4  2  6]
 [11  2  8 ... 20  6 23]
 ...
 [ 2  6  4 ... 14  4  7]
 [10  2  4 ... 13 12  3]
 [23  3  2 ...  3 10  2]]
0.12781918
[[23  2 11 ... 19 11  8]
 [ 5  3 11 ...  4  2  6]
 [11  2  8 ... 20  6 23]
 ...
 [ 2  6  4 ... 14  4  7]
 [10  2 13 ... 13 12  3]
 [23  3  2 ...  3 10  2]] [[23  2 11 ... 19 11  8]
 [ 5  3 11 ...  4  2  6]
 [11  2  8 ... 20  6 23]
 ...
 [ 2  6  4 ... 14  4  7]
 [10  2  4 ... 13 12  3]
 [23  3  2 ...  3 10  2]]
0.018912686
[[23  2 11 ... 19 11 

KeyboardInterrupt: 

In [197]:
print(jnp.argmax(jnp.squeeze(logits), axis=-1)[:5] == jnp.argmax(jnp.squeeze(y), axis=-1)[:5])

[[ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True

In [192]:
recipe_reader.decode(x, itos)[0]


<tf.Tensor: shape=(), dtype=string, numpy=b'ok slowly stirring occasionally  minutes or until meat is tender serve over rice or buttered noodles yields  to  servings'>

In [39]:
recipe_reader.decode(jnp.argmax(jnp.squeeze(logits), axis=-1), itos)

<tf.Tensor: shape=(32,), dtype=string, numpy=
array([b'ptee and depfel and feape pn  bottteedsp inn basa sot with ranoining  totlp ons of punge  bake at  for  houeitle n',
       b'r cleaniand glice berrie wirh aid sric  atplesade ve fperigg on appr s doolappras stice munanau wasp and cot gripes in h',
       b'n sruterp rroed and cat wth bi chnt ou trr  cirotl mross oels  ith perted org ang soed redy be ce call get miorir aree  ',
       b'atr ctaens cieed witatb bbooe and pay untar onrn shace aried shil en pankstin ti c ravl boll stir in botrecoe sooceyamsecina ',
       b' ponstof  attar frkendt  for  hourtitle ltak dapeplar sterk instructionss roel sfore strurs insblfar crud  indtoelran',
       b'hernitel butte firt aiof mextorl oo booll tin  and nolar mix welr ade rhabeob in r an bea soucei o sea e   ro inch pana',
       b'utle coovee   ivg f to tsppek hip bagun ex uf ooy intt phociwate chip ald shrasfon tox are to syeenthec g dnn al fndkeyesgv',
       b'tibnsr mnx ingrepients 

In [141]:
recipe_reader.decode(
jnp.argmax(jnp.squeeze(logits), axis=-1).reshape(-1, 1),
itos
)

recipe_reader.decode(jnp.argmax(y, axis=-1), itos)

<tf.Tensor: shape=(32,), dtype=string, numpy=
array([b'y', b' ', b' ', b't', b't', b'i', b'd', b'o', b'e', b'm', b'i',
       b'l', b'e', b'h', b'n', b's', b't', b'l', b'h', b's', b'i', b' ',
       b'n', b'f', b' ', b'd', b's', b'o', b'f', b'r', b'h', b'o'],
      dtype=object)>

In [79]:
jnp.argmax(jnp.squeeze(logits), axis=-1), y[:, -1]

(Array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 2, 2, 2, 2, 2], dtype=int32),
 array([10, 20, 10,  9,  9,  2,  5, 17,  3,  7, 11,  8,  4, 22, 10, 12, 10,
        20, 19, 26,  2,  2, 10,  6,  6, 12,  3, 18, 22,  7,  6,  3]))

Array([[[-2.6592684 , -1.6256614 , -1.4934324 , ..., -1.4009198 ,
         -1.1144856 , -1.5977932 ]],

       [[-2.0052314 , -0.45797104, -0.5283097 , ..., -2.510156  ,
         -0.22983934, -0.13950777]],

       [[-0.57558966, -1.0191025 , -2.2016664 , ..., -3.0531387 ,
         -0.3830731 , -1.412798  ]],

       ...,

       [[-0.57177866, -0.13992631, -2.1057372 , ..., -2.6731741 ,
         -1.4790828 , -3.0105035 ]],

       [[-1.2264494 , -1.0429674 , -0.26391634, ..., -3.3369868 ,
         -0.5846946 , -2.311417  ]],

       [[-2.279438  ,  0.06160758, -1.3319504 , ..., -2.0458322 ,
         -0.41904426, -1.8828329 ]]], dtype=float32)