# Synthetic dataset for parenthesis balancing

## Setup

In [60]:
# --- Notebook configuration -------------------------------------------------
# When True, run on the CPU even if CUDA is available.
FORCE_CPU = True
# Seed for NumPy's global random number generator.
SEED = 2384
# Name of the pretrained model whose tokenizer is checked against the samples.
MODEL_NAME = "gelu-1l"

# Number of samples to generate.
DATASET_SIZE = 100000
# Number of characters per sample (brackets and letters alternate).
DATASET_SAMPLE_LENGTH = 31
# Maximum bracket nesting depth allowed within a sample.
DATASET_MAX_DEPTH = 3

# Output path. The length/depth suffix is derived from the constants above so
# the file name cannot drift out of sync when they are changed.
DATA_FILE = (
    "../../data/paren-balancing/"
    f"synthetic_a_l{DATASET_SAMPLE_LENGTH}_d{DATASET_MAX_DEPTH}.csv"
)

In [61]:
import math
import string
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from transformer_lens import HookedTransformer

In [62]:
# Seed NumPy's global (legacy) generator so the np.random.randint draws used
# to build the dataset are reproducible.
np.random.seed(seed=SEED)

In [63]:
# Use CUDA only when it is not disabled by FORCE_CPU and is actually available;
# the availability probe is skipped entirely when FORCE_CPU is set.
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cpu


In [64]:
# Load the pretrained model named by MODEL_NAME onto the selected device.
# Below, the model is used only for its tokenizer (vocabulary inspection and
# tokenization of the generated samples).
model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)

Loaded pretrained model gelu-1l into HookedTransformer


In [65]:
# Lowercase ASCII letters a-z, taken from the standard library instead of a
# hand-typed literal; the uppercase half is derived with .upper() where needed.
alphabet = string.ascii_lowercase

## Tokens

In [66]:
# Size of the tokenizer's vocabulary.
d_vocab = model.tokenizer.vocab_size
# Decode every token id 0..d_vocab-1 to its string form, without a BOS token
# being prepended to the id sequence.
all_token_ids = torch.arange(d_vocab)
str_tokens = model.to_str_tokens(all_token_ids, prepend_bos=False)

In [67]:
# Every vocabulary token that contains an opening or closing parenthesis.
paren_tokens = [str_token for str_token in str_tokens if "(" in str_token or ")" in str_token]

# Print each parenthesis token that also contains an ASCII letter, exactly
# once per token (the previous nested loop printed a token once per matching
# letter).
# NOTE(review): presumably such tokens matter because they could fuse a bracket
# with a neighbouring letter when the samples are tokenized - confirm intent.
ascii_letters = alphabet + alphabet.upper()
for paren_token in paren_tokens:
    if any(letter in paren_token for letter in ascii_letters):
        print(paren_token)

## Dataset

In [68]:
# Split the sample length between brackets and letters. Brackets get the
# rounded-up half, so an odd length yields one more bracket than letters.
num_brackets = (DATASET_SAMPLE_LENGTH + 1) // 2
num_letters = DATASET_SAMPLE_LENGTH - num_brackets

In [69]:
# Draw an unconstrained +1/-1 random walk, one step per bracket position.
bracket_delta = 2 * np.random.randint(0, 2, (DATASET_SIZE, num_brackets)) - 1
bracket_cumsum = np.cumsum(bracket_delta, axis=1)

# Reflect the walk at 0 so the nesting depth never goes negative.
bracket_cumsum = np.abs(bracket_cumsum)

# Fold the walk into [0, DATASET_MAX_DEPTH] with a triangle wave of period
# 2 * DATASET_MAX_DEPTH. A single reflection at DATASET_MAX_DEPTH is only
# correct while the walk stays below 2 * DATASET_MAX_DEPTH; beyond that it
# produced negative depths, i.e. a ")" emitted at depth 0. Reducing modulo the
# period first makes the reflection valid for walks of any height, and the
# folded walk still moves by exactly one level per step.
bracket_cumsum = bracket_cumsum % (2 * DATASET_MAX_DEPTH)
bracket_cumsum = DATASET_MAX_DEPTH - np.abs(DATASET_MAX_DEPTH - bracket_cumsum)

# Recover the per-position step (+1 = open, -1 = close) from the folded depth;
# prepend=0 makes the first step relative to an initial depth of 0.
bracket_delta = np.diff(bracket_cumsum, axis=1, prepend=0)

# Sanity checks: every step opens or closes exactly one level, the depth stays
# within bounds, and the steps reproduce the depth profile.
assert np.all(np.abs(bracket_delta) == 1)
assert np.all((bracket_cumsum >= 0) & (bracket_cumsum <= DATASET_MAX_DEPTH))
assert np.all(np.cumsum(bracket_delta, axis=1) == bracket_cumsum)

In [70]:
# Map each depth step to its character: +1 opens a bracket, anything else
# closes one.
is_open = bracket_delta == 1
brackets = np.where(is_open, "(", ")")
brackets

