In [None]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch, os, glob, json, math
from torch import nn
import numpy as np
import onnxruntime as ort
import inspect

# Load the stereo-small MusicGen checkpoint (processor + model) and create the export folder.
checkpoint = "facebook/musicgen-stereo-small"
processor = AutoProcessor.from_pretrained(checkpoint)
model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)

folder = './musicgen-stereo-small'
os.makedirs(folder, exist_ok=True)

Test run: reference generation with the PyTorch model (also a rough throughput check)

In [None]:
inputs = processor(
    text=["80s pop track with bassy drums and synth asd asd asd"],
    padding=True,
    return_tensors="pt",
)

# The PyTorch reference generation is slow, so it is opt-in. A bare `raise` used to stop the
# cell here, but with no active exception it always raised RuntimeError and broke Run All.
# `inputs` is still created above because later ONNX Runtime cells use it.
RUN_TORCH_REFERENCE = False
if RUN_TORCH_REFERENCE:
    res = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=259, temperature=2.0, top_k=500, top_p=0.0)

### Exporting the model
TODO:
- [x] Configuration lists exported
- [x] Text encoder exported
- [x] Projection layer exported
- [x] Decoder layer exported
- [ ] Sampling function exported
- [x] Output decoder exported
- [ ] Look at making layers efficient
- [ ] Full model throughput test
- [ ] Research way to export the sample input version

Flow (the `[Sample]` step runs inside a for loop, once per generated token):
tokenized inputs and mask -> Text Encoder -> PreLoop -> [Sample] -> Audio Encoder (decode) -> Wav

Export configs

In [3]:
# Save everything the runtime needs besides the ONNX graphs:
# tokenizer + processor files, the model config and the generation config.
processor.tokenizer.save_pretrained(folder)
processor.save_pretrained(folder)
model.config.to_json_file(f"{folder}/config.json")
model.generation_config.to_json_file(f"{folder}/generation_config.json")

#### Export text encoder

In [38]:
class TextEncoderWrapper(nn.Module):
    """Wrap the text encoder so it can be exported to ONNX with an optional CFG input.

    When ``cfg`` is given and greater than 1 (classifier-free guidance), the hidden states
    and the attention mask are doubled along dim 0: the first half holds the real values
    and the second half is all zeros.
    """

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids, attention_mask, cfg=None):
        """Encode ``input_ids``.

        Args:
            input_ids: token ids passed straight to the text encoder.
            attention_mask: mask passed straight to the text encoder.
            cfg: optional guidance-scale tensor; only ``cfg > 1`` changes the outputs.

        Returns:
            ``(last_hidden_state, res_attention_mask)``. Without guidance these are the
            encoder output and the unchanged mask. With guidance both are doubled and
            multiplied by a float condition, so the mask is returned as float.
        """
        last_hidden_state = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Default "no guidance" result. Previously this was only set when cfg was None, so
        # passing cfg <= 1 raised UnboundLocalError at the return.
        res_attention_mask = attention_mask

        if cfg is not None:
            cfg_tensor = cfg.unsqueeze(0)  # Convert to tensor for ONNX
            condition = (cfg_tensor > 1).float()
            # Python-level branch on a tensor: tracing bakes in the branch the dummy cfg takes.
            # Multiplying by `condition` keeps cfg as a real input of the exported graph.
            if condition:
                last_hidden_state = condition * torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
                res_attention_mask = condition * torch.concatenate([attention_mask, torch.zeros_like(attention_mask)], dim=0)

        return last_hidden_state, res_attention_mask

In [39]:
text_encoder_wrapper = TextEncoderWrapper(model.text_encoder)

# Dynamic axes must be keyed by the names in `input_names` / `output_names`. The previous
# 'encoded' key matched no output, so both outputs were exported with fixed shapes.
# With CFG the outputs' dim 0 is twice the input batch, so it gets its own symbol.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'last_hidden_state': {0: 'cfg_batch_size', 1: 'sequence_length'},
    'res_attention_mask': {0: 'cfg_batch_size', 1: 'sequence_length'},
}

# Dummy inputs: batch size 1, sequence length 18, a 0/1 mask with some padding, and
# cfg = 3 (> 1) so the guidance branch is the one that gets traced.
dummy_input_ids = torch.randint(0, 100, (1, 18), dtype=torch.int64)
dummy_attention_mask = torch.ones((1, 18), dtype=torch.int64)
dummy_attention_mask[:, -2:] = 0
dummy_cfg = torch.tensor(3, dtype=torch.int64)

