In [None]:
from dictionary_learning.trainers.top_k import AutoEncoderTopK
from musicsae.nnsight_model import MusicGenLanguageModel, AutoProcessor
import torch
from utils import MODELS_DIR, OUTPUT_DATA_DIR, INPUT_DATA_DIR
import torchaudio
import nnsight
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import json

# Notebook-wide settings.
model_name: str = "facebook/musicgen-medium"
device: str = "cuda:1"
model_sr: int = 32000  # rate every clip is resampled to and the `sampling_rate` given to the processor
max_tokens: int = 200  # `max_length` passed (with truncation=True) through nnsight's invoker_args
base_dir = INPUT_DATA_DIR / "music-bench" / "datashare-instruments"  # dataset `location` paths are resolved against this

In [None]:
# Decoder layer the SAE is attached to. The checkpoint directory below is named with the
# same index (presumably the layer the SAE was trained on), so both uses go through this
# single constant instead of two independent hard-coded 16s.
sae_layer = 16

nn_model = MusicGenLanguageModel(model_name, device_map=device)
processor = AutoProcessor.from_pretrained(model_name)
ae = AutoEncoderTopK.from_pretrained(
    MODELS_DIR / "medium-sae-trivial-medium-sae-ee3b" / str(sae_layer) / "trainer_0" / "checkpoints" / "ae_71100.pt"
).to(device)
layer = nn_model.decoder.model.decoder.layers[sae_layer]
# Streaming avoids downloading the whole MusicBench split up front.
ds = load_dataset("amaai-lab/MusicBench", split="train", streaming=True)

