In [None]:
def in_google_colab():
  """Report whether this kernel appears to be running inside Google Colab.

  Detection probes for the ``google.colab`` package, which is expected to be
  importable only on Colab runtimes.

  Returns:
    bool: True if ``import google.colab`` succeeds, False if it raises
    ImportError. Any other exception from the import propagates.
  """
  try:
    import google.colab  # noqa: F401  (imported only to probe availability)
  except ImportError:
    return False
  else:
    return True

if in_google_colab():
  !git clone https://github.com/Sappique/SubspacePartition.git
  %cd SubspacePartition
  # Install uv
  !curl -LsSf https://astral.sh/uv/install.sh | sh
  import os
  os.environ['PATH'] = f"/root/.cargo/bin:{os.environ['PATH']}"

  # install the cpu version
  !uv pip install --system -r pyproject.toml
  # then install the cuda version
  !uv pip install --system --reinstall \
    "torch>=2.5.0" \
    "torchvision>=0.20.0" \
    --index-url https://download.pytorch.org/whl/cu121

In [None]:
import os
import signal

# Kill the kernel so the runtime restarts and a fresh interpreter picks up the
# packages installed by the setup cell. Only do this on Colab: on a local
# Jupyter server it would just abort "Run All". `in_google_colab` is defined in
# the setup cell above.
if in_google_colab():
  os.kill(os.getpid(), signal.SIGKILL)

In [None]:
def in_google_colab():
  """Report whether this kernel appears to be running inside Google Colab.

  Re-defined here because the runtime restart in the previous cell clears every
  name defined earlier in the session.

  Returns:
    bool: True if ``import google.colab`` succeeds, False if it raises
    ImportError. Any other exception from the import propagates.
  """
  try:
    import google.colab  # noqa: F401  (imported only to probe availability)
  except ImportError:
    return False
  else:
    return True
if in_google_colab():
  %cd SubspacePartition

In [None]:
from subspace_partition.subspace_partition import (
    run_subspace_partition,
    SubspacePartitionConfig,
)

import transformer_lens
from pathlib import Path
import copy_transformer.data
import copy_transformer.tokenizer

In [None]:
# --- Model architecture ------------------------------------------------------
EMBEDDING_DIM = 64  # used as d_model; split evenly across NUM_HEADS heads
NUM_HEADS = 8
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # single-character alphabet A-Z
CONTEXT_LENGTH = 32  # used as n_ctx and as the dataset's sequence length

# --- Training data -----------------------------------------------------------
MAX_PATTERN_LENGTH = 16  # upper bound passed to the repeating-pattern dataset

# --- Subspace-partition run --------------------------------------------------
MAX_STEPS = 20_000

EXPERIMENT_NAME = "copy_transformer_truly_sensible_sp_parameters2"
OUTPUT_DIR = Path("out/subspace_partition")

# NOTE(review): presumably step counts / a dimension for the merge schedule;
# see SubspacePartitionConfig for the exact semantics.
MERGE_INTERVAL = 3200
MERGE_START = 5000
UNIT_SIZE = 8

In [None]:
# Character-level tokenizer over VOCABULARY plus the special tokens below.
tokenizer = copy_transformer.tokenizer.SingleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
    name_or_path="custom",
    add_bos_token=True,
)

# d_head is computed by floor division below; a NUM_HEADS that does not divide
# EMBEDDING_DIM would silently shrink the total attention width.
assert EMBEDDING_DIM % NUM_HEADS == 0, (
    f"EMBEDDING_DIM ({EMBEDDING_DIM}) must be divisible by NUM_HEADS ({NUM_HEADS})"
)

# Attention-only transformer; the vocabulary size comes from the tokenizer.
model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDING_DIM,
    d_head=EMBEDDING_DIM // NUM_HEADS,
    n_layers=2,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,
    attn_only=True,
)
# NOTE(review): presumably weights from a separate training run that match
# model_config — confirm the file exists before starting the long run.
model_state_dict_path = Path("out/copy_transformer.pt")

dataset = copy_transformer.data.InfinitePureRepeatingPatternDataset(
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

subspace_partition_config = SubspacePartitionConfig(
    exp_name=EXPERIMENT_NAME,
    output_dir=OUTPUT_DIR,
    model_config=model_config,
    model_weights_path=model_state_dict_path,
    # Residual stream after every block, derived from n_layers so the hook
    # list stays in sync if the layer count above is changed.
    act_sites=[
        f"blocks.{layer}.hook_resid_post" for layer in range(model_config.n_layers)
    ],
    tokenizer=tokenizer,
    dataset=dataset,
    max_steps=MAX_STEPS,
    merge_interval=MERGE_INTERVAL,
    merge_start=MERGE_START,
    unit_size=UNIT_SIZE,
    search_steps=1,
)

In [None]:
# Launch the subspace-partition run described by the config cell above
# (outputs are presumably written under OUTPUT_DIR / EXPERIMENT_NAME — confirm).
run_subspace_partition(subspace_partition_config)