In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch, os, glob, json, math
from torch import nn
import numpy as np
import onnxruntime as ort
import inspect

processor = AutoProcessor.from_pretrained("facebook/musicgen-stereo-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")

folder = './musicgen-stereo-small'
os.makedirs(folder, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


Test run throughputs

In [4]:
inputs = processor(
    text=["80s pop track with bassy drums and synth asd asd asd"],
    padding=True,
    return_tensors="pt",
)

res = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=259, temperature=2.0, top_k=500, top_p=0.0)

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'output_attentions': False, 'output_hidden_states': False, 'return_dict': True, 'input_ids': tensor([[ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,    11,
         13353,    38,    26,    38,    26,    38,    26,     1]])}
{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'use_cache': True, 'guidance_scale': 3, 'encoder_outputs': BaseModelOutput(last_hidden_state=tensor([[[ 0.2295, -0.0886, -0.0691,  ..., -0.0351, -0.9258, -0.2265],
         [ 0.1784,  0.0213, -0.1259,  ..., -0.0948, -0.3672, -0.3329],
         [-0.0491, -0.1878,  0.1695,  ..., -0.3521, -0.5222, -0.4096],
         ...,
         [-0.1734,  0.0202,  0.0772,  ...,  0.1520,  0.1643,  0.0042],
         [-0.5666,  0.1295, -0.0371,  ..., -0.0621, -0.1690,  0.2761],
         [ 0.0207,  0.0166, -0.0044,  ...,  0.0049,  0.0139,  0.

KeyboardInterrupt: 

#### Try and export the full model

Doesnt work bc of problems in attention layer

In [None]:
class MusicGenWrapper(nn.Module):
    def __init__(self, model: MusicgenForConditionalGeneration):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, guidance_scale=3, max_new_tokens=256, temperature=2.0, top_k=500, top_p=0.0):
        '''Taken from last section of the model'''

        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

        output_values = model.generate(**inputs, guidance_scale=guidance_scale.item(), max_new_tokens=max_new_tokens.item(), temperature=temperature.item(), top_k=top_k.item(), top_p=top_p.item())

        return output_values

In [None]:
musicgen_wrapper = MusicGenWrapper(model)

# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'}, # Allow variable batch size and sequence length
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, # Allow variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_guidance_scale = torch.tensor(3, dtype=torch.int64)
dummy_max_new_tokens = torch.tensor(256, dtype=torch.int64)
dummy_temperature = torch.tensor(2.0, dtype=torch.float32)
dummy_top_k = torch.tensor(500, dtype=torch.int64)
dummy_top_p = torch.tensor(0.0, dtype=torch.float32)

# Export the model to ONNX format
torch.onnx.export(
    musicgen_wrapper,
    (inputs['input_ids'], inputs['attention_mask'], dummy_guidance_scale, dummy_max_new_tokens, dummy_temperature, dummy_top_k, dummy_top_p),
    f"{folder}/musicgen.onnx", 
    input_names=[
        'input_ids',
        'attention_mask',
        'guidance_scale',
        'max_new_tokens',
        'temperature',
        'top_k',
        'top_p'
    ],
    output_names=['output_values'],
    dynamic_axes=dynamic_axes
)

### Exporting the model
TODO:
- [x] Configuration lists exported
- [x] Text encoder exported
- [x] Projection layer exported
- [x] Decoder layer exported
- [ ] Sampling function exported
- [x] Output decoder exported
- [ ] Look at making layers efficient
- [ ] Full model throughput test

Export configs

In [None]:
processor.tokenizer.save_pretrained(f'{folder}')
processor.save_pretrained(f'{folder}')
model.config.to_json_file(f'{folder}/config.json')
model.generation_config.to_json_file(f'{folder}/generation_config.json')

#### Export text encoder

In [49]:
class TextEncoderWrapper(nn.Module):
    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids, attention_mask, cfg=None):
        last_hidden_state = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        if cfg is not None and cfg.item() > 1:
            last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
            res_attention_mask = torch.concatenate([attention_mask, torch.zeros_like(attention_mask)], dim=0)

        return last_hidden_state, res_attention_mask

In [57]:
text_encoder_wrapper = TextEncoderWrapper(model.text_encoder)

# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoded': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_input_ids = torch.randint(0, 100, (1, 18), dtype=torch.int64)
dummy_attention_mask = torch.randint(0, 100, (1, 18), dtype=torch.int64)
dummy_cfg = torch.tensor([3], dtype=torch.int64)

# Export the model to ONNX format
torch.onnx.export(
    text_encoder_wrapper,
    (dummy_input_ids, dummy_attention_mask, dummy_cfg),
    f"{folder}/text_encoder.onnx",
    input_names=['input_ids', 'attention_mask', 'cfg'],
    output_names=['last_hidden_state', 'res_attention_mask'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  if cfg is not None and cfg.item() > 1:


In [58]:
ort_session = ort.InferenceSession(f"{folder}/text_encoder.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)
input_ids_np = inputs['input_ids'].detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
attention_mask_np = inputs['attention_mask'].detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    'input_ids': input_ids_np,
    'attention_mask': attention_mask_np,
    'cfg': torch.tensor([3], dtype=torch.int64).detach().numpy()
}
encoded = ort_session.run(None, ort_inputs)[0]

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid input name: cfg

#### Export the projection layer

In [None]:
# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'encoder_hidden_states_in': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoder_hidden_states_out': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_encoder_hidden_states = torch.randint(0, 100, (2, 12, 768), dtype=torch.float32)

# Export the model to ONNX format
torch.onnx.export(
    model.enc_to_dec_proj,                             # Model to export
    (dummy_encoder_hidden_states,),                             # Example input tuple
    f"{folder}/enc_to_dec_proj.onnx",               # Export path
    input_names=['encoder_hidden_states_in'],          # Input tensor names
    output_names=['encoder_hidden_states_out'],         # Output tensor name
    dynamic_axes=dynamic_axes,
    opset_version=17                       # Dynamic axes for variable-length inputs
)

#### Export the decoder layer

In [4]:
class DecoderWrapper(nn.Module):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        self.past_key_values = None

    def forward(self, input_ids, encoder_hidden_states, encoder_attention_mask):
        outputs = self.decoder(
            input_ids = input_ids,
            attention_mask = None,
            encoder_hidden_states = encoder_hidden_states,
            encoder_attention_mask = encoder_attention_mask,
            head_mask = None,
            cross_attn_head_mask = None,
            past_key_values = self.past_key_values,
            inputs_embeds = None,
            labels = None,
            use_cache = True,
            output_attentions = False,
            output_hidden_states = False,
            return_dict = True
        )

        self.past_key_values = outputs.past_key_values

        logits = outputs.logits

        return logits

In [5]:
decoder_wrapper = DecoderWrapper(model.decoder)

# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'input_ids': {0: 'batch_size'},
    'encoder_hidden_states': {0: 'batch_size', 1: 'sequence_length'},
    'encoder_attention_mask': {0: 'batch_size', 1: 'sequence_length'}
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_input_ids = torch.randint(0, 100, (16, 1), dtype=torch.int64)
dummy_encoder_hidden_states = torch.randn((2, 18, 1024), dtype=torch.float32)
dummy_encoder_attention_mask = torch.randint(0, 100, (2, 18), dtype=torch.int64)

# Export the model to ONNX format
torch.onnx.export(
    decoder_wrapper,
    (
        dummy_input_ids,
        dummy_encoder_hidden_states,
        dummy_encoder_attention_mask,
    ),
    f"{folder}/decoder.onnx",
    input_names=[
        'input_ids',
        'encoder_hidden_states',
        'encoder_attention_mask',
    ],
    output_names=['logits'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  if input_shape[-1] > 1 or self.sliding_window is not None:
  if seq_len > self.weights.size(0):
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):


#### Export the audio_encoder

Make a PT wrapper for the decoding portion of the model

In [None]:
class DecodeAudioWrapper(nn.Module):
    def __init__(self, audio_encoder):
        super().__init__()
        self.audio_encoder = audio_encoder

    def apply_delay_pattern_mask(self, input_ids, decoder_pad_token_mask):
        seq_len = input_ids.shape[-1]
        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
        return input_ids

    def forward(self, output_ids: torch.Tensor, decoder_delay_pattern_mask: torch.Tensor, pad_token_id: int):
        '''Taken from last section of the model'''

        batch_size = 1 # We will only allow sampling of single samples for now, otherwise it might be too slow

        # apply the pattern mask to the final ids
        output_ids = self.apply_delay_pattern_mask(output_ids, decoder_delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id
        output_ids = output_ids[output_ids != pad_token_id].reshape(
            batch_size, 8, -1
        )

        # append the frame dimension back to the audio codes
        output_ids = output_ids[None, ...]

        audio_scales = [None] * batch_size

        codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
        output_values_left = codec_outputs_left.audio_values

        codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
        output_values_right = codec_outputs_right.audio_values

        output_values = torch.cat([output_values_left, output_values_right], dim=1)

        return output_values

In [None]:
audio_decoder_wrapper = DecodeAudioWrapper(model.audio_encoder)

# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'output_ids': {1: 'sequence_length'}, # Allow variable batch size and sequence length
    'decoder_delay_pattern_mask': {1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_output_ids = torch.randint(0, 100, (8, 257), dtype=torch.int64)
dummy_decoder_delay_pattern_mask = torch.randint(0, 100, (8, 257), dtype=torch.int64)
dummy_pad_token_id = torch.tensor(2048, dtype=torch.int64)

# Export the model to ONNX format
torch.onnx.export(
    audio_decoder_wrapper,
    (dummy_output_ids, dummy_decoder_delay_pattern_mask, dummy_pad_token_id),
    f"{folder}/audio_token_decoder.onnx", 
    input_names=[
        'output_ids',
        'decoder_delay_pattern_mask',
        'pad_token_id'
    ],
    output_names=['output_values'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

#### Export the pre_loop

In [None]:
class PreLoop(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, inputs, )

#### Export the inner_loop

Then export this version

### Ignore

In [None]:
ort_session = ort.InferenceSession(f"{folder}/audio_token_decoder.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)
input_ids_np = output_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
attention_mask_np = decoder_delay_pattern_mask.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
pad_token_np = torch.tensor(2048, dtype=torch.int64).detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    # 'input_ids': np.expand_dims(np.concatenate((input_ids_np, attention_mask_np), axis=0), 0),
    'output_ids': input_ids_np,
    'decoder_delay_pattern_mask': attention_mask_np,
    'pad_token_id': pad_token_np
}
encoded = ort_session.run(None, ort_inputs)[0]

In [12]:
with open("forward_method_code.py", "w") as file:
    file.write(inspect.getsource(model.text_encoder.encoder.forward))

In [None]:
with open("forward_method_code.py", "w") as file:
    file.write(inspect.getsource(model.audio_encoder.quantizer.decode))

In [None]:
with open("forward_method_code.py", "w") as file:
    file.write(inspect.getsource(model.audio_encoder._decode_frame))

In [None]:
with open("forward_method_code.py", "w") as file:
    file.write(inspect.getsource(model.audio_encoder.decode))

In [None]:
model.audio_encoder.config

In [None]:
# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'encoder_hidden_states_in': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoder_hidden_states_out': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_encoder_hidden_states = torch.randint(0, 100, (2, 12, 768), dtype=torch.float32)

# Export the model to ONNX format
torch.onnx.export(
    model.enc_to_dec_proj,                             # Model to export
    (dummy_encoder_hidden_states,),                             # Example input tuple
    f"{folder}/enc_to_dec_proj.onnx",               # Export path
    input_names=['encoder_hidden_states_in'],          # Input tensor names
    output_names=['encoder_hidden_states_out'],         # Output tensor name
    dynamic_axes=dynamic_axes                       # Dynamic axes for variable-length inputs
)

Export the audio_encoder