torch.onnx.export(
    text_encoder_wrapper,
    (dummy_input_ids, dummy_attention_mask, dummy_cfg),
    f"{folder}/text_encoder.onnx",
    input_names=['input_ids', 'attention_mask', 'cfg'],
    output_names=['last_hidden_state', 'res_attention_mask'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  if condition: # This enforces the addition of cfg as a variable


#### Export the projection layer

In [6]:
# Export the projection from text-encoder width to decoder width (model.enc_to_dec_proj).
dynamic_axes = {
    'encoder_hidden_states_in': {0: 'batch_size', 1: 'sequence_length'},
    'encoder_hidden_states_out': {0: 'batch_size', 1: 'sequence_length'}
}

# Dummy input: batch 2, sequence length 12, hidden size read from the projection itself
# (assumes enc_to_dec_proj is an nn.Linear; previously hard-coded to 768).
dummy_encoder_hidden_states = torch.randn((2, 12, model.enc_to_dec_proj.in_features), dtype=torch.float32)

torch.onnx.export(
    model.enc_to_dec_proj,                       # Model to export
    (dummy_encoder_hidden_states,),              # Example input tuple
    f"{folder}/enc_to_dec_proj.onnx",            # Export path
    input_names=['encoder_hidden_states_in'],    # Input tensor names
    output_names=['encoder_hidden_states_out'],  # Output tensor name
    dynamic_axes=dynamic_axes,                   # Variable batch / sequence length
    opset_version=17
)

#### Export the decoder layer

In [7]:
class DecoderWrapper(nn.Module):
    """Export wrapper around the MusicGen decoder that keeps its KV cache on the module.

    NOTE(review): the cache lives in ``self.past_key_values``. It is not an ONNX input or
    output, and it is never reset here, so repeated calls keep reusing it.
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        self.past_key_values = None  # replaced after every forward call

    def forward(self, input_ids, encoder_hidden_states, encoder_attention_mask):
        """Run one decoder pass, store the returned cache and return the logits."""
        decoder_kwargs = dict(
            input_ids=input_ids,
            attention_mask=None,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=None,
            cross_attn_head_mask=None,
            past_key_values=self.past_key_values,
            inputs_embeds=None,
            labels=None,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        result = self.decoder(**decoder_kwargs)
        self.past_key_values = result.past_key_values
        return result.logits

In [None]:
decoder_wrapper = DecoderWrapper(model.decoder)

# The decoder rows (input_ids, and presumably logits) and the encoder batch have different
# sizes in the dummies (16 vs 2), so they get separate symbolic names instead of sharing
# 'batch_size'.
dynamic_axes = {
    'input_ids': {0: 'decoder_batch_size'},
    'encoder_hidden_states': {0: 'batch_size', 1: 'sequence_length'},
    'encoder_attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    # NOTE(review): assumes logits dim 0 follows input_ids rows — confirm on the exported graph.
    'logits': {0: 'decoder_batch_size'},
}

# Dummies: 16 decoder rows x 1 step (presumably 8 codebooks x 2 CFG copies), and
# 2 encoder rows x 18 tokens x 1024 hidden, with a 0/1 mask that includes some padding.
dummy_input_ids = torch.randint(0, 100, (16, 1), dtype=torch.int64)
dummy_encoder_hidden_states = torch.randn((2, 18, 1024), dtype=torch.float32)
dummy_encoder_attention_mask = torch.ones((2, 18), dtype=torch.int64)
dummy_encoder_attention_mask[:, -2:] = 0

torch.onnx.export(
    decoder_wrapper,
    (
        dummy_input_ids,
        dummy_encoder_hidden_states,
        dummy_encoder_attention_mask,
    ),
    f"{folder}/decoder.onnx",
    input_names=[
        'input_ids',
        'encoder_hidden_states',
        'encoder_attention_mask',
    ],
    output_names=['logits'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

#### Export the audio_encoder

Wrap the decoding portion of the model (`audio_encoder.decode`) in a PyTorch module so it can be exported.

In [40]:
class DecodeAudioWrapper(nn.Module):
    """Turn sampled codebook ids into stereo audio with ``audio_encoder.decode``.

    The interleaved codebook rows are split into left (even rows) and right (odd rows)
    channels. Each half is decoded separately and the results are concatenated on dim 1.
    """

    def __init__(self, audio_encoder, num_codebooks=8):
        super().__init__()
        self.audio_encoder = audio_encoder
        # Number of codebook rows per sample (previously hard-coded to 8 in forward).
        self.num_codebooks = num_codebooks

    def apply_delay_pattern_mask(self, input_ids, decoder_pad_token_mask):
        """Keep ``input_ids`` where the mask is -1; elsewhere take the mask's value."""
        seq_len = input_ids.shape[-1]
        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
        return torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)

    def forward(self, output_ids: torch.Tensor, decoder_delay_pattern_mask: torch.Tensor, pad_token_id: "torch.Tensor | int"):
        """Decode ``output_ids`` (num_codebooks rows x steps) to audio values.

        Args:
            output_ids: sampled ids, one row per codebook.
            decoder_delay_pattern_mask: delay-pattern mask from the pre-loop (-1 = sampled slot).
            pad_token_id: pad id (tensor or int); positions equal to it are dropped.

        Returns:
            Left and right decoded ``audio_values`` concatenated along dim 1.
        """
        batch_size = 1  # We will only allow sampling of single samples for now, otherwise it might be too slow

        # Apply the pattern mask to the final ids.
        output_ids = self.apply_delay_pattern_mask(output_ids, decoder_delay_pattern_mask)

        # Revert the delay pattern by filtering out the pad token id. This assumes every row
        # contains the same number of pad tokens, so the flat result reshapes cleanly.
        output_ids = output_ids[output_ids != pad_token_id].reshape(batch_size, self.num_codebooks, -1)

        # Append the frame dimension back to the audio codes.
        output_ids = output_ids[None, ...]

        audio_scales = [None] * batch_size

        output_values_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales).audio_values
        output_values_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales).audio_values

        return torch.cat([output_values_left, output_values_right], dim=1)

