In [None]:
# default_exp model.rnn

# RNN

> API details. @nathan

In [1]:
# export
import pandas as pd
import tensorflow as tf

from pathlib import Path

In [2]:
# hide
# Setup
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

df_fake = pd.DataFrame(
    ["aaaa(bb(aaaa(bb()()ccc)dd)()ccc)dd", "aaaa(bb()ccccc)dd"], columns=["code"]
)

In [None]:
# export
def _loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )

In [42]:
# export
class GRUModel(tf.keras.Model):
    def __init__(
        self, n_layers, vocab_size, embedding_dim, rnn_units, batch_size, out_path
    ):
        super(GRUModel, self).__init__()
        gru_layers = [
            tf.keras.layers.GRU(
                rnn_units,
                return_sequences=True,
                # I think we need to have this not be stateful since we don't
                # chop up examples
                # stateful=True,
                recurrent_initializer="glorot_uniform",
                # following BigCode != Big Vocab Paper
                dropout=0.1,
            )
            for _ in range(n_layers)
        ]
        self.model = tf.keras.Sequential(
            [
                tf.keras.layers.Embedding(
                    input_dim=vocab_size,
                    output_dim=embedding_dim,
                    mask_zero=True,  # Zero cannot be used in the vocabulary
                    batch_input_shape=[batch_size, None],
                ),
            ]
            + gru_layers
            + [
                tf.keras.layers.Dense(vocab_size),
            ]
        )

        self.config_name = (
            f"gru_vocab{vocab_size}_embed{embedding_dim}_units{rnn_units}"
        )
        self.out_path = Path(out_path) / self.config_name
        checkpoint_prefix = self.out_path / self.config_name / "ckpt_{epoch}"
        self.callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_prefix, save_weights_only=True
            )
        ]

    def call(self, batch):
        return self.model(batch)

    # TODO add code to easily train model
    def train(self, dataset, epochs):
        self.compile(optimizer="adam", loss=_loss)
        _ = self.model.fit(dataset, epochs=epochs, callbacks=self.callbacks)

In [51]:
# Tokenize the data
tokenized_mthds = [
    tokenizer.encode(mthd, max_length=32, padding="max_length")
    for mthd in df_fake.code.values
]
ds = tf.data.Dataset.from_tensor_slices(tokenized_mthds).batch(2, drop_remainder=True)

In [52]:
gru = GRUModel(1, len(tokenizer), 256, 1024, 1)
gru.model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_2 (Embedding)      (1, None, 256)            12866048  
_________________________________________________________________
gru_2 (GRU)                  (1, None, 1024)           3938304   
_________________________________________________________________
dense_2 (Dense)              (1, None, 50258)          51514450  
Total params: 68,318,802
Trainable params: 68,318,802
Non-trainable params: 0
_________________________________________________________________


In [55]:
for example in ds.take(1):
    output = gru(example)

In [56]:
output

<tf.Tensor: shape=(2, 32, 50258), dtype=float32, numpy=
array([[[ 1.1789987e-03, -1.0764061e-03, -7.0536515e-04, ...,
         -1.8938105e-03,  3.0941013e-04,  5.4561184e-04],
        [ 6.6408433e-04, -3.6950165e-04,  2.1222243e-03, ...,
         -1.6042381e-03,  6.4406340e-05, -3.4853775e-04],
        [ 1.7635215e-03, -2.3366189e-04,  2.0413476e-03, ...,
          1.0840559e-03, -8.7091571e-04,  2.6355719e-03],
        ...,
        [-2.8481015e-03,  1.8251555e-03,  1.5210311e-03, ...,
          1.9312165e-03,  1.0953154e-03,  1.0994463e-03],
        [-2.8373816e-03,  1.8310322e-03,  1.5131955e-03, ...,
          1.9235963e-03,  1.1112771e-03,  1.1064576e-03],
        [-2.8300122e-03,  1.8353187e-03,  1.5077242e-03, ...,
          1.9177452e-03,  1.1218428e-03,  1.1099423e-03]],

       [[ 1.1789987e-03, -1.0764061e-03, -7.0536515e-04, ...,
         -1.8938105e-03,  3.0941013e-04,  5.4561184e-04],
        [ 6.6408433e-04, -3.6950165e-04,  2.1222243e-03, ...,
         -1.6042381e-03,  6

In [None]:
# hide
from nbdev.export import notebook2script

notebook2script()