# Simplified Transformer (implementation)
This Jupyter notebook implements the **simplified transformer** described in the provided PDF ("Extracting LTL formulas from Transformers") — a leftmost-hard-max single-head transformer used for Boolean word classification.

Features implemented:
- Token encoding matrix `Wenc` (no two rows equal)
- Leftmost-argmax attention per row: `A = LArgMax(X Q (X K)^T)`
- Layer update: `X' = phi( A (X V) O ) + X`
- Output classification: `True` iff the entry in the last row, first column of the final layer's matrix is `> theta`
- Options for reproducible (seeded) random weights and a selectable activation (`relu`, `tanh`, or `identity`)

The cells below include the full implementation, a demo, and a few unit tests/examples.


In [None]:
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def leftmost_argmax_rows(matrix: np.ndarray) -> np.ndarray:
    """
    Build the hard-attention selector for a 2D score matrix.

    Every output row is all zeros except for a single 1 placed in the column
    that holds that row's maximum; ties go to the smallest column index.

    Input: 2D array of shape (n, m)
    Output: array with the same shape and dtype as the input
    Raises: ValueError if `matrix` is not 2D
    """
    if matrix.ndim != 2:
        raise ValueError("matrix must be 2D")
    n_rows = matrix.shape[0]
    # argmax reports the first index at which the maximum occurs, which is
    # exactly the leftmost tie-break required here.
    winners = matrix.argmax(axis=1)
    selector = np.zeros_like(matrix)
    selector[np.arange(n_rows), winners] = 1
    return selector

# Activation helpers
def get_activation(name: str):
    """
    Look up an element-wise activation function by name.

    Supported names are 'relu', 'tanh' and 'identity'; any other value
    raises ValueError.
    """
    def _relu(x):
        return np.maximum(0, x)

    def _identity(x):
        return x

    # Checked in order with plain equality, one candidate at a time.
    known = (
        ('relu', _relu),
        ('tanh', np.tanh),
        ('identity', _identity),
    )
    for key, fn in known:
        if name == key:
            return fn
    raise ValueError(f"Unknown activation: {name}")


