In [1]:
import flax
import flax.linen as nn

import optax 

import jax
import jax.numpy as jnp

from model import SimpleRNN

from FlaxTrainer.trainer import TrainerModule
from FlaxTrainer.trainstates import TrainState

import transformers
from datasets import load_dataset_builder, load_dataset, get_dataset_split_names
import re


In [2]:
get_dataset_split_names('imdb')

['train', 'test', 'unsupervised']

In [3]:
from typing import List


train_ds = load_dataset('imdb', split='train')
test_ds = load_dataset('imdb', split='test')

full_text = " ".join(train_ds['text'])

# full_text = re.sub("^[A-Za-z0-9_-]*$", '', full_text)
# full_text = full_text.replace('<br',"").replace('/>',"")
# vocabs = set(full_text.split(' '))
# vocabs = (list(vocabs))

# def tokenizing(vocab: List[str]):
#     id_to_text = {i:v for i, v in enumerate(vocab)}
#     text_to_id = {v:i for i, v in enumerate(vocab)}

#     return id_to_text, text_to_id

# id_to_text, text_to_id = tokenizing(vocabs)



In [4]:
import nltk

seperator = lambda x: nltk.regexp_tokenize(x, "\w+|\\?\!")

vocabs = list(set(seperator(full_text)))

def tokenizing(vocab: List[str]):
    id_to_text = {i:v for i, v in enumerate(vocab)}
    text_to_id = {v:i for i, v in enumerate(vocab)}

    return id_to_text, text_to_id

id_to_text, text_to_id = tokenizing(vocabs)

In [5]:
len(id_to_text)

93930

In [6]:
def scentence_tokenized(sent):
    word_seperation = seperator(sent)
    return [text_to_id[w] for w in word_seperation]

In [7]:
from typing import Any


class RnnClassifier(nn.Module):
    rng: Any
    embed_size: int
    embed_dim: int
    hidden_size: int

    @nn.compact
    def __call__(self, h, x, **kwargs):
        x = nn.Embed(self.embed_size, self.embed_dim)(x)
        o, h = SimpleRNN(20, 2)(x, h)

        return o, h
    @staticmethod
    def initialize_carry(
        rng, 
        batch_dims, 
        size,
        init_fn=nn.initializers.zeros
    ):
        return SimpleRNN.initialize_carry(rng, batch_dims, size, init_fn)

In [8]:
rnnC = RnnClassifier(
    jax.random.PRNGKey(0),
    embed_size=len(text_to_id),
    embed_dim=20,
    hidden_size=10
)



In [9]:
h = RnnClassifier.initialize_carry(
    jax.random.PRNGKey(0),
    (1, ),
    20
)
x = jax.random.randint(
    jax.random.PRNGKey(0),
    (1,),
    minval=1, maxval=20
)
h

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.]], dtype=float32)

In [10]:
variables = rnnC.init(jax.random.PRNGKey(0), h, x)

In [11]:
jax.tree_map(lambda x: x.shape , variables)

FrozenDict({
    params: {
        Embed_0: {
            embedding: (93930, 20),
        },
        SimpleRNN_0: {
            hidden_state: {
                bias: (20,),
                kernel: (40, 20),
            },
            output: {
                bias: (2,),
                kernel: (40, 2),
            },
        },
    },
})

In [12]:
rnnC.apply(variables=variables, x=x, h=h)

(DeviceArray([[-0.02662284,  0.39234704]], dtype=float32),
 DeviceArray([[0.        , 0.        , 0.36096382, 0.10458659, 0.        ,
               0.        , 0.        , 0.        , 0.        , 0.        ,
               0.28245187, 0.18464863, 0.02286166, 0.04945395, 0.        ,
               0.00896228, 0.21438086, 0.        , 0.04968444, 0.        ]],            dtype=float32))

In [13]:
example = jnp.array(scentence_tokenized(train_ds['text'][0]))
example.shape

(297,)

In [14]:
h

DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.]], dtype=float32)

In [15]:
from time import time


t = time()
for x in example.reshape(-1, 1):
     o, h = rnnC.apply(variables, h, x)
print(o, h)

print(time() - t)

[[ 0.101531   -0.03670267]] [[0.         0.33540764 0.27758348 0.         0.27738208 0.
  0.         0.03469883 0.59443516 0.11394904 0.         0.26410994
  0.2550305  0.         0.         0.21341687 0.10182323 0.
  0.14639407 0.        ]]
13.508159160614014


In [16]:
class RnnTrainState(TrainState):
    hidden_fn: Any=None