In [None]:
def sanity(batch_size: int = 10):
    """Smoke-test the MusicGen -> SAE pipeline on the first streamed batch.

    Pulls one raw batch from ``ds``, loads its audio with ``process_batch`` (defined in a
    later cell -- run that cell before calling this), captures ``layer``'s output during an
    nnsight trace and encodes it with ``ae``.

    Args:
        batch_size: number of raw dataset rows to pull.

    Returns:
        The shape of the SAE codes ``z``.

    Raises:
        ValueError: if none of the rows in the batch could be loaded.
    """
    raw_batch = next(ds.iter(batch_size))
    # Raw rows carry "location"/"main_caption"; process_batch is what builds the
    # "audio_tensor" key the processor call below needs.
    batch = process_batch(raw_batch, base_dir, model_sr)
    if not batch["audio_tensor"]:
        raise ValueError("sanity(): no audio could be loaded from the first batch")

    def forward_audio(batch):
        inputs = processor(
            audio=batch["audio_tensor"],
            sampling_rate=model_sr,
            text=batch["main_caption"],
            padding=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            with nn_model.trace(inputs, invoker_args={"truncation": True, "max_length": max_tokens}):
                act = layer.output[0].save()
        # The saved value is only populated once the trace context has exited.
        return act

    act = forward_audio(batch)
    with torch.no_grad():
        z = ae.encode(act)
    return z.shape

In [None]:
from typing import Dict, List, Tuple

# Feature filter on the corpus activation rate r_i: a feature is kept when
# THETA_MIN < r_i <= THETA_MAX.
THETA_MIN: float = 0.01  # r_i lower bound (exclusive)
THETA_MAX: float = 0.25  # r_i upper bound (inclusive)
ACT_THRESHOLD: float = 0.0  # τ – a track counts as having a feature when its mean activation is > τ
TOP_K_EXAMPLES: int = 10  # number of example tracks retained per feature


def compute_mean_activation(z: torch.Tensor) -> torch.Tensor:
    """Average SAE codes over dimension 1 (the time axis T).

    Args:
        z: activations whose dim 1 is time, e.g. (B, T, F).

    Returns:
        ``z`` reduced by mean over dim 1, e.g. (B, F).

    NOTE(review): every position along dim 1 is averaged, including any padding frames
    introduced by batching -- no mask is applied.
    """
    return torch.mean(z, dim=1)


def update_corpus_statistics(
    batch_mean_act: torch.Tensor,
    track_ids: torch.Tensor,
    sum_delta: torch.Tensor,
    example_scores: Dict[int, List[float]],
    example_ids: Dict[int, List[int]],
    active_tracks: Dict[int, set],
):
    """Accumulate corpus-level activation counts and per-feature example buffers in place.

    Args:
        batch_mean_act: (B, F) mean activations for the current batch.
        track_ids:      (B,) global ids of the tracks in the batch.
        sum_delta:      (F,) running count of tracks in which each feature is active
                        (mean activation > ACT_THRESHOLD); updated in place.
        example_scores/example_ids: per-feature buffers of at most TOP_K_EXAMPLES
                        (score, track id) pairs, kept sorted by descending score. Only a
                        track's own TOP_K_EXAMPLES strongest features are offered to them.
        active_tracks:  feature index -> *set* of every track id where the feature is active.
    """
    # δ_{i,j} indicator mask: 1 if mean act > τ. Counted on sum_delta's own device so the
    # in-place add works even if the batch lives elsewhere.
    delta = batch_mean_act > ACT_THRESHOLD
    sum_delta += delta.to(device=sum_delta.device, dtype=sum_delta.dtype).sum(dim=0)

    # The rest is per-element Python bookkeeping: copy to the CPU once instead of
    # synchronising with the accelerator for every indexed element.
    mean_cpu = batch_mean_act.detach().cpu()
    ids = [int(t) for t in track_ids]
    B, F = mean_cpu.shape

    # Every (track, feature) pair where the feature is active, in row-major order.
    for b, feat in (mean_cpu > ACT_THRESHOLD).nonzero().tolist():
        active_tracks.setdefault(feat, set()).add(ids[b])

    # Per-track strongest features (sorted descending), computed for the whole batch at once.
    top_scores, top_feats = mean_cpu.topk(min(TOP_K_EXAMPLES, F), dim=1)
    for b in range(B):
        tid = ids[b]
        for score, feat in zip(top_scores[b].tolist(), top_feats[b].tolist()):
            if score <= ACT_THRESHOLD:
                continue
            buf_scores = example_scores.setdefault(feat, [])
            buf_ids = example_ids.setdefault(feat, [])
            # First slot holding a strictly lower score keeps the buffer sorted descending.
            insert_pos = next((i for i, s in enumerate(buf_scores) if score > s), len(buf_scores))
            buf_scores.insert(insert_pos, score)
            buf_ids.insert(insert_pos, tid)
            if len(buf_scores) > TOP_K_EXAMPLES:
                buf_scores.pop()
                buf_ids.pop()


def process_batch(batch, base_dir: Path, model_sr: int):
    """Load and resample the audio referenced by a raw dataset batch.

    Rows whose ``location`` contains ``"data_aug2"`` are skipped. Rows whose audio fails to
    load are skipped too (best-effort), but no longer silently: a warning is issued and
    each failure is logged at DEBUG level. A ``.wav`` suffix on ``location`` is swapped
    for ``.mp3`` before loading.

    Args:
        batch: mapping with parallel lists under ``"location"`` and ``"main_caption"``.
        base_dir: directory the ``location`` paths are resolved against.
        model_sr: sampling rate every track is resampled to.

    Returns:
        dict with parallel lists ``"main_caption"``, ``"audio_tensor"`` (numpy array of the
        first channel after resampling) and ``"location"`` for the rows that loaded.
    """
    import logging
    import warnings

    # One Resample transform per source rate: the transform precomputes its kernel, so
    # reusing it avoids rebuilding that kernel for every file.
    resamplers = {}

    def load_audio(location):
        path = Path(base_dir) / location
        if path.suffix == ".wav":  # only the file suffix changes, never directory names
            path = path.with_suffix(".mp3")
        waveform, sr = torchaudio.load(str(path))
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(sr, model_sr)
        return resamplers[sr](waveform).numpy()[0]

    loaded = {"main_caption": [], "audio_tensor": [], "location": []}
    n_failed = 0
    for location, caption in zip(batch["location"], batch["main_caption"]):
        if "data_aug2" in location:
            continue
        try:
            audio = load_audio(location)
        except Exception as exc:  # best-effort: one missing/corrupt file must not abort the run
            n_failed += 1
            logging.getLogger(__name__).debug("Skipping %r: %r", location, exc)
            continue
        loaded["main_caption"].append(caption)
        loaded["audio_tensor"].append(audio)
        loaded["location"].append(location)

    if n_failed:
        # Constant message so the default warning filter reports it once, not per batch.
        warnings.warn(
            "process_batch: some tracks failed to load and were skipped "
            "(enable DEBUG logging for per-file details)",
            RuntimeWarning,
            stacklevel=2,
        )
    return loaded


def analyse_dataset(
    ds,
    processor,
    nn_model,
    layer,
    ae,
    batch_size: int = 10,
    max_tracks: int = 100,
    max_tokens: int = 1024,
    device: str | torch.device = "cuda",
    audio_dir: Path | None = None,
    sample_rate: int | None = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, List[str], Dict[str, List[str]]]:
    """Full end-to-end analysis.

    Streams ``ds`` in batches of ``batch_size``, loads audio with ``process_batch``,
    captures ``layer``'s output during an nnsight trace, encodes it with ``ae`` and
    averages the SAE codes over time per track.

    Args:
        max_tracks: at most this many tracks are analysed (the final batch is trimmed).
        max_tokens: ``max_length`` passed (with truncation=True) through invoker_args.
        audio_dir: root for ``process_batch``; defaults to the notebook-level ``base_dir``.
        sample_rate: target sampling rate; defaults to the notebook-level ``model_sr``.

    Returns
    -------
    mean_df         : (tracks × features) table of µ_{i,j}, indexed by track id (may be huge!)
    corpus_df       : per-feature table with activation rate r_i & keep flag
                      (THETA_MIN < r_i <= THETA_MAX)
    kept_features   : list[str] feature names kept after filtering
    tracks_per_feat : kept feature -> dataset locations of up to TOP_K_EXAMPLES active
                      tracks, ordered by descending mean activation

    Raises
    ------
    ValueError: if ``max_tracks < 1`` or no track could be processed.
    """
    if max_tracks < 1:
        raise ValueError(f"max_tracks must be >= 1, got {max_tracks}")
    # Fall back to the notebook-level settings so existing calls behave as before.
    if audio_dir is None:
        audio_dir = base_dir
    if sample_rate is None:
        sample_rate = model_sr

    num_features: int = ae.encoder.out_features
    # Running aggregates, updated in place by update_corpus_statistics.
    # NOTE: example_scores/example_ids are accumulated but not returned.
    sum_delta = torch.zeros(num_features, dtype=torch.float32, device=device)
    example_scores: Dict[int, List[float]] = {}
    example_ids: Dict[int, List[int]] = {}
    active_tracks: Dict[int, set] = {}
    track_id_to_loc: Dict[int, str] = {}
    mean_rows: List[torch.Tensor] = []  # per-batch (B, F) CPU tensors -> concatenated below
    mean_index: List[int] = []

    processed_tracks = 0  # also the next global track id
    for raw_batch in tqdm(ds.iter(batch_size), desc="Analysing dataset"):
        batch = process_batch(raw_batch, audio_dir, sample_rate)
        # Trim so the total number of analysed tracks never exceeds max_tracks.
        remaining = max_tracks - processed_tracks
        batch = {name: values[:remaining] for name, values in batch.items()}
        B = len(batch["audio_tensor"])
        if B == 0:
            continue
        track_ids = torch.arange(processed_tracks, processed_tracks + B)
        track_id_to_loc.update(zip(track_ids.tolist(), batch["location"]))

        # Forward
        inputs = processor(
            audio=batch["audio_tensor"],
            sampling_rate=sample_rate,
            text=batch["main_caption"],
            padding=True,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            with nn_model.trace(inputs, invoker_args={"truncation": True, "max_length": max_tokens}):
                act = layer.output[0].save()
            z = ae.encode(act)  # (B, T, F)

        batch_mean_act = compute_mean_activation(z)  # (B, F)
        mean_rows.append(batch_mean_act.cpu())
        mean_index.extend(track_ids.tolist())

        update_corpus_statistics(batch_mean_act, track_ids, sum_delta, example_scores, example_ids, active_tracks)

        processed_tracks += B
        if processed_tracks >= max_tracks:
            break

    if not mean_rows:
        raise ValueError("analyse_dataset: no tracks could be processed; check audio_dir and dataset locations")

    # ── Assemble µ_{i,j} big matrix
    mean_tensor = torch.cat(mean_rows, dim=0)
    feature_cols = [f"f{idx:04d}" for idx in range(num_features)]
    mean_df = pd.DataFrame(mean_tensor.numpy(), index=mean_index, columns=feature_cols)

    # ── Corpus-level activation rate r_i
    n_tracks = len(mean_df)
    r_i = (sum_delta.detach().cpu() / n_tracks).numpy()

    corpus_df = pd.DataFrame({"feature": feature_cols, "activation_rate": r_i})
    corpus_df["kept"] = (corpus_df.activation_rate > THETA_MIN) & (corpus_df.activation_rate <= THETA_MAX)

    kept_features = corpus_df[corpus_df.kept].feature.tolist()
    kept_set = set(kept_features)  # O(1) membership for the loop below

    # ── Kept feature -> locations of its strongest active tracks
    tracks_per_feat: Dict[str, List[str]] = {}
    for feat_idx, track_set in active_tracks.items():
        feat_name = feature_cols[feat_idx]
        if feat_name not in kept_set:
            continue
        top_ids = mean_df.loc[list(track_set), feat_name].nlargest(TOP_K_EXAMPLES).index.tolist()
        tracks_per_feat[feat_name] = [track_id_to_loc[tid] for tid in top_ids]

    return mean_df, corpus_df, kept_features, tracks_per_feat


# Run the full corpus analysis (expensive: streams and encodes ~10k tracks).
feature_stats, corpus_df, kept_features, example_dict = analyse_dataset(
    ds=ds,
    processor=processor,
    nn_model=nn_model,
    layer=layer,
    ae=ae,
    batch_size=15,
    max_tracks=10_000,
    device=device,
)

In [None]:
# Persist, per kept feature, the locations of its example tracks.
# dict.fromkeys drops duplicates while preserving the descending-activation order produced
# by analyse_dataset; set() would scramble it (nondeterministically, due to str hashing).
features_path = INPUT_DATA_DIR / "interp" / "features.json"
features_path.parent.mkdir(parents=True, exist_ok=True)
with open(features_path, "w") as fh:
    json.dump({k: list(dict.fromkeys(v)) for k, v in example_dict.items()}, fh, indent=4)

In [None]:
from IPython.display import Audio, display

with open(INPUT_DATA_DIR / "interp" / "features.json", "r") as fh:
    feat = json.load(fh)

key = "f1595"
# Play the examples from base_dir's sibling directory without the "-instruments" suffix,
# using the locations as stored (no .wav -> .mp3 swap as in process_batch). Only the last
# path component is renamed; the loop-invariant root is computed once.
audio_root = base_dir.with_name(base_dir.name.replace("-instruments", ""))
for p in dict.fromkeys(feat[key]):  # de-duplicate while keeping the stored ranking order
    display(Audio(str(audio_root / p)))
    # display(Audio(str(base_dir / p).replace('.wav', '.mp3')))

In [None]:
# Steering experiment: generate `n` samples from `prompt` while, on every generation step
# whose index is a multiple of 5, replacing `layer`'s output with the SAE reconstruction
# of its codes after forcing SAE feature 6140 to -9.
# NOTE(review): `prompt` is assigned twice below; only the last assignment is used.
prompt = "Prominent use of string instruments like violins, guitars, and other stringed instruments, often with melodic roles."
# prompt = "Lively Scottish bagpipe tune in G mixolydian mode, featuring a continuous drone, traditional grace notes, and a strong, march-like rhythm"
prompt = "This is an alternative rock song with slow tempo and guitar and drums"
tokens = 255  # used both as max_new_tokens and as the number of intervention-loop steps
n = 5  # number of copies of the prompt in the batch, i.e. number of samples generated
# set_seed(42)
with nn_model.generate([prompt] * n, max_new_tokens=tokens):
    outputs = nnsight.list().save()  # Initialize & .save() nnsight list
    # NOTE(review): nothing is appended to `act` (the append below is commented out),
    # so it stays empty.
    act = nnsight.list().save()
    for i in range(tokens):
        # set_seed(42)
        if i % 5 == 0:
            # Encode the layer output with the SAE (use_threshold=True), set feature 6140
            # to a fixed negative value, then write the decoded result back in place.
            z = ae.encode(layer.output[0][:], use_threshold=True)
            z[:, :, 6140] = -9
            # act.append(z[:, :, 4881].detach().clone().cpu())
            # NOTE(review): the body of this loop is only a placeholder, so it changes nothing.
            for f in [367, 1145, 1129, 3444, 1911, 6140, 5775, 1556]:
                # z[:,:,f]=0.5
                ...
            layer.output[0][:] = ae.decode(z)
        outputs.append(nn_model.generator.output)
        # Presumably advances nnsight's interventions to the next generation step — confirm
        # against the nnsight version in use.
        nn_model.next()
# Save each sample as a WAV file. Only outputs[0] is used; channels_first=True means each
# outputs[0][i] must be (channels, frames) — assumes outputs[0] holds all n samples, TODO confirm.
# NOTE(review): OUTPUT_DATA_DIR / "musicgen-sae" must already exist.
for i in range(n):
    torchaudio.save(
        OUTPUT_DATA_DIR / "musicgen-sae" / f"out_{i}.wav",
        src=outputs[0][i].detach().cpu(),
        sample_rate=nn_model.config.sampling_rate,
        channels_first=True,
    )

In [None]:
# Number of items collected in `act` during generation. The only append to `act` in the
# generation cell is commented out, so this is expected to evaluate to 0.
len(act)