In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import torchaudio

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the processor and the generation model from the same checkpoint id,
# then move the model onto the GPU.
checkpoint = "facebook/musicgen-small"
processor = AutoProcessor.from_pretrained(checkpoint)
model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
model = model.to('cuda')

Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summ

In [3]:
# Build processor inputs from a reference audio clip plus text prompts.
# The clip is resampled to TARGET_SAMPLE_RATE before being handed to the processor.
TARGET_SAMPLE_RATE = 32000

# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
audio, sample_rate = torchaudio.load("/home/mszawerda/music-sae/dependencies/musicgen/example/dataset/audio/electro_1.wav")

sample = {
    # Resample from the file's native rate, then keep only the first entry along
    # dim 0 (presumably the first channel — torchaudio.load is channels-first by default).
    "array": torchaudio.functional.resample(audio, sample_rate, TARGET_SAMPLE_RATE)[0],
    # Fix: the array above is at TARGET_SAMPLE_RATE, so the recorded rate must be
    # the target rate, not the file's original `sample_rate`.
    "sampling_rate": TARGET_SAMPLE_RATE,
}

# NOTE(review): a single audio array is paired with three identical prompts;
# confirm the processor/model accept this batch-size mismatch as intended.
inputs = processor(
    audio=sample["array"],
    sampling_rate=sample["sampling_rate"],
    text=["80s blues track with groovy saxophone"]*3,
    padding=True,
    return_tensors="pt",
)

In [4]:
# Show the module tree rooted at the decoder's layer list.
decoder_layers_path = 'decoder.model.decoder.layers'
model.get_submodule(decoder_layers_path)

ModuleList(
  (0-23): 24 x MusicgenDecoderLayer(
    (self_attn): MusicgenSdpaAttention(
      (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
    )
    (activation_fn): GELUActivation()
    (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder_attn): MusicgenSdpaAttention(
      (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
      (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
    )
    (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=1024, out_features=4096, bias=False)
  

In [5]:
# Generate audio from a text prompt while the output of one decoder layer's
# `fc2` submodule is replaced by zeros through a forward hook.
inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt",
)

n = 12  # index of the decoder layer whose fc2 output is ablated
ablate_hook = model.get_submodule(f'decoder.model.decoder.layers.{n}.fc2')


@torch.no_grad()
def perform_ablation(module, inputs, outputs):
    """Forward hook that swaps the hooked module's output for zeros.

    Args:
        module: The module the hook is attached to (unused).
        inputs: The positional inputs the module was called with (unused).
        outputs: The tensor the module produced.

    Returns:
        A zero tensor with the same shape, dtype and device as ``outputs``;
        returning a value from a forward hook replaces the module's output.
    """
    return torch.zeros_like(outputs)


hook = ablate_hook.register_forward_hook(perform_ablation)
try:
    with torch.no_grad():
        audio_values = model.generate(**{k:v.to('cuda') for k,v in inputs.items()}, do_sample=True, guidance_scale=3, max_new_tokens=256)
finally:
    # Always detach the hook, even if generation fails or is interrupted;
    # otherwise the model would stay ablated for every later cell.
    hook.remove()



In [6]:
# Disabled experiment, kept for reference: attach a forward hook to
# 'decoder.model.decoder.layers.12.encoder_attn.out_proj' that appends each
# (input, output) pair to `activations` during one plain forward pass of the
# model, then unpack the first recorded pair.
# NOTE(review): the handle returned by register_forward_hook is discarded below,
# so if this is re-enabled the hook stays attached to the model — confirm intent.
# hook_point = model.get_submodule('decoder.model.decoder.layers.12.encoder_attn.out_proj')
# activations = []
# def perform_sae(module, input, output):
#     activations.append((input, output))
# hook_point.register_forward_hook(perform_sae)
# with torch.no_grad():
#     model(**{k: v.to('cuda') for k,v in inputs.items()})
# len(activations)
# input, output = activations[0]