class RnnTrainer(TrainerModule):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_model(
        self, model: nn.Module, 
        exmp_input: Any, 
        tabulated: bool = True):
        state =  super().init_model(model, exmp_input, tabulated)
        st_dict = state.__dict__
        
        state = RnnTrainState(
            **{
                **st_dict, "hidden_fn":model.init_state()
            }
        )
        return state
    
        
    def create_functions(self):
        def cross_entropy_loss(params, apply_fn, batch):
            x, y = batch
            y = jax.nn.one_hot(y, num_classes=10)
            logit = apply_fn({'params':params}, x)
            loss = optax.softmax_cross_entropy(logits=logit , labels=y).mean()
            return loss
    
        def train_step(state, batch):
            loss_fn = lambda params: cross_entropy_loss(params, state.apply_fn, batch)
            for word in 
            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            state = state.apply_gradients(grads=grads)
            metrics = {'loss': loss}
            return state, metrics
        
        def eval_step(state, batch):
            loss = cross_entropy_loss(state.params, state.apply_fn, batch)
            return {'loss': loss}

        return train_step, eval_step


SyntaxError: invalid syntax (2885781288.py, line 34)

In [16]:

def fn(i, ho):
    x = example.reshape(-1, 1)[i]
    h = ho[0]
    o, h = rnnC.apply(variables=variables, x=x, h=h)
    return [h, o]
    

In [17]:
h = RnnClassifier.initialize_carry(
    jax.random.PRNGKey(0),
    (1, ),
    20
)

ho = [h, jnp.zeros((1,2))]
fn(1, ho)

[DeviceArray([[0.00241531, 0.51561534, 0.3784473 , 0.        , 0.24155109,
               0.        , 0.02114906, 0.0136361 , 0.21647874, 0.00569096,
               0.        , 0.0266905 , 0.1747001 , 0.08343177, 0.19916058,
               0.20462283, 0.        , 0.16992225, 0.        , 0.        ]],            dtype=float32),
 DeviceArray([[-0.13409576,  0.09588944]], dtype=float32)]

In [18]:
jax.lax.fori_loop(0, len(example), fn,ho)

[DeviceArray([[0.        , 0.33540764, 0.27758348, 0.        , 0.27738208,
               0.        , 0.        , 0.03469883, 0.59443516, 0.11394904,
               0.        , 0.26410994, 0.2550305 , 0.        , 0.        ,
               0.21341687, 0.10182323, 0.        , 0.14639407, 0.        ]],            dtype=float32),
 DeviceArray([[ 0.101531  , -0.03670267]], dtype=float32)]

In [19]:
l = jax.lax.map(lambda x: jnp.concatenate([h, x.reshape(1,1)], axis=-1), example)

In [20]:
def fn(h, x):
    h = h[...,:-1]
    x = x[0,...,-1].astype(jnp.int32).reshape(-1,)
    o, h = rnnC.apply(variables=variables, x=x, h=h)
    h_r = jnp.concatenate([h, jnp.zeros((1,1))], axis=-1)
    return h_r, (h_r, o)

In [21]:
fn(l[0], l[1])

(DeviceArray([[0.00241531, 0.51561534, 0.3784473 , 0.        , 0.24155109,
               0.        , 0.02114906, 0.0136361 , 0.21647874, 0.00569096,
               0.        , 0.0266905 , 0.1747001 , 0.08343177, 0.19916058,
               0.20462283, 0.        , 0.16992225, 0.        , 0.        ,
               0.        ]], dtype=float32),
 (DeviceArray([[0.00241531, 0.51561534, 0.3784473 , 0.        , 0.24155109,
                0.        , 0.02114906, 0.0136361 , 0.21647874, 0.00569096,
                0.        , 0.0266905 , 0.1747001 , 0.08343177, 0.19916058,
                0.20462283, 0.        , 0.16992225, 0.        , 0.        ,
                0.        ]], dtype=float32),
  DeviceArray([[-0.13409576,  0.09588944]], dtype=float32)))

In [23]:
_, out = jax.lax.scan(fn, jnp.concatenate([h, jnp.zeros((1,1))], axis=1), l)
    

In [307]:
len(out[1])

297

In [25]:
jax.config.values

{'jax_tracer_error_num_traceback_frames': 5,
 'jax_pprint_use_color': True,
 'jax_host_callback_inline': False,
 'jax_host_callback_max_queue_byte_size': 256000000,
 'jax_host_callback_outfeed': False,
 'jax_host_callback_ad_transforms': False,
 'jax2tf_associative_scan_reductions': False,
 'jax2tf_default_experimental_native_lowering': False,
 'jax_platforms': None,
 'jax_enable_checks': False,
 'jax_check_tracer_leaks': False,
 'jax_debug_nans': False,
 'jax_debug_infs': False,
 'jax_log_compiles': False,
 'jax_parallel_functions_output_gda': False,
 'jax_array': False,
 'jax_distributed_debug': False,
 'jax_enable_custom_prng': False,
 'jax_default_prng_impl': 'threefry2x32',
 'jax_enable_custom_vjp_by_custom_transpose': False,
 'jax_hlo_source_file_canonicalization_regex': None,
 'jax_default_dtype_bits': '64',
 'jax_numpy_dtype_promotion': 'standard',
 'jax_enable_x64': False,
 'jax_default_device': None,
 'jax_disable_jit': False,
 'jax_numpy_rank_promotion': 'allow',
 'jax_defau