In [2]:
from dictionary_learning.trainers.top_k import AutoEncoderTopK
from musicsae.nnsight_model import MusicGenLanguageModel, AutoProcessor
from utils import MODELS_DIR, INPUT_DATA_DIR
import torchaudio
from datasets import load_dataset

In [3]:
# Device that hosts both the MusicGen model and the sparse autoencoder.
device = "cuda:0"
# Hugging Face checkpoint id; any Huggingface model id can be used here.
model_name = "facebook/musicgen-medium"

# nnsight-style wrapper around MusicGen, placed on `device`.
model = MusicGenLanguageModel(model_name, device_map=device)
# Decoder layer used as the hook point.
# NOTE(review): index 16 presumably matches the "16" directory in the SAE
# checkpoint path below — confirm they refer to the same layer.
submodule = model.decoder.model.decoder.layers[16]
# Processor paired with the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(model_name)

# Load the top-k autoencoder checkpoint, then move it to the model's device.
ae = AutoEncoderTopK.from_pretrained(
    MODELS_DIR / "musicgen-sae-topk" / "16" / "trainer_0" / "checkpoints" / "ae_500.pt"
)
ae = ae.to(device)

Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summ

In [None]:
def add_audio_to_sample(model_sr, sample):
    """Attach resampled audio to a dataset sample.

    The audio file is looked up under
    ``INPUT_DATA_DIR / "music-bench" / "datashare"`` using ``sample["location"]``,
    loaded with ``torchaudio.load`` and resampled from its native rate to
    ``model_sr``.

    Args:
        model_sr: Target sampling rate handed to the resampler and stored on
            the sample.
        sample: Mapping with a ``"location"`` entry; it is modified in place.

    Returns:
        The same ``sample`` object, now carrying ``"audio_tensor"`` (element 0
        of the resampled audio converted to a NumPy array) and ``"sr"``
        (equal to ``model_sr``).
    """
    source_file = INPUT_DATA_DIR / "music-bench" / "datashare" / sample["location"]
    waveform, native_sr = torchaudio.load(str(source_file))

    resampler = torchaudio.transforms.Resample(native_sr, model_sr)
    resampled = resampler(waveform).numpy()

    # Index 0 keeps only the first row of the loaded audio — the first channel
    # under torchaudio's default channels-first layout; other channels are dropped.
    sample["audio_tensor"] = resampled[0]
    sample["sr"] = model_sr
    return sample


# Build a tiny evaluation stream: the first 5 test examples of MusicBench,
# each enriched with audio resampled to 32 kHz, served in batches of 2.
ds = load_dataset("amaai-lab/MusicBench", split="test")
ds = ds.select(range(5))
ds = ds.map(lambda row: add_audio_to_sample(32000, row))
ds = ds.select_columns(["main_caption", "audio_tensor", "sr"])
ds = ds.iter(batch_size=2)

# Pull the first batch and turn its captions and audio into model inputs.
batch = next(ds)
inputs = processor(
    text=batch["main_caption"],
    audio=batch["audio_tensor"],
    sampling_rate=32000,  # same rate the audio was resampled to above
    padding=True,
    return_tensors="pt",
)

In [None]:
# Run a traced forward pass of the model on the prepared batch and keep a
# handle to the model's output via `.save()` for use after the trace exits.
# NOTE(review): `invoker_args` is presumably forwarded to the wrapper's input
# preparation (truncating to max_length=10) — confirm against
# MusicGenLanguageModel / nnsight.
# NOTE(review): despite the name, what is saved is the whole `model.output`
# object — confirm it is the logits tensor rather than an output wrapper.
with model.trace(inputs, invoker_args={"truncation": True, "max_length": 10}):
    logits_original = model.output.save()
# Bare last expression so the notebook displays the saved output.
logits_original

In [None]:
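# Disabled call kept for reference: evaluate the autoencoder with `loss_recovered`.
# NOTE: `loss_recovered` is not imported in the cells shown in this notebook;
# import it before re-enabling this call.
# loss_recovered(
#     inputs,
#     model,
#     submodule,
#     ae,
# )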