<a href="https://colab.research.google.com/github/romlingroup/flatpack-ai/blob/main/notebooks/flatpack_ai_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# flatpack.ai - RNN
This Colab notebook demonstrates how to train a Recurrent Neural Network (RNN) on a text dataset and then use the trained model to generate new text sequences. The model is trained character by character, learning the patterns and structure of the input text, and can then generate coherent and contextually relevant text sequences.

https://github.com/romlingroup/flatpack-ai

In [None]:
!pip uninstall flatpack -y
!pip install flatpack

In [None]:
from flatpack import datasets, instructions, models, utils

# --- Data -------------------------------------------------------------------
text_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
TEXT_LIMIT = 10000  # forwarded as `limit` to the download/preprocess helper

# --- Training configuration -------------------------------------------------
# Everything a reader might want to tune lives here instead of being buried in
# the call below. EMBED_SIZE / HIDDEN_SIZE / NUM_LAYERS define the network
# architecture, so whoever rebuilds the model to load the saved weights has to
# use exactly these values.
SAVE_DIR = '/content/saved_model'
EPOCHS = 100
BATCH_SIZE = 2048
SEQ_LENGTH = 64
EMBED_SIZE = 256
HIDDEN_SIZE = 128
NUM_LAYERS = 4

# Download the corpus and obtain the indexed text together with the two
# character <-> index lookup tables.
indexed_text, char_to_index, index_to_char = datasets.download_and_preprocess_text(text_url, limit=TEXT_LIMIT)

device = utils.configure_device()


def train_rnn(epochs, batch_size):
    """Train the character-level RNN on the text prepared above.

    Passed to ``instructions.build`` as ``user_train_function``; the parameter
    names are kept as ``epochs`` and ``batch_size`` so the callback can be
    invoked positionally or by keyword.

    Args:
        epochs: Number of training epochs, forwarded to ``train_model``.
        batch_size: Batch size, forwarded to ``train_model``.

    Returns:
        Whatever ``models.RNN.train_model`` returns.
    """
    return models.RNN.train_model(
        indexed_text=indexed_text,
        vocab_size=len(char_to_index),
        seq_length=SEQ_LENGTH,
        embed_size=EMBED_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        epochs=epochs,
        batch_size=batch_size,
        device=device,
    )


instructions.build(
    framework='pytorch',
    model_type='rnn',
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    char_to_index=char_to_index,
    index_to_char=index_to_char,
    save_dir=SAVE_DIR,
    device=device,
    user_train_function=train_rnn,
)

In [None]:
import random
from flatpack import models, utils

import torch  # needed only to seed PyTorch's random number generator below

# --- Configuration ----------------------------------------------------------
SAVE_DIR = '/content/saved_model'
MODEL_PATH = f'{SAVE_DIR}/rnn_model.pth'
# Architecture of the network being rebuilt. The saved weights are loaded into
# it with load_state_dict, so these must equal the values used for training.
EMBED_SIZE = 256
HIDDEN_SIZE = 128
NUM_LAYERS = 4
GENERATE_LENGTH = 512  # forwarded as `generate_length`
TEMPERATURE = 1        # forwarded as `temperature`
START_SEQUENCE = "To be, or not to be"
SEED = 42

# Seed both generators. Seeding only Python's `random` leaves PyTorch's RNG
# untouched; the model is a PyTorch module, so its sampling presumably draws
# from PyTorch's generator (TODO confirm against flatpack's generate_text).
random.seed(SEED)
torch.manual_seed(SEED)

device = utils.configure_device()

# Rebuild the network, restore its vocabulary size from SAVE_DIR, then load
# the trained weights from MODEL_PATH.
model = models.RNN(EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS).to(device)
model.load_vocab_size(SAVE_DIR)
model.load_state_dict(models.RNN.load_torch_model(MODEL_PATH))
# NOTE(review): the model was already moved to `device` on construction. These
# explicit moves presumably cover layers that load_vocab_size creates after
# that point - confirm before simplifying them away.
model.embedding = model.embedding.to(device)
model.rnn = model.rnn.to(device)
model.fc = model.fc.to(device)

generated_text = model.generate_text(
    SAVE_DIR,
    start_sequence=START_SEQUENCE,
    generate_length=GENERATE_LENGTH,
    temperature=TEMPERATURE,
    device=device,
)

print("Generated text:", generated_text)