## Exporting the stereo model

Trick to make the model actually function

In [None]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, MusicgenModel

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")
x = model.config.to_dict()
x['decoder']['num_codebooks'] = 4
model.config = model.config.from_dict(x)
model.save_pretrained("musicgen_fixed")

In [None]:
import soundfile as sf


sf.write('./test.wav', audio_values.detach().numpy()[0].T, samplerate=41000)

Try to replicate the export method

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

processor = AutoProcessor.from_pretrained("facebook/musicgen-stereo-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")

  from .autonotebook import tqdm as notebook_tqdm
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [32]:
import math
from torch import nn
import random
from typing import List, Optional, Tuple, Union

def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)

def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # if the 4D mask has correct shape - invert it and fill with negative infinity
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    return attention_mask

class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with slided window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be multiplied with attention scores

    Examples:

    ```python
    >>> import torch
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
    """

    is_causal: bool
    sliding_window: int

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        # If shape is not cached, create a new causal mask and cache it
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )

        if causal_4d_mask is not None:
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        # expanded_attn_mask + causal_4d_mask can cause some overflow
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make causal mask used for bi-directional self-attention.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1

            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
        sliding_window: Optional[int] = None,
        is_training: bool = False,
    ) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
        """

        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        is_tracing = (
            torch.jit.is_tracing()
            or isinstance(inputs_embeds, torch.fx.Proxy)
            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
        )

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
            # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            if len(attention_mask.shape) == 4:
                return False
            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask

class MusicgenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states  = None,
        past_key_value = None,
        attention_mask  = None,
        layer_head_mask  = None,
        output_attentions = False,
    ):
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        query_states = query_states#.repeat(key_states.size(1))
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        #     raise ValueError(
        #         f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
        #         f" {attn_weights.size()}"
        #     )

        if attention_mask is not None:
            # if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            #     raise ValueError(
            #         f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            #     )
            if attention_mask.size(1) == 1:  # If num_heads dimension is 1, expand it to match attn_weights
                attention_mask = attention_mask.expand(-1, 16, -1, -1)
                attention_mask.to(attn_weights.dtype)
            
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attention_mask = attention_mask.squeeze()
            attn_weights = attn_weights.squeeze()
            print(attn_weights.size(), attention_mask.size())
            attn_weights = torch.add(attn_weights, attention_mask)
            # attn_weights = attn_weights + attention_mask
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            # if layer_head_mask.size() != (self.num_heads,):
            #     raise ValueError(
            #         f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
            #         f" {layer_head_mask.size()}"
            #     )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None


        attn_output = torch.bmm(attn_weights, value_states)

        # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        #     raise ValueError(
        #         f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
        #         f" {attn_output.size()}"
        #     )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value

class MusicgenSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.make_weights(num_positions, embedding_dim)

    def make_weights(self, num_embeddings: int, embedding_dim: int):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.weights = nn.Parameter(emb_weights)
        self.weights.requires_grad = False
        self.weights.detach_()

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int):
        """
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, codebooks, seq_len = input_ids.size()
        # Create the position ids from the input token ids.
        position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
        # expand embeddings if needed
        if seq_len > self.weights.size(0):
            self.make_weights(seq_len + self.offset, self.embedding_dim)
        return self.weights.index_select(0, position_ids.view(-1)).detach()

class MusicgenDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config['hidden_size']

        self.self_attn = MusicgenAttention(
            embed_dim=self.embed_dim,
            num_heads=config['num_attention_heads'],
            dropout=config['attention_dropout'],
            is_decoder=True,
            bias=False,
            is_causal=True,
            config=config,
        )

        self.dropout = config['dropout']
        self.activation_fn = nn.GELU()
        self.activation_dropout = config['activation_dropout']

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = MusicgenAttention(
            self.embed_dim,
            config['num_attention_heads'],
            dropout=config['attention_dropout'],
            is_decoder=True,
            bias=False,
            config=config,
        )
        
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config['ffn_dim'], bias=False)
        self.fc2 = nn.Linear(config['ffn_dim'], self.embed_dim, bias=False)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask = None,
        encoder_hidden_states = None,
        encoder_attention_mask = None,
        layer_head_mask = None,
        cross_attn_layer_head_mask = None,
        past_key_value = None,
        output_attentions = False,
        use_cache = True,
    ) -> torch.Tensor:
        
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = residual + hidden_states

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class NakedDecoder(torch.nn.Module):
    '''Class to store a dressed down version of the MusicDecoder'''

    def __init__(self, config, attn_impl):
        super().__init__()
        self.config = config

        # Where self.model was going to go (based on MusicgenForCausalLM)
        self.dropout = config['dropout']
        self.layerdrop = config['layerdrop']
        self.max_target_positions = config['max_position_embeddings']
        self.d_model = config['hidden_size']
        self.num_codebooks = config['num_codebooks']
        self.embed_scale = math.sqrt(config['hidden_size']) if config['scale_embedding'] else 1.0

        embed_dim = config['vocab_size'] + 1
        self.embed_tokens = torch.nn.ModuleList(
            [torch.nn.Embedding(embed_dim, config['hidden_size']) for _ in range(config['num_codebooks'])]
        )

        self.embed_positions = MusicgenSinusoidalPositionalEmbedding( # Same code as in the original model
            config['max_position_embeddings'],
            config['hidden_size'],
        )

        self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config['num_hidden_layers'])]) # This is otherwise too much code to copy
        self.layer_norm = torch.nn.LayerNorm(config['hidden_size'])
        self.attn_implementation = attn_impl

        self.num_books = self.config['num_codebooks']
        self.lm_heads = torch.nn.ModuleList(
            [torch.nn.Linear(config['hidden_size'], config['vocab_size'], bias=False) for _ in range(self.num_codebooks)]
        )

    def forward(self, input_ids, encoder_hidden_states, encoder_attention_mask):
        input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
        bsz, num_codebooks, seq_len = input.shape
        input_shape = (bsz, seq_len)

        past_key_values_length = 0
        inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

        attention_mask = _prepare_4d_causal_attention_mask(
            None, input_shape, inputs_embeds, past_key_values_length
        )

        encoder_attention_mask = _prepare_4d_attention_mask(
            encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )
        
        positions = self.embed_positions(input, past_key_values_length)

        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)


        next_decoder_cache = ()

        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=None,
                cross_attn_layer_head_mask=None,
                past_key_value=past_key_value,
                output_attentions=False,
                use_cache=True,
            )
            hidden_states = layer_outputs[0]

            next_decoder_cache += (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        next_cache = next_decoder_cache

        lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)

        # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
        lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
        
        return lm_logits, next_cache
    
test_model = NakedDecoder(model.config.to_dict()['decoder'], model.config.decoder._attn_implementation)
test_model.load_state_dict({k.replace('model.', '').replace('decoder.', ''):v for k,v in model.decoder.state_dict().items()}) # Load the weights of the actual model
test_model.eval()
test_model(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask
)

torch.Size([2, 16, 27]) torch.Size([2, 16, 1, 27])


RuntimeError: The size of tensor a (2) must match the size of tensor b (16) at non-singleton dimension 1

In [28]:
import onnx

# Load the ONNX model
model_path = f"musicgen-stereo/decoder.onnx"  # Replace with your model path
onnx_model = onnx.load(model_path)

# Print the ONNX graph to see the details of layers and nodes
print(onnx.helper.printable_graph(onnx_model.graph))

for node in onnx_model.graph.node:
    if 'encoder_attn' in node.name:
        print(f"Node name: {node.name}, Operation type: {node.op_type}, Inputs: {node.input}, Outputs: {node.output}")


graph main_graph (
  %input_ids[INT64, batch_sizex1]
  %encoder_hidden_states[FLOAT, batch_sizexsequence_lengthx1024]
  %encoder_attention_mask[INT64, batch_sizexsequence_length]
) initializers (
  %embed_tokens.0.weight[FLOAT, 2049x1024]
  %embed_tokens.1.weight[FLOAT, 2049x1024]
  %embed_tokens.2.weight[FLOAT, 2049x1024]
  %embed_tokens.3.weight[FLOAT, 2049x1024]
  %embed_tokens.4.weight[FLOAT, 2049x1024]
  %embed_tokens.5.weight[FLOAT, 2049x1024]
  %embed_tokens.6.weight[FLOAT, 2049x1024]
  %embed_tokens.7.weight[FLOAT, 2049x1024]
  %embed_positions.weights[FLOAT, 2048x1024]
  %layers.0.self_attn_layer_norm.weight[FLOAT, 1024]
  %layers.0.self_attn_layer_norm.bias[FLOAT, 1024]
  %layers.0.encoder_attn_layer_norm.weight[FLOAT, 1024]
  %layers.0.encoder_attn_layer_norm.bias[FLOAT, 1024]
  %layers.0.final_layer_norm.weight[FLOAT, 1024]
  %layers.0.final_layer_norm.bias[FLOAT, 1024]
  %layers.1.self_attn_layer_norm.weight[FLOAT, 1024]
  %layers.1.self_attn_layer_norm.bias[FLOAT, 1024]
 

In [26]:
test_model = NakedDecoder(model.config.to_dict()['decoder'], model.config.decoder._attn_implementation)
test_model.load_state_dict({k.replace('model.', '').replace('decoder.', ''):v for k,v in model.decoder.state_dict().items()}) # Load the weights of the actual model
test_model.eval()

# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoder_hidden_states': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoder_attention_mask': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'output': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_input_ids = torch.randint(0, 100, (16, 1), dtype=torch.int64)
dummy_encoder_hidden_states = torch.randint(0, 100, (2, 27, 1024), dtype=torch.float32)
dummy_encoder_attention_mask = torch.randint(0, 100, (2, 27), dtype=torch.int64)

# Export the model to ONNX format
torch.onnx.export(
    test_model,                             # Model to export
    (dummy_input_ids,dummy_encoder_hidden_states,dummy_encoder_attention_mask),                             # Example input tuple
    f"musicgen-stereo/decoder.onnx",                  # Export path
    input_names=['input_ids', 'encoder_hidden_states', 'encoder_attention_mask'],                      # Input tensor names
    output_names=['output'],                       # Output tensor name
    dynamic_axes=dynamic_axes,                       # Dynamic axes for variable-length inputs
    opset_version = 17
)

  if input_shape[-1] > 1 or self.sliding_window is not None:
  if seq_len > self.weights.size(0):
  if attention_mask.size(1) == 1:  # If num_heads dimension is 1, expand it to match attn_weights


In [27]:
# Test the exported model
import onnxruntime as ort
ort_session = ort.InferenceSession(f"musicgen-stereo/decoder.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)

decoder_input_ids_np = decoder_input_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
encoder_hidden_states_np = encoder_hidden_states.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
attention_mask_np = decoder_input_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    # 'input_ids': np.expand_dims(np.concatenate((input_ids_np, attention_mask_np), axis=0), 0),
    'input_ids': decoder_input_ids_np,
    'encoder_hidden_states': encoder_hidden_states_np,
    'encoder_attention_mask': attention_mask_np,
}
encoded = ort_session.run(None, ort_inputs)
print(encoded)

[1;31m2024-10-05 15:57:37.462391463 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Add node. Name:'/layers.0/encoder_attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 16 by 27
[m


RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/layers.0/encoder_attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 16 by 27


In [3]:
import onnxruntime as ort
ort_session = ort.InferenceSession(f"musicgen-stereo/decoder.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)

decoder_input_ids_np = decoder_input_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
encoder_hidden_states_np = encoder_hidden_states.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors
attention_mask_np = decoder_input_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    # 'input_ids': np.expand_dims(np.concatenate((input_ids_np, attention_mask_np), axis=0), 0),
    'input_ids': decoder_input_ids_np,
    'encoder_hidden_states': encoder_hidden_states_np,
    'encoder_attention_mask': attention_mask_np,
}
encoded = ort_session.run(None, ort_inputs)
print(encoded)

NameError: name 'decoder_input_ids' is not defined

### Exports

In [2]:
inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt",
)

# audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)

Export the configs

In [15]:
import os, json, glob
import torch.onnx
import onnxruntime as ort
import numpy as np

folder = './musicgen-stereo'
os.makedirs(folder, exist_ok=True)
processor.tokenizer.save_pretrained(f'{folder}')
processor.save_pretrained(f'{folder}')
model.config.to_json_file(f'{folder}/config.json')
model.generation_config.to_json_file(f'{folder}/generation_config.json')

Export the text encoder

In [34]:
# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoded': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_input_ids = torch.randint(0, 100, (2, 10), dtype=torch.int64)

# Export the model to ONNX format
torch.onnx.export(
    model.text_encoder,                             # Model to export
    (dummy_input_ids,),                             # Example input tuple
    f"{folder}/text_encoder.onnx",                  # Export path
    input_names=['input_ids'],                      # Input tensor names
    output_names=['encoded'],                       # Output tensor name
    dynamic_axes=dynamic_axes                       # Dynamic axes for variable-length inputs
)

Export the extra layer (projection layer for encoder decoder)

In [48]:
model.text_encoder.config.hidden_size, model.decoder.config.hidden_size

(768, 1024)

In [79]:
# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'encoder_hidden_states_in': {0: 'batch_size', 1: 'sequence_length'},  # Allow variable batch size and sequence length
    'encoder_hidden_states_out': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_encoder_hidden_states = torch.randint(0, 100, (2, 12, 768), dtype=torch.float32)

# Export the model to ONNX format
torch.onnx.export(
    model.enc_to_dec_proj,                             # Model to export
    (dummy_encoder_hidden_states,),                             # Example input tuple
    f"{folder}/enc_to_dec_proj.onnx",               # Export path
    input_names=['encoder_hidden_states_in'],          # Input tensor names
    output_names=['encoder_hidden_states_out'],         # Output tensor name
    dynamic_axes=dynamic_axes                       # Dynamic axes for variable-length inputs
)

In [16]:
# Define the dynamic axes for variable-length input shapes
dynamic_axes = {
    'encoder_hidden_states': {0: 'batch_size', 1: 'sequence_length'},  # Output will also have variable batch size and sequence length
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'}  # Output will also have variable batch size and sequence length
}

# Example input shapes (with batch size = 2, sequence length = 10)
dummy_input_ids = torch.randint(0, 100, (16, 1), dtype=torch.int64)
dummy_encoder_hidden_states = torch.randn((2, 27, 1024), dtype=torch.float32)
dummy_encoder_attention_mask = torch.randn((2, 27), dtype=torch.float32)
dummy_output_attentions = torch.tensor(False)
dummy_output_hidden_states = torch.tensor(False)
dummy_use_cache = torch.tensor(True)
dummy_return_dict = torch.tensor(True)

# Export the model to ONNX format
torch.onnx.export(
    model.decoder,                                  # Model to export
    (
        dummy_input_ids,
        dummy_encoder_hidden_states,
        dummy_encoder_attention_mask,
        dummy_output_attentions,
        dummy_output_hidden_states,
        dummy_use_cache,
        dummy_return_dict
    ),                 # Example input tuple
    f"{folder}/decoder.onnx",               # Export path
    input_names=[
        'input_ids',
        'encoder_hidden_states',
        'encoder_attention_mask',
        'output_attentions',
        'output_hidden_states',
        'use_cache',
        'return_dict',
        ],       # Input tensor names
    output_names=['output'],     # Output tensor name
    dynamic_axes=dynamic_axes                       # Dynamic axes for variable-length inputs
)



IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

Export the decoder

In [127]:
print(decoder_input_ids.size())
None
print(encoder_hidden_states.size())
print(attention_mask.size())
None
False
False
True
None
True
None
None

model.decoder(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
    inputs_embeds=decoder_inputs_embeds,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    use_cache=use_cache,
    past_key_values=past_key_values,
    return_dict=return_dict,
    labels=labels,
    head_mask=head_mask
)[0].size()

torch.Size([16, 1])
torch.Size([2, 27, 1024])
torch.Size([2, 27])


torch.Size([16, 1, 2048])

### Test exported model

In [86]:
max_length = 1024
outputs = processor.tokenizer(["80s pop track with bassy drums and synth"])
input_ids, attention_mask = torch.tensor(outputs['input_ids']), torch.tensor(outputs['attention_mask'])

ort_session = ort.InferenceSession(f"{folder}/text_encoder.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)
input_ids_np = input_ids.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    'input_ids': input_ids_np,
}
encoded = ort_session.run(None, ort_inputs)[0]

In [None]:
ort_session = ort.InferenceSession(f"{folder}/enc_to_dec_proj.onnx")

# Prepare input data (assuming you already have input_ids and attention_mask as PyTorch tensors)
encoder_hidden_states_np = encoder_hidden_states.detach().numpy()  # Convert to NumPy arrays if they're in PyTorch tensors

# Run the model
ort_inputs = {
    'encoder_hidden_states_in': encoder_hidden_states_np,
}
ort_outs = ort_session.run(None, ort_inputs)

ort_outs[0]

### Test the inputs

Preprocess the text

In [2]:
max_length = 1024
outputs = processor.tokenizer(["80s pop track with bassy drums and synth asd asd sad asd sdaa sd"])
input_ids, attention_mask = torch.tensor(outputs['input_ids']), torch.tensor(outputs['attention_mask'])

In [4]:
input_ids.size(), attention_mask.size()

(torch.Size([1, 27]), torch.Size([1, 27]))

Encode the text

In [5]:
# Input the data to the text encoder
encoded = model.text_encoder(input_ids)

In [6]:
encoded

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 2.2239e-01, -4.1488e-02, -8.3079e-02,  ..., -5.2557e-02,
          -9.2018e-01, -1.8920e-01],
         [ 1.7860e-01,  1.7744e-02, -1.4980e-01,  ..., -8.8605e-02,
          -3.4882e-01, -3.2289e-01],
         [-1.2259e-01, -1.4499e-01,  1.6029e-01,  ..., -3.7012e-01,
          -5.4097e-01, -3.8603e-01],
         ...,
         [-5.0348e-01,  6.1504e-01, -1.8940e-01,  ..., -3.5043e-01,
           7.3551e-02, -3.9094e-03],
         [-4.5106e-01, -1.5706e-01, -2.3703e-01,  ..., -1.9732e-01,
           1.8915e-01,  7.9153e-02],
         [ 1.4276e-02,  1.0114e-02,  8.1314e-04,  ...,  3.8193e-03,
           1.5045e-02,  1.0282e-02]]], grad_fn=<MulBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

Prepare the encoded text

In [4]:
# Apply the guidance scale or something
# When the guidance scale is > 1 then we need to apply zeros to the mask and the last hidden state
encoded.last_hidden_state = torch.concatenate([encoded.last_hidden_state, torch.zeros_like(encoded.last_hidden_state)], dim=0)
attention_mask = torch.concatenate([attention_mask, torch.zeros_like(attention_mask)], dim=0)

Prepare for the decoder

In [5]:
# Prepare for decoder inputs
num_codebooks = model.config.to_dict()['decoder']['num_codebooks']
decoder_input_ids = torch.ones((input_ids.size(0) * num_codebooks, 1), dtype=torch.long) * model.generation_config.decoder_start_token_id

# Build delay pattern
decoder_input_ids = decoder_input_ids.reshape(-1, num_codebooks, decoder_input_ids.shape[-1])
bsz, num_codebooks, seq_len = decoder_input_ids.shape
channel_codebooks = num_codebooks // 2
decoder_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long) * -1

# Just remember this as when the user wants a really small sample 
# Really small is max_len < (2 * num_codebooks - 1)
# decoder_input_ids.reshape(bsz * num_codebooks, -1), decoder_ids_shifted.reshape(bsz * num_codebooks, -1)

# Now fill the shifted ids with the prompt
for codebook in range(channel_codebooks):
    decoder_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = decoder_input_ids[:, 2 * codebook]
    decoder_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = decoder_input_ids[:, 2 * codebook + 1]

delay_pattern = torch.triu(
    torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal = max_length - channel_codebooks + 1
)
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

mask = ~delay_pattern.to(input_ids.device)
decoder_input_ids = mask * decoder_ids_shifted + ~mask * model.generation_config.decoder_start_token_id
first_codebook_ids = decoder_input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
if len(start_ids) > 0:
    first_start_id = min(start_ids)
else:
    # we have no tokens that need to be filled - return entire matrix of input ids
    first_start_id = seq_len
pattern_mask = decoder_input_ids.reshape(bsz * num_codebooks, -1)
decoder_input_ids = decoder_input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)

Prepare the logic Processor

In [6]:
model.generation_config.watermarking_config

Setup the generation type:

In [7]:
num_beams = model.config.to_dict()['decoder']['num_beams']
num_beam_groups = model.config.to_dict()['decoder']['num_beam_groups']
do_sample = model.config.to_dict()['decoder']['do_sample']
is_greedy_gen_mode = (
    (num_beams == 1)
    and (num_beam_groups == 1)
    and do_sample is False
)
is_sample_gen_mode = (
    (num_beams == 1)
    and (num_beam_groups == 1)
    and do_sample is True
)

For now we will just implement the sampler mode, greedy version will be researched later

Now we have to expand the input dims based on the number of samples we get

In [8]:
# n_samples = 1
# decoder_input_ids = decoder_input_ids.repeat_interleave(n_samples, dim=0)
# attention_mask = attention_mask.repeat_interleave(n_samples, dim=0)
# pattern_mask = pattern_mask.repeat_interleave(n_samples, dim=0)

Prepare inputs for generation fn

In [9]:
# Apply mask
decoder_input_ids = torch.where(pattern_mask[..., :decoder_input_ids.shape[-1]] == -1, decoder_input_ids, pattern_mask[..., :decoder_input_ids.shape[-1]])

# Prep for the CGF
decoder_input_ids = decoder_input_ids.repeat((2,1))

In [10]:
encoder_hidden_states = encoded[0]

if (model.text_encoder.config.hidden_size != model.decoder.config.hidden_size 
    and model.decoder.config.cross_attention_hidden_size is None):
        encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states) # We need to export this aswell !!!

encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

Do Sample

In [11]:
decoder_input_ids = decoder_input_ids
decoder_attention_mask = None
encoder_hidden_states = encoder_hidden_states
attention_mask = attention_mask
decoder_inputs_embeds = None
output_attentions = False
output_hidden_states = False
use_cache = True
past_key_values = None
return_dict = True
labels = None
head_mask = None

model.decoder(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
    inputs_embeds=decoder_inputs_embeds,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    use_cache=use_cache,
    past_key_values=past_key_values,
    return_dict=return_dict,
    labels=labels,
    head_mask=head_mask
)

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-0.9733, -2.9774, -4.2276,  ...,  0.0935, -3.2401, -4.0112]],

        [[-0.9868, -2.8926, -4.3329,  ...,  0.1551, -3.6115, -4.0284]],

        [[-0.4077, -3.0823,  0.3101,  ..., -2.1532,  0.2348,  0.5535]],

        ...,

        [[-2.2284,  1.4892, -3.0453,  ..., -1.4783, -0.2749, -3.8544]],

        [[ 3.3329,  0.5850,  2.3480,  ...,  0.1123, -1.5855,  2.0758]],

        [[ 1.5364,  0.9343,  3.2742,  ..., -1.2733, -1.1586,  3.4866]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[ 5.1178e-01,  2.7520e-01, -8.8865e-01,  ...,  4.5323e-01,
            3.4695e-01, -1.7394e+00]],

         [[-1.2855e+00,  2.9377e-01,  1.4516e+00,  ...,  8.0278e-01,
            3.0280e-01,  1.6564e-02]],

         [[ 4.0691e-01,  1.2419e+00,  1.2989e+00,  ..., -6.5682e-01,
            3.9645e-01, -3.5813e-01]],

         ...,

         [[ 3.2502e-01,  6.5805e-01,  2.8554e-01,  ...,  2.5637e-01,
           -1.9514e-03, -9.3271e-0

In [12]:
model.decoder(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
    inputs_embeds=decoder_inputs_embeds,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    use_cache=use_cache,
    past_key_values=past_key_values,
    return_dict=return_dict,
    labels=labels,
    head_mask=head_mask
)

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-0.9733, -2.9774, -4.2276,  ...,  0.0935, -3.2401, -4.0112]],

        [[-0.9868, -2.8926, -4.3329,  ...,  0.1551, -3.6115, -4.0284]],

        [[-0.4077, -3.0823,  0.3101,  ..., -2.1532,  0.2348,  0.5535]],

        ...,

        [[-2.2284,  1.4892, -3.0453,  ..., -1.4783, -0.2749, -3.8544]],

        [[ 3.3329,  0.5850,  2.3480,  ...,  0.1123, -1.5855,  2.0758]],

        [[ 1.5364,  0.9343,  3.2742,  ..., -1.2733, -1.1586,  3.4866]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[ 5.1178e-01,  2.7520e-01, -8.8865e-01,  ...,  4.5323e-01,
            3.4695e-01, -1.7394e+00]],

         [[-1.2855e+00,  2.9377e-01,  1.4516e+00,  ...,  8.0278e-01,
            3.0280e-01,  1.6564e-02]],

         [[ 4.0691e-01,  1.2419e+00,  1.2989e+00,  ..., -6.5682e-01,
            3.9645e-01, -3.5813e-01]],

         ...,

         [[ 3.2502e-01,  6.5805e-01,  2.8554e-01,  ...,  2.5637e-01,
           -1.9514e-03, -9.3271e-0

Decode the data

In [13]:
num_beams = model.config.to_dict()['decoder']['num_beams']
num_beam_groups = model.config.to_dict()['decoder']['num_beam_groups']
do_sample = model.config.to_dict()['decoder']['do_sample']
is_greedy_gen_mode = (
    (num_beams == 1)
    and (num_beam_groups == 1)
    and do_sample is False
)
is_sample_gen_mode = (
    (num_beams == 1)
    and (num_beam_groups == 1)
    and do_sample is True
)

### Following are just test code

In [None]:
model.generation_config.do_sample

In [None]:
decoder_input_ids

In [None]:
model.generation_config.decoder_start_token_id, model.generation_config.max_length

In [None]:
input_ids.size(), attention_mask.size()

In [None]:
model.generation_config.bos_token_id

In [None]:
input_ids.shape, attention_mask.shape, encoded[0].shape

In [None]:
decoder_ins = input_ids.unsqueeze(1).repeat((1,8,1))

In [None]:
model.decoder(decoder_ins, attention_mask)

In [None]:
encoded[0]

In [None]:
input_ids.size()

In [None]:
model.text_encoder(input_ids=input_ids)[0].size()

lets try an run it without the model

In [None]:
import torch

# Create dummy input data for ONNX export
# Adjust input_ids shape and other inputs as per the actual model's input requirements
dummy_input_ids = torch.randint(0, 10, (1, 16), dtype=torch.long)  # Example shape (batch_size=1, seq_len=16)

# Export the model to ONNX format
torch.onnx.export(
    model.text_encoder,  # The model to be exported
    dummy_input_ids,  # Example inputs for the model
    f"{folder}/text_encoder.onnx",  # The path where the ONNX model will be saved
    input_names=["input_ids", "attention_mask", "encoder_hidden_states"],  # Input names
    output_names=["logits"],  # Output name(s)
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},  # Dynamic axes for varying batch_size and seq_len
        "logits": {0: "batch_size", 1: "sequence_length"}
    },
    opset_version=13  # Use the appropriate ONNX opset version
)

print("Model exported to ONNX successfully!")


In [None]:
model.text_encoder(input_ids)

In [None]:
inputs

In [None]:
model.decoder

In [None]:
model.audio_encoder

In [None]:
model.text_encoder

Then export the local model

In [1]:
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import MusicgenOnnxConfig
from transformers import MusicgenConfig

model_id = "facebook/musicgen-small"

main_export(
    model_id,
    output="musicgen-small",
    task='text-to-audio',
)

  from .autonotebook import tqdm as notebook_tqdm
Framework not specified. Using pt to export the model.
  return func(*args, **kwargs)


KeyboardInterrupt: 

Make it efficient

In [None]:
!optimum-cli onnxruntime quantize --avx512 --onnx_model musicgen-stereo -o quantized_musicgen

## Testing the model

Load the other configs

In [None]:
import onnxruntime as ort
import json

# Load the ORT config
with open("./quantized_musicgen/ort_config.json", "r") as f:
    ort_config = json.load(f)

# Apply ORT configuration when initializing the session
session_options = ort.SessionOptions()
if "graph_optimization_level" in ort_config:
    session_options.graph_optimization_level = ort_config["graph_optimization_level"]

# Example: Setting execution providers, thread counts, etc.
if "execution_providers" in ort_config:
    session_options.execution_mode = ort_config["execution_providers"]

Load the tokenizer

In [None]:
from transformers import PreTrainedTokenizerFast, AddedToken

# Load tokenizer configuration and special tokens map
with open("./quantized_musicgen/tokenizer_config.json", "r") as f:
    tokenizer_config = json.load(f)

with open("./quantized_musicgen/special_tokens_map.json", "r") as f:
    special_tokens_map = json.load(f)
    for key, value in special_tokens_map.items():
        if key != 'additional_special_tokens':
            special_tokens_map[key] = AddedToken(
                content = value['content'], 
                single_word = value['single_word'], 
                lstrip = value['lstrip'], 
                rstrip = value['rstrip'], 
                special = True, 
                normalized = value['normalized']
            )

# Load the model configuration (config.json)
with open("./quantized_musicgen/config.json", "r") as f:
    model_config = json.load(f)

# Load the tokenizer with configuration
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./quantized_musicgen/tokenizer.json")

# Add the special tokens from the special_tokens_map.json
tokenizer.add_special_tokens(special_tokens_map)

# Configure tokenizer with settings from tokenizer_config.json
if "padding_side" in tokenizer_config:
    print('adding padding_side')
    tokenizer.padding_side = tokenizer_config["padding_side"]
if "truncation_side" in tokenizer_config:
    print('adding truncation_side')
    tokenizer.truncation_side = tokenizer_config["truncation_side"]

Load the model slices

In [None]:
text_encoder_session = ort.InferenceSession('./quantized_musicgen/text_encoder_quantized.onnx', sess_options=session_options)
decoder_session = ort.InferenceSession('./quantized_musicgen/decoder_model_quantized.onnx', sess_options=session_options)

In [None]:
input_text = "80s pop track with bassy drums and synth"
inputs = tokenizer(input_text, return_tensors="np")

In [None]:
# Run inference for text encoding
encoded_text = text_encoder_session.run(None, {
    'input_ids': inputs['input_ids'],
    'attention_mask': inputs['attention_mask']
})

In [None]:
import numpy as np
np.repeat(inputs['input_ids'], repeats=4, axis=0)

In [None]:
model_config['decoder']

In [None]:
# Process output and run decoder (adjusted based on model config)
decoder_inputs = {
    'input_ids': np.repeat(inputs['input_ids'], repeats=4, axis=0),
    'encoder_hidden_states': encoded_text[0],
    'encoder_attention_mask': inputs['attention_mask']
}

# Generate output from the decoder
decoder_output = decoder_session.run(None, decoder_inputs)

In [None]:
import os
os.listdir('./quantized_musicgen')

In [None]:
import numpy as np

# Number of decoder layers (in your case, 24 for Musicgen)
num_layers = 24

# Assuming hidden_size is the dimension of the model (1024 for Musicgen)
hidden_size = 1024

# Batch size, number of heads, sequence length (1 for the first step), and attention head size
batch_size = 1
num_heads = 16  # This depends on your model configuration
sequence_length = 1
head_size = hidden_size // num_heads

# Create past_key_values as a list of zero tensors for each layer
past_key_values = []

for _ in range(num_layers):
    decoder_key = np.zeros((batch_size, num_heads, sequence_length, head_size), dtype=np.float32)
    decoder_value = np.zeros((batch_size, num_heads, sequence_length, head_size), dtype=np.float32)
    encoder_key = np.zeros((batch_size, num_heads, sequence_length, head_size), dtype=np.float32)
    encoder_value = np.zeros((batch_size, num_heads, sequence_length, head_size), dtype=np.float32)
    past_key_values.append({
        "decoder.key": decoder_key,
        "decoder.value": decoder_value,
        "encoder.key": encoder_key,
        "encoder.value": encoder_value,
    })

In [None]:
encoder_hidden_states[:,:3,:].shape

In [None]:
input_tokens['attention_mask']

In [None]:
# Initialize variables
generated_tokens = decoder_input_ids
use_cache_branch = np.array([False], dtype=bool)  # Use False for first step

for step in range(gen_config.max_length):
    # Prepare the input dictionary for the ONNX session
    inputs = {
        "input_ids": decoder_input_ids,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": input_tokens['attention_mask'],
        "use_cache_branch": np.array([False], dtype=bool),  # Set to True to use past key values
    }

    # Add past key values to the input
    for i, layer_past in enumerate(past_key_values):
        inputs[f"past_key_values.{i}.decoder.key"] = layer_past["decoder.key"]
        inputs[f"past_key_values.{i}.decoder.value"] = layer_past["decoder.value"]
        inputs[f"past_key_values.{i}.encoder.key"] = layer_past["encoder.key"]
        inputs[f"past_key_values.{i}.encoder.value"] = layer_past["encoder.value"]

    # Run the ONNX session
    decoder_outputs = decoder_session.run(None, inputs)

    
    # Get logits and past key values
    logits = decoder_outputs[0]
    # Extract past_key_values from decoder_outputs if they are present
    
    # Sample next token (using greedy search, beam search, or sampling)
    next_token_id = np.argmax(logits[:, -1, :], axis=-1).reshape(4, 1)
    
    # Append the next token to generated tokens
    generated_tokens = np.concatenate([generated_tokens, next_token_id], axis=1)
    
    # Update inputs for next step
    use_cache_branch = np.array([True], dtype=bool)
    # Update past_key_values for next step


In [None]:
# Prepare input for encodec decoder
encodec_inputs = {
    "codes": generated_tokens  # Ensure this matches the expected input shape
}

# Run the encodec decoder
audio_outputs = encodec_decoder_session.run(None, encodec_inputs)

# Get the audio waveform
audio_waveform = audio_outputs[0]  # Adjust index based on actual output

In [None]:
import soundfile as sf

sf.write('generated_audio.wav', audio_waveform.squeeze(), samplerate=gen_config.sampling_rate)

In [None]:
for input_meta in decoder_session.get_inputs():
    print(f"Input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")

In [None]:
matrix = np.zeros((len(vecs), len(vecs[0])))
for i in range(len(vecs)):
    matrix[i, :] = vecs[i]
matrix = np.dot(matrix,matrix.T)
for row in matrix:
    print(" ".join(f"{value:10.2f}" for value in row))

In [None]:
dfmax, dfmin = matrix.max(), matrix.min()

matrix = (matrix - dfmin)/(dfmax - dfmin)
for row in matrix:
    print(" ".join(f"{value:10.2f}" for value in row))