# LSTM for Text Generation

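In this notebook, we'll build a character-level **LSTM-based text generator**. LSTMs (long short-term memory networks) are well suited to sequence tasks because their gated cell state lets them carry information across many time steps. This helps them capture longer-range dependencies than a plain recurrent network.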

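We'll train on a single short sentence to show how an LSTM learns patterns in text and then generates new sequences one character at a time. The dataset is deliberately tiny, so this demonstrates the workflow rather than a realistic generator.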
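To make the "long-term memory" idea concrete, here is a minimal NumPy sketch of a single LSTM time step using the standard gate equations (input, forget and output gates plus a candidate cell state). The names `lstm_step` and `sigmoid` and the stacked weight layout are our own illustrative choices, not Keras's internal implementation. The key line is the additive cell-state update `c = f * c_prev + i * g`: when the forget gate `f` stays close to 1, information in the cell state passes forward largely unchanged across many steps.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM time step (illustrative layout, not Keras's internal one).

    x:      input vector, shape (input_dim,)
    h_prev: previous hidden state, shape (units,)
    c_prev: previous cell state, shape (units,)
    W:      input weights, shape (input_dim, 4 * units)
    U:      recurrent weights, shape (units, 4 * units)
    b:      bias, shape (4 * units,)
    Gates are stacked in the order [input, forget, candidate, output].
    """
    z = x @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4)
    i = sigmoid(i)          # input gate: how much new information to write
    f = sigmoid(f)          # forget gate: how much of the old cell state to keep
    g = np.tanh(g)          # candidate values to write into the cell
    o = sigmoid(o)          # output gate: how much of the cell to expose
    c = f * c_prev + i * g  # additive cell-state update
    h = o * np.tanh(c)      # new hidden state
    return h, c

# Tiny usage example with random weights (input_dim=3, units=4)
rng = np.random.default_rng(0)
input_dim, units = 3, 4
W = rng.normal(size=(input_dim, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
h, c = np.zeros(units), np.zeros(units)
for t in range(5):
    h, c = lstm_step(rng.normal(size=input_dim), h, c, W, U, b)
print(h.shape, c.shape)  # (4,) (4,)
```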

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

print("TensorFlow version:", tf.__version__)

## Prepare Data
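We'll turn a short text string into a character-level dataset: each character is mapped to an integer index, each input is a window of `seq_length` consecutive characters, and its target is the character that immediately follows that window.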

In [None]:
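text = "hello world this is an lstm text generator"

# Vocabulary: every distinct character, plus lookup tables in both directions
chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}

# Slide a window of seq_length characters over the text;
# the target for each window is the character that follows it
seq_length = 10
sequences = []
next_chars = []

for i in range(len(text) - seq_length):
    sequences.append(text[i:i + seq_length])
    next_chars.append(text[i + seq_length])

# X: (num_sequences, seq_length) integer indices; y: (num_sequences,) integer targets
X = np.array([[char_to_idx[c] for c in seq] for seq in sequences])
y = np.array([char_to_idx[c] for c in next_chars])

print("Vocabulary size:", len(chars))
print("Number of sequences:", len(sequences))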
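The loop above builds overlapping windows: window `i` covers characters `i` to `i + seq_length - 1`, and its target is character `i + seq_length`. As an optional sketch, the same arrays can be built without a Python loop using NumPy's `sliding_window_view` (available since NumPy 1.20). The helper name `make_windows` is hypothetical and not part of the original notebook.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(text, seq_length, char_to_idx):
    """Hypothetical helper: vectorized version of the sliding-window loop."""
    encoded = np.array([char_to_idx[c] for c in text])
    # All windows of seq_length consecutive indices:
    # shape (len(text) - seq_length + 1, seq_length), as a read-only view
    windows = sliding_window_view(encoded, seq_length)
    # Drop the last window (it has no following character) and copy to a normal array
    X = windows[:-1].copy()
    # Target for window i is the character right after it
    y = encoded[seq_length:]
    return X, y

text = "hello world this is an lstm text generator"
chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
X_vec, y_vec = make_windows(text, 10, char_to_idx)
print(X_vec.shape, y_vec.shape)  # (32, 10) (32,)
```

For a corpus this small the loop is perfectly adequate; the vectorized form only matters once the text grows large.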

## Build LSTM Model

In [None]:
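model = Sequential([
    # Fixed-length input of seq_length character indices; this replaces
    # Embedding's input_length argument, which Keras 3 deprecates
    tf.keras.Input(shape=(seq_length,)),
    # Map each character index to a learned 50-dimensional vector
    Embedding(input_dim=len(chars), output_dim=50),
    # Read the whole window and return only the final hidden state
    LSTM(128, return_sequences=False),
    # Probability distribution over the next character
    Dense(len(chars), activation='softmax')
])

# Integer targets (y) pair with sparse_categorical_crossentropy, so no one-hot encoding is needed
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()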
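As a sanity check on `model.summary()`, the trainable parameter counts can be worked out by hand. This sketch assumes the 16-character vocabulary produced by the sample sentence (substitute `len(chars)` for a different text); the formulas are the standard ones for an embedding table, an LSTM (four gates, each with input weights, recurrent weights and a bias) and a dense layer. The helper function names are our own.

```python
def embedding_params(vocab_size, embed_dim):
    # One embed_dim-dimensional vector per vocabulary entry
    return vocab_size * embed_dim

def lstm_params(input_dim, units):
    # 4 gates x (input weights + recurrent weights + bias)
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # Weight matrix plus one bias per output unit
    return input_dim * units + units

vocab_size, embed_dim, lstm_units = 16, 50, 128
counts = {
    "embedding": embedding_params(vocab_size, embed_dim),
    "lstm": lstm_params(embed_dim, lstm_units),
    "dense": dense_params(lstm_units, vocab_size),
}
print(counts)                # {'embedding': 800, 'lstm': 91648, 'dense': 2064}
print(sum(counts.values()))  # 94512
```

Nearly all of the capacity sits in the LSTM's recurrent weights, which is far more than 32 training windows can constrain.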

## Train Model

In [None]:
history = model.fit(X, y, epochs=50, verbose=1)
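The model is trained with `sparse_categorical_crossentropy`, which takes integer class indices as targets (like `y` here) rather than one-hot vectors. For each example the loss is the negative log of the probability the softmax assigns to the true next character, averaged over the batch. The NumPy sketch below (function names are our own, not Keras APIs) shows that this equals ordinary categorical cross-entropy computed on one-hot targets, which is why no one-hot encoding step is needed.

```python
import numpy as np

def sparse_ce(probs, targets):
    # probs: (batch, vocab) softmax outputs; targets: (batch,) integer indices
    return -np.mean(np.log(probs[np.arange(len(targets)), targets]))

def categorical_ce(probs, one_hot):
    # one_hot: (batch, vocab) with a single 1 per row
    return -np.mean(np.sum(one_hot * np.log(probs), axis=1))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
targets = np.array([0, 2])
one_hot = np.eye(3)[targets]  # build one-hot rows from the integer targets

print(sparse_ce(probs, targets))       # -(log 0.7 + log 0.6) / 2, about 0.434
print(categorical_ce(probs, one_hot))  # same value
```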

## Generate Text
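Generation is autoregressive: we feed the model the last `seq_length` characters, take the most probable next character (greedy decoding with `argmax`), append it, and repeat with the window shifted forward by one. The seed must be at least `seq_length` characters long and contain only characters from the training vocabulary.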

In [None]:
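def generate_text(seed_text, length=50):
    """Greedily extend seed_text by `length` characters."""
    if len(seed_text) < seq_length:
        raise ValueError(f"seed_text must be at least {seq_length} characters long")
    unknown = set(seed_text) - set(char_to_idx)
    if unknown:
        raise ValueError(f"seed_text contains characters not in the vocabulary: {unknown}")

    generated = seed_text
    for _ in range(length):
        # Encode the most recent seq_length characters as a batch of one
        x_pred = np.array([[char_to_idx[c] for c in generated[-seq_length:]]])
        # Calling the model directly avoids predict()'s per-call overhead in a loop
        preds = model(x_pred, training=False).numpy()[0]
        # Greedy choice: the single most probable next character
        next_idx = int(np.argmax(preds))
        generated += idx_to_char[next_idx]
    return generated

print(generate_text("hello worl"))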
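`generate_text` always picks the single most likely character (greedy decoding), which on a tiny corpus tends to reproduce the training text or fall into repeating loops. A common alternative, not used in the original notebook, is to sample from the predicted distribution after rescaling it by a temperature: values below 1 push the choice toward `argmax`, values above 1 make it more varied. The helper `sample_with_temperature` below is a hypothetical sketch; to try it, one would replace `np.argmax(preds)` inside the loop with a call to it.

```python
import numpy as np

def sample_with_temperature(preds, temperature=1.0, rng=None):
    """Hypothetical helper: sample an index from a probability vector.

    preds: 1-D array of softmax probabilities (e.g. one row of model output).
    temperature: must be > 0; lower is closer to argmax, higher is more random.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = np.random.default_rng() if rng is None else rng
    # Rescale in log space, then renormalize with a numerically stable softmax
    logits = np.log(np.clip(preds, 1e-12, 1.0)) / temperature
    logits -= logits.max()
    probs = np.exp(logits)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

preds = np.array([0.1, 0.6, 0.3])
rng = np.random.default_rng(42)
samples = [sample_with_temperature(preds, 0.5, rng) for _ in range(1000)]
print(np.bincount(samples, minlength=3))  # counts skewed toward index 1
```

At temperature 0.5 the probabilities are effectively squared and renormalized, so index 1 dominates; at 1.0 the samples follow `preds` unchanged.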