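Workaround to make the stereo model usable: override the decoder's `num_codebooks` in the config before exporting (the patched model must be saved to `musicgen_fixed`, which the export step loads).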

In [80]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, MusicgenModel

# Workaround for exporting the stereo checkpoint: load it, override the decoder's
# `num_codebooks` in the config, and save the result locally for the export step.
# NOTE(review): only the config is rewritten; the loaded weights are untouched.
# Confirm the exported graph uses the codebooks you expect.
SOURCE_MODEL_ID = "facebook/musicgen-stereo-small"
PATCHED_MODEL_DIR = "musicgen_fixed"  # loaded by the ONNX export cell (its model_id)
PATCHED_NUM_CODEBOOKS = 4

model = MusicgenForConditionalGeneration.from_pretrained(SOURCE_MODEL_ID)
config_dict = model.config.to_dict()
config_dict["decoder"]["num_codebooks"] = PATCHED_NUM_CODEBOOKS
# `from_dict` is a classmethod; calling it on the config's class makes that explicit.
model.config = type(model.config).from_dict(config_dict)

# The export cell reads PATCHED_MODEL_DIR from disk, so it must be written here.
# Without this save, a fresh kernel cannot run the export.
# The processor is saved alongside so that tokenizer files are available to the
# export. Later cells load tokenizer.json etc. from the quantized output.
model.save_pretrained(PATCHED_MODEL_DIR)
AutoProcessor.from_pretrained(SOURCE_MODEL_ID).save_pretrained(PATCHED_MODEL_DIR)

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Sanity check: generate audio with the stock PyTorch model through Hugging Face.
# NOTE(review): the processor comes from the stereo checkpoint, but the model is
# the mono `musicgen-small`. Confirm this pairing is intentional.
PROCESSOR_ID = "facebook/musicgen-stereo-small"
REFERENCE_MODEL_ID = "facebook/musicgen-small"

processor = AutoProcessor.from_pretrained(PROCESSOR_ID)
model = MusicgenForConditionalGeneration.from_pretrained(REFERENCE_MODEL_ID)

prompts = ["80s pop track with bassy drums and synth"]
inputs = processor(text=prompts, padding=True, return_tensors="pt")

generation_kwargs = dict(do_sample=True, guidance_scale=3, max_new_tokens=256)
audio_values = model.generate(**inputs, **generation_kwargs)

  from .autonotebook import tqdm as notebook_tqdm
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


Export the patched local checkpoint (`musicgen_fixed`) to ONNX

In [None]:
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import MusicgenOnnxConfig  # currently unused here
from transformers import MusicgenConfig  # currently unused here

# Export the locally patched checkpoint to ONNX. The quantization step reads
# the resulting directory.
model_id = "musicgen_fixed"
onnx_export_dir = "musicgen-stereo"

main_export(model_id, output=onnx_export_dir, task="text-to-audio")

Quantize the exported ONNX models (targeting AVX-512 CPUs) to make inference more efficient

In [None]:
# Quantize the ONNX models exported to `musicgen-stereo` for AVX-512 CPUs and
# write the results to `quantized_musicgen`. Later cells read ort_config.json,
# the tokenizer files, config.json and the *_quantized.onnx models from that directory.
!optimum-cli onnxruntime quantize --avx512 --onnx_model musicgen-stereo -o quantized_musicgen

Load the ONNX Runtime settings and the tokenizer

In [18]:
import onnxruntime as ort
import json

QUANTIZED_DIR = "./quantized_musicgen"


def _to_graph_optimization_level(level):
    """Convert a JSON graph-optimization setting into ``ort.GraphOptimizationLevel``.

    ``level`` may be an enum member name (e.g. ``"ORT_ENABLE_ALL"``), its integer
    value, or already an ``ort.GraphOptimizationLevel`` (returned unchanged).
    ``SessionOptions.graph_optimization_level`` requires the enum type.
    """
    if isinstance(level, ort.GraphOptimizationLevel):
        return level
    if isinstance(level, str):
        return getattr(ort.GraphOptimizationLevel, level)
    return ort.GraphOptimizationLevel(int(level))


# Load the ORT config written next to the quantized models.
# NOTE(review): confirm these top-level keys actually exist in this ort_config.json.
with open(f"{QUANTIZED_DIR}/ort_config.json", "r") as f:
    ort_config = json.load(f)

# Apply ORT configuration when initializing the session
session_options = ort.SessionOptions()
if "graph_optimization_level" in ort_config:
    session_options.graph_optimization_level = _to_graph_optimization_level(
        ort_config["graph_optimization_level"]
    )

# Execution providers are not a SessionOptions setting. `execution_mode` only
# selects sequential vs. parallel operator execution. Keep the providers here so
# they can be passed as `ort.InferenceSession(..., providers=execution_providers)`.
# The value is None when the config lists none.
execution_providers = ort_config.get("execution_providers")

In [28]:
import json

from transformers import PreTrainedTokenizerFast, AddedToken

QUANTIZED_DIR = "./quantized_musicgen"

# AddedToken flags that special_tokens_map.json may store for each token.
_ADDED_TOKEN_FLAGS = ("single_word", "lstrip", "rstrip", "normalized")


def _to_added_token(value):
    """Convert one special_tokens_map.json entry into a special ``AddedToken``.

    ``value`` is either the token text (a plain string) or a dict with a
    ``content`` key plus optional AddedToken flags. Flags missing from the dict
    are left to the AddedToken defaults instead of raising KeyError.
    """
    if isinstance(value, str):
        return AddedToken(value, special=True)
    flags = {name: value[name] for name in _ADDED_TOKEN_FLAGS if name in value}
    return AddedToken(content=value["content"], special=True, **flags)


# Load tokenizer configuration and special tokens map
with open(f"{QUANTIZED_DIR}/tokenizer_config.json", "r") as f:
    tokenizer_config = json.load(f)

with open(f"{QUANTIZED_DIR}/special_tokens_map.json", "r") as f:
    raw_special_tokens = json.load(f)

# Build a fresh mapping instead of rewriting the loaded dict while iterating it.
special_tokens_map = {}
for key, value in raw_special_tokens.items():
    if key == "additional_special_tokens":
        # Entries are kept as stored, except dict-serialized ones become AddedTokens.
        special_tokens_map[key] = [
            _to_added_token(token) if isinstance(token, dict) else token
            for token in value
        ]
    else:
        special_tokens_map[key] = _to_added_token(value)

# Load the model configuration (config.json)
with open(f"{QUANTIZED_DIR}/config.json", "r") as f:
    model_config = json.load(f)

# Load the tokenizer with configuration
tokenizer = PreTrainedTokenizerFast(tokenizer_file=f"{QUANTIZED_DIR}/tokenizer.json")

# Add the special tokens from the special_tokens_map.json
tokenizer.add_special_tokens(special_tokens_map)

# Configure tokenizer with settings from tokenizer_config.json
if "padding_side" in tokenizer_config:
    print('adding padding_side')
    tokenizer.padding_side = tokenizer_config["padding_side"]
if "truncation_side" in tokenizer_config:
    print('adding truncation_side')
    tokenizer.truncation_side = tokenizer_config["truncation_side"]

Load the exported model components (text encoder and decoder) as ONNX Runtime sessions

In [24]:
QUANTIZED_DIR = "./quantized_musicgen"

# One ONNX Runtime session per exported component, all sharing `session_options`.
text_encoder_session, decoder_session = (
    ort.InferenceSession(f"{QUANTIZED_DIR}/{filename}", sess_options=session_options)
    for filename in ("text_encoder_quantized.onnx", "decoder_model_quantized.onnx")
)

In [161]:
# Tokenize the prompt into NumPy arrays, the format ONNX Runtime sessions accept.
input_text = "80s pop track with bassy drums and synth"
inputs = tokenizer(text=input_text, return_tensors="np")

In [162]:
# Run inference for text encoding
# Run the text encoder on the tokenized prompt. Its first output is used below
# as the decoder's `encoder_hidden_states`.
text_encoder_feed = {name: inputs[name] for name in ("input_ids", "attention_mask")}
encoded_text = text_encoder_session.run(None, text_encoder_feed)

In [159]:
import numpy as np

# Preview: the prompt's token ids copied once per row (4 rows).
# NOTE(review): these are text-token ids. The decoder's embedding only accepts
# ids within [-2049, 2048] (see the Gather error below), so they are not valid
# decoder input_ids.
num_copies = 4
np.repeat(inputs["input_ids"], num_copies, axis=0)

array([[ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,
           11, 13353,     1],
       [ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,
           11, 13353,     1],
       [ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,
           11, 13353,     1],
       [ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,
           11, 13353,     1]])

In [163]:
# Inspect the decoder sub-config (token ids, sizes) of the exported model.
decoder_config = model_config["decoder"]
decoder_config

{'_name_or_path': '',
 'activation_dropout': 0.0,
 'activation_function': 'gelu',
 'add_cross_attention': False,
 'architectures': None,
 'attention_dropout': 0.0,
 'audio_channels': 2,
 'bad_words_ids': None,
 'begin_suppress_tokens': None,
 'bos_token_id': 2048,
 'chunk_size_feed_forward': 0,
 'cross_attention_hidden_size': None,
 'decoder_start_token_id': None,
 'diversity_penalty': 0.0,
 'do_sample': False,
 'dropout': 0.1,
 'early_stopping': False,
 'encoder_no_repeat_ngram_size': 0,
 'eos_token_id': None,
 'exponential_decay_length_penalty': None,
 'ffn_dim': 4096,
 'finetuning_task': None,
 'forced_bos_token_id': None,
 'forced_eos_token_id': None,
 'hidden_size': 1024,
 'id2label': {'0': 'LABEL_0', '1': 'LABEL_1'},
 'initializer_factor': 0.02,
 'is_decoder': False,
 'is_encoder_decoder': False,
 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
 'layerdrop': 0.0,
 'length_penalty': 1.0,
 'max_length': 20,
 'max_position_embeddings': 2048,
 'min_length': 0,
 'model_type': 'musicgen_deco

In [156]:
# Process output and run decoder (adjusted based on model config)
decoder_inputs = {
    'input_ids': np.repeat(inputs['input_ids'], repeats=4, axis=0),
    'encoder_hidden_states': encoded_text[0],
    'encoder_attention_mask': inputs['attention_mask']
}

# Generate output from the decoder
decoder_output = decoder_session.run(None, decoder_inputs)

[1;31m2024-09-27 03:22:47.168206832 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Gather node. Name:'/decoder/model/decoder/embed_tokens.1/Gather' Status Message: indices element out of data bounds, idx=2775 must be within the inclusive range [-2049,2048][m


InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:'/decoder/model/decoder/embed_tokens.1/Gather' Status Message: indices element out of data bounds, idx=2775 must be within the inclusive range [-2049,2048]

In [None]:
import os

QUANTIZED_DIR = "./quantized_musicgen"

# Show which files export + quantization produced (useful for finding model names).
os.listdir(path=QUANTIZED_DIR)

In [80]:
import numpy as np

# Build all-zero past key/values for the first decoder step.
# They are placeholders: the generation loop runs with `use_cache_branch=False`,
# where the past is presumably ignored by the merged decoder. Confirm.
# Sizes come from the exported decoder config instead of hard-coded numbers.
# The fallbacks are the musicgen-small values this cell used before.
decoder_config = model_config["decoder"]
num_layers = decoder_config.get("num_hidden_layers", 24)
hidden_size = decoder_config.get("hidden_size", 1024)
num_heads = decoder_config.get("num_attention_heads", 16)
head_size = hidden_size // num_heads

# Batch size and sequence length (1 for the first step) of the dummy tensors.
batch_size = 1
sequence_length = 1
cache_shape = (batch_size, num_heads, sequence_length, head_size)

# One dict per layer holding self-attention ("decoder.*") and
# cross-attention ("encoder.*") key/value tensors.
past_key_values = [
    {
        name: np.zeros(cache_shape, dtype=np.float32)
        for name in ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
    }
    for _ in range(num_layers)
]

In [None]:
# Shape of the text-encoder output restricted to the first three prompt positions.
# `encoded_text[0]` is the tensor fed to the decoder as encoder_hidden_states.
encoded_text[0][:, :3, :].shape

In [None]:
# Attention mask of the tokenized prompt. The decoder receives it as encoder_attention_mask.
inputs["attention_mask"]

In [144]:
import numpy as np

# Greedy autoregressive decoding without a KV cache. At every step the full token
# history is fed back in, and the argmax at the last position is appended.
# NOTE(review): this is plain greedy decoding. The HF reference cell samples
# (do_sample=True) with guidance_scale=3, and Musicgen also applies a codebook
# delay pattern. Neither is reproduced here, so the codes will not match
# `model.generate`. TODO.
MAX_NEW_TOKENS = 256  # same budget as max_new_tokens in the HF reference cell

decoder_config = model_config["decoder"]
num_codebooks = decoder_config["num_codebooks"]
bos_token_id = decoder_config["bos_token_id"]

encoder_hidden_states = encoded_text[0]
encoder_attention_mask = inputs["attention_mask"]
prompt_batch_size = encoder_hidden_states.shape[0]

# Only feed tensors the graph declares; ONNX Runtime rejects unknown feed names.
decoder_input_names = {meta.name for meta in decoder_session.get_inputs()}

# One BOS-started row per (prompt, codebook): shape (batch * num_codebooks, 1).
generated_tokens = np.full(
    (prompt_batch_size * num_codebooks, 1), bos_token_id, dtype=np.int64
)

for _ in range(MAX_NEW_TOKENS):
    candidate_feed = {
        "input_ids": generated_tokens,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
        # No cache is maintained, so never take a merged decoder's cached branch.
        "use_cache_branch": np.array([False], dtype=bool),
    }
    if "use_cache_branch" in decoder_input_names:
        # Merged decoders also require (dummy) past key/values as inputs.
        for i, layer_past in enumerate(past_key_values):
            for suffix, tensor in layer_past.items():
                candidate_feed[f"past_key_values.{i}.{suffix}"] = tensor
    decoder_feed = {
        name: value for name, value in candidate_feed.items() if name in decoder_input_names
    }

    logits = decoder_session.run(None, decoder_feed)[0]
    # NOTE(review): assumes logits are (batch * num_codebooks, seq_len, vocab); confirm.
    next_token_id = np.argmax(logits[:, -1, :], axis=-1).astype(np.int64).reshape(-1, 1)
    generated_tokens = np.concatenate([generated_tokens, next_token_id], axis=1)


In [None]:
# Prepare input for encodec decoder
encodec_inputs = {
    "codes": generated_tokens  # Ensure this matches the expected input shape
}

# Run the encodec decoder
audio_outputs = encodec_decoder_session.run(None, encodec_inputs)

# Get the audio waveform
audio_waveform = audio_outputs[0]  # Adjust index based on actual output

In [None]:
import numpy as np
import soundfile as sf

# Sampling rate of the audio codec, read from the exported config.
# NOTE(review): assumes config.json has `audio_encoder.sampling_rate`, as in the
# HF MusicgenConfig. Confirm.
sampling_rate = model_config["audio_encoder"]["sampling_rate"]

waveform = np.squeeze(audio_waveform)
# soundfile expects (frames, channels). A channels-first stereo array
# (channels, frames) would otherwise be misread, so transpose it.
# Heuristic: audio has far fewer channels than frames.
if waveform.ndim == 2 and waveform.shape[0] < waveform.shape[1]:
    waveform = waveform.T

sf.write('generated_audio.wav', waveform, samplerate=sampling_rate)

In [None]:
# List the decoder graph's declared inputs (name, shape, element type).
for input_meta in decoder_session.get_inputs():
    name, shape, elem_type = input_meta.name, input_meta.shape, input_meta.type
    print(f"Input name: {name}, shape: {shape}, type: {elem_type}")

In [125]:
# NOTE(review): `vecs` is not defined in any visible cell of this notebook. This
# looks like a leftover from a separate experiment.
# Stack the vectors as rows, then form the matrix of pairwise dot products.
gram_rows = np.zeros((len(vecs), len(vecs[0])))
for row_idx in range(len(vecs)):
    gram_rows[row_idx] = vecs[row_idx]
matrix = gram_rows @ gram_rows.T

for row in matrix:
    print(" ".join(format(value, "10.2f") for value in row))

     14.00       9.00      13.00      11.30       8.25      11.40      14.00      13.40      14.00      10.97       8.13      10.65
      9.00      13.00      12.00       8.25      10.15       9.85      13.00      13.00      13.00       8.14      10.41       9.35
     13.00      12.00      17.00      11.40       9.85      13.90      16.93      16.49      17.00      11.16      10.17      13.47
     11.30       8.25      11.40       9.98       7.96      10.69      13.94      13.48      13.93       9.90       7.94      10.22
      8.25      10.15       9.85       7.96       9.11       9.18      12.96      12.76      12.95       8.02       9.19       8.84
     11.40       9.85      13.90      10.69       9.18      12.79      16.89      16.42      16.89      10.70       9.38      12.46
     14.00      13.00      16.93      13.94      12.96      16.89      24.65      23.97      24.59      14.30      13.22      16.47
     13.40      13.00      16.49      13.48      12.76      16.42      23.97

In [126]:
import numpy as np

# Min-max scale `matrix` (the dot-product matrix from the previous cell) into
# [0, 1] and print it.
dfmax, dfmin = matrix.max(), matrix.min()
value_range = dfmax - dfmin

if value_range == 0:
    # Constant matrix: dividing by a zero range would produce NaNs, so map all to 0.
    matrix = np.zeros_like(matrix, dtype=float)
else:
    matrix = (matrix - dfmin) / value_range

for row in matrix:
    print(" ".join(f"{value:10.2f}" for value in row))

      0.36       0.06       0.30       0.20       0.02       0.21       0.36       0.33       0.36       0.18       0.01       0.16
      0.06       0.30       0.24       0.02       0.13       0.11       0.30       0.30       0.30       0.01       0.15       0.08
      0.30       0.24       0.54       0.21       0.11       0.36       0.54       0.51       0.54       0.19       0.13       0.33
      0.20       0.02       0.21       0.12       0.00       0.16       0.36       0.33       0.36       0.12       0.00       0.14
      0.02       0.13       0.11       0.00       0.07       0.07       0.30       0.29       0.30       0.00       0.07       0.05
      0.21       0.11       0.36       0.16       0.07       0.29       0.54       0.51       0.54       0.17       0.09       0.27
      0.36       0.30       0.54       0.36       0.30       0.54       1.00       0.96       1.00       0.38       0.32       0.51
      0.33       0.30       0.51       0.33       0.29       0.51       0.96