## Fine-tuning a model to do the parenthesis balancing task

## Setup

In [1]:
# --- Runtime -----------------------------------------------------------------
FORCE_CPU = False  # set to True to stay on the CPU even when CUDA is available
SEED = 2384  # random seed for reproducibility
BASE_MODEL_NAME = "gelu-1l"  # pretrained model that gets fine-tuned

# --- Data --------------------------------------------------------------------
# CSV file read with pandas; its "text" column holds the input strings.
TEXT_DATASET_FILE = "../../data/paren-balancing/single_line.csv"

# String tokens that open / close a parenthesis.
OPEN_PAREN_STR_TOKENS = ["("]
CLOSE_PAREN_STR_TOKENS = [")"]

# Fractions of the full dataset held out for testing and validation.
TEST_DATASET_SIZE = 0.1
VALIDATION_DATASET_SIZE = 0.1

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 256
EPOCHS = 10
LEARNING_RATE = 0.0001
LR_SCHEDULER_PATIENCE = 1_000

In [2]:
from copy import deepcopy
from math import sqrt

from IPython.display import display

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import einops

import numpy as np

import pandas as pd

from tqdm import tqdm

import plotly.express as px

from transformer_lens import HookedTransformer
from transformer_lens.components import Unembed

In [3]:
# Use the GPU when one is available, unless the config forces CPU execution.
use_cuda = torch.cuda.is_available() and not FORCE_CPU
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


In [4]:
# Choose the plotly renderer used for every figure in this notebook.
from plotly import io as pio

pio.renderers.default = "colab+vscode"

## Model

In [5]:
# Seed the global RNG before anything is randomly initialised, so the new
# classification head below starts from the same weights on every
# Restart & Run All.
torch.manual_seed(SEED)

base_model = HookedTransformer.from_pretrained(BASE_MODEL_NAME, device=device)

# Same configuration as the pretrained model, except that the unembedding
# produces 2 output classes instead of a distribution over the vocabulary.
model_config = deepcopy(base_model.cfg)
model_config.d_vocab_out = 2

model = HookedTransformer(cfg=model_config)

# Start from the pretrained weights and swap in a freshly initialised
# unembedding (weight and bias) with the new output size.
model_state_dict = deepcopy(base_model.state_dict())

unembed_init_bound = 1 / sqrt(model_config.d_model)
model_state_dict["unembed.W_U"] = torch.empty(
    model_config.d_model, model_config.d_vocab_out, device=device
)
model_state_dict["unembed.b_U"] = torch.empty(
    model_config.d_vocab_out, device=device
)
for head_param_name in ("unembed.W_U", "unembed.b_U"):
    # Uniform initialisation in [-1/sqrt(d_model), 1/sqrt(d_model)].
    nn.init.uniform_(
        model_state_dict[head_param_name],
        -unembed_init_bound,
        unembed_init_bound,
    )

# Left as the last expression so the key-matching report is displayed.
model.load_state_dict(model_state_dict)

Loaded pretrained model gelu-1l into HookedTransformer


<All keys matched successfully>

In [6]:
# Id of the padding token in the base model's tokenizer; printed for inspection.
pad_token_id = base_model.tokenizer.pad_token_id
print(pad_token_id)

2


## Dataset

In [7]:
# Tokenise each parenthesis string (no BOS, kept on the CPU, no truncation) and
# keep the first token id of every row as a plain Python int.
open_paren_tokens = model.to_tokens(
    OPEN_PAREN_STR_TOKENS,
    prepend_bos=False,
    move_to_device=False,
    truncate=False,
)[:, 0].tolist()

closed_paren_tokens = model.to_tokens(
    CLOSE_PAREN_STR_TOKENS,
    prepend_bos=False,
    move_to_device=False,
    truncate=False,
)[:, 0].tolist()

In [8]:
# Read the raw examples, then tokenise the "text" column into a single tensor
# that stays on the CPU.
text_data = pd.read_csv(TEXT_DATASET_FILE)
text_data_tokenised = base_model.to_tokens(
    text_data.loc[:, "text"].values,
    move_to_device=False,
)

In [9]:
# Boolean masks marking where an opening / closing parenthesis token occurs.
open_bracket = torch.isin(text_data_tokenised, torch.tensor(open_paren_tokens))
closed_bracket = torch.isin(text_data_tokenised, torch.tensor(closed_paren_tokens))

