In [55]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, AutoConfig
from nnsight import LanguageModel
import nnsight
from IPython.display import clear_output
import torch

In [41]:
# Checkpoint identifier handed to every `from_pretrained` call in this cell.
MUSICGEN_CHECKPOINT = "facebook/musicgen-small"

# The two loads are independent of each other, so their order does not matter.
cfg = AutoConfig.from_pretrained(MUSICGEN_CHECKPOINT)
processor = AutoProcessor.from_pretrained(MUSICGEN_CHECKPOINT)

# Disabled: direct model load moved onto CUDA.
# model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to('cuda')

In [42]:
class MusicGenLanguageModel(LanguageModel):
    """`LanguageModel` subclass whose loader hooks build a MusicGen model.

    Both `_load_meta` and `_load` are overridden so that the returned module
    always comes from `MusicgenForConditionalGeneration.from_pretrained`.
    """

    def _musicgen_setup(self, repo_id, tokenizer_kwargs, **kwargs):
        """Record `repo_id`, then call `_load_config` and `_load_tokenizer`.

        Args:
            repo_id: Checkpoint identifier forwarded to both loaders.
            tokenizer_kwargs: Extra keyword arguments for `_load_tokenizer`;
                `None` means "no extra arguments".
            **kwargs: Forwarded unchanged to `_load_config`.
        """
        self.repo_id = repo_id
        self._load_config(repo_id, **kwargs)
        self._load_tokenizer(repo_id, **(tokenizer_kwargs or {}))

    def _load_meta(
        self,
        repo_id: str,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """Set up config/tokenizer and return a MusicGen model for `repo_id`.

        Args:
            repo_id: Checkpoint identifier to load.
            tokenizer_kwargs: Optional keyword arguments for the tokenizer
                loader. Defaults to `None` instead of a shared mutable `{}`.
            **kwargs: Forwarded to `_load_config` only.

        Returns:
            The result of `MusicgenForConditionalGeneration.from_pretrained`.
        """
        self._musicgen_setup(repo_id, tokenizer_kwargs, **kwargs)
        # NOTE(review): this hook loads the checkpoint via `from_pretrained`
        # exactly like `_load` does -- confirm that is intended for the
        # "meta" variant.
        return MusicgenForConditionalGeneration.from_pretrained(repo_id)

    def _load(
        self,
        repo_id: str,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """Set up config/tokenizer and return the model on the requested device.

        Args:
            repo_id: Checkpoint identifier to load.
            tokenizer_kwargs: Optional keyword arguments for the tokenizer
                loader. Defaults to `None` instead of a shared mutable `{}`.
            **kwargs: Forwarded to `_load_config`. If it contains
                `device_map`, that value is passed to the model's `.to()`.

        Returns:
            The loaded model, moved with `.to(device_map)` when `device_map`
            was supplied and left where `from_pretrained` put it otherwise.
        """
        self._musicgen_setup(repo_id, tokenizer_kwargs, **kwargs)
        model = MusicgenForConditionalGeneration.from_pretrained(repo_id)

        # `device_map` is optional: indexing `kwargs` directly raised KeyError
        # whenever the caller did not pass it.
        device = kwargs.get("device_map")
        if device is not None:
            model = model.to(device)
        return model


# Use CUDA when it is available and fall back to the CPU otherwise, so this
# cell no longer fails on machines without a GPU.
MODEL_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

nn_model = MusicGenLanguageModel(
    "facebook/musicgen-small",
    config=cfg,
    tokenizer=processor.tokenizer,
    device_map=MODEL_DEVICE,
)

# Short throwaway generation with an empty body; nothing is saved from it.
# NOTE(review): presumably this exists to force the model to be fully loaded
# before the later cells run -- confirm.
with nn_model.generate("Hello world!", max_new_tokens=10):
    ...

# Wipe whatever the steps above printed into this cell's output area.
clear_output()

In [71]:
# Run two text prompts through the processor as one padded batch of
# PyTorch tensors.
inputs = processor(
    text=[
        "80s pop track with bassy drums and synth",
        "90s rock song with loud guitars and heavy drums",
    ],
    padding=True,
    return_tensors="pt",
)

# Decoder prompt made entirely of the pad token:
# rows = input_ids.shape[0] * num_codebooks, columns = 254.
pad_token_id = nn_model.generation_config.pad_token_id
decoder_rows = inputs.input_ids.shape[0] * nn_model.decoder.num_codebooks
decoder_input_ids = torch.full(
    (decoder_rows, 254),
    pad_token_id,
    dtype=torch.long,
    device=nn_model.device,
)

# Move every processor output onto the model's device before tracing.
device_inputs = {name: tensor.to(nn_model.device) for name, tensor in inputs.items()}

layer = nn_model.decoder.model.decoder.layers[16]
with nn_model.trace(device_inputs, decoder_input_ids=decoder_input_ids):
    # Save this layer's output so it can be read after the trace exits.
    out = layer.output.save()
out[0].shape

torch.Size([508, 1024])

In [7]:
tokens = 255
prompt = "Recreate the essence of a classic video game theme with chiptune sounds and nostalgic melodies."

# Saved generation results keyed by the index of the ablated layer. The loop
# rebinds `outputs` on every pass, so without this mapping only the result
# for the last layer in the list would remain accessible afterwards.
ablation_outputs = {}

for n in [2, 8]:
    ablate_layer = nn_model.decoder.model.decoder.layers[n]
    with nn_model.generate([prompt] * 3, max_new_tokens=tokens):
        outputs = nnsight.list().save()  # Initialize & .save() nnsight list
        for _ in range(tokens):
            # Overwrite the layer's output[0] with its input[0] so the layer
            # contributes nothing of its own at this step.
            # NOTE(review): what `.input[0]` selects (first positional
            # argument vs. first batch row) depends on the installed nnsight
            # version -- confirm this is not broadcasting a single sample
            # across the whole batch.
            ablate_layer.output[0][:] = ablate_layer.input[0][:]
            outputs.append(nn_model.generator.output)
            nn_model.next()
    # Keep this layer's result; `outputs` itself stays bound to the most
    # recent layer, as before.
    ablation_outputs[n] = outputs
    # torchaudio.save(
    #     f"out_{n}.wav",
    #     src=outputs[0][0].detach().cpu(),
    #     sample_rate=nn_model.config.sampling_rate,
    #     channels_first=True,
    # )

In [9]:
# Display the shape of the first entry stored in `outputs`.
outputs[0].shape

torch.Size([3, 1, 161280])

In [45]:
# Display the device that the wrapped model reports.
nn_model.device

device(type='cuda', index=0)