In [1]:
from dictionary_learning.trainers.top_k import AutoEncoderTopK
from yue.yue import YuEInferenceConfig, YuEProcessorConfig, YuEProcessor
from yue.common import BlockTokenRangeProcessor, load_tags, filter_tags, get_instrumental_only_lyrics
from transformers import AutoModelForCausalLM, LogitsProcessorList
from nnsight import LanguageModel
import torch
from src.project_config import MODELS_DIR, INPUT_DATA_DIR
import torchaudio
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import json

# ── Runtime configuration
device = "cuda:0"  # device that hosts the YuE language model
device_ae = "cpu"  # device the sparse autoencoder is moved to
model_name = "7B-anneal-en-icl"  # suffix of the m-a-p/YuE-s1-* checkpoint
max_tokens = 200  # token budget used by the trace in sanity()

# ── Data locations: the two stem folders of the MusicBench share
music_bench_dir = INPUT_DATA_DIR / "music-bench"
vocals_dir = music_bench_dir / "datashare-vocals"
instruments_dir = music_bench_dir / "datashare-instruments"
base_dir = instruments_dir  # same folder as instruments_dir
model_sr = 16000  # sample rate audio is resampled to before tokenisation

2025-06-02 17:51:34,154 INFO PyTorch version 2.5.1+cu121 available.


In [2]:
# Tokenisation / codec front-end for YuE.
processor_config = YuEProcessorConfig(
    codec_parent_path="./dependencies",
    tokenizer_model="./models/mm_tokenizer_v0.2_hf/tokenizer.model",
)
processor = YuEProcessor(device, processor_config)

# Stage-1 YuE checkpoint loaded in bf16 (flash-attention 2 is left disabled).
hf_model = AutoModelForCausalLM.from_pretrained(f"m-a-p/YuE-s1-{model_name}", torch_dtype=torch.bfloat16)

# Wrap with nnsight so intermediate layer outputs can be traced and saved.
model = LanguageModel(hf_model, input_names=["input_ids"])
model.to(device)
model.eval()

ae = AutoEncoderTopK.from_pretrained(MODELS_DIR / "yue" / "yue.pt").to(device_ae)
# NOTE(review): presumably the layer whose activations yue.pt was trained on — confirm.
layer = model.model.layers[15]
ds = load_dataset("amaai-lab/MusicBench", split="train", streaming=True)

  WeightNorm.apply(module, name, dim)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  state_dict = t.load(path)


