Embedded Reber grammars were used by Hochreiter and Schmidhuber in their paper about LSTMs. They are artificial grammars that produce strings such as “BPBTSXXVPSEPE”. Check out Jenny Orr’s nice introduction to this topic, then choose a particular embedded Reber grammar (such as the one represented on Orr’s page), then train an RNN to identify whether a string respects that grammar or not. You will first need to write a function capable of generating a training batch containing about 50% strings that respect the grammar, and 50% that don’t.

First i wrote the algorithm using three classes to represent the graph and the random next step choice. Without the embedded option, just simple phrases. And the method to validate algorithmically a sequence. Now i have the capability of generating valid sentences. Lets try to generate wrong ones randomly, but in order to the model to learn all the rules i must include the right mistakes in the sentences such as:

- Wrong char: Switch a random wrong char in a random position
- Invalid Transition: In order to the model to understand the state a wrong transition can be generated 
- Incomplete sequences: The model must acknoledge the correct ending of the sequence
- Extra chars: The model must recognize a wrong end of sequence
- Permutation: Change the order of correct sub-sequences making the model understand the order logic

I have a background in object oriented programming so i made it in classes. Now that i can create wrong and right sentences. I must prepare a dataset to train my model.

I was lazy so i asked the ia if it could generate me a class to generate a dataset leveraging TensorFlow Dataset API, it made a quite good job, maybe too much for this experiment.

In [42]:
#Create dataset classes

from random import randint, choice


class ReberGraph:
    def __init__(self, embedded=False):
        self.embedded = embedded
        self.result_string = ''
        self.initial_node = ReberNode()
        second_node = ReberNode()
        third_node = ReberNode()
        fourth_node = ReberNode()
        fifth_node = ReberNode()
        sixth_node = ReberNode()
        seventh_node = ReberNode()
        final_node = ReberNode(is_terminal=True)

        self.initial_node.add_connection(ReberConnection(self.initial_node, second_node, 'B'))
        
        second_node.add_connection(ReberConnection(second_node, third_node, 'T'))        
        second_node.add_connection(ReberConnection(second_node, fourth_node, 'P'))
        
        third_node.add_connection(ReberConnection(third_node, third_node, 'S'))
        third_node.add_connection(ReberConnection(third_node, fifth_node, 'X'))

        fourth_node.add_connection(ReberConnection(fourth_node, sixth_node, 'V'))
        fourth_node.add_connection(ReberConnection(fourth_node, fourth_node, 'T'))

        fifth_node.add_connection(ReberConnection(fifth_node, fourth_node, 'X'))
        fifth_node.add_connection(ReberConnection(fifth_node, seventh_node, 'S'))

        sixth_node.add_connection(ReberConnection(sixth_node, fifth_node, 'V'))
        sixth_node.add_connection(ReberConnection(sixth_node, seventh_node, 'P'))

        seventh_node.add_connection(ReberConnection(seventh_node, final_node, 'E'))
        
        self.error_strategies = []

    def set_error_strategies(self, strategies):
        self.error_strategies = strategies

    def generate_sequence(self):
        sequence = ''
        current_node = self.initial_node
        while not current_node.is_terminal:
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
        
        return sequence   

    def validate_sequence(self, sequence):
        current_node = self.initial_node
        for char in sequence:
            found = False
            for connection in current_node.connections:
                if connection.label == char:
                    current_node = connection.node_to
                    found = True
                    break
            if not found:
                return False
        return current_node.is_terminal

    def generate_wrong_sequence(self):
        selected_strategy = choice(self.error_strategies)
        return selected_strategy.generate_error(self)

        
class ReberConnection:
    def __init__(self, node_from, node_to, label):
        self.node_from = node_from
        self.node_to = node_to
        self.label = label

class ReberNode:
    def __init__(self, is_terminal=False):
        self.is_terminal = is_terminal
        self.connections = []

    def add_connection(self, connection):
        self.connections.append(connection)

    


In [43]:
reber_graph = ReberGraph()
sequence = reber_graph.generate_sequence()
print(sequence)
print(reber_graph.validate_sequence(sequence))


BPVVSE
True


In [None]:
# Error Strategy Interface and Implementations

from abc import ABC, abstractmethod
from random import randint, choice


class SequenceErrorStrategy(ABC):
    
    @abstractmethod
    def generate_error(self, graph):
        pass