array([['(', '(', '(', ..., '(', '(', ')'],
       ['(', ')', '(', ..., '(', ')', ')'],
       ['(', ')', '(', ..., ')', '(', ')'],
       ...,
       ['(', ')', '(', ..., ')', '(', '('],
       ['(', ')', '(', ..., ')', '(', ')'],
       ['(', ')', '(', ..., '(', ')', '(']], dtype='<U1')

In [71]:
# Lookup table of all lowercase followed by all uppercase letters.
alphabet_full_np = np.array(list(alphabet + alphabet.upper()))
# Draw one uniformly random index into the table per letter position, then
# gather the corresponding characters.
letter_indices = np.random.randint(
    0, len(alphabet_full_np), (DATASET_SIZE, num_letters)
)
letters = alphabet_full_np[letter_indices]
letters

array([['Q', 'n', 'f', ..., 'O', 'x', 'L'],
       ['a', 'f', 'w', ..., 'S', 'H', 'v'],
       ['l', 'j', 'Y', ..., 'd', 'q', 'h'],
       ...,
       ['A', 'a', 'A', ..., 'b', 'n', 'd'],
       ['f', 'm', 'p', ..., 'y', 'O', 'D'],
       ['d', 'L', 'V', ..., 'k', 's', 't']], dtype='<U1')

In [72]:
# Character grid with one row per sample and one column per character.
interleaved = np.empty(shape=(DATASET_SIZE, DATASET_SAMPLE_LENGTH), dtype="U1")
# Letters fill the odd columns and brackets the even ones, so every sample
# alternates bracket, letter, bracket, ... starting with a bracket.
interleaved[:, 1::2] = letters
interleaved[:, 0::2] = brackets
interleaved

array([['(', 'Q', '(', ..., '(', 'L', ')'],
       ['(', 'a', ')', ..., ')', 'v', ')'],
       ['(', 'l', ')', ..., '(', 'h', ')'],
       ...,
       ['(', 'A', ')', ..., '(', 'd', '('],
       ['(', 'f', ')', ..., '(', 'D', ')'],
       ['(', 'd', ')', ..., ')', 't', '(']], dtype='<U1')

In [73]:
# Collapse each row of single characters into one string per sample.
dataset = list(map("".join, interleaved))
# Preview the first ten samples.
dataset[:10]

['(Q(n(f)w)B)l(Z)E(b(z(a)G)O(x(L)',
 '(a)f(w(g)M)p(R(Z)d)v(T)i(S(H)v)',
 '(l)j(Y)f(S)x(v(O)d)A(P(P)d)q(h)',
 '(C)H(w(j)F(p)r(D(t)U)e(B(J)b(r)',
 '(P(j(e)i(s)e)y(V(U)b(r)y(L)N(N)',
 '(K(a(k)R(q)s(n)C)D(v)b(Q(Y)H)m(',
 '(x(K)Q)j(t(S(g)N(h)O)u(u)B)J(q)',
 '(B)X(k)L(P)c(C(k(s)m)b)n(M)N(M)',
 '(O(V)x)R(U(E(A)d(t)w(t)l)x(K)P(',
 '(e(J)W(f)J)s(G)W(G)h(k)c(W(d)u)']

## Checking tokenization

In [74]:
# Tokenize all generated samples in a single call.
# NOTE(review): to_tokens presumably returns a (batch, position) tensor with
# shorter rows padded to the longest one - confirm against the TransformerLens
# documentation.
tokenized = model.to_tokens(dataset)

In [75]:
# Sanity check: no sample may have token id 2 in its final position.
# NOTE(review): 2 is presumably the tokenizer's padding id, in which case this
# verifies that every sample tokenized to the same length - TODO confirm
# against model.tokenizer.pad_token_id.
last_token_ids = tokenized[:, -1]
assert not (last_token_ids == 2).any()

## Saving

In [76]:
# Wrap the samples in a single-column table; the column is named "text".
dataset_df = pd.DataFrame({"text": dataset})
# Preview the first rows.
dataset_df.head()

Unnamed: 0,text
0,(Q(n(f)w)B)l(Z)E(b(z(a)G)O(x(L)
1,(a)f(w(g)M)p(R(Z)d)v(T)i(S(H)v)
2,(l)j(Y)f(S)x(v(O)d)A(P(P)d)q(h)
3,(C)H(w(j)F(p)r(D(t)U)e(B(J)b(r)
4,(P(j(e)i(s)e)y(V(U)b(r)y(L)N(N)


In [77]:
# Create the output directory if it is missing; to_csv does not create parent
# directories and would otherwise fail after the whole dataset was generated.
Path(DATA_FILE).parent.mkdir(parents=True, exist_ok=True)
# Write the samples as CSV without the row index, so the file contains only
# the "text" column.
dataset_df.to_csv(DATA_FILE, index=False)