Embedded Reber grammars were used by Hochreiter and Schmidhuber in their paper about LSTMs. They are artificial grammars that produce strings such as “BPBTSXXVPSEPE”. Check out Jenny Orr’s nice introduction to this topic, then choose a particular embedded Reber grammar (such as the one represented on Orr’s page), then train an RNN to identify whether a string respects that grammar or not. You will first need to write a function capable of generating a training batch containing about 50% strings that respect the grammar, and 50% that don’t.
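A minimal sketch of how the embedded variant can be built on top of a plain Reber generator: the string starts with B, then T or P, then a complete inner Reber string, then the same T or P again, then E. The transition table below is an assumption (the classic Reber table, which accepts the inner part of the example "BPBTSXXVPSEPE"). The names `REBER`, `generate_embedded` and `is_embedded` are made up for this illustration and are not part of the classes in this notebook.

```python
import random

# Assumed transition table for the classic Reber grammar:
# state -> list of (symbol, next_state); state 7 is the accepting state.
REBER = {
    0: [("B", 1)],
    1: [("T", 2), ("P", 3)],
    2: [("S", 2), ("X", 4)],
    3: [("T", 3), ("V", 5)],
    4: [("X", 3), ("S", 6)],
    5: [("P", 4), ("V", 6)],
    6: [("E", 7)],
}
ACCEPT = 7


def generate_reber(rng=random):
    """Random walk through the table until the accepting state is reached."""
    state, out = 0, []
    while state != ACCEPT:
        symbol, state = rng.choice(REBER[state])
        out.append(symbol)
    return "".join(out)


def is_reber(s):
    """Follow the unique matching edge for each symbol; reject if none exists."""
    state = 0
    for ch in s:
        nxt = dict(REBER.get(state, [])).get(ch)
        if nxt is None:
            return False
        state = nxt
    return state == ACCEPT


def generate_embedded(rng=random):
    """B, then T or P, then an inner Reber string, then the same T or P, then E."""
    branch = rng.choice("TP")
    return "B" + branch + generate_reber(rng) + branch + "E"


def is_embedded(s):
    """Second and second-to-last symbols must match; the middle must be a Reber string."""
    return (
        len(s) >= 4
        and s[0] == "B"
        and s[-1] == "E"
        and s[1] in "TP"
        and s[1] == s[-2]
        and is_reber(s[2:-2])
    )


print(generate_embedded())
print(is_embedded("BPBTSXXVPSEPE"))
```

The long-range dependency is the matching T or P at both ends, which is what makes the embedded version harder for a recurrent model than the plain one.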

First I wrote the algorithm using three classes to represent the graph and the random choice of the next step. For now it only covers simple strings, without the embedded option, plus a method that validates a sequence algorithmically. With that I can generate valid sentences. Next I want to generate wrong ones at random, but for the model to learn all the rules I must include the right kinds of mistakes in the sentences, such as:

- Wrong char: replace the character at a random position with one that is outside the grammar's alphabet
- Invalid transition: emit a symbol that is not allowed from the current state, so the model has to keep track of the state
- Incomplete sequences: the model must learn what a correct ending of the sequence looks like
- Extra chars: the model must recognize characters appended after the end of the sequence
- Permutation: change the order of parts of a valid sequence so the model learns the ordering logic

I have a background in object-oriented programming, so I built it with classes. Now that I can create both wrong and right sentences, I must prepare a dataset to train my model.

I was lazy, so I asked the AI to generate a class that builds a dataset with the TensorFlow Dataset API. It did quite a good job, maybe more than this experiment needs.

I had doubts about using prefetch: I'm not loading anything from files, the sequences are generated quickly in memory, they are short, and there is no disk I/O or heavy processing. I found out that prefetch can still be useful for padding and batching. Padding variable-length sentences takes some time, and with prefetch the input pipeline prepares the next padded batch while the model trains on the current one, which helps keep a GPU busy if one is used.
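To make the padding step concrete, here is a plain-Python sketch of what happens to one batch of variable-length sequences. `pad_batch` is a hypothetical helper written only for this illustration. The real pipeline in this notebook relies on `padded_batch`, which, when the padded shape is left as `None`, pads each batch to the length of its own longest sequence.

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]


# Two tokenized sequences of different lengths (index 0 is reserved for padding).
batch = pad_batch([[1, 2, 4, 5, 7], [1, 3, 6, 6, 5, 6, 3, 7]])
for row in batch:
    print(row)
```

This is the work that prefetch can overlap with training: while one padded batch is being consumed, the next one is being assembled.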

While programming I forgot something very important: setting the TensorFlow seed. It matters because, with the same randomness on every run, a change in the results can be attributed to a change in the model rather than to chance.
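A small sketch of the idea using only Python's `random` module, which is the module the sequence generator in this notebook draws from. `sample_walk` is a hypothetical stand-in for the real generator. Seeding TensorFlow covers TensorFlow's own random operations, so the Python seed has to be set as well for the generated strings to repeat.

```python
import random


def sample_walk(n=8):
    """Stand-in for a sequence generator: n random symbols from the Reber alphabet."""
    return "".join(random.choice("BTPSXVE") for _ in range(n))


random.seed(42)
first = sample_walk()

random.seed(42)  # resetting the seed replays the same random choices
second = sample_walk()

print(first == second)  # True
```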

In [1]:
#Create dataset classes

from random import randint, choice


class ReberGraph:
    def __init__(self, embedded=False):
        self.embedded = embedded
        self.result_string = ''
        self.initial_node = ReberNode()
        second_node = ReberNode()
        third_node = ReberNode()
        fourth_node = ReberNode()
        fifth_node = ReberNode()
        sixth_node = ReberNode()
        seventh_node = ReberNode()
        final_node = ReberNode(is_terminal=True)

        self.initial_node.add_connection(ReberConnection(self.initial_node, second_node, 'B'))
        
        second_node.add_connection(ReberConnection(second_node, third_node, 'T'))        
        second_node.add_connection(ReberConnection(second_node, fourth_node, 'P'))
        
        third_node.add_connection(ReberConnection(third_node, third_node, 'S'))
        third_node.add_connection(ReberConnection(third_node, fifth_node, 'X'))

        fourth_node.add_connection(ReberConnection(fourth_node, sixth_node, 'V'))
        fourth_node.add_connection(ReberConnection(fourth_node, fourth_node, 'T'))

        fifth_node.add_connection(ReberConnection(fifth_node, fourth_node, 'X'))
        fifth_node.add_connection(ReberConnection(fifth_node, seventh_node, 'S'))

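        # Note: in this variant 'V' loops back and 'P' exits; the classic Reber
        # grammar has these two labels the other way round.
        sixth_node.add_connection(ReberConnection(sixth_node, fifth_node, 'V'))
        sixth_node.add_connection(ReberConnection(sixth_node, seventh_node, 'P'))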

        seventh_node.add_connection(ReberConnection(seventh_node, final_node, 'E'))
        
        self.error_strategies = []

    def set_error_strategies(self, strategies):
        self.error_strategies = strategies

    def generate_sequence(self):
        sequence = ''
        current_node = self.initial_node
        while not current_node.is_terminal:
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
        
        return sequence   

    def validate_sequence(self, sequence):
        current_node = self.initial_node
        for char in sequence:
            found = False
            for connection in current_node.connections:
                if connection.label == char:
                    current_node = connection.node_to
                    found = True
                    break
            if not found:
                return False
        return current_node.is_terminal

    def generate_wrong_sequence(self):
        selected_strategy = choice(self.error_strategies)
        return selected_strategy.generate_error(self)

        
class ReberConnection:
    def __init__(self, node_from, node_to, label):
        self.node_from = node_from
        self.node_to = node_to
        self.label = label

class ReberNode:
    def __init__(self, is_terminal=False):
        self.is_terminal = is_terminal
        self.connections = []

    def add_connection(self, connection):
        self.connections.append(connection)

    


In [2]:
reber_graph = ReberGraph()
sequence = reber_graph.generate_sequence()
print(sequence)
print(reber_graph.validate_sequence(sequence))


BTXSE
True


In [3]:
from abc import ABC, abstractmethod
from random import randint, choice


class SequenceErrorStrategy(ABC):
    
    @abstractmethod
    def generate_error(self, graph):
        pass


class WrongCharError(SequenceErrorStrategy):    
    
    def generate_error(self, graph):
        sequence = graph.generate_sequence()
        if len(sequence) <= 2:  # randint(1, len - 2) below needs at least 3 characters
            return sequence
        
        invalid_chars = ['Z', 'Q', 'W', 'R', 'Y', 'U', 'I', 'O', 'A', 'C', 'D', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N']
        random_pos = randint(1, len(sequence) - 2)
        wrong_char = choice(invalid_chars)
        
        return sequence[:random_pos] + wrong_char + sequence[random_pos + 1:]


class InvalidTransitionError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = ''
        current_node = graph.initial_node
        valid_labels = {'B', 'T', 'P', 'S', 'X', 'V', 'E'}
        
        steps_before_error = randint(1, 4)
        step_count = 0
        
        while not current_node.is_terminal and len(sequence) < 15:
            step_count += 1
            
            if step_count == steps_before_error and len(current_node.connections) > 0:
                available_labels = {c.label for c in current_node.connections}
                invalid_labels = list(valid_labels - available_labels)
                
                if invalid_labels:
                    wrong_label = choice(invalid_labels)
                    sequence += wrong_label
                    return sequence
            
            if len(current_node.connections) == 0:
                break
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
        
        return sequence if not current_node.is_terminal else sequence[:-1]


class IncompleteSequenceError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = ''
        current_node = graph.initial_node
        max_steps = randint(2, 5)
        step_count = 0
        
        while not current_node.is_terminal and step_count < max_steps:
            if len(current_node.connections) == 0:
                break
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
            step_count += 1
        
        return sequence if len(sequence) > 0 else 'BP'


class ExtraCharsError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = graph.generate_sequence()
        extra_chars = ['B', 'T', 'P', 'S', 'X', 'V']
        num_extra = randint(1, 3)
        
        for _ in range(num_extra):
            sequence += choice(extra_chars)
        
        return sequence


class WrongStartError(SequenceErrorStrategy):    
    
    def generate_error(self, graph):
        invalid_starts = ['T', 'P', 'S', 'X', 'V', 'E']
        wrong_start = choice(invalid_starts)
        
        sequence = ''
        current_node = graph.initial_node
        
        if len(current_node.connections) > 0:
            first_connection = current_node.connections[0]
            current_node = first_connection.node_to
        
        while not current_node.is_terminal and len(sequence) < 10:
            if len(current_node.connections) == 0:
                break
            selected_connection_index = randint(0, len(current_node.connections) - 1)
            sequence += current_node.connections[selected_connection_index].label
            current_node = current_node.connections[selected_connection_index].node_to
        
        return wrong_start + sequence


class MissingMiddleError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = graph.generate_sequence()
        if len(sequence) <= 3:
            return sequence[:-1]
        
        pos_to_remove = randint(1, len(sequence) - 2)
        return sequence[:pos_to_remove] + sequence[pos_to_remove + 1:]


class SwapNonAdjacentError(SequenceErrorStrategy):
    
    def generate_error(self, graph):
        sequence = list(graph.generate_sequence())
        if len(sequence) < 4:
            return ''.join(sequence[::-1])
        
        pos1 = randint(1, len(sequence) - 3)
        pos2 = randint(pos1 + 2, len(sequence) - 1)
        sequence[pos1], sequence[pos2] = sequence[pos2], sequence[pos1]
        
        return ''.join(sequence)

In [4]:
# Usage example: Configure and test error strategies

error_strategies = [
    WrongCharError(),
    InvalidTransitionError(),
    IncompleteSequenceError(),
    ExtraCharsError(),
    WrongStartError(),
    MissingMiddleError(),
    SwapNonAdjacentError()
]

reber_graph.set_error_strategies(error_strategies)

print("=== Valid sequences ===")
for i in range(3):
    valid_seq = reber_graph.generate_sequence()
    print(f"Valid {i+1}: {valid_seq} (validation: {reber_graph.validate_sequence(valid_seq)})")

print("\n=== Invalid sequences ===")
for i in range(5):
    wrong_seq = reber_graph.generate_wrong_sequence()
    print(f"Invalid {i+1}: {wrong_seq} (validation: {reber_graph.validate_sequence(wrong_seq)})")



=== Valid sequences ===
Valid 1: BPVVXVPE (validation: True)
Valid 2: BPTTTVPE (validation: True)
Valid 3: BTSSXSE (validation: True)

=== Invalid sequences ===
Invalid 1: PPVPE (validation: False)
Invalid 2: BTXXTTVPE (validation: True)
Invalid 3: BTSXSE (validation: True)
Invalid 4: BTXSEVS (validation: False)
Invalid 5: BSXTE (validation: False)
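
Invalid 2 and Invalid 3 above pass validation: deleting or swapping a character can produce another string that the grammar accepts (for example, dropping one S from BTSSXSE gives BTSXSE). A minimal sketch of the usual workaround, rejection sampling, which is also what the `_generate_guaranteed_invalid` method in the dataset class does. The table mirrors the graph built in this notebook, and `drop_one` and `sample_invalid` are hypothetical helpers for this illustration only.

```python
import random

# Transition table mirroring the ReberGraph of this notebook:
# state -> {symbol: next_state}; state 7 is the terminal state.
GRAPH = {
    0: {"B": 1},
    1: {"T": 2, "P": 3},
    2: {"S": 2, "X": 4},
    3: {"V": 5, "T": 3},
    4: {"X": 3, "S": 6},
    5: {"V": 4, "P": 6},
    6: {"E": 7},
}


def is_valid(seq):
    """Walk the table; fail on the first symbol with no matching edge."""
    state = 0
    for ch in seq:
        state = GRAPH.get(state, {}).get(ch)
        if state is None:
            return False
    return state == 7


def drop_one(seq, rng=random):
    """Hypothetical corruption: delete one interior character."""
    pos = rng.randint(1, len(seq) - 2)
    return seq[:pos] + seq[pos + 1:]


def sample_invalid(seq, corrupt=drop_one, max_attempts=50, rng=random):
    """Retry the corruption until the validator rejects the result."""
    for _ in range(max_attempts):
        candidate = corrupt(seq, rng)
        if not is_valid(candidate):
            return candidate
    raise RuntimeError("could not produce an invalid sequence")


print(is_valid("BTSXSE"))  # True: a corruption can land on a valid string
print(is_valid(sample_invalid("BTSSXSE")))  # False
```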


In [None]:
import tensorflow as tf
import numpy as np
from typing import Tuple, Optional
import random

tf.random.set_seed(42)
np.random.seed(42)
random.seed(42)

class ReberDataset:
    def __init__(self, graph: ReberGraph, vocab: Optional[dict] = None):
        self.graph = graph
        self.reber_alphabet = ['B', 'T', 'P', 'S', 'X', 'V', 'E']
        self.invalid_chars = ['Z', 'Q', 'W', 'R', 'Y', 'U', 'I', 'O', 'A', 'C', 'D', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N']
        
        if vocab is None:
            self.vocab = {'<PAD>': 0}
            for idx, char in enumerate(self.reber_alphabet, start=1):
                self.vocab[char] = idx
            for idx, char in enumerate(self.invalid_chars, start=len(self.reber_alphabet) + 1):
                self.vocab[char] = idx
        else:
            self.vocab = vocab
        
        self.idx_to_char = {idx: char for char, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)
    
    def _sequence_to_indices(self, sequence: str) -> list:
        return [self.vocab.get(char, 0) for char in sequence]
    
    def _generate_guaranteed_invalid(self) -> str:
        """Generates a guaranteed invalid sequence"""
        max_attempts = 50
        
        for _ in range(max_attempts):
            sequence = self.graph.generate_wrong_sequence()
            if not self.graph.validate_sequence(sequence):
                return sequence
        
        valid_seq = self.graph.generate_sequence()
        if len(valid_seq) > 2:
            pos = len(valid_seq) // 2
            return valid_seq[:pos] + 'Z' + valid_seq[pos+1:]
        return "BZE"
    
    def _generate_sample(self, is_valid: bool) -> Tuple[list, int]:
        if is_valid:
            sequence = self.graph.generate_sequence()
            label = 1
        else:
            sequence = self._generate_guaranteed_invalid()
            label = 0
        
        tokenized = self._sequence_to_indices(sequence)
        return tokenized, label
    
    def _generator(self, num_samples: int, valid_ratio: float = 0.5):
        num_valid = int(num_samples * valid_ratio)
        num_invalid = num_samples - num_valid
        
        samples_order = [True] * num_valid + [False] * num_invalid
        random.shuffle(samples_order)
        
        for is_valid in samples_order:
            tokenized, label = self._generate_sample(is_valid=is_valid)
            yield (np.array(tokenized, dtype=np.int32), np.array(label, dtype=np.int32))
    
    def generate_dataset(
        self,
        num_samples: int = 10000,
        batch_size: int = 32,
        valid_ratio: float = 0.5,
        max_length: Optional[int] = None,
        shuffle: bool = True,
        shuffle_buffer_size: int = 10000,
        prefetch: bool = True,
        prefetch_size: int = tf.data.AUTOTUNE,
        cache: bool = False,
        repeat: bool = False,
    ) -> tf.data.Dataset:
        dataset = tf.data.Dataset.from_generator(
            lambda: self._generator(num_samples, valid_ratio),
            output_signature=(
                tf.TensorSpec(shape=(None,), dtype=tf.int32),
                tf.TensorSpec(shape=(), dtype=tf.int32)
            )
        )
        
        if shuffle:
            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
        
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=([max_length] if max_length else [None], []),
            padding_values=(self.vocab['<PAD>'], 0),
            drop_remainder=False
        )
        
        if cache:
            dataset = dataset.cache()
        if prefetch:
            dataset = dataset.prefetch(prefetch_size)
        if repeat:
            dataset = dataset.repeat()
        
        return dataset
    
    def generate_train_val_split(
        self,
        num_samples: int = 10000,
        val_ratio: float = 0.2,
        batch_size: int = 32,
        **kwargs
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        num_train = int(num_samples * (1 - val_ratio))
        num_val = num_samples - num_train
        
        train_dataset = self.generate_dataset(num_samples=num_train, batch_size=batch_size, **kwargs)
        val_dataset = self.generate_dataset(
            num_samples=num_val, batch_size=batch_size, shuffle=False, cache=False,
            **{k: v for k, v in kwargs.items() if k not in ['shuffle', 'cache']}
        )
        
        return train_dataset, val_dataset


✅ ReberDatasetFixed defined correctly


In [None]:
# Verify that the dataset is balanced

dataset_generator = ReberDataset(reber_graph)

train_ds, val_ds = dataset_generator.generate_train_val_split(
    num_samples=10000,
    val_ratio=0.2,
    batch_size=32,
    shuffle=True,
    prefetch=True
)


label_counts = {0: 0, 1: 0}
correct_labels = 0
total = 0

for sequences, labels in train_ds.take(50):
    for i in range(len(sequences)):
        seq_str = ''.join([dataset_generator.idx_to_char.get(int(idx), '') 
                          for idx in sequences[i].numpy() if idx != 0])
        label = int(labels[i].numpy())
        is_valid = reber_graph.validate_sequence(seq_str)
        
        label_counts[label] += 1
        total += 1
        
        # Verify consistency: label must match actual validation
        expected_valid = (label == 1)
        if is_valid == expected_valid:
            correct_labels += 1

print(f"Total verified samples: {total}")
print(f"Label 0 (invalid): {label_counts[0]} ({100*label_counts[0]/total:.1f}%)")
print(f"Label 1 (valid): {label_counts[1]} ({100*label_counts[1]/total:.1f}%)")
print(f"\nLabel-validation consistency: {correct_labels}/{total} ({100*correct_labels/total:.1f}%)")


=== CORRECTED DATASET VERIFICATION ===

Total verified samples: 1600
Label 0 (invalid): 760 (47.5%)
Label 1 (valid): 840 (52.5%)

Label-validation consistency: 1600/1600 (100.0%)


Now let's train a model and see how it goes.

In [7]:
import datetime
import os

%load_ext tensorboard

log_dir = os.path.join("logs", "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
os.makedirs(log_dir, exist_ok=True)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,  # Log weight histograms every epoch
    write_graph=True,  # Visualize the computation graph
    write_images=True,  # Write model weights as images
    update_freq='epoch',  # Log metrics every epoch
    profile_batch=0,  # Disable profiling (set to batch number to enable)
    embeddings_freq=0,  # Frequency to visualize embeddings (0 = disabled)
)

print(f"TensorBoard logs will be saved to: {log_dir}")
print("\nTo start TensorBoard, run in terminal:")
print(f"tensorboard --logdir={os.path.dirname(log_dir)}")
print("\nOr use the magic command:")
print(f"%tensorboard --logdir {os.path.dirname(log_dir)}")


TensorBoard logs will be saved to: logs/fit/20251201-205332

To start TensorBoard, run in terminal:
tensorboard --logdir=logs/fit

Or use the magic command:
%tensorboard --logdir logs/fit


In [8]:
# Example: Training with TensorBoard callback

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(
        input_dim=dataset_generator.vocab_size,
        output_dim=64,
        mask_zero=True
    ),
    tf.keras.layers.LSTM(64, return_sequences=False),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[tensorboard_callback],
    verbose=1
)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


The model's performance is excellent. The sequences are not that long and the rules are procedural, so there is not even a need for a bidirectional encoder-decoder architecture that processes information in both directions before making an inference. A simple LSTM network is enough to solve this problem.