In [41]:
audio_decoder_wrapper = DecodeAudioWrapper(model.audio_encoder)

# Only the time axis is dynamic: the wrapper fixes batch_size = 1 and 8 codebook rows.
dynamic_axes = {
    'output_ids': {1: 'sequence_length'},
    'decoder_delay_pattern_mask': {1: 'sequence_length'},
    # NOTE(review): assumes output_values is (1, 2, samples) — confirm on the exported graph.
    'output_values': {2: 'audio_length'},
}

# Dummies: 8 codebook rows x 257 steps; an all -1 mask keeps every sampled id, and the
# pad id is a 1-element tensor (the ONNX Runtime cells feed np.array([2048])).
dummy_output_ids = torch.randint(0, 100, (8, 257), dtype=torch.int64)
dummy_decoder_delay_pattern_mask = torch.full((8, 257), -1, dtype=torch.int64)
dummy_pad_token_id = torch.tensor([2048], dtype=torch.int64)

torch.onnx.export(
    audio_decoder_wrapper,
    (dummy_output_ids, dummy_decoder_delay_pattern_mask, dummy_pad_token_id),
    f"{folder}/audio_token_decoder.onnx",
    input_names=[
        'output_ids',
        'decoder_delay_pattern_mask',
        'pad_token_id'
    ],
    output_names=['output_values'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  if len(audio_codes) != 1:
  quantized_out = torch.tensor(0.0, device=codes.device)
  for i, indices in enumerate(codes):
  max_pad = max(padding_left, padding_right)
  if length <= max_pad:


#### Export the pre_loop

In [42]:
class PreLoop(nn.Module):
    """Build the initial decoder ids and the delay-pattern mask before the sampling loop.

    Every codebook row starts with the decoder start token (2048). The codebooks are then
    offset against each other according to the delay pattern.
    """

    def __init__(self, num_codebooks=8, audio_channels=2):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.audio_channels = audio_channels

    def build_delay_pattern_mask(self, input_ids, pad_token_id, max_length):
        """Build the delay-pattern mask and crop ``input_ids`` to the first slot to sample.

        Args:
            input_ids: ids with ``bsz * num_codebooks`` rows.
            pad_token_id: value written into the delay-padding positions.
            max_length: 0-d integer tensor, the total pattern length.

        Returns:
            ``(input_ids, pattern_mask)``, both with ``bsz * num_codebooks`` rows. In
            ``pattern_mask``, -1 marks positions that still have to be sampled. If
            ``max_length < 2 * channel_codebooks - 1``, the ids are returned uncropped
            with an all -1 mask.
        """
        input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
        bsz, num_codebooks, seq_len = input_ids.shape

        max_length = max_length.unsqueeze(0)  # Convert to tensor for ONNX
        condition = (max_length > 0).long()

        # NOTE: .item() turns max_length into a Python int, so a traced ONNX graph keeps the
        # dummy max_length for these shapes.
        input_ids_shifted = condition * (
            torch.ones((bsz, num_codebooks, max_length.item()), dtype=torch.long, device=input_ids.device) * -1
        )

        max_length = max_length.item()

        channel_codebooks = num_codebooks // 2 if self.audio_channels == 2 else num_codebooks
        # We only apply the mask if we have a large enough seq len - otherwise we return as is.
        if max_length < 2 * channel_codebooks - 1:
            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)

        # Fill the shifted ids with the prompt entries, offset by the codebook idx.
        for codebook in range(channel_codebooks):
            if self.audio_channels == 1:
                # Mono: loop over the codebooks one-by-one.
                input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
            else:
                # Left/right channels are interleaved in the generated codebooks.
                input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
                input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

        # Pattern of padding positions per codebook: first the upper triangle (EOS padding)...
        delay_pattern = torch.triu(
            torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
        ).to(torch.int64)
        # ...then the lower triangle (BOS padding).
        delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))

        if self.audio_channels == 2:
            # Stereo: duplicate every row of the pattern in an interleaved fashion.
            delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

        delay_pattern = delay_pattern.to(torch.bool)

        mask = ~delay_pattern.to(input_ids.device)
        input_ids = mask * input_ids_shifted + ~mask * pad_token_id

        # The first position to generate is the first -1 in the first codebook, which has no
        # codebook offset.
        first_codebook_ids = input_ids[:, 0, :]
        start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
        if len(start_ids) > 0:
            first_start_id = start_ids.min()
        else:
            # No tokens need to be filled - return the entire matrix of input ids.
            first_start_id = seq_len

        # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
        pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
        input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)

        return input_ids, pattern_mask

    def forward(self, batch_size: torch.Tensor, decoder_input_ids=None, decoder_attention_mask=None, max_length=torch.tensor(256, dtype=torch.int64)):
        """Return ``(decoder_input_ids, decoder_delay_pattern_mask)`` for ``batch_size`` samples.

        TODO: audio-prompt input (``decoder_input_ids`` / ``decoder_attention_mask``). The
        prepended attention mask is computed but not returned yet.
        """
        # Equivalent to _prepare_decoder_input_ids_for_generation. 2048 is used as both the
        # start and the pad token (the audio decoder cells also use pad id 2048).
        decoder_start_token_id = 2048
        pad_token_id = 2048
        decoder_input_ids_start = (
            torch.ones((batch_size * self.num_codebooks, 1), dtype=torch.long) * decoder_start_token_id
        )

        if decoder_input_ids is None:
            decoder_input_ids = decoder_input_ids_start
        elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
            if decoder_attention_mask is not None:
                decoder_attention_mask = torch.cat(
                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
                    dim=-1,
                )

        return self.build_delay_pattern_mask(input_ids=decoder_input_ids, pad_token_id=pad_token_id, max_length=max_length)

