# Feature Exploration Notebook

This notebook walks through a workflow for exploring SAE features:
1. Environment & imports
2. Loading models and datasets
3. Selecting features of interest
4. Analyzing activations
5. Visualizing results
6. Inspecting token-level context

# Feature exploration

## 1. Set up

In [2]:
import torch
import sys
from pathlib import Path
from tqdm import tqdm
import heapq
import json

module_path = Path("..").resolve()
sys.path.append(str(module_path))

from music_interpret.modeling.sae import SparseAutoencoder
from music_interpret.activations.dataset import ActivationDataset
from music_interpret.model_adapters.mmt_adapter import MMTAdapter

## 2. Load sae and dataset

### Set paths to resources that are to be explored
- ckpt_path - path to the sae checkpoint 
- activ_root - path to shard*.npy files containing activations
- meta_root - path to shard*.npy files that allow to tranlate activations to music tokens

In [3]:
ckpt_path = Path("/home/albert-torzewski/PycharmProjects/Start/reports/sanity/diff_ini/sae_ckpt_001.pt")
activ_root = Path("/home/albert-torzewski/PycharmProjects/Start/data/activations/lmd_ape_residual/layer_007/processed") 
meta_root  = Path("/home/albert-torzewski/PycharmProjects/Start/data/activations/lmd_ape_residual/layer_007/meta")

### Load SAE checkpoint

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt = torch.load(ckpt_path, map_location=DEVICE)

sae = SparseAutoencoder(
    input_dim=ckpt["input_dim"],
    latent_dim=ckpt["latent_dim"],
    normalize_decoder=True,
).to(DEVICE)

sae.load_state_dict(ckpt["model_state"])
sae.eval()

SparseAutoencoder(
  (encoder): Linear(in_features=512, out_features=2048, bias=True)
  (decoder): Linear(in_features=2048, out_features=512, bias=False)
  (activation): ReLU()
)

### Load activation dataset

In [76]:
dataset = ActivationDataset(
    activ_root=activ_root,
    meta_root=meta_root,
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4096,
    num_workers=4,
)

## 3. Feature categorisation based on firing rates

In [6]:
latent_dim = sae.latent_dim

fire_counts = torch.zeros(latent_dim, dtype=torch.long)
total_tokens = 0

with torch.no_grad():
    for batch in loader:
        x = batch["activations"].to(DEVICE)
        z = sae.encode(x)
        fire_mask = z > 0
        fire_counts += fire_mask.sum(dim=0).cpu()
        total_tokens += z.size(0)

freqs = fire_counts.float() / total_tokens

In [9]:
print("Total tokens processed:", total_tokens)
print("Num features:", latent_dim)
print("Dead features:", (freqs == 0).sum().item())
print("Overactive features (>10%):", (freqs > 0.1).sum().item())

candidates = torch.where((freqs > 1e-4) & (freqs < 5e-2))[0]

print("First 20 candidate features:", candidates[:20].tolist())

Total tokens processed: 13166182
Num features: 2048
Dead features: 102
Overactive features (>10%): 1051
First 20 candidate features: [2, 9, 12, 21, 24, 30, 38, 41, 50, 63, 86, 88, 112, 144, 152, 177, 205, 235, 247, 262]


### Choose latent to inspect
- feature_pos - index number of latnet dictiornary param to inspect
- examples_to_display - number of best 'k' tokens to save for later visualisation

In [92]:
feature_pos = 112
examples_to_display = 100

### Select top k most activating tokens

In [93]:
topk = []

with torch.no_grad():
    for batch in tqdm(loader, desc=f"Collecting top activations for feature {feature_pos}"):
        x = batch["activations"].to(DEVICE)
        z = sae.encode(x)

        vals = z[:, feature_pos].cpu()
        sample_idx = batch.get("sample_idx")
        token_pos  = batch.get("token_pos")

        for i in range(len(vals)):
            v = float(vals[i])
            if v <= 0:
                continue

            item = (
                v,
                int(sample_idx[i]),
                int(token_pos[i]),
            )

            if len(topk) < examples_to_display:
                heapq.heappush(topk, item)
            else:
                if v > topk[0][0]:
                    heapq.heapreplace(topk, item)

topk_sorted = sorted(topk, key=lambda x: -x[0])

Collecting top activations for feature 112: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3215/3215 [01:05<00:00, 49.33it/s]


### TopK print

In [94]:
print("Top 10 activations:")
for r in topk_sorted[:10]:
    print(r)

