# Sparse Autoencoder Training Demo

## Setup

In [1]:
# Automatically reload edited modules (e.g. sparse_autoencoder) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

from hta.trace_analysis import TraceAnalysis
import pandas as pd
import torch
from torch.autograd.profiler_util import FunctionEventAvg
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import PreTrainedTokenizerBase
import wandb

from sparse_autoencoder import SparseAutoencoder, TensorActivationStore, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset


# Turn off the Hugging Face `tokenizers` internal parallelism (presumably to avoid its
# fork-related warnings when the source dataset is tokenized).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [3]:
# Pinned to the CPU because the profiling cell below only records ProfilerActivity.CPU.
# To train on an accelerator when not profiling, use `get_device()` instead.
device = torch.device(type="cpu")

### Source Model

In [4]:
# Source model whose MLP activations (`blocks.0.mlp.hook_post`) the autoencoder learns from.
# Load it onto `device` explicitly so the model, the activation store and the pipeline all
# agree; without `device`, from_pretrained picks its own default (presumably via
# `get_device()`, which may select an accelerator).
src_model = HookedTransformer.from_pretrained("solu-1l", dtype="float32", device=str(device))
src_model.eval()  # Inference mode (e.g. disables dropout)
src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore
src_d_mlp

Loaded pretrained model solu-1l into HookedTransformer


2048

### Source Dataset

In [5]:
# Tokenize the source dataset with the source model's own tokenizer so token ids match
# its vocabulary. HookedTransformer's tokenizer attribute can be None, so fail fast here
# rather than with an obscure error inside the dataset.
assert src_model.tokenizer is not None, "Source model was loaded without a tokenizer"
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
source_data = PileUncopyrightedDataset(tokenizer=tokenizer)

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

### Activation Store

In [6]:
# Activation buffer: presumably holds up to `max_items` vectors of width `src_d_mlp`
# on `device` (confirm against TensorActivationStore). `max_items` is also passed to the
# pipeline as `num_activations_before_training`.
# NOTE(review): assuming float32 storage, this is max_items * src_d_mlp * 4 bytes
# (~781 MB here), so lower it if memory is tight.
max_items = 100_000
store = TensorActivationStore(max_items, src_d_mlp, device)

### Autoencoder

In [7]:
# Width the autoencoder below is built from: the source model's MLP size (d_mlp).
src_d_mlp

2048

In [8]:
# Hidden width is 4x the MLP width (presumably the number of learned features).
# NOTE(review): the zeros vector presumably initialises a bias / geometric-median term — confirm.
n_learned_features = src_d_mlp * 4
autoencoder = SparseAutoencoder(
    src_d_mlp,
    n_learned_features,
    torch.zeros(src_d_mlp),
)

## Training

If you initialise [wandb](https://wandb.ai/site) (uncomment the cell below), the pipeline automatically logs all metrics to wandb.

In [9]:
# Uncomment to log pipeline metrics to Weights & Biases (needs a wandb account / API key).
# wandb.init(project="sparse-autoencoder", dir=".cache/wandb")

In [20]:
export_dir = Path("./.profile")
# Writes TensorBoard/HTA-compatible traces, but only if `on_trace_ready` below is re-enabled.
trace_handler = tensorboard_trace_handler(dir_name=str(export_dir))

# Memory, shape and stack recording are all needed for `export_memory_timeline`
# in the analysis cell (PyTorch's memory profiler requires all three).
profiler_options = dict(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
    # on_trace_ready=trace_handler,
)

pipeline_options = dict(
    src_model=src_model,
    src_model_activation_hook_point="blocks.0.mlp.hook_post",
    src_model_activation_layer=0,
    source_dataset=source_data,
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
    device=device,
    max_activations=100_000,
)

# Run the whole training pipeline under the profiler.
with profile(**profiler_options) as profiler:
    pipeline(**pipeline_options)

STAGE:2023-11-13 07:51:24 60564:2479878 ActivityProfilerController.cpp:312] Completed Stage: Warm Up


Total activations trained on:   0%|          | 0/100000 [00:00<?, ?it/s, Current mode=initializing]

STAGE:2023-11-13 07:51:44 60564:2479878 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-11-13 07:51:45 60564:2479878 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


In [21]:
# Summarise the profile: memory attributed to each op, largest first, then export.
train_profile: profile = profiler

# `self_cpu_memory_usage` is the net bytes an op allocated itself (children excluded;
# frees count as negative), summed over all of its calls. This is cumulative
# allocation volume, not peak resident memory, so "Total used memory" below can be
# much larger than the actual peak usage.
data = {
    event.key: event.self_cpu_memory_usage for event in train_profile.key_averages()
}