In [43]:
pre_loop = PreLoop(model.config.decoder.num_codebooks, model.config.decoder.audio_channels)

# Both outputs have batch_size * num_codebooks rows, but their lengths differ (cropped ids
# vs. the full max_length pattern), so the length axes get separate symbolic names.
dynamic_axes = {
    'decoder_input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'decoder_delay_pattern_mask': {0: 'batch_size', 1: 'max_length'}
}

# Dummy scalars: 1 sample, pattern length 256. NOTE: build_delay_pattern_mask calls
# max_length.item(), so the exported graph keeps 256 for its shapes (see the
# TracerWarnings). The audio-prompt inputs are passed as None and are not exported.
dummy_batch_size = torch.tensor(1, dtype=torch.int64)
dummy_max_length = torch.tensor(256, dtype=torch.int64)

torch.onnx.export(
    pre_loop,
    (dummy_batch_size, None, None, dummy_max_length),
    f"{folder}/pre_loop.onnx",
    input_names=['batch_size', 'max_length'],
    output_names=['decoder_input_ids', 'decoder_delay_pattern_mask'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  torch.ones((bsz, num_codebooks, max_length.item()), dtype=torch.long, device=input_ids.device) * -1
  max_length = max_length.item()
  if max_length < 2 * channel_codebooks - 1:
  if len(start_ids) > 0:
  first_start_id = min(start_ids)
  first_start_id = min(start_ids)


#### Export Sample

In [44]:
class Sample(nn.Module):
    """One MusicGen sampling step, packaged for ONNX export.

    Each call:
    1. applies the delay-pattern mask to the current ids;
    2. optionally duplicates the batch for classifier-free guidance;
    3. projects and masks the text-encoder states, then runs the decoder;
    4. applies CFG and temperature/top-k/top-p warping;
    5. samples one token per row and returns the ids with that token appended.
    """

    def __init__(self, decoder, enc_proj):
        super().__init__()
        self.decoder = decoder
        self.enc_proj = enc_proj
        # Decoder cache from the previous call; when set, only the last position is fed.
        self.past_key_values = None
        self.filter_value = -float('inf')

    @staticmethod
    def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
        """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
        the mask is set to -1, and otherwise setting to the value detailed in the mask."""
        seq_len = input_ids.shape[-1]
        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
        return input_ids

    def logits_process(self, next_token_logits, cfg):
        """Classifier-free guidance: ``uncond + (cond - uncond) * cfg``.

        The first half of the batch is treated as conditional, the second as unconditional.
        """
        unguided_bsz = next_token_logits.shape[0] // 2
        cond_logits, uncond_logits = next_token_logits.split(unguided_bsz, dim=0)
        next_token_scores = uncond_logits + (cond_logits - uncond_logits) * cfg
        return next_token_scores

    def logits_warp(self, scores: torch.Tensor, temperature: torch.Tensor, topk: torch.Tensor, topp: torch.Tensor):
        """Apply temperature, then top-k, then top-p filtering (filtered scores become -inf)."""
        # Temperature
        scores_processed = scores / temperature

        # Top-k: remove all tokens with a score below the k-th largest.
        top_k = min(topk, scores_processed.size(-1))  # Safety check
        indices_to_remove = scores_processed < torch.topk(scores_processed, top_k)[0][..., -1, None]
        scores_processed = scores_processed.masked_fill(indices_to_remove, self.filter_value)

        # Top-p: on ascending-sorted scores, drop tokens whose cumulative probability is
        # <= 1 - topp. The highest-scoring token is always kept, so topp = 0 keeps only it.
        sorted_logits, sorted_indices = torch.sort(scores_processed, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - topp)
        sorted_indices_to_remove[..., -1 :] = 0

        # Scatter the sorted removal mask back to the original indexing.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores_processed = scores_processed.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed

    def forward(
            self,
            decoder_input_ids,
            attention_mask,
            encoder_hidden_states,
            delay_pattern_mask,
            cfg = torch.tensor(3),
            temperature = torch.tensor(0.7),
            topk = torch.tensor(500),
            topp = torch.tensor(0.0)
        ):
        """Sample one token per row and return ``decoder_input_ids`` with it appended."""
        # Input prep
        model_inputs = self.apply_delay_pattern_mask(decoder_input_ids, delay_pattern_mask)

        # Defaults for the "no guidance" path. Previously these were only defined inside the
        # CFG branch, so cfg=None, cfg <= 1 or attention_mask=None raised NameError.
        use_cfg = False
        model_input_attention_mask = attention_mask

        if cfg is not None:
            cfg_tensor = cfg.unsqueeze(0)  # Convert to tensor for ONNX
            condition = (cfg_tensor > 1).int()
            # Python-level branch on a tensor: tracing bakes in the branch the dummy cfg takes.
            if condition:
                use_cfg = True
                model_inputs = condition * model_inputs.repeat((2, 1))
                if attention_mask is not None:
                    model_input_attention_mask = condition * attention_mask.repeat((2, 1))

        if self.past_key_values is not None:
            model_inputs = model_inputs[:, -1:]

        # Forward pass
        encoder_hidden_states = self.enc_proj(encoder_hidden_states)
        if model_input_attention_mask is not None:
            encoder_hidden_states = encoder_hidden_states * model_input_attention_mask[..., None]

        outputs = self.decoder(
            input_ids = model_inputs,
            attention_mask = None,
            encoder_hidden_states = encoder_hidden_states,
            encoder_attention_mask = model_input_attention_mask,
            head_mask = None,
            cross_attn_head_mask = None,
            past_key_values = self.past_key_values,
            inputs_embeds = None,
            labels = None,
            use_cache = True,
            output_attentions = False,
            output_hidden_states = False,
            return_dict = True
        )

        self.past_key_values = outputs.past_key_values

        next_token_logits = outputs.logits[:, -1, :].clone()

        # CFG processing only when the batch was duplicated above.
        if use_cfg:
            next_token_scores = self.logits_process(next_token_logits, cfg)
        else:
            next_token_scores = next_token_logits

        next_token_scores = self.logits_warp(next_token_scores, temperature, topk, topp)

        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        decoder_input_ids = torch.cat([decoder_input_ids, next_tokens[:, None]], dim=-1)

        return decoder_input_ids

In [45]:
sample = Sample(model.decoder, model.enc_to_dec_proj)
sample.eval()

# One entry per tensor name (the old dict listed 'decoder_input_ids' twice). Axes get distinct
# symbols because the dummy sizes differ: codebook rows = 8, text-mask batch = 1,
# encoder-state batch = 2 (CFG), text length = 18, pattern length = 260.
# NOTE: the output reuses the name 'decoder_input_ids'; in ONNX Runtime the input is fed
# as 'decoder_input_ids.1'.
dynamic_axes = {
    'decoder_input_ids': {0: 'codebook_rows', 1: 'generated_length'},
    'attention_mask': {0: 'batch_size', 1: 'text_length'},
    'encoder_hidden_states': {0: 'cfg_batch_size', 1: 'text_length'},
    'delay_pattern_mask': {0: 'codebook_rows', 1: 'max_length'},
}

dummy_decoder_input_ids = torch.randint(0, 100, (8, 1), dtype=torch.int64)
dummy_attention_mask = torch.randint(0, 100, (1, 18), dtype=torch.int64)
dummy_encoder_hidden_states = torch.randn((2, 18, 768), dtype=torch.float32)
dummy_delay_pattern_mask = torch.randint(0, 100, (8, 260), dtype=torch.int64)
dummy_cfg = torch.tensor(3, dtype=torch.int64)
dummy_temperature = torch.tensor(0.7, dtype=torch.float32)
dummy_topk = torch.tensor(500, dtype=torch.int64)
dummy_topp = torch.tensor(0.0, dtype=torch.float32)

torch.onnx.export(
    sample,
    (
        dummy_decoder_input_ids,
        dummy_attention_mask,
        dummy_encoder_hidden_states,
        dummy_delay_pattern_mask,
        dummy_cfg,
        dummy_temperature,
        dummy_topk,
        dummy_topp
    ),
    f"{folder}/sampler.onnx",
    input_names=[
        'decoder_input_ids',
        'attention_mask',
        'encoder_hidden_states',
        'delay_pattern_mask',
        'cfg',
        'temperature',
        'topk',
        'topp'
    ],
    output_names=['decoder_input_ids'],
    dynamic_axes=dynamic_axes,
    opset_version=17
)

  if condition:
  if input_shape[-1] > 1 or self.sliding_window is not None:
  if seq_len > self.weights.size(0):
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  if condition:
  top_k = min(topk, scores_processed.size(-1))  # Safety check


#### Test the generating

In [46]:
ort_session = ort.InferenceSession(f"{folder}/text_encoder.onnx")

# Encode the prompt tokenized at the top of the notebook; cfg = 3 asks for the CFG-doubled batch.
input_ids_np, attention_mask_np = (inputs[name].detach().numpy() for name in ('input_ids', 'attention_mask'))

ort_inputs = dict(
    input_ids=input_ids_np,
    attention_mask=attention_mask_np,
    cfg=np.array([3], dtype=np.int64),
)
encoded = ort_session.run(None, ort_inputs)[0]

In [47]:
ort_session = ort.InferenceSession(f"{folder}/pre_loop.onnx")

# 0-d int64 scalars with the same values used as export dummies.
dummy_batch_size = np.array(1, dtype=np.int64)
dummy_max_length = np.array(256, dtype=np.int64)

ort_inputs = dict(batch_size=dummy_batch_size, max_length=dummy_max_length)
decoder_input_ids, decoder_delay_pattern_mask = ort_session.run(None, ort_inputs)

In [48]:
ort_session = ort.InferenceSession(f"{folder}/sampler.onnx")

# Inputs that never change between steps are built once instead of on every iteration.
static_inputs = {
    'attention_mask': inputs['attention_mask'].detach().numpy(),
    'encoder_hidden_states': encoded,
    'delay_pattern_mask': decoder_delay_pattern_mask,
    'cfg': np.array([3], dtype=np.int64),
    'temperature': np.array([0.7], dtype=np.float32),
    'topk': np.array([500], dtype=np.int64),
    'topp': np.array([0.0], dtype=np.float32),
}

# Each step appends one token per codebook row; 256 steps matches the pre_loop max_length.
for _ in range(256):
    ort_inputs = {'decoder_input_ids.1': decoder_input_ids, **static_inputs}
    decoder_input_ids = ort_session.run(None, ort_inputs)[0]

In [49]:
ort_session = ort.InferenceSession(f"{folder}/audio_token_decoder.onnx")

# Drop the last sampled column so output_ids lines up with the delay-pattern mask.
# TODO: we either need to remove the first tokens or add tokens to the delay pattern mask;
# check the original implementation for inspiration.
ort_inputs = dict(
    output_ids=decoder_input_ids[:, :-1],
    decoder_delay_pattern_mask=decoder_delay_pattern_mask,
    pad_token_id=np.array([2048], dtype=np.int64),
)

output_values = ort_session.run(None, ort_inputs)[0]

In [None]:
from IPython.display import Audio

# Play the first generated clip inline at the codec's sampling rate.
sampling_rate = model.config.audio_encoder.sampling_rate
Audio(data=output_values[0], rate=sampling_rate)

In [50]:
# `import scipy` alone does not guarantee the `scipy.io.wavfile` submodule is loaded
# (older SciPy versions raise AttributeError), so import the submodule explicitly.
import scipy.io.wavfile

sampling_rate = model.config.audio_encoder.sampling_rate
# wavfile.write expects (n_samples, n_channels) for multi-channel data, hence the transpose.
scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=output_values[0].T)

### Try and export the full model

Doesn't work because of problems in the attention layer.

In [None]:
class MusicGenWrapper(nn.Module):
    """Wrap the full ``generate()`` call so it can (in principle) be exported as one graph.

    NOTE: exporting this wrapper currently fails because of problems in the attention layer.
    """

    def __init__(self, model: "MusicgenForConditionalGeneration"):
        super().__init__()
        self.model = model

    @staticmethod
    def _to_python(value):
        """Return ``value.item()`` for tensors and ``value`` unchanged otherwise."""
        return value.item() if isinstance(value, torch.Tensor) else value

    def forward(self, input_ids, attention_mask, guidance_scale=3, max_new_tokens=256, temperature=2.0, top_k=500, top_p=0.0):
        """Run ``self.model.generate`` with the given settings.

        The generation settings may be tensors (as passed during export) or plain Python
        numbers (the defaults). Previously the defaults crashed on ``.item()``, and the
        notebook-global ``model`` was used instead of the wrapped one.
        """
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

        output_values = self.model.generate(
            **inputs,
            guidance_scale=self._to_python(guidance_scale),
            max_new_tokens=self._to_python(max_new_tokens),
            temperature=self._to_python(temperature),
            top_k=self._to_python(top_k),
            top_p=self._to_python(top_p),
        )

        return output_values

In [None]:
musicgen_wrapper = MusicGenWrapper(model)

# Only the text inputs get dynamic axes; the generation settings are scalar tensors.
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
}

