In [1]:
!pip install --upgrade datasets huggingface_hub

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-0.33.2-py3-none-any.whl.metadata (14 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading huggingface_hub-0.33.2-py3-none-any.whl (515 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.4/515.4 kB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, huggingface_hub, datasets
  Attempting uninstall: fsspec
    Found existing i

Define a Decoder Only Transformer Using the TRA Attention Mechanism

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


class SwiGLU(nn.Module):
    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x1, x2):
        return x1 * F.silu(x2)

class MLPWithSwiGLU(nn.Module):
    def __init__(self, e_dim, dropout, middle_factor=4):
        super(MLPWithSwiGLU, self).__init__()
        self.linear1 = nn.Linear(e_dim, e_dim * middle_factor)  # Gated mechanism
        self.swiglu = SwiGLU()
        self.linear2 = nn.Linear(middle_factor//2 * e_dim, e_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x_proj = self.linear1(x)
        x1, x2 = x_proj.chunk(2, dim=-1)  # Split the projection into two parts
        x = self.dropout(self.swiglu(x1, x2))
        x = self.linear2(x)
        return x



class TRA(nn.Module):
    def __init__(self, e_dim, n_heads, dropout, max_len=5000):
        super(TRA, self).__init__()
        self.E = e_dim
        self.e = e_dim // n_heads
        self.h = n_heads
        self.tokeys =  nn.Linear(e_dim, e_dim, bias=False)
        self.toqueries = nn.Linear(e_dim, e_dim, bias=False)
        self.tovalues = nn.Linear(e_dim, e_dim, bias=False)
        self.delta_proj = nn.Linear(e_dim, self.h)
        self.c_out = nn.Linear(e_dim, e_dim)
        # use qk norm but no scale to save params
        self.qk_norm = RMSNorm(self.e, elementwise_affine=False)
        self.register_buffer("bias", torch.triu(torch.ones(max_len, max_len), diagonal=1).view(1, 1,max_len, max_len).to(torch.bool))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, s, _ = x.shape # batch_size, seq_len, e_dim
        # transform q, k, v
        q = self.toqueries(x).view(b, s, self.h, self.e).transpose(1, 2)
        k = self.tokeys(x).view(b, s, self.h, self.e).transpose(1, 2)
        v = self.tovalues(x).view(b, s, self.h, self.e).transpose(1, 2)
        # qk norm (no scale -> can share)
        q = self.qk_norm(q)
        k = self.qk_norm(k)
        # compute attn dot
        S = (q @ k.transpose(-2,-1)) / torch.sqrt(torch.tensor(self.e, dtype=torch.float, device=x.device))
        S = S.masked_fill(self.bias[:,:,:s,:s], float('-1e11'))
        S = F.relu(S)
        mask = (S == 0)
        # compute positional weights D
        delta = F.logsigmoid(self.delta_proj(x)).transpose(1,2).unsqueeze(-1).contiguous().view(b, self.h, s, 1)
        D = (~mask).float() * delta
        D = D.flip(-1).cumsum(-1).flip(-1)
        A = self.dropout(S + D)
        A[mask] = -1e11 # mask: casual + threshold
        A = A.softmax(dim=-1)
        # protect against leaks where mask all = 0 from no-op attention
        A = A.masked_fill(mask.all(-1).unsqueeze(-1), 0)
        # compute output
        out = A @ v
        out = out.transpose(1, 2).contiguous().view(b, s, self.E)
        return self.c_out(out)


class TraBlock(nn.Module):
    def __init__(self, e_dim, n_heads, dropout=0.01):
        super(TraBlock, self).__init__()
        self.E = e_dim
        self.e = e_dim // n_heads
        self.h = n_heads
        self.dropout = nn.Dropout(p=dropout)
        self.ln1 = RMSNorm(e_dim)
        self.attn = TRA(e_dim, n_heads, dropout)
        self.mlp = MLPWithSwiGLU(e_dim, dropout)
        self.ln2 = RMSNorm(e_dim)

    def forward(self, x):
        x = self.ln1(x)
        x = x + self.attn(x)
        x = self.ln2(x)
        x = x + self.mlp(x)
        return x


class TraDecoder(nn.Module):
    def __init__(self, e_dim, n_layers, n_heads, vocab_size, max_dist, dropout=0.01):
        super(TraDecoder, self).__init__()
        # hparams
        self.e = e_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.v = vocab_size

        # components
        self.dropout = nn.Dropout(p=dropout)
        self.embedding = nn.Embedding(self.v, e_dim, padding_idx=0)
        self.layers = nn.ModuleList([TraBlock(e_dim, n_heads,) for i in range(n_layers)])
        self.outln = RMSNorm(e_dim)
        self.out = nn.Linear(e_dim, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        nn.init.uniform_(self.embedding.weight, -0.01, 0.01)
        nn.init.normal_(self.out.weight, std=0.02)

    def forward(self, e):
        # e: input sequence,
        x = self.dropout(self.embedding(e))
        for layer in self.layers:
            x = layer(x)
        return self.out(self.outln(x))

Load the Flip-Flop Language Modeling Dataset

In [3]:
from datasets import load_dataset
ds = load_dataset("synthseq/flipflop")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

(…)-00000-of-00002-b4ca324082d96883.parquet:   0%|          | 0.00/160M [00:00<?, ?B/s]

(…)-00001-of-00002-bf5f777704418c83.parquet:   0%|          | 0.00/160M [00:00<?, ?B/s]

(…)-00000-of-00001-fec0c03d88b56508.parquet:   0%|          | 0.00/3.21M [00:00<?, ?B/s]

(…)-00000-of-00001-cd57636c8e1ff5a5.parquet:   0%|          | 0.00/831k [00:00<?, ?B/s]

(…)-00000-of-00001-1fcf65938d0f40dd.parquet:   0%|          | 0.00/29.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1600000 [00:00<?, ? examples/s]

Generating val split:   0%|          | 0/16000 [00:00<?, ? examples/s]

Generating val_dense split:   0%|          | 0/4000 [00:00<?, ? examples/s]

Generating val_sparse split:   0%|          | 0/160000 [00:00<?, ? examples/s]

Set-Up Our Training Scripts

In [4]:
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import get_cosine_schedule_with_warmup
from torch.nn.utils.rnn import pad_sequence
import random
import math
import torch.optim as optim


# vocab for flip-flops
def make_vocab():
    vocab = {}
    vocab['[PAD]'] = 0
    vocab['w'] = 1
    vocab['r'] = 2
    vocab['i'] = 3
    vocab['0'] = 4
    vocab['1'] = 5
    return vocab


def full_sequence_accuracy(y_pred, y_true, mask):
    y_pred = y_pred.argmax(dim=-1)
    y_pred = y_pred[:, :-1]
    y_true = y_true[:, 1:]
    y_pred[~mask[:, :-1]] = y_true[~mask[:, :-1]]
    return (y_pred == y_true).all(dim=-1).sum().item() / y_true.size(0)



class FFLDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, text_column='text'):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.text_column = text_column

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        line = item[self.text_column]
        tl = [self.tokenizer[x] for x in line]
        targs = [0 if x != 'r' else 1 for x in line] # select read instructions as targets
        return torch.tensor(tl), torch.tensor(targs)

def collate_fn(batch):
    data = torch.stack([x[0] for x in batch])
    targs = torch.stack([x[1] for x in batch])
    return data, targs

def create_dataloader(hf_split_dataset, tokenizer, batch_size, text_column='text'):
    dataset = FFLDataset(hf_split_dataset, tokenizer, text_column=text_column)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader

def train_epoch(model, dataloader, device, optimizer, scheduler):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        input_ids, selecta = batch
        input_ids, selecta = input_ids.to(device), selecta.bool().to(device)
        out = model(input_ids)
        # backprop the loss
        loss = F.cross_entropy(out[:, :-1, :][selecta[:, :-1]], input_ids[:, 1:][selecta[:, :-1]])
        loss.backward()
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
        # calculate accuracy
        logits = out
        acc = full_sequence_accuracy(logits, input_ids, selecta)
        epoch_acc += acc
    return epoch_loss / len(dataloader), epoch_acc / len(dataloader)


@torch.no_grad()
def evaluate(model, dataloader, device):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    for batch in tqdm(dataloader):
        input_ids, selecta = batch
        input_ids, selecta = input_ids.to(device), selecta.bool().to(device)
        out = model(input_ids)
        loss = F.cross_entropy(out[:, :-1, :][selecta[:, :-1]], input_ids[:, 1:][selecta[:, :-1]])
        epoch_loss += loss.item()
        # calculate accuracy
        logits = out
        epoch_acc += full_sequence_accuracy(logits, input_ids, selecta)
    return epoch_loss / len(dataloader), epoch_acc / len(dataloader),


Train Our TRA Model on Flip Flop Language Modelling

*   We will use a four layer four head model
*   Should take about 2 hours on T4 GPU, 20 minutes on A100 - when compiled otherwise about twice as long
*   We will only train for one epoch at batch size 64 (if you increase batch size you may need more epochs)
*   You should see the model fully generalise across all splits - otherwise please let us know! We tested 50 different random seeds and it generalised every time..





In [5]:
# init model
tokenizer = make_vocab()
model = TraDecoder(256, 4, 4, len(list(tokenizer.keys())), 5000)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model.to(device)
# use torch compile to double the speed
print('compiling')
torch.set_float32_matmul_precision('high')
model = torch.compile(model)
# set up dataloaders
BATCH_SIZE = 64
train_dataloader = create_dataloader(ds['train'], tokenizer, BATCH_SIZE)
val_iid_dataloader = create_dataloader(ds['val'], tokenizer, BATCH_SIZE)
val_sparse_dataloader = create_dataloader(ds['val_sparse'], tokenizer, BATCH_SIZE)
val_dense_dataloader = create_dataloader(ds['val_dense'], tokenizer, BATCH_SIZE)
# set optimiser and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader))
# train
loss, acc = train_epoch(model, train_dataloader, device, optimizer, scheduler)
print(f'Train Loss: {loss}')
print(f'Train Acc: {acc}')
val_loss, val_acc = evaluate(model, val_iid_dataloader, device)
print(f'Val Loss: {val_loss}')
print(f'Val Acc: {val_acc}')
sparse_loss, sparse_acc = evaluate(model, val_sparse_dataloader, device)
print(f'Sparse Loss: {sparse_loss}')
print(f'Sparse Acc: {sparse_acc}')
dense_loss, dense_acc = evaluate(model, val_dense_dataloader, device)
print(f'Val Dense Loss: {dense_loss}')
print(f'Val Dense Acc: {dense_acc}')

cuda
compiling


  0%|          | 0/25000 [00:00<?, ?it/s]W0707 14:34:45.620000 1183 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode
100%|██████████| 25000/25000 [2:06:52<00:00,  3.28it/s]


Train Loss: 0.0014076740333366208
Train Acc: 0.997623125


100%|██████████| 250/250 [00:43<00:00,  5.71it/s]


Val Loss: 1.1866352303968597e-07
Val Acc: 1.0


100%|██████████| 2500/2500 [04:50<00:00,  8.60it/s]


Sparse Loss: 1.1920191242609235e-07
Sparse Acc: 1.0


100%|██████████| 63/63 [00:25<00:00,  2.42it/s]

Val Dense Loss: 1.1889904132860667e-07
Val Dense Acc: 1.0