batches = 1  # The whole profiled run is reported as a single "batch"
memory_column = "Memory usage (MB per batch)"

profile_dataframe = pd.Series(data, dtype="float64")
profile_dataframe = profile_dataframe / 1024 / 1024  # Convert to MB
profile_dataframe = profile_dataframe / batches  # Average over batches
profile_dataframe = profile_dataframe.round(0)
# Keep net allocators only (drops deallocations and ops that round to 0 MB)
profile_dataframe = profile_dataframe[profile_dataframe > 0]
profile_dataframe = profile_dataframe.sort_values(ascending=False)  # Order by memory usage
profile_dataframe.name = memory_column
total = profile_dataframe.sum()

# Expected size of the activation store's buffer.
# NOTE(review): assumes the store holds float32 (4-byte) values, like the source model.
bytes_per_activation_value = 4
store_memory = max_items * src_d_mlp * bytes_per_activation_value / 1024 / 1024
print(f"Total used memory: {total:.0f} MB")
print(f"Store memory: {store_memory:.0f} MB")
print(f"Difference: {(total - store_memory/batches):.0f} MB")

# Add a percentage column, computed from the memory Series (not the whole frame) so
# the assignment doesn't rely on pandas coercing a one-column DataFrame.
profile_dataframe = profile_dataframe.to_frame()
profile_dataframe["Percentage"] = (profile_dataframe[memory_column] / total * 100).round(0)

# Export
export_dir.mkdir(parents=True, exist_ok=True)
chrome_trace_path = export_dir / "chrome-trace.json"
# export_memory_timeline chooses its format from the suffix (".html" plot, ".gz" gzipped
# JSON, anything else plain JSON), so give the plain-JSON output a ".json" name.
memory_timeline_path = export_dir / "memory-timeline.json"
chrome_trace_path.unlink(missing_ok=True)
memory_timeline_path.unlink(missing_ok=True)

# profiler.export_chrome_trace(str(chrome_trace_path))
profiler.export_memory_timeline(str(memory_timeline_path))

profile_dataframe

Total used memory: 69491 MB
Store memory: 781 MB
Difference: 68710 MB


Unnamed: 0,Memory usage (MB per batch),Percentage
aten::div,7908.0,11.0
aten::mm,7888.0,11.0
aten::empty_strided,6542.0,9.0
aten::mul,5507.0,8.0
aten::sub,3944.0,6.0
aten::add,3455.0,5.0
aten::sqrt,3203.0,5.0
aten::clamp_min,3125.0,4.0
aten::addmm,3125.0,4.0
aten::threshold_backward,3125.0,4.0


In [22]:
# Holistic Trace Analysis (HTA) over the trace files in `export_dir`.
# NOTE(review): HTA needs a trace written by `trace_handler`, but `on_trace_ready`
# is commented out in the profiling cell, so a fresh run writes no trace here.
# Re-enable it first. HTA may also ask for a `"distributedInfo": {"rank": 0}` key in
# the trace JSON. The GPU-kernel, communication and memory-bandwidth analyses
# presumably target GPU/distributed traces and may be empty for this CPU-only
# profile — confirm.
# analyzer = TraceAnalysis(trace_dir=str(export_dir))

# Temporal breakdown
# temporal_breakdown_df = analyzer.get_temporal_breakdown()
# temporal_breakdown_df
# Idle time breakdown
# idle_time_df = analyzer.get_idle_time_breakdown()

# # Kernel breakdown
# kernel_breakdown_df = analyzer.get_gpu_kernel_breakdown()

# # Communication computation overlap
# comm_comp_overlap_df = analyzer.get_comm_comp_overlap()

# # Memory bandwidth time series
# memory_bw_series = analyzer.get_memory_bw_time_series()

# # Memory bandwidth summary
# memory_bw_summary = analyzer.get_memory_bw_summary()

2023-11-13 07:53:10,316 - hta - trace.py:L389 - INFO - .profile


2023-11-13 07:53:12,465 - hta - trace_file.py:L61 - ERROR - If the trace file does not have the rank specified in it, then add the following snippet key to the json files to use HTA; "distributedInfo": {"rank": 0}. If there are multiple traces files, then each file should have a unique rank value.
2023-11-13 07:53:12,465 - hta - trace_file.py:L94 - INFO - Rank to trace file map:
{0: '.profile/Alans-MacBook-Pro.local_60564.1699832504856162000.pt.trace.json'}
2023-11-13 07:53:12,466 - hta - trace.py:L535 - INFO - ranks=[0]
2023-11-13 07:53:23,793 - hta - trace.py:L118 - INFO - Parsed .profile/Alans-MacBook-Pro.local_60564.1699832504856162000.pt.trace.json time = 11.33 seconds 


KeyboardInterrupt: 