## Fine-tuning a model to do the parenthesis balancing task

## Setup

In [1]:
# Run-wide configuration for this notebook.

# When True, the device-selection cell uses the CPU even if CUDA is available.
FORCE_CPU = True
# Seed for random number generation (reproducibility).
SEED = 2384
# Model name handed to HookedTransformer.from_pretrained.
BASE_MODEL_NAME = "gelu-1l"

# CSV file read with pandas; its "text" column is tokenised to build the dataset.
TEXT_DATASET_FILE = "../../data/paren-balancing/single_line.csv"

# Strings that are tokenised to obtain the token ids treated as opening /
# closing parentheses when the labels are built.
OPEN_PAREN_STR_TOKENS = ["("]
CLOSE_PAREN_STR_TOKENS = [")"]

# Fractions of the rows held out for the test and validation splits; the
# remainder is used for training.
TEST_DATASET_SIZE = 0.1
VALIDATION_DATASET_SIZE = 0.1

# Training hyperparameters.
# NOTE(review): not referenced in the cells shown here; presumably consumed by
# a later training loop — confirm.
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
LR_SCHEDULER_PATIENCE = 1000

In [2]:
from IPython.display import display

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import pandas as pd

import plotly.express as px
import plotly.io as pio

from transformer_lens import HookedTransformer

In [3]:
# Pick the compute device: CUDA only when it is both permitted (FORCE_CPU is
# False) and actually available; otherwise fall back to the CPU.
device = torch.device(
    "cuda" if (not FORCE_CPU and torch.cuda.is_available()) else "cpu"
)
print(device)

cpu


In [4]:
# Select the Plotly renderer(s) used to display figures in this notebook.
# `pio` (plotly.io) is imported in the notebook's import cell.
pio.renderers.default = "colab+vscode"

## Model

In [5]:
# Load the pretrained base model named in the config onto the selected device.
base_model = HookedTransformer.from_pretrained(BASE_MODEL_NAME, device=device)

Loaded pretrained model gelu-1l into HookedTransformer


In [6]:
# Display the model's module tree (its repr lists the blocks and hook points).
base_model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_re

## Dataset

In [7]:
# Resolve every configured parenthesis string to its vocabulary id.
#
# `to_single_token` tokenises without a BOS token and raises if the string does
# not map to exactly one token, so a multi-token entry in the config fails
# loudly instead of being silently reduced to its first token.
# Result: two plain lists of Python ints, one id per configured string.
open_paren_tokens = [
    base_model.to_single_token(str_token) for str_token in OPEN_PAREN_STR_TOKENS
]
closed_paren_tokens = [
    base_model.to_single_token(str_token) for str_token in CLOSE_PAREN_STR_TOKENS
]

In [8]:
# Load the raw text rows and tokenise them into one padded id tensor
# (kept on the CPU via move_to_device=False).
text_data = pd.read_csv(TEXT_DATASET_FILE)
# Pass a real list of strings: `to_tokens` is documented for `str | list[str]`,
# and a NumPy object array only works when the model happens to rebuild the
# input as a list while prepending BOS.
text_data_tokenised = base_model.to_tokens(
    text_data["text"].tolist(), move_to_device=False
)

In [10]:
# Boolean masks marking where an opening / closing parenthesis token occurs.
open_bracket = torch.isin(text_data_tokenised, torch.tensor(open_paren_tokens))
closed_bracket = torch.isin(text_data_tokenised, torch.tensor(closed_paren_tokens))

# +1 for every opening token, -1 for every closing token, 0 elsewhere (int64).
bracket_values = open_bracket.long() - closed_bracket.long()

# Running sum along the last dimension gives the nesting depth after each token.
cumsum = bracket_values.cumsum(dim=-1)

# Label is 1 wherever the running depth is positive, 0 otherwise.
output_data = cumsum.gt(0).to(torch.int8)

In [11]:
# Split the rows into train / validation / test subsets.
#
# The permutation is drawn from a dedicated generator seeded with SEED so the
# same split is produced on every Restart & Run All, regardless of how much of
# the global RNG stream earlier cells consumed.
num_examples = text_data_tokenised.shape[0]
split_generator = torch.Generator().manual_seed(SEED)
shuffled_indices = torch.randperm(num_examples, generator=split_generator)

# Boundaries (in shuffled order): [0, train_end) is train,
# [train_end, validation_end) is validation, the rest is test.
train_end = int(num_examples * (1 - TEST_DATASET_SIZE - VALIDATION_DATASET_SIZE))
validation_end = int(num_examples * (1 - TEST_DATASET_SIZE))

train_indices = shuffled_indices[:train_end]
validation_indices = shuffled_indices[train_end:validation_end]
test_indices = shuffled_indices[validation_end:]

# Every row must land in exactly one split.
assert (
    len(train_indices) + len(validation_indices) + len(test_indices) == num_examples
), "train/validation/test split does not cover the dataset"

# Each dataset yields (token ids, per-position labels) pairs.
train_dataset = TensorDataset(
    text_data_tokenised[train_indices], output_data[train_indices]
)
validation_dataset = TensorDataset(
    text_data_tokenised[validation_indices], output_data[validation_indices]
)
test_dataset = TensorDataset(
    text_data_tokenised[test_indices], output_data[test_indices]
)

In [12]:
# Inspect the first training example: a (token ids, labels) pair of tensors.
train_dataset[0]

(tensor([    1,   342,  1800,  4432,    65,  9468,    65, 18349,    10,  3075,
            65, 42388,    14, 13997,    65, 12435,    85,    14,   770,    65,
         42388,    11,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  