In [2]:
from dictionary_learning import ActivationBuffer
from dictionary_learning.trainers.top_k import AutoEncoderTopK, TopKTrainer
from dictionary_learning.training import trainSAE
from musicsae.nnsight_model import MusicGenLanguageModel, AutoProcessor
import torch as t
import gc
from utils import MODELS_DIR, OUTPUT_DATA_DIR
import torchaudio
import nnsight
from datasets import load_dataset, Dataset
from torch.utils.data import Dataset as TorchDataset

In [3]:
device = "cuda:0"
model_name = "facebook/musicgen-medium"  # any MusicGen checkpoint on the Hugging Face Hub

# Wrap the checkpoint for nnsight tracing and load its matching processor.
model = MusicGenLanguageModel(model_name, device_map=device)
processor = AutoProcessor.from_pretrained(model_name)

# Decoder layer whose output is hooked to collect activations.
submodule = model.decoder.model.decoder.layers[16]

# Width of the hidden states produced by the decoder layers. It is read from
# the checkpoint config rather than hard-coded: a fixed 1024 only fits
# checkpoints whose decoder hidden size is 1024 and breaks the activation
# buffer for any other MusicGen size.
activation_dim = model.config.decoder.hidden_size
dictionary_size = 2 * activation_dim  # dictionary has twice as many features as input dims

Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summ

In [3]:
class PromptDataset(TorchDataset):
    """Map-style view over a captioned dataset that yields only the caption.

    Item ``idx`` is the value stored under ``"main_caption"`` in row ``idx`` of
    the wrapped dataset.
    """

    def __init__(self, ds: Dataset):
        # Only a reference is kept; rows are looked up on demand.
        self.ds = ds

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        row_count = len(self.ds)
        return row_count

    def __getitem__(self, idx):
        """Return the ``"main_caption"`` entry of row ``idx``."""
        row = self.ds[idx]
        return row["main_caption"]


class PromtLoader:
    """Endless iterator over a prompt dataset.

    When the current pass over the dataset is exhausted, iteration restarts
    from the first item, so ``next()`` keeps returning prompts as long as the
    dataset is non-empty.
    """

    def __init__(self, ds: PromptDataset):
        self.ds = ds
        self.data_iter = iter(ds)

    def __iter__(self):
        """Return the loader itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the next prompt, rewinding to the start after the last one."""
        try:
            return next(self.data_iter)
        except StopIteration:
            # Current pass is finished: begin a fresh pass over the dataset.
            self.data_iter = iter(self.ds)
        # An empty dataset raises StopIteration here, which ends iteration.
        return next(self.data_iter)

In [13]:
tokens = 255


class MusicActivationBuffer(ActivationBuffer):
    """Activation buffer that is refilled from model generation.

    ``refresh`` feeds text batches to ``self.model.generate`` and stores the
    hooked submodule's output from every generation step as rows of
    ``self.activations``.
    """

    # Number of generation steps recorded per prompt batch. Captured once at
    # class definition so that reassigning the notebook-level ``tokens`` in
    # another cell does not silently change how the buffer is refilled.
    max_new_tokens = tokens

    def refresh(self):
        """Drop consumed rows and top the buffer up with freshly generated activations."""
        gc.collect()
        t.cuda.empty_cache()

        # Carry the rows that have not been read yet to the front of a new,
        # full-size buffer.
        unread = self.activations[~self.read]
        current_idx = len(unread)
        self.activations = t.empty(
            self.activation_buffer_size, self.d_submodule, device=self.device, dtype=self.model.dtype
        )
        self.activations[:current_idx] = unread

        n_steps = self.max_new_tokens
        while current_idx < self.activation_buffer_size:
            with t.no_grad():
                with self.model.generate(self.text_batch(), max_new_tokens=n_steps):
                    step_outputs = nnsight.list().save()
                    for _ in range(n_steps):
                        hidden_states = self.submodule.output.save()
                        # First element of the submodule output; presumably the
                        # hidden-state tensor -- TODO confirm.
                        step_outputs.append(hidden_states[0])
                        # Advance the buffer's own model rather than whatever the
                        # notebook-level name `model` currently points to.
                        self.model.next()

            # Merge all leading dimensions so every captured position becomes one
            # row. The feature dimension is left untouched, so a width that does
            # not match `d_submodule` still fails loudly on assignment below.
            activations = t.cat(step_outputs).flatten(0, -2)

            # The loop condition guarantees at least one free row here.
            remaining_space = self.activation_buffer_size - current_idx
            activations = activations[:remaining_space]
            self.activations[current_idx : current_idx + len(activations)] = activations.to(self.device)
            current_idx += len(activations)

        # Every row in the refilled buffer starts out unread.
        self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device)