class WrongCharError(SequenceErrorStrategy):    
    
    def generate_error(self, graph):
        sequence = graph.generate_sequence()
        if len(sequence) == 0:
            return sequence
        
        invalid_chars = ['Z', 'Q', 'W', 'R', 'Y', 'U', 'I', 'O', 'A', 'C', 'D', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N']
        random_pos = randint(0, len(sequence) - 1)
        wrong_char = choice(invalid_chars)
        
        return sequence[:random_pos] + wrong_char + sequence[random_pos + 1:]


class InvalidTransitionError(SequenceErrorStrategy):

    def generate_error(self, graph):
        sequence = ''
        current_node = graph.initial_node
        valid_labels = {'B', 'T', 'P', 'S', 'X', 'V', 'E'}
        
        steps_before_error = randint(1, 3)
        step_count = 0
        
        while not current_node.is_terminal and len(sequence) < 15:
            step_count += 1
            
            if step_count == steps_before_error:
                # Get valid labels that are NOT in the current connections
                available_labels = {c.label for c in current_node.connections}
                invalid_labels = list(valid_labels - available_labels)
                
                if invalid_labels:
                    # Insert a valid character but not allowed in this state
                    wrong_label = choice(invalid_labels)
                    sequence += wrong_label
                    break  # Terminate because it's already invalid
                else:
                    # If all labels are valid, use a random one from the available ones
                    selected_connection_index = randint(0, len(current_node.connections) - 1)
                    sequence += current_node.connections[selected_connection_index].label
                    current_node = current_node.connections[selected_connection_index].node_to
            else:
                # Continue normally
                selected_connection_index = randint(0, len(current_node.connections) - 1)
                sequence += current_node.connections[selected_connection_index].label
                current_node = current_node.connections[selected_connection_index].node_to
        
        return sequence


class IncompleteSequenceError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = ''
        current_node = graph.initial_node
        # Stop after a random number of steps (but before completing)
        max_steps = randint(2, 6)
        step_count = 0
        
        while not current_node.is_terminal and step_count < max_steps:
            if len(current_node.connections) == 0:
                break
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
            step_count += 1
        
        # The sequence does not end in 'E', therefore it's invalid
        return sequence


class ExtraCharsError(SequenceErrorStrategy):
    """Strategy: Add extra characters after the valid end of the sequence"""
    
    def generate_error(self, graph):
        sequence = graph.generate_sequence()
        # Add extra characters after the final 'E'
        extra_chars = ['X', 'S', 'T', 'P', 'V']
        num_extra = randint(1, 4)
        
        for _ in range(num_extra):
            sequence += choice(extra_chars)
        
        return sequence


class PermutationError(SequenceErrorStrategy):    
    def generate_error(self, graph):
        sequence = list(graph.generate_sequence())
        if len(sequence) >= 2:
            pos = randint(0, len(sequence) - 2)
            sequence[pos], sequence[pos + 1] = sequence[pos + 1], sequence[pos]
        return ''.join(sequence)


class WrongStartError(SequenceErrorStrategy):    
    def generate_error(self, graph):
        invalid_starts = ['T', 'P', 'S', 'X', 'V', 'E']
        sequence = choice(invalid_starts)
        
        current_node = graph.initial_node
        while not current_node.is_terminal:
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
        
        return sequence



In [45]:
# Usage example: Configure and test error strategies

error_strategies = [
    WrongCharError(),
    InvalidTransitionError(),
    IncompleteSequenceError(),
    ExtraCharsError(),
    PermutationError(),
    WrongStartError()
]

reber_graph.set_error_strategies(error_strategies)

print("=== Valid sequences ===")
for i in range(3):
    valid_seq = reber_graph.generate_sequence()
    print(f"Valid {i+1}: {valid_seq} (validation: {reber_graph.validate_sequence(valid_seq)})")

print("\n=== Invalid sequences ===")
for i in range(5):
    wrong_seq = reber_graph.generate_wrong_sequence()
    print(f"Invalid {i+1}: {wrong_seq} (validation: {reber_graph.validate_sequence(wrong_seq)})")



=== Valid sequences ===
Valid 1: BPTVPE (validation: True)
Valid 2: BTXXVPE (validation: True)
Valid 3: BPTTTTTTTTVVSE (validation: True)

=== Invalid sequences ===
Invalid 1: TBXSE (validation: False)
Invalid 2: BPVVVXPE (validation: False)
Invalid 3: BS (validation: False)
Invalid 4: BPE (validation: False)
Invalid 5: BTSDSE (validation: False)


In [46]:
import tensorflow as tf
import numpy as np
from typing import Tuple, Optional


class ReberDataset:
    def __init__(self, graph: ReberGraph, vocab: Optional[dict] = None):
        """
        Initialize the ReberDataset generator.
        
        Args:
            graph: ReberGraph instance for generating sequences
            vocab: Optional vocabulary mapping. If None, will be created from Reber alphabet
        """
        self.graph = graph
        self.reber_alphabet = ['B', 'T', 'P', 'S', 'X', 'V', 'E']
        
        # Create vocabulary mapping: char -> int
        if vocab is None:
            self.vocab = {char: idx + 1 for idx, char in enumerate(self.reber_alphabet)}
            self.vocab['<PAD>'] = 0  # Padding token
        else:
            self.vocab = vocab
        
        # Reverse vocabulary: int -> char
        self.idx_to_char = {idx: char for char, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)
    
    def _sequence_to_indices(self, sequence: str) -> list:
        """Convert a sequence string to a list of token indices"""
        return [self.vocab.get(char, 0) for char in sequence]
    
    def _generate_sample(self, is_valid: bool) -> Tuple[list, int]:
        """
        Generate a single sample (sequence, label).
        
        Args:
            is_valid: True for valid sequence, False for invalid
            
        Returns:
            Tuple of (tokenized_sequence, label)
        """
        if is_valid:
            sequence = self.graph.generate_sequence()
            label = 1
        else:
            sequence = self.graph.generate_wrong_sequence()
            label = 0
        
        tokenized = self._sequence_to_indices(sequence)
        return tokenized, label
    
    def _generator(self, num_samples: int, valid_ratio: float = 0.5):
        """
        Generator function for creating sequences on-the-fly.
        This enables lazy loading and memory efficiency.
        
        Args:
            num_samples: Total number of samples to generate
            valid_ratio: Ratio of valid sequences (default 0.5 for 50/50 split)
        """
        num_valid = int(num_samples * valid_ratio)
        num_invalid = num_samples - num_valid
        
        # Generate valid sequences
        for _ in range(num_valid):
            tokenized, label = self._generate_sample(is_valid=True)
            yield (np.array(tokenized, dtype=np.int32), np.array(label, dtype=np.int32))
        
        # Generate invalid sequences
        for _ in range(num_invalid):
            tokenized, label = self._generate_sample(is_valid=False)
            yield (np.array(tokenized, dtype=np.int32), np.array(label, dtype=np.int32))
    
    def generate_dataset(
        self,
        num_samples: int = 10000,
        batch_size: int = 32,
        valid_ratio: float = 0.5,
        max_length: Optional[int] = None,
        shuffle: bool = True,
        shuffle_buffer_size: int = 1000,
        prefetch: bool = True,
        prefetch_size: int = tf.data.AUTOTUNE,
        cache: bool = False,
        repeat: bool = False,
        num_parallel_calls: int = tf.data.AUTOTUNE
    ) -> tf.data.Dataset:
        """
        Generate a TensorFlow Dataset with all optimizations.
        
        Args:
            num_samples: Total number of samples to generate
            batch_size: Batch size for training
            valid_ratio: Ratio of valid sequences (default 0.5)
            max_length: Maximum sequence length for padding. If None, uses max in batch
            shuffle: Whether to shuffle the dataset
            shuffle_buffer_size: Buffer size for shuffling
            prefetch: Whether to prefetch batches for better GPU utilization
            prefetch_size: Number of batches to prefetch (AUTOTUNE recommended)
            cache: Whether to cache the dataset in memory (useful for small datasets)
            repeat: Whether to repeat the dataset indefinitely (for training loops)
            num_parallel_calls: Number of parallel calls for map operations
            
        Returns:
            tf.data.Dataset ready for training
        """
        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            lambda: self._generator(num_samples, valid_ratio),
            output_signature=(
                tf.TensorSpec(shape=(None,), dtype=tf.int32),  # Variable length sequences
                tf.TensorSpec(shape=(), dtype=tf.int32)  # Labels
            )
        )
        
        # Shuffle if requested
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=shuffle_buffer_size,
                reshuffle_each_iteration=True
            )
        
        # Pad sequences to the same length within each batch
        # This is more efficient than padding to a fixed max_length
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=(
                [max_length] if max_length else [None],  # Pad sequences
                []  # Labels don't need padding
            ),
            padding_values=(
                self.vocab['<PAD>'],  # Padding value for sequences
                0  # Padding value for labels (won't be used)
            ),
            drop_remainder=False  # Keep last incomplete batch
        )
        
        # Cache if requested (useful for small datasets that fit in memory)
        if cache:
            dataset = dataset.cache()
        
        # Prefetch for better GPU utilization
        if prefetch:
            dataset = dataset.prefetch(prefetch_size)
        
        # Repeat if requested (for training loops)
        if repeat:
            dataset = dataset.repeat()
        
        return dataset
    
    def generate_train_val_split(
        self,
        num_samples: int = 10000,
        val_ratio: float = 0.2,
        batch_size: int = 32,
        **kwargs
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """
        Generate training and validation datasets with proper splitting.
        
        Args:
            num_samples: Total number of samples
            val_ratio: Ratio of validation samples
            batch_size: Batch size
            **kwargs: Additional arguments passed to generate_dataset
            
        Returns:
            Tuple of (train_dataset, val_dataset)
        """
        num_train = int(num_samples * (1 - val_ratio))
        num_val = num_samples - num_train
        
        train_dataset = self.generate_dataset(
            num_samples=num_train,
            batch_size=batch_size,
            **kwargs
        )
        
        val_dataset = self.generate_dataset(
            num_samples=num_val,
            batch_size=batch_size,
            shuffle=False,  # Don't shuffle validation set
            cache=True,  # Cache validation set for faster evaluation
            **{k: v for k, v in kwargs.items() if k not in ['shuffle', 'cache']}
        )
        
        return train_dataset, val_dataset


In [47]:
# Example: Using the TensorFlow Dataset

# Initialize the dataset generator
dataset_generator = ReberDataset(reber_graph)

# Generate a dataset with all optimizations
train_dataset = dataset_generator.generate_dataset(
    num_samples=10000,
    batch_size=32,
    valid_ratio=0.5,
    shuffle=True,
    prefetch=True,
    cache=False  # Set to True if dataset fits in memory
)

# Or generate train/val split directly
train_ds, val_ds = dataset_generator.generate_train_val_split(
    num_samples=10000,
    val_ratio=0.2,
    batch_size=32
)

# Inspect the dataset
print("Dataset structure:")
print(f"Vocabulary size: {dataset_generator.vocab_size}")
print(f"Vocabulary: {dataset_generator.vocab}")
print("\nSample batch:")
for sequences, labels in train_dataset.take(1):
    print(f"Sequences shape: {sequences.shape}")  # (batch_size, max_seq_length)
    print(f"Labels shape: {labels.shape}")      # (batch_size,)
    print(f"First sequence: {sequences[0].numpy()}")
    print(f"First label: {labels[0].numpy()}")
    print(f"Decoded sequence: {''.join([dataset_generator.idx_to_char.get(int(idx), '?') for idx in sequences[0].numpy() if idx != 0])}")


Dataset structure:
Vocabulary size: 8
Vocabulary: {'B': 1, 'T': 2, 'P': 3, 'S': 4, 'X': 5, 'V': 6, 'E': 7, '<PAD>': 0}

Sample batch:
Sequences shape: (32, 22)
Labels shape: (32,)
First sequence: [1 2 4 4 5 4 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
First label: 1
Decoded sequence: BTSSXSE


In [None]:
# Example: Using with Keras Model (commented out - uncomment to use)

# # Create train/validation datasets
# train_ds, val_ds = dataset_generator.generate_train_val_split(
#     num_samples=10000,
#     val_ratio=0.2,
#     batch_size=32
# )

# # Example model architecture
# model = tf.keras.Sequential([
#     tf.keras.layers.Embedding(
#         input_dim=dataset_generator.vocab_size,
#         output_dim=64,
#         mask_zero=True  # Automatically mask padding tokens
#     ),
#     tf.keras.layers.LSTM(64, return_sequences=False),
#     tf.keras.layers.Dense(32, activation='relu'),
#     tf.keras.layers.Dense(1, activation='sigmoid')  # Binary classification
# ])

# model.compile(
#     optimizer='adam',
#     loss='binary_crossentropy',
#     metrics=['accuracy']
# )

# # Train the model
# # The dataset is already batched, shuffled, and prefetched!
# history = model.fit(
#     train_ds,
#     validation_data=val_ds,
#     epochs=10,
#     verbose=1
# )

print("Dataset ready for training! Uncomment the code above to train a model.")
