In [1]:
import shutil
import string
from pathlib import Path

import transformer_lens

import copy_transformer.data
from subspace_partition.subspace_partition import (
    run_subspace_partition,
    SubspacePartitionConfig,
)

In [2]:
# Model / data hyperparameters.
EMBEDDING_DIM = 64
# Backward-compatible alias for the original (misspelled) name used by later cells.
EMBEDDNING_DIM = EMBEDDING_DIM
NUM_HEADS = 8
# d_head is computed as EMBEDDING_DIM // NUM_HEADS; a remainder would silently
# make n_heads * d_head != d_model.
assert EMBEDDING_DIM % NUM_HEADS == 0, "EMBEDDING_DIM must be divisible by NUM_HEADS"
# Token alphabet: the 26 uppercase ASCII letters, one token per character.
VOCABULARY = list(string.ascii_uppercase)
CONTEXT_LENGTH = 32

EXPERIMENT_NAME = "test_experiment"
OUTPUT_DIR = Path("test_outputs")

# Remove anything left in the output directory by a previous run.
shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

In [3]:
from transformers import PreTrainedTokenizer
from typing import List, Optional


class SimpleCharTokenizer(PreTrainedTokenizer):
    """Super simple character tokenizer that takes a list of chars as alphabet.

    Each distinct character of ``alphabet`` becomes one token, with ids assigned
    contiguously from 0 in first-seen order. Special tokens passed through
    ``**kwargs`` (e.g. ``bos_token``/``eos_token``/``unk_token``/``pad_token``)
    are handed to ``PreTrainedTokenizer.__init__``; afterwards their string/id
    pairs are copied into the same char <-> id maps, so ``vocab_size`` and
    ``get_vocab()`` include them.

    NOTE(review): this class neither reads an ``add_bos_token`` kwarg nor
    overrides ``build_inputs_with_special_tokens``, so any BOS prepending is
    left to the base class — confirm that matches what callers expect.
    """

    def __init__(self, alphabet: List[str], **kwargs):
        """
        Args:
            alphabet: List of characters to use as vocabulary. Repeated
                characters are kept only once (first occurrence wins).
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        # Keep the alphabet exactly as given by the caller.
        self.alphabet = alphabet

        # Deduplicate while preserving order. With repeated characters,
        # enumerate(alphabet) would leave holes in the id range, so
        # len(char_to_id) (our vocab_size) would be smaller than the largest
        # id + 1, and ids handed out afterwards could clash with existing ones.
        unique_chars = list(dict.fromkeys(alphabet))

        # Create vocab mapping: char -> id (and the inverse). Built before
        # super().__init__ so vocab_size/get_vocab already work if the parent
        # constructor consults them.
        self.char_to_id = {char: idx for idx, char in enumerate(unique_chars)}
        self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}

        super().__init__(**kwargs)

        # The special-token attributes are read only after super().__init__()
        # has processed kwargs; merge every configured one into our maps.
        special_tokens = [
            (self.bos_token, self.bos_token_id),
            (self.eos_token, self.eos_token_id),
            (self.unk_token, self.unk_token_id),
            (self.pad_token, self.pad_token_id),
        ]
        for token, token_id in special_tokens:
            if token and token_id is not None:
                self.char_to_id[token] = token_id
                self.id_to_char[token_id] = token

    @property
    def vocab_size(self) -> int:
        """Number of entries in the char -> id map (alphabet + merged special tokens)."""
        return len(self.char_to_id)

    def get_vocab(self):
        """Return a copy of the char -> id map so callers cannot mutate ours."""
        return self.char_to_id.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Split text into individual characters."""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a character to its id, falling back to the unk id (or 0)."""
        return self.char_to_id.get(token, self.char_to_id.get(self.unk_token, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to its character, falling back to unk_token (or "")."""
        return self.id_to_char.get(index, self.unk_token or "")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join characters back into a string."""
        return "".join(tokens)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ):
        """Write the char -> id map as JSON to ``[<prefix>-]vocab.json``.

        Args:
            save_directory: Directory to write into; created if missing.
            filename_prefix: Optional prefix joined to the file name with "-".

        Returns:
            A one-element tuple containing the path of the written file.
        """
        import json
        import os

        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.char_to_id, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

In [4]:
# Character-level tokenizer over VOCABULARY plus four special tokens.
tokenizer = SimpleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
    name_or_path="custom",
    # NOTE(review): SimpleCharTokenizer never reads `add_bos_token` itself; it is
    # only forwarded to PreTrainedTokenizer via **kwargs. Confirm BOS is actually
    # prepended wherever the pipeline expects it.
    add_bos_token=True,
)

# Number of transformer layers. act_sites below is derived from it so the list
# of hooked layers cannot drift out of sync with the model config.
N_LAYERS = 2

model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDNING_DIM,
    d_head=EMBEDDNING_DIM // NUM_HEADS,
    n_layers=N_LAYERS,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,
    attn_only=True,
)
# Pretrained weights; presumably saved from a model with exactly this config — verify.
model_state_dict_path = Path("copy_transformer.pt")

subspace_partition_config = SubspacePartitionConfig(
    exp_name=EXPERIMENT_NAME,
    output_dir=OUTPUT_DIR,
    model_config=model_config,
    model_weights_path=model_state_dict_path,
    # The hook_resid_post hook of every block: blocks.0..., blocks.1..., ...
    act_sites=[f"blocks.{layer}.hook_resid_post" for layer in range(N_LAYERS)],
    tokenizer=tokenizer,
    dataset=copy_transformer.data.IterablePureRepeatingPatternDataset(
        num_samples=10_000,
        vocabulary=VOCABULARY,
        context_length=CONTEXT_LENGTH,
        max_pattern_length=7,
    ),
    max_steps=500,
)

In [5]:
# Run the partition training/evaluation configured above. Per the recorded
# output below, it trains for `max_steps` steps and then evaluates, once per
# entry in `act_sites`, saving under OUTPUT_DIR / EXPERIMENT_NAME.
# NOTE(review): the recorded output prints each eval result (and the
# "training for blocks.1.hook_resid_post" header) twice — possibly duplicate
# logging handlers inside run_subspace_partition; confirm.
run_subspace_partition(subspace_partition_config)

training for blocks.0.hook_resid_post


 40%|████      | 201/500 [00:32<00:46,  6.40it/s]

{'R_grad_norm': 0.015745067843236028, 'training_loss': 0.10712239142507314}


 80%|████████  | 401/500 [01:03<00:16,  6.11it/s]

{'R_grad_norm': 0.016980729922652246, 'training_loss': 0.10401578504592181}


100%|██████████| 500/500 [01:19<00:00,  6.26it/s]



finish training (500)
saving.. test_outputs/test_experiment
evaluating (25 steps)...
 ******* eval result *******
mean (weighted) 0.1004851758480072
mean (unweighted) 0.1004851758480072
tensor([0.0958, 0.1052])
training for blocks.1.hook_resid_post
 ******* eval result *******
mean (weighted) 0.1004851758480072
mean (unweighted) 0.1004851758480072
tensor([0.0958, 0.1052])
training for blocks.1.hook_resid_post


 40%|████      | 201/500 [00:32<00:50,  5.91it/s]

{'R_grad_norm': 0.04853655453771353, 'training_loss': 0.34381822042167187}


 80%|████████  | 401/500 [01:07<00:16,  6.16it/s]

{'R_grad_norm': 0.07641626667231322, 'training_loss': 0.331877730935812}


100%|██████████| 500/500 [01:23<00:00,  5.98it/s]



finish training (500)
saving.. test_outputs/test_experiment
evaluating (25 steps)...
 ******* eval result *******
mean (weighted) 0.32304659485816956
mean (unweighted) 0.32304659485816956
tensor([0.2947, 0.3514])
 ******* eval result *******
mean (weighted) 0.32304659485816956
mean (unweighted) 0.32304659485816956
tensor([0.2947, 0.3514])