# +1 for every opening token, -1 for every closing token, 0 everywhere else.
bracket_values = open_bracket.long() - closed_bracket.long()

# Running sum along the sequence; the label is 1 wherever that sum is positive.
cumsum = bracket_values.cumsum(dim=-1)
output_data = cumsum.gt(0).long()

In [10]:
# True at every position that is not the padding token.
loss_mask = text_data_tokenised.ne(pad_token_id)

In [11]:
# Shuffle with a dedicated generator seeded from SEED so that the
# train/validation/test membership is identical on every run.
split_generator = torch.Generator().manual_seed(SEED)
num_examples = text_data_tokenised.shape[0]
shuffled_indices = torch.randperm(num_examples, generator=split_generator)

# Split boundaries, computed once: [0, train_end) is train,
# [train_end, validation_end) is validation, and the remainder is test.
train_end = int(num_examples * (1 - TEST_DATASET_SIZE - VALIDATION_DATASET_SIZE))
validation_end = int(num_examples * (1 - TEST_DATASET_SIZE))

train_indices = shuffled_indices[:train_end]
validation_indices = shuffled_indices[train_end:validation_end]
test_indices = shuffled_indices[validation_end:]

# Each dataset yields (tokens, labels, loss mask) for the selected rows.
train_dataset = TensorDataset(
    text_data_tokenised[train_indices],
    output_data[train_indices],
    loss_mask[train_indices],
)
validation_dataset = TensorDataset(
    text_data_tokenised[validation_indices],
    output_data[validation_indices],
    loss_mask[validation_indices],
)
test_dataset = TensorDataset(
    text_data_tokenised[test_indices],
    output_data[test_indices],
    loss_mask[test_indices],
)

## Training

In [12]:
# Training batches are reshuffled at the start of every epoch so the model does
# not see the same batch composition each time; the shuffle order is driven by
# a generator seeded from SEED to keep runs reproducible.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Validation and test batches keep a fixed order; the final partial batch is
# dropped.
validation_dataloader = DataLoader(
    validation_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
)

In [13]:
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
# The scheduler is stepped once per batch (see below), so the patience is
# measured in batches. It is taken from the config constant instead of being
# hard-coded here.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, "min", patience=LR_SCHEDULER_PATIENCE, verbose=True
)
# Per-epoch history. NaN-filled so that epochs which never ran (for example
# after an interrupted run) are visibly missing instead of holding
# uninitialised memory.
losses = np.full(EPOCHS, np.nan)
accuracies = np.full(EPOCHS, np.nan)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    total_agreement = 0.0

    iterator = tqdm(
        train_dataloader,
        total=len(train_dataloader),
        desc=f"Epoch {epoch + 1}/{EPOCHS}",
    )
    # The per-batch mask gets its own name so the dataset-level `loss_mask`
    # defined earlier in the notebook is not overwritten by the loop.
    for tokens, gold_output, batch_loss_mask in iterator:
        tokens = tokens.to(device)
        gold_output = gold_output.to(device)
        batch_loss_mask = batch_loss_mask.to(device)

        optimizer.zero_grad()

        output = model(tokens)
        # cross_entropy expects the class dimension directly after the batch
        # dimension.
        output_rearranged = einops.rearrange(output, "batch seq out -> batch out seq")
        loss = F.cross_entropy(output_rearranged, gold_output, reduction="none")
        # Average the per-position loss over the unmasked positions only.
        loss = loss[batch_loss_mask].mean()
        loss.backward()
        optimizer.step()

        # Hand the scheduler a plain float rather than a graph-attached tensor.
        loss_value = loss.item()
        scheduler.step(loss_value)

        total_loss += loss_value
        with torch.no_grad():
            predictions = torch.argmax(output, dim=-1)
            total_agreement += (
                (predictions == gold_output)[batch_loss_mask].float().mean().item()
            )

    # Mean of the per-batch loss and accuracy over the epoch.
    losses[epoch] = total_loss / len(train_dataloader)
    accuracies[epoch] = total_agreement / len(train_dataloader)

    print(f"Loss: {losses[epoch]:.4f}, Accuracy: {accuracies[epoch]:.4f}")

Epoch 1/10:   1%|▏         | 9/637 [00:39<46:15,  4.42s/it]


KeyboardInterrupt: 