# How many MusicBench test captions are used as generation prompts.
n = 100

# Load the test split, keep only the caption column, and truncate to the first `n` rows.
musicbench_test = load_dataset("amaai-lab/MusicBench", split="test")
caption_rows = musicbench_test.select_columns(["main_caption"]).select(range(n))
prompts_ds = PromptDataset(caption_rows)

# Size-related settings forwarded unchanged to the activation buffer.
buffer_sizes = dict(n_ctxs=10, ctx_len=10, refresh_batch_size=10, out_batch_size=10)

# Buffer that supplies SAE training with activations of `submodule`,
# cycling through the prompts indefinitely.
buffer = MusicActivationBuffer(
    data=PromtLoader(prompts_ds),
    model=model,
    submodule=submodule,
    d_submodule=activation_dim,
    device=device,
    **buffer_sizes,
)

In [14]:
# Hyperparameters for a single TopK sparse-autoencoder trainer.
trainer_cfg = {
    "trainer": TopKTrainer,
    "dict_class": AutoEncoderTopK,
    "activation_dim": activation_dim,
    "dict_size": dictionary_size,
    "lr": 1e-3,
    "device": device,
    "steps": 1000,
    # NOTE(review): presumably has to match the decoder layer index used for
    # `submodule` -- confirm when changing either one.
    "layer": 16,
    # Record the checkpoint that is actually loaded. The previous hard-coded
    # "MusicGen-small" label contradicted `model_name`.
    "lm_name": model_name,
    "warmup_steps": 2,
    "k": 10,
}

# Train the sparse autoencoder (SAE) on activations drawn from `buffer`.
trainSAE(
    data=buffer,  # another activation source (e.g. a PyTorch DataLoader) could be used instead
    trainer_configs=[trainer_cfg],
    steps=trainer_cfg["steps"],
    save_dir=MODELS_DIR,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:48<00:00, 20.48it/s]


In [12]:
# Decoder layer to intervene on. The same number is used as the checkpoint
# sub-directory name (presumably one directory per trained layer -- verify).
sae_layer = 3

# Load the trained SAE weights and move them to the model's device.
ae_path = MODELS_DIR / "musicgen-sae" / str(sae_layer) / "trainer_0" / "ae.pt"
ae = AutoEncoderTopK.from_pretrained(ae_path).to(device)

# Re-point `submodule` at the layer the loaded SAE is applied to.
submodule = model.decoder.model.decoder.layers[sae_layer]

In [14]:
prompt = "Recreate the essence of a classic video game theme with chiptune sounds and nostalgic melodies."
tokens = 255
n = 3

# Generate `n` samples from the same prompt. At every generation step the
# hooked layer's first output element is overwritten in place with the result
# of passing it through the SAE.
with model.generate([prompt] * n, max_new_tokens=tokens):
    outputs = nnsight.list().save()  # Initialize & .save() nnsight list
    for _ in range(tokens):
        submodule.output[0][:] = ae(submodule.output[0][:])
        outputs.append(model.generator.output)
        model.next()

# Make sure the target folder exists before writing; saving into a missing
# directory would fail after the (expensive) generation has already run.
audio_dir = OUTPUT_DATA_DIR / "musicgen-sae"
audio_dir.mkdir(parents=True, exist_ok=True)

# Only the first recorded entry of `outputs` is used; it is indexed per sample.
for i in range(n):
    torchaudio.save(
        audio_dir / f"out_{i}.wav",
        src=outputs[0][i].detach().cpu(),
        sample_rate=model.config.sampling_rate,
        channels_first=True,
    )