# Scalar generation settings, in forward() argument order.
dummy_guidance_scale = torch.tensor(3, dtype=torch.int64)
dummy_max_new_tokens = torch.tensor(256, dtype=torch.int64)
dummy_temperature = torch.tensor(2.0, dtype=torch.float32)
dummy_top_k = torch.tensor(500, dtype=torch.int64)
dummy_top_p = torch.tensor(0.0, dtype=torch.float32)

export_args = (
    inputs['input_ids'],
    inputs['attention_mask'],
    dummy_guidance_scale,
    dummy_max_new_tokens,
    dummy_temperature,
    dummy_top_k,
    dummy_top_p,
)
export_input_names = ['input_ids', 'attention_mask', 'guidance_scale', 'max_new_tokens', 'temperature', 'top_k', 'top_p']

torch.onnx.export(
    musicgen_wrapper,
    export_args,
    f"{folder}/musicgen.onnx",
    input_names=export_input_names,
    output_names=['output_values'],
    dynamic_axes=dynamic_axes
)

### Ignore

In [None]:
ort_session = ort.InferenceSession(f"{folder}/audio_token_decoder.onnx")

# Scratch check of the exported audio decoder. `output_ids` is not defined anywhere in this
# notebook, and `decoder_delay_pattern_mask` is already a numpy array (from the pre_loop
# ONNX run), so `.detach()` would fail on it. Take the inputs from the ONNX sampling loop
# instead. The pad id is a 1-element array to match the export dummy.
ort_inputs = {
    'output_ids': decoder_input_ids[:, :-1],
    'decoder_delay_pattern_mask': decoder_delay_pattern_mask,
    'pad_token_id': np.array([2048], dtype=np.int64),
}
# Stored under its own name so the text-encoder output in `encoded` is not overwritten.
decoded_audio = ort_session.run(None, ort_inputs)[0]