Top 10 activations:
(11.125716209411621, 1569, 13)
(10.652200698852539, 7283, 15)
(10.577179908752441, 1809, 12)
(10.44907283782959, 12015, 8)
(10.295289993286133, 5774, 7)
(10.030010223388672, 7307, 7)
(10.018009185791016, 12050, 9)
(9.991106986999512, 6754, 22)
(9.952404975891113, 11422, 7)
(9.88181209564209, 11276, 9)


### Save tokens for later use

In [95]:
out_path = Path(f"feature_{feature_pos}_top{examples_to_display}.json")

records = [
    {
        "activation": v,
        "sample_idx": s,
        "token_pos": p,
    }
    for (v, s, p) in topk_sorted
]

with open(out_path, "w") as f:
    json.dump(records, f, indent=2)

print("Saved to:", out_path)


Saved to: feature_112_top100.json


In [96]:
print("Top 10 activations:")
for r in records[:10]:
    print(r)

Top 10 activations:
{'activation': 11.125716209411621, 'sample_idx': 1569, 'token_pos': 13}
{'activation': 10.652200698852539, 'sample_idx': 7283, 'token_pos': 15}
{'activation': 10.577179908752441, 'sample_idx': 1809, 'token_pos': 12}
{'activation': 10.44907283782959, 'sample_idx': 12015, 'token_pos': 8}
{'activation': 10.295289993286133, 'sample_idx': 5774, 'token_pos': 7}
{'activation': 10.030010223388672, 'sample_idx': 7307, 'token_pos': 7}
{'activation': 10.018009185791016, 'sample_idx': 12050, 'token_pos': 9}
{'activation': 9.991106986999512, 'sample_idx': 6754, 'token_pos': 22}
{'activation': 9.952404975891113, 'sample_idx': 11422, 'token_pos': 7}
{'activation': 9.88181209564209, 'sample_idx': 11276, 'token_pos': 9}


## 4. Data visualisation

### Optional records load

In [102]:
records_path = "feature_2_top200.json"

with open(records_path) as f:
    records = json.load(f)

### Load original dataset and mmt helper

In [78]:
adapter = MMTAdapter(dataset_name="lmd", data_repr="ape")
mmt_loader = adapter.create_dataloader(batch_size=1, shuffle=False)

data = mmt_loader.dataset

### Define helper print functions

In [65]:
def get_context_tokens(dataset, sample_idx, pos, window=8):
    """
    Returns:
        tokens  : list of encoded tokens
        center  : index of center token in this window
    """
    sample = dataset[sample_idx][0]
    seq = sample["seq"]

    start = max(0, pos - window)
    end   = min(len(seq), pos + window + 1)

    tokens = seq[start:end]
    center = pos - start
    return tokens, center


def print_context(adapter, dataset, sample_idx, pos, window=8):
    tokens, center = get_context_tokens(dataset, sample_idx, pos, window)

    text = adapter.decode_notes(tokens)

    lines = text.splitlines()

    print("=" * 80)
    print(f"sample_idx={sample_idx} pos={pos}")
    print("-" * 80)

    for i, line in enumerate(lines):
        if i == center:
            print(f">>> {line}  <<<")
        else:
            print(f"    {line}")

    print("-" * 80)

### Select display options
- window - how many tokens before and after should be displayed
- n_show - how many features to print

In [87]:
window = 4
n_show = 100

### Show tokens in text format

In [103]:
for r in records[:n_show]:
    print_context(
        adapter,
        data,
        sample_idx=r["sample_idx"],
        pos=r["token_pos"],
        window=window,
    )

sample_idx=3750 pos=5
--------------------------------------------------------------------------------
    Instrument: piano
    Instrument: organ
    Instrument: synth-bass
    Instrument: brasses
>>> Instrument: synth-brasses  <<<
    Instrument: lead
    Instrument: synth-drums
    Start of notes
    Note: beat=0, position=0, pitch=D#4, duration=1, instrument=lead
--------------------------------------------------------------------------------
sample_idx=9113 pos=5
--------------------------------------------------------------------------------
    Instrument: organ
    Instrument: steel-string-guitar
    Instrument: synth-bass
    Instrument: strings
>>> Instrument: synth-brasses  <<<
    Instrument: lead
    Start of notes
    Note: beat=0, position=11, pitch=E4, duration=72, instrument=strings
    Note: beat=0, position=11, pitch=G4, duration=72, instrument=strings
--------------------------------------------------------------------------------
sample_idx=8045 pos=7
-------------