In [None]:
def sanity():
    """Smoke-test the LM -> SAE pipeline on the first streamed batch of 10 rows.

    Returns:
        The shape of the SAE codes ``z`` so the result is displayed when the call
        is the last expression of a cell.

    NOTE(review): this path feeds ``batch["audio_tensor"]`` straight to
    ``processor(...)``; confirm the streamed rows actually carry that column.
    """
    batch = next(ds.iter(10))

    def forward_audio(batch):
        inputs = processor(
            audio=batch["audio_tensor"],
            sampling_rate=32000,
            text=batch["main_caption"],
            padding=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            with model.trace(inputs, invoker_args={"truncation": True, "max_length": max_tokens}):
                return layer.output[0].save()

    act = forward_audio(batch)
    # The SAE was moved to `device_ae`, while the LM output lives on `device`;
    # move the activations over before encoding to avoid a device mismatch.
    z = ae.encode(act.to(device_ae))
    return z.shape

In [4]:
from typing import Dict, List, Tuple

# Feature selection on the corpus activation rate r_i:
# a feature is kept when THETA_MIN < r_i <= THETA_MAX.
THETA_MIN = 0.01
THETA_MAX = 0.25

# τ: a track counts as activating a feature when its time-averaged
# activation is strictly greater than this value.
ACT_THRESHOLD = 0.0

# How many example tracks are retained per feature.
TOP_K_EXAMPLES = 10

# Hidden size used to flatten the traced layer output.
activation_dim = 4096


def compute_mean_activation(z: torch.Tensor) -> torch.Tensor:
    """Average SAE codes over axis 1.

    Args:
        z: SAE codes, presumably shaped (B, T, F) with T the time axis.

    Returns:
        The input reduced by its mean along axis 1, e.g. (B, F).
    """
    return torch.mean(z, dim=1)


def update_corpus_statistics(
    batch_mean_act: torch.Tensor,
    track_ids: torch.Tensor,
    sum_delta: torch.Tensor,
    example_scores: Dict[int, List[float]],
    example_ids: Dict[int, List[int]],
    active_tracks: Dict[int, set],
):
    """Fold one batch of per-track mean activations into the running corpus statistics.

    Every container argument is updated in place; nothing is returned.

    Args:
        batch_mean_act: (B, F) mean activation of each feature for each track.
        track_ids:      (B,) global id of each track in the batch.
        sum_delta:      (F,) running count of tracks in which each feature is
                        active (mean activation > ACT_THRESHOLD); incremented here.
        example_scores: feature -> scores sorted descending, at most TOP_K_EXAMPLES long.
        example_ids:    feature -> track ids aligned with ``example_scores``.
        active_tracks:  feature -> set of every track id in which it is active.

    Note:
        A track is offered to a feature's example buffer only when that feature
        is among the track's TOP_K_EXAMPLES highest mean activations.
    """
    is_active = batch_mean_act > ACT_THRESHOLD
    sum_delta += is_active.float().sum(dim=0)

    n_rows, n_features = batch_mean_act.shape
    k = min(TOP_K_EXAMPLES, n_features)
    for row_idx in range(n_rows):
        track_id = int(track_ids[row_idx])
        row = batch_mean_act[row_idx]

        # Register this track under every feature it activates.
        for feat in torch.nonzero(is_active[row_idx]).flatten().tolist():
            active_tracks.setdefault(feat, set()).add(track_id)

        # Offer the track to the example buffers of its k strongest features.
        for feat in row.topk(k).indices.tolist():
            score = float(row[feat])
            if score <= ACT_THRESHOLD:
                continue
            scores = example_scores.setdefault(feat, [])
            ids = example_ids.setdefault(feat, [])
            # Insert before the first strictly smaller score: the list stays
            # sorted descending and existing entries win ties.
            pos = len(scores)
            for i, existing in enumerate(scores):
                if score > existing:
                    pos = i
                    break
            scores.insert(pos, score)
            ids.insert(pos, track_id)
            if len(scores) > TOP_K_EXAMPLES:
                scores.pop()
                ids.pop()


def process_batch(batch, base_dir: Path, model_sr: int):
    """Load the vocal and instrumental stems for every usable row of a MusicBench batch.

    Rows whose location contains ``"data_aug2"`` are skipped, as are rows whose
    stems fail to load. The four returned lists are index-aligned: entry ``i``
    of each list describes the same track.

    Args:
        batch: mapping with parallel ``"location"`` and ``"main_caption"`` lists.
        base_dir: unused; stems are read from the module-level ``vocals_dir`` and
            ``instruments_dir``. Kept so existing callers keep working.
        model_sr: sample rate the audio is resampled to.

    Returns:
        dict with ``"main_caption"``, ``"vocals_tensor"``, ``"instruments_tensor"``
        and ``"location"`` lists; audio entries are 1-D numpy arrays holding the
        first channel only.
    """

    def load_audio(root, location, target_sr):
        # The stem folders are read as .mp3 files in place of the listed .wav paths.
        path = Path(root) / location
        if path.suffix == ".wav":
            path = path.with_suffix(".mp3")
        waveform, sr = torchaudio.load(str(path))
        resampled = torchaudio.transforms.Resample(sr, target_sr)(waveform)
        return resampled.numpy()[0]

    captions, vocals_list, instruments_list, locations = [], [], [], []
    for location, caption in zip(batch["location"], batch["main_caption"]):
        if "data_aug2" in location:
            continue
        try:
            # Load BOTH stems before appending anything, so a failure on the
            # second one cannot leave the output lists misaligned.
            vocals = load_audio(vocals_dir, location, model_sr)
            instruments = load_audio(instruments_dir, location, model_sr)
        except Exception:
            # Best effort: tracks with missing or unreadable stems are skipped.
            continue
        vocals_list.append(vocals)
        instruments_list.append(instruments)
        captions.append(caption)
        locations.append(location)

    return {
        "main_caption": captions,
        "vocals_tensor": vocals_list,
        "instruments_tensor": instruments_list,
        "location": locations,
    }


def analyse_dataset(
    ds,
    processor,
    nn_model,
    layer,
    ae,
    batch_size: int = 10,
    max_tracks: int = 100,
    max_tokens: int = 1024,
    device: str | torch.device = "cuda",
) -> Tuple[pd.DataFrame, pd.DataFrame, List[str], Dict[str, List[str]]]:
    """Full end-to-end analysis: trace each track, encode it with the SAE, filter features.

    Each streamed batch is loaded with ``process_batch``; every surviving track is
    run through ``nn_model`` and the first output of ``layer`` is encoded by ``ae``.
    Processing stops once ``max_tracks`` tracks have been analysed.

    Args:
        ds: streaming dataset exposing ``.iter(batch_size)``.
        processor: YuE processor providing ``process_trace`` and ``eoa``.
        nn_model: nnsight-wrapped LM that is traced.
        layer: module of ``nn_model`` whose ``output[0]`` is saved.
        ae: sparse autoencoder with ``encode`` and ``encoder.out_features``.
        batch_size: rows requested from ``ds`` per step.
        max_tracks: maximum number of tracks to analyse.
        max_tokens: token budget passed to the trace (truncation ``max_length``
            and the new-token length kwargs).
        device: device the SAE input and running statistics are placed on.

    Returns
    -------
    mean_df         : (tracks × features) table of µ_{i,j} (may be huge!)
    corpus_df       : per-feature table with activation rate & keep flag
    kept_features   : list[str] feature names kept after filtering
    tracks_per_feat : kept feature -> dataset locations of (up to) TOP_K_EXAMPLES
                      active tracks with the highest mean activation, best first
    """
    num_features: int = ae.encoder.out_features
    feature_cols = [f"f{idx:04d}" for idx in range(num_features)]

    # Running aggregates, updated in place by update_corpus_statistics.
    sum_delta = torch.zeros(num_features, dtype=torch.float32, device=device)
    example_scores: Dict[int, List[float]] = {}
    example_ids: Dict[int, List[int]] = {}
    active_tracks: Dict[int, set] = {}
    track_id_to_loc: Dict[int, str] = {}
    mean_rows: List[torch.Tensor] = []  # one (1, F) tensor per analysed track
    mean_index: List[int] = []  # track id of each entry in mean_rows

    tags = load_tags()
    args_inference = YuEInferenceConfig()

    def forward_track(vocals_np, instruments_np, caption):
        """Trace one track through ``nn_model``; return (saved layer output, begin, end)."""
        vocals = torch.tensor(vocals_np, dtype=torch.float32).unsqueeze(0)
        instruments = torch.tensor(instruments_np, dtype=torch.float32).unsqueeze(0)
        genres = filter_tags(tags, caption)
        lyrics = get_instrumental_only_lyrics()
        inputs, begin, end = processor.process_trace(genres, lyrics, vocals, instruments)

        # Activations are only read, so skip building an autograd graph.
        with torch.no_grad():
            with nn_model.trace(
                inputs=inputs,
                max_new_tokens=max_tokens,
                min_new_tokens=max_tokens,
                do_sample=True,
                top_p=args_inference.top_p,
                temperature=args_inference.temperature,
                repetition_penalty=args_inference.repetition_penalty,
                eos_token_id=processor.eoa,
                pad_token_id=processor.eoa,
                logits_processor=LogitsProcessorList(
                    [BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]
                ),
                guidance_scale=args_inference.guidance_scale,
                invoker_args={"truncation": True, "max_length": max_tokens},
            ):
                return layer.output[0].save(), begin, end

    n_processed = 0  # doubles as the global id of the next track
    for raw_batch in tqdm(ds.iter(batch_size), desc="Analysing dataset"):
        batch = process_batch(raw_batch, base_dir, model_sr)
        tracks = zip(
            batch["vocals_tensor"], batch["instruments_tensor"], batch["main_caption"], batch["location"]
        )
        # Forward EVERY track of the batch, not just the first one.
        for vocals_np, instruments_np, caption, location in tracks:
            if n_processed >= max_tracks:
                break
            track_id = n_processed
            track_id_to_loc[track_id] = location

            act, begin, end = forward_track(vocals_np, instruments_np, caption)
            # Flatten away the batch axis, move to CPU, then keep every second
            # position of the span after `begin` up to `end` from process_trace.
            # NOTE(review): presumably these are the audio-token positions — confirm.
            act = act.view(-1, act.shape[-1]).detach().cpu()
            act = act[begin + 1 : end : 2, :].unsqueeze(0)

            z = ae.encode(act.to(device))
            track_mean_act = compute_mean_activation(z)
            mean_rows.append(track_mean_act.cpu())
            mean_index.append(track_id)
            update_corpus_statistics(
                track_mean_act,
                torch.tensor([track_id]),
                sum_delta,
                example_scores,
                example_ids,
                active_tracks,
            )
            n_processed += 1
        if n_processed >= max_tracks:
            break

    # ── Assemble µ_{i,j} big matrix (empty-safe)
    if mean_rows:
        mean_tensor = torch.cat(mean_rows, dim=0)
    else:
        mean_tensor = torch.zeros(0, num_features)
    mean_df = pd.DataFrame(mean_tensor.detach().numpy(), index=mean_index, columns=feature_cols)

    # ── Corpus-level activation rate r_i (all zeros rather than NaN when empty)
    n_tracks = len(mean_df)
    r_i = (sum_delta.detach().cpu() / max(n_tracks, 1)).numpy()

    corpus_df = pd.DataFrame({"feature": feature_cols, "activation_rate": r_i})
    corpus_df["kept"] = (corpus_df.activation_rate > THETA_MIN) & (corpus_df.activation_rate <= THETA_MAX)
    kept_features = corpus_df.loc[corpus_df.kept, "feature"].tolist()
    kept_set = set(kept_features)  # O(1) membership in the loop below

    # ── Kept feature -> locations of its highest-activating active tracks
    tracks_per_feat: Dict[str, List[str]] = {}
    for feat_idx, track_set in active_tracks.items():
        feat_name = feature_cols[feat_idx]
        if feat_name not in kept_set:
            continue
        top_ids = mean_df.loc[sorted(track_set), feat_name].nlargest(TOP_K_EXAMPLES).index.tolist()
        tracks_per_feat[feat_name] = [track_id_to_loc[tid] for tid in top_ids]

    return mean_df, corpus_df, kept_features, tracks_per_feat


# One streamed row per step; the SAE and statistics live on `device_ae`.
# `max_tokens` is not forwarded, so analyse_dataset's own default applies here
# rather than the notebook-level `max_tokens`.
feature_stats, corpus_df, kept_features, example_dict = analyse_dataset(
    ds=ds,
    processor=processor,
    nn_model=model,
    layer=layer,
    ae=ae,
    batch_size=1,
    max_tracks=2000,
    device=device_ae,
)

Analysing dataset: 39532it [08:19, 79.12it/s]  


In [5]:
features_path = INPUT_DATA_DIR / "interp" / "features.json"
features_path.parent.mkdir(parents=True, exist_ok=True)
with open(features_path, "w") as fh:
    # dict.fromkeys de-duplicates while keeping the order returned by
    # analyse_dataset (a set would scramble it).
    json.dump({k: list(dict.fromkeys(v)) for k, v in example_dict.items()}, fh, indent=4)

In [6]:
from IPython.display import Audio, display

with open(INPUT_DATA_DIR / "interp" / "features.json", "r") as fh:
    feat = json.load(fh)

# Feature to audition; only kept features are present in the JSON.
key = "f1595"
# Play from the folder named like base_dir without its "-instruments" suffix
# (presumably the un-separated MusicBench audio). Only the last path component
# is rewritten, unlike a replace over the whole path string.
audio_root = base_dir.with_name(base_dir.name.replace("-instruments", ""))
# dict.fromkeys de-duplicates while keeping the order stored in the JSON.
for p in dict.fromkeys(feat[key]):
    display(Audio(str(audio_root / p)))