This project implements text generation models using RNN, LSTM, and Transformer architectures in PyTorch. It trains these models on a small text dataset from Project Gutenberg and evaluates their performance using perplexity and BLEU score metrics.
- Modular implementation of RNN, LSTM, and Transformer models
- BPE tokenization using SentencePiece
- Training pipeline with early stopping and learning rate scheduling
- Evaluation using perplexity and BLEU score metrics
- Text generation with temperature sampling
- Visualization of results and model comparisons
- Command-line interface for training, evaluation, and generation
project/
├── config.py                    # Configuration parameters and hyperparameters
├── main.py                      # Main entry point script
├── requirements.txt             # Project dependencies
├── data/                        # Data directory
│   ├── raw/                     # Raw text files
│   ├── train.jsonl              # Training data
│   └── test.jsonl               # Testing data
├── models/                      # Directory for saved models
├── plots/                       # Directory for output plots
└── src/                         # Source code directory
    ├── data/                    # Data handling modules
    │   ├── dataset.py           # Dataset and dataloader implementations
    │   └── tokenizer.py         # Tokenizer implementation
    ├── models/                  # Model implementations
    │   ├── base_model.py        # Base class for all models
    │   ├── rnn_model.py         # RNN model implementation
    │   ├── lstm_model.py        # LSTM model implementation
    │   └── transformer_model.py # Transformer model implementation
    ├── training/                # Training modules
    │   └── trainer.py           # Training loop implementation
    ├── evaluation/              # Evaluation modules
    │   └── metrics.py           # Evaluation metrics
    └── visualization/           # Visualization modules
        └── loss_plots.py        # Training/validation loss plots
# create virtual environment
$ conda create -n slm77 python=3.10 -y
# activate environment
$ conda activate slm77
# install packages
$ pip install -r requirements.txt
The dataset consists of classic literature texts from Project Gutenberg, including works such as:
- Alice in Wonderland
- Art of War
- Dracula
- Frankenstein
- Great Gatsby
- and more...
The data is organized into:
- `data/raw/`: Original text files
- `train.jsonl`: Training examples in JSON format with prompt and completion pairs
- `test.jsonl`: Testing examples in the same format
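A minimal sketch of reading this format, assuming each line is a JSON object with `prompt` and `completion` fields (names taken from the description above):

```python
import json

# Illustrative only: read the first training example from train.jsonl.
with open("data/train.jsonl", encoding="utf-8") as f:
    for line in f:
        example = json.loads(line)
        print(example["prompt"], "->", example["completion"])
        break  # inspect just the first example
```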
To train the models (RNN, LSTM, and Transformer), either all at once or individually:
# Train all models
python main.py --train --model_type 0
# Train only the RNN model
python main.py --train --model_type 1
# Train only the LSTM model
python main.py --train --model_type 2
# Train only the Transformer model
python main.py --train --model_type 3
This will:
- Load and tokenize the data
- Train a BPE tokenizer with a vocabulary size of 10,000 (see the sketch after this list)
- Train each model for up to 30 epochs with early stopping
- Save model checkpoints to the `models/` directory
- Generate training/validation loss plots
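For context, here is a minimal sketch of how a BPE tokenizer with a 10,000-token vocabulary can be trained with SentencePiece. The input path and model prefix are illustrative assumptions, not the project's exact settings:

```python
import sentencepiece as spm

# Train a BPE tokenizer on a raw text file (paths are hypothetical).
spm.SentencePieceTrainer.train(
    input="data/raw/alice_in_wonderland.txt",
    model_prefix="models/bpe_10k",
    vocab_size=10000,
    model_type="bpe",
)

# Load the trained model and round-trip a sentence.
sp = spm.SentencePieceProcessor(model_file="models/bpe_10k.model")
ids = sp.encode("Alice was beginning to get very tired", out_type=int)
print(ids)
print(sp.decode(ids))
```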
To evaluate the trained models on the test set:
# Evaluate all the models
python main.py --evaluate --model_type 0
This will:
- Load the trained models
- Calculate perplexity and BLEU score for each model (see the sketch after this list)
- Generate comparison visualizations
- Print a summary of results
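For intuition, a minimal sketch of both metrics; the function names and tensor shapes are illustrative, and the project's actual implementation lives in `src/evaluation/metrics.py`:

```python
import math
import torch
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # Perplexity is the exponential of the mean token-level cross-entropy.
    # logits: (num_tokens, vocab_size), targets: (num_tokens,)
    loss = F.cross_entropy(logits, targets)
    return math.exp(loss.item())

def bleu(reference_tokens: list[str], generated_tokens: list[str]) -> float:
    # BLEU measures n-gram overlap between generated text and a reference.
    smoothing = SmoothingFunction().method1  # avoids zero scores on short texts
    return sentence_bleu([reference_tokens], generated_tokens,
                         smoothing_function=smoothing)
```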
To generate text from trained models:
# Generate text using only the Transformer model
python main.py --generate --prompt "Which do you prefer? Dogs or cats?" --model_type 3 --temperature 0.8 --max_length 100
Additional arguments:
- `--temperature`: Controls randomness in generation (default: 1.0)
- `--max_length`: Maximum number of tokens to generate (default: 100)
- `--seed`: Random seed for reproducibility (default: 42)
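Temperature sampling divides the model's logits by the temperature before applying softmax, so values below 1.0 make generation more deterministic and values above 1.0 more random. A minimal sketch (the function name and signature are illustrative, not the project's API):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    # logits: (vocab_size,) unnormalized scores for the next token.
    # Lower temperature sharpens the distribution; higher flattens it.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
```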
You can also combine operations:
python main.py --train --evaluate --model_type 0
Model hyperparameters and training settings can be modified in the `config.py` file:
# Model parameters
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NUM_LAYERS = 2
DROPOUT = 0.2
TRANSFORMER_HEADS = 8
# Training parameters
BATCH_SIZE = 128
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 30
EARLY_STOPPING_PATIENCE = 3
GRADIENT_CLIP_VAL = 1.0
# Generation parameters
TEMPERATURE = 0.8
MAX_GENERATION_LENGTH = 100
After training and evaluation, you'll find:
- Trained model checkpoints in the `models/` directory
- Visualization plots in the `plots/` directory:
  - Training and validation loss curves
  - Perplexity comparison
  - BLEU score comparison
  - Normalized metrics comparison
The modular structure makes it easy to extend this project:
- Add new models: Create a new model file in `src/models/` that inherits from `BaseTextGenerationModel` (see the sketch after this list)
- Add new metrics: Implement additional evaluation metrics in `src/evaluation/metrics.py`
- Try different datasets: Update the data loading and preprocessing in `src/data/`
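As a hypothetical example of the first point, a GRU-based model could be added roughly like this. The import path, constructor arguments, and `forward` signature are assumptions and should be matched to whatever interface `BaseTextGenerationModel` actually defines:

```python
import torch.nn as nn
from src.models.base_model import BaseTextGenerationModel  # path assumed from the layout above

class GRUModel(BaseTextGenerationModel):
    """Hypothetical GRU-based text generation model."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers,
                          dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, hidden=None):
        # input_ids: (batch, seq_len) token indices -> (batch, seq_len, vocab_size) logits
        embedded = self.embedding(input_ids)
        output, hidden = self.gru(embedded, hidden)
        return self.fc(output), hidden
```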
- The dataset is derived from public domain texts from Project Gutenberg
- The implementation is built with PyTorch