In [None]:
import onnx

# List the graph input names of the exported sampler; useful because ONNX can rename
# inputs whose names clash with outputs.
model_path = f"{folder}/sampler.onnx"  # Update this with your ONNX model path
onnx_model = onnx.load(model_path)

print("Model Inputs:")
for graph_input in onnx_model.graph.input:
    print(f"Input name: {graph_input.name}")


In [None]:
# Dump the source for offline reading. An explicit UTF-8 encoding avoids UnicodeEncodeError
# on platforms whose default locale encoding cannot represent the source text.
with open("forward_method_code.py", "w", encoding="utf-8") as file:
    file.write(inspect.getsource(model.text_encoder.encoder.forward))

In [None]:
# Dump the source for offline reading (explicit UTF-8 so non-ASCII source text can be written).
with open("forward_method_code.py", "w", encoding="utf-8") as file:
    file.write(inspect.getsource(model.audio_encoder.quantizer.decode))

In [None]:
# Dump the source for offline reading (explicit UTF-8 so non-ASCII source text can be written).
with open("forward_method_code.py", "w", encoding="utf-8") as file:
    file.write(inspect.getsource(model.audio_encoder._decode_frame))

In [None]:
# Dump the source for offline reading (explicit UTF-8 so non-ASCII source text can be written).
with open("forward_method_code.py", "w", encoding="utf-8") as file:
    file.write(inspect.getsource(model.audio_encoder.decode))

