<a href="https://colab.research.google.com/github/romlingroup/flatpack-ai/blob/main/notebooks/flatpack_ai_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# flatpack.ai - RNN
This Colab notebook demonstrates how to train a Recurrent Neural Network (RNN) on a text dataset and then use the trained model to generate new text sequences. The model is trained character by character, learning the patterns and structure of the input text, and can then generate coherent and contextually relevant text sequences.

https://github.com/romlingroup/flatpack-ai

In [None]:
!pip uninstall flatpack -y
!pip install flatpack

In [None]:
from flatpack import datasets, instructions, models

# Model hyperparameters and output location, named once instead of being
# repeated as inline literals. Any code that later rebuilds the network to
# reload the trained weights should use these same values.
EMBED_SIZE = 32
HIDDEN_SIZE = 128
NUM_LAYERS = 4
SEQ_LENGTH = 64
SAVE_DIR = '/content/saved_model'

# Download the text and keep only its first 10,000 characters
text_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
text = datasets.download_text(text_url)[:10000]

# Vocabulary: every distinct character of the text, in sorted order
chars = sorted(set(text))

# Create char_to_index and index_to_char mappings
char_to_index = {char: i for i, char in enumerate(chars)}
index_to_char = {i: char for i, char in enumerate(chars)}

# Encode the text as vocabulary indices. Because `chars` has no duplicates,
# the dict lookup yields exactly the index chars.index(char) would, without
# rescanning the vocabulary list for every character of the text.
indexed_text = [char_to_index[char] for char in text]

# Initialize the custom dataset and train the RNN model.
# The training callable receives `epochs` and `batch_size` and forwards them to
# RNNLM.train_model together with a TextDataset built from the indexed text.
instructions.build(
    user_train_function=lambda epochs, batch_size: models.RNNLM.train_model(
        dataset=datasets.TextDataset(indexed_text, seq_length=SEQ_LENGTH),
        vocab_size=len(chars),
        embed_size=EMBED_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        epochs=epochs,
        batch_size=batch_size
    ),
    save_dir=SAVE_DIR,
    char_to_index=char_to_index,
    index_to_char=index_to_char,
    batch_size=128,
    epochs=100,
    framework='pytorch',
    model_type='rnn'
)

In [None]:
from flatpack import models

# Location of the saved artifacts: the directory, and the weights file inside it
SAVE_DIR = '/content/saved_model'
MODEL_PATH = SAVE_DIR + '/rnn_model.pth'

# Architecture settings used to rebuild the network
EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS = 32, 128, 4

# Sampling settings: number of characters to generate and sampling temperature
GENERATE_LENGTH, TEMPERATURE = 1024, 1.0

# Rebuild the network, read its vocabulary size from SAVE_DIR, then load the
# stored weights from MODEL_PATH into it.
model = models.RNN(EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS)
model.load_vocab_size(SAVE_DIR)
model.load_state_dict(
    models.RNN.load_torch_model(MODEL_PATH)
)

# Generate GENERATE_LENGTH characters continuing from the seed phrase
generated_text = model.generate_text(
    SAVE_DIR,
    start_sequence="To be, or not to be",
    generate_length=GENERATE_LENGTH,
    temperature=TEMPERATURE,
)

print("Generated text:", generated_text)