In [None]:
class SimplifiedTransformer:
    """
    Leftmost-hard-max, single-head transformer for Boolean word classification.

    Each layer computes A = LArgMax((X Q) (X K)^T) and then updates
    X' = phi(A (X V) O) + X. A word is classified True iff the entry in the
    last row, first column of the final X is strictly greater than theta.

    - vocab: list of symbols (should include '!' symbol separately if used)
    - d: feature dimension
    - d0: internal attention/value dimension
    - L: number of layers
    - activation: 'relu', 'tanh', or 'identity'
    - theta: classification threshold
    - seed: optionally set random seed for deterministic weights
    - Wenc: optional |vocab| x d encoding matrix; drawn at random when omitted
    - Qs/Ks/Vs/Os: optional per-layer matrices (d x d0 for Q/K/V, d0 x d
      for O); each list that is omitted is drawn at random

    The parameters Q/K/V/O per layer and Wenc are numpy arrays. Matrices
    passed in by the caller are stored as given (not copied).

    Raises ValueError if Wenc does not have shape (|vocab|, d) or if a
    supplied per-layer list holds fewer than L matrices.
    """
    def __init__(self,
                 vocab: List[str],
                 d: int = 8,
                 d0: int = 4,
                 L: int = 2,
                 activation: str = 'relu',
                 theta: float = 0.0,
                 seed: Optional[int] = None,
                 Wenc: Optional[np.ndarray] = None,
                 Qs: Optional[List[np.ndarray]] = None,
                 Ks: Optional[List[np.ndarray]] = None,
                 Vs: Optional[List[np.ndarray]] = None,
                 Os: Optional[List[np.ndarray]] = None):
        self.vocab = list(vocab)
        self.vocab_index = {tok: i for i, tok in enumerate(self.vocab)}
        self.d = d
        self.d0 = d0
        self.L = L
        self.activation_name = activation
        self.phi = get_activation(activation)
        self.theta = theta
        self.rng = np.random.RandomState(seed)

        # Encoding matrix: |vocab| x d, no two rows equal
        n_tokens = len(self.vocab)
        if Wenc is None:
            base = self.rng.randn(n_tokens, d)
            # Shift row i by i * 1e-3 so that rows which happened to be drawn
            # identical still end up distinct.
            base += 1e-3 * np.arange(n_tokens)[:, np.newaxis]
            self.Wenc = base
        else:
            # Explicit raise rather than `assert`, so the check is not
            # stripped when Python runs with -O.
            if Wenc.shape != (n_tokens, d):
                raise ValueError(
                    f"Wenc must have shape {(n_tokens, d)}, got {Wenc.shape}"
                )
            self.Wenc = Wenc

        # Per-layer matrices. The draw order (all Q, then K, then V, then O)
        # determines which weights a given seed produces.
        def rand_mat(shape):
            return self.rng.randn(*shape)

        self.Qs = Qs if Qs is not None else [rand_mat((d, d0)) for _ in range(L)]
        self.Ks = Ks if Ks is not None else [rand_mat((d, d0)) for _ in range(L)]
        self.Vs = Vs if Vs is not None else [rand_mat((d, d0)) for _ in range(L)]
        self.Os = Os if Os is not None else [rand_mat((d0, d)) for _ in range(L)]

        # Fail here rather than with an IndexError halfway through forward().
        for label, mats in (('Qs', self.Qs), ('Ks', self.Ks),
                            ('Vs', self.Vs), ('Os', self.Os)):
            if len(mats) < L:
                raise ValueError(
                    f"{label} must provide one matrix per layer: "
                    f"expected at least {L}, got {len(mats)}"
                )

    def encode(self, word: List[str]) -> np.ndarray:
        """
        Return X0: n x d matrix for the input admissible word (list of tokens).

        Row i is a copy of the Wenc row of the i-th token.
        Raises ValueError if a token is not in the vocab or the word is empty.
        """
        indices = []
        for tok in word:
            try:
                indices.append(self.vocab_index[tok])
            except KeyError:
                raise ValueError(f"Token '{tok}' not in vocab") from None
        if not indices:
            # An empty word has no last row to classify; say so explicitly.
            raise ValueError("Cannot encode an empty word: at least one token is required")
        # Integer-list indexing gathers the rows into a fresh array.
        return self.Wenc[indices]

    def forward(self, word: List[str], return_all: bool = False) -> Dict[str, Any]:
        """
        Run the transformer on an admissible word (list of tokens).

        Always returns a dict with keys:
          - 'classification': boolean result (score > theta)
          - 'score': final value XL[last_row, 0]
        With return_all=True the dict additionally holds:
          - 'X_layers': list of X matrices (including X0)
          - 'A_layers': list of attention one-hot matrices per layer
          - 'word': the input word
        """
        X = self.encode(word)  # X0
        X_layers = [X.copy()]
        A_layers = []

        for ell in range(self.L):
            Q = self.Qs[ell]
            K = self.Ks[ell]
            V = self.Vs[ell]
            O = self.Os[ell]

            # compute scores: (X Q) (X K)^T  -> shape n x n
            scores = (X.dot(Q)).dot((X.dot(K)).T)
            A = leftmost_argmax_rows(scores)
            A_layers.append(A)

            # compute A (X V) O
            XV = X.dot(V)        # n x d0
            A_XV = A.dot(XV)     # n x d0
            Y = A_XV.dot(O)      # n x d
            X = self.phi(Y) + X  # residual
            X_layers.append(X.copy())

        final_score = X[-1, 0]  # XL[n,1] in 1-based indexing
        classification = bool(final_score > self.theta)

        out = {
            'X_layers': X_layers,
            'A_layers': A_layers,
            'classification': classification,
            'score': final_score,
            'word': word
        }
        if return_all:
            return out
        else:
            return {'classification': classification, 'score': final_score}

    def pretty_print_trace(self, trace: Dict[str, Any]):
        """
        Print the word, every X layer with the attention matrix applied to it,
        and the final score/classification.

        `trace` must contain the 'word', 'X_layers' and 'A_layers' entries,
        i.e. it has to come from forward(..., return_all=True).
        """
        print(f"Input word: {' '.join(trace['word'])}")
        for i, X in enumerate(trace['X_layers']):
            print(f"\nX layer {i} (shape {X.shape}):")
            print(np.round(X, 4))
            # There is one attention matrix fewer than X layers.
            if i < len(trace['A_layers']):
                print(f"A layer {i+1}:")
                print(trace['A_layers'][i])
        print(f"\nFinal score (last-row, first-col): {trace['score']:.6f}")
        print(f"Classification (> {self.theta}): {trace['classification']}")