In [None]:
# Display the audio encoder's config (e.g. the sampling rate used when writing the wav file).
model.audio_encoder.config

In [None]:
# Re-export of the projection layer. This writes to the same file as the earlier export, so it
# uses the same opset (17); without it, re-running this cell silently replaced the graph with
# one built for the exporter's default opset.
dynamic_axes = {
    'encoder_hidden_states_in': {0: 'batch_size', 1: 'sequence_length'},
    'encoder_hidden_states_out': {0: 'batch_size', 1: 'sequence_length'}
}

# Dummy input: batch 2, sequence length 12, hidden size 768.
dummy_encoder_hidden_states = torch.randn((2, 12, 768), dtype=torch.float32)

torch.onnx.export(
    model.enc_to_dec_proj,                       # Model to export
    (dummy_encoder_hidden_states,),              # Example input tuple
    f"{folder}/enc_to_dec_proj.onnx",            # Export path
    input_names=['encoder_hidden_states_in'],    # Input tensor names
    output_names=['encoder_hidden_states_out'],  # Output tensor name
    dynamic_axes=dynamic_axes,                   # Variable batch / sequence length
    opset_version=17
)

Re-run the exported text encoder with ONNX Runtime (sanity check)

In [None]:
ort_session = ort.InferenceSession(f"{folder}/text_encoder.onnx")

# ONNX Runtime takes numpy arrays, so convert the tokenized prompt.
input_ids_np = inputs['input_ids'].detach().numpy()
attention_mask_np = inputs['attention_mask'].detach().numpy()

ort_inputs = {
    'input_ids': input_ids_np,
    'attention_mask': attention_mask_np,
    'cfg': np.array([3], dtype=np.int64)
}
encoded = ort_session.run(None, ort_inputs)[0]
# Show the shape rather than dumping the whole hidden-state array into the notebook.
encoded.shape