In [None]:
# Demo: run a randomly initialised (seed 42) three-layer transformer over a
# handful of admissible words and print the full trace of every run.
vocab = ['a', 'b', 'c', '!']
transformer = SimplifiedTransformer(
    vocab=vocab,
    d=8,
    d0=4,
    L=3,
    activation='relu',
    theta=0.0,
    seed=42,
)

# Each example word ends with the '!' symbol.
examples = [
    ['a', 'a', '!'],
    ['a', 'b', '!'],
    ['b', 'b', '!'],
    ['c', 'a', 'b', '!'],
]

for w in examples:
    trace = transformer.forward(w, return_all=True)
    transformer.pretty_print_trace(trace)
    # Blank line, a 60-dash rule, blank line between consecutive traces.
    print('', '-' * 60, '', sep='\n')


In [None]:
# Unit-test style example: craft Wenc and layer matrices so attention picks a predictable token
vocab = ['x', 'y', '!']
# choose d=3, d0=2 for clarity
Wenc = np.array([[1.0, 0.0, 0.0],  # x
                 [0.0, 1.0, 0.0],  # y
                 [0.0, 0.0, 1.0]]) # !

# design Q and K so that tokens with matching one-hot attend to each other;
# '!' is projected to the zero vector, so its whole score row is 0 and the
# leftmost tie-break makes it attend to position 0
Q = np.array([[1.0,0.0],[0.0,1.0],[0.0,0.0]])  # d x d0
K = Q.copy()
V = np.array([[1.0,0.0],[0.0,1.0],[0.0,0.0]])  # identity-ish
O = np.array([[1.0,0.0,0.0],[0.0,1.0,0.0]])      # d0 x d -> will reconstruct

# single layer transformer with known behavior
T = SimplifiedTransformer(vocab=vocab, d=3, d0=2, L=1, activation='identity', theta=0.5, seed=0,
                          Wenc=Wenc, Qs=[Q], Ks=[K], Vs=[V], Os=[O])

# Both words end with '!'. The classifier reads the last row's first column,
# and the '!' row copies the value of the first token, so the result reveals
# whether the word starts with 'x'.
w1 = ['x', '!']
w2 = ['y', '!']

trace1 = T.forward(w1, return_all=True)
trace2 = T.forward(w2, return_all=True)

print('Trace for', w1)
T.pretty_print_trace(trace1)
print('\nTrace for', w2)
T.pretty_print_trace(trace2)

# Checks of the hand-derived behavior: in both words every row attends to
# position 0, so the last row gains the first token's value vector.
assert np.array_equal(trace1['A_layers'][0], np.array([[1.0, 0.0], [1.0, 0.0]]))
assert np.array_equal(trace2['A_layers'][0], np.array([[1.0, 0.0], [1.0, 0.0]]))
assert np.array_equal(trace1['X_layers'][-1], np.array([[2.0, 0.0, 0.0], [1.0, 0.0, 1.0]]))
assert np.array_equal(trace2['X_layers'][-1], np.array([[0.0, 2.0, 0.0], [0.0, 1.0, 1.0]]))
assert trace1['score'] == 1.0 and trace1['classification'] is True    # 1.0 > 0.5
assert trace2['score'] == 0.0 and trace2['classification'] is False   # 0.0 <= 0.5

print('\nNote: with identity activation and the chosen matrices, you can reason about which positions each row attends to and how the final value changes.')


## Notes
- This notebook follows the exact formulas in the PDF: attention uses leftmost-argmax, single-head.
- No positional encoding was added (as in the PDF).
- You can experiment by changing seeds, layer count `L`, dimensions `d` and `d0`, and activation.

## How to run
1. Download the notebook file and open it in Jupyter (or JupyterLab / VS Code).
2. Run all cells.

Possible extensions:
- add more visualizations (plot attention patterns),
- convert this to a ready-to-run `.py` module or a hosted Google Colab link,
- or extend it to multi-head / softmax attention.
