Workaround to make the stereo checkpoint usable: override the decoder's `num_codebooks` to 4 and save a local copy

In [8]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, MusicgenModel

# Load the stereo checkpoint, then rewrite its config so the decoder declares
# four codebooks, and save the patched copy locally as "musicgen_fixed".
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")

# The config is round-tripped through a plain dict so the nested decoder
# setting can be overridden before the config object is rebuilt.
x = model.config.to_dict()
x['decoder'].update(num_codebooks=4)
model.config = model.config.from_dict(x)

model.save_pretrained("musicgen_fixed")

In [None]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Reference generation with the regular transformers pipeline.
# NOTE(review): the processor comes from "musicgen-stereo-small" while the model
# is "musicgen-small" — confirm that mixing the two checkpoints is intended.
processor = AutoProcessor.from_pretrained("facebook/musicgen-stereo-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

# Tokenize the text prompt(s) into padded PyTorch tensors.
prompts = ["80s pop track with bassy drums and synth"]
inputs = processor(text=prompts, padding=True, return_tensors="pt")

# Sample up to 256 new tokens with a guidance scale of 3.
generation_kwargs = dict(do_sample=True, guidance_scale=3, max_new_tokens=256)
audio_values = model.generate(**inputs, **generation_kwargs)

Then export the locally saved model to ONNX

In [10]:
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import MusicgenOnnxConfig
from transformers import MusicgenConfig

# Export the locally patched checkpoint ("musicgen_fixed") for the
# text-to-audio task; the exported files are written to "musicgen-stereo".
model_id = "musicgen_fixed"

export_kwargs = {"output": "musicgen-stereo", "task": 'text-to-audio'}
main_export(model_id, **export_kwargs)

Framework not specified. Using pt to export the model.
  return func(*args, **kwargs)
Some weights of the model checkpoint at musicgen_fixed were not used when initializing MusicgenForConditionalGeneration: ['decoder.lm_heads.4.weight', 'decoder.lm_heads.5.weight', 'decoder.lm_heads.6.weight', 'decoder.lm_heads.7.weight', 'decoder.model.decoder.embed_tokens.4.weight', 'decoder.model.decoder.embed_tokens.5.weight', 'decoder.model.decoder.embed_tokens.6.weight', 'decoder.model.decoder.embed_tokens.7.weight']
- This IS expected if you are initializing MusicgenForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MusicgenForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassificati

Quantize the exported ONNX models to make inference more efficient

In [11]:
# Quantize the ONNX models found in musicgen-stereo/ with the AVX512 preset and
# write the quantized models (plus ort_config.json) to quantized_musicgen/.
!optimum-cli onnxruntime quantize --avx512 --onnx_model musicgen-stereo -o quantized_musicgen

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Creating dynamic quantizer: QOperator (mode: IntegerOps, schema: u8/s8, channel-wise: False)
Quantizing model...
  elem_type: 7
  shape {
    dim {
      dim_value: 2
    }
    dim {
      dim_param: "unk__121"
    }
  }
}
.
Saving quantized model at: quantized_musicgen (external data format: False)
Configuration saved in quantized_musicgen/ort_config.json
Creating dynamic quantizer: QOperator (mode: IntegerOps, schema: u8/s8, channel-wise: False)
Quantizing model...
  elem_type: 7
  shape {
    dim {
      dim_param: "batch_size"
    }
    dim {
      dim_value: 4
    }
    dim {
      dim_param: "chunk_length"
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "unk__3"
    }
    dim {
      dim_value: 2
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "unk__26"
    }
    dim {
      dim_value: 2
    }
  }
}
.
  elem_type: 7
  shape {
    dim {
      dim_param: "unk__38"
    }
    dim {
      dim_value: 2
    }
  }
}
.
  elem_type: 7
  shape {
    dim {

Load the ONNX Runtime configuration and the tokenizer

In [9]:
import json

import onnxruntime as ort

# Load the ORT config that the quantization step saves next to the quantized
# models. The file does not exist until that step has been run.
with open("./quantized_musicgen/ort_config.json", "r") as f:
    ort_config = json.load(f)

# Apply ORT configuration when initializing the session
session_options = ort.SessionOptions()
if "graph_optimization_level" in ort_config:
    level = ort_config["graph_optimization_level"]
    # JSON can only carry the level as a name or a number, whereas
    # SessionOptions expects a GraphOptimizationLevel enum member.
    if isinstance(level, str):
        level = getattr(ort.GraphOptimizationLevel, level)
    elif isinstance(level, int):
        level = ort.GraphOptimizationLevel(level)
    session_options.graph_optimization_level = level

# Execution providers are not a SessionOptions setting (`execution_mode` is the
# unrelated sequential/parallel switch), so they are kept in their own variable.
# Pass it as `providers=` when constructing an ort.InferenceSession; None means
# "let onnxruntime pick its defaults".
execution_providers = ort_config.get("execution_providers")

FileNotFoundError: [Errno 2] No such file or directory: './quantized_musicgen/ort_config.json'

In [None]:
import json

from transformers import PreTrainedTokenizerFast, AddedToken


def _to_added_token(spec):
    """Convert one special_tokens_map.json entry for `add_special_tokens`.

    Dict entries (serialized AddedToken fields) are rebuilt as `AddedToken`
    objects flagged as special. Any other entry, such as a plain string token,
    is returned unchanged.
    """
    if not isinstance(spec, dict):
        return spec
    return AddedToken(
        content=spec['content'],
        single_word=spec.get('single_word', False),
        lstrip=spec.get('lstrip', False),
        rstrip=spec.get('rstrip', False),
        special=True,
        normalized=spec.get('normalized', False),
    )


# Load tokenizer configuration and special tokens map
with open("./quantized_musicgen/tokenizer_config.json", "r") as f:
    tokenizer_config = json.load(f)

with open("./quantized_musicgen/special_tokens_map.json", "r") as f:
    raw_special_tokens = json.load(f)

# Build a fresh mapping instead of rewriting the dict while iterating over it.
# 'additional_special_tokens' holds a list of tokens; every other key holds one.
special_tokens_map = {}
for key, value in raw_special_tokens.items():
    if key == 'additional_special_tokens':
        special_tokens_map[key] = [_to_added_token(token) for token in value]
    else:
        special_tokens_map[key] = _to_added_token(value)

# Load the model configuration (config.json)
with open("./quantized_musicgen/config.json", "r") as f:
    model_config = json.load(f)

# Load the tokenizer with configuration
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./quantized_musicgen/tokenizer.json")

# Add the special tokens from the special_tokens_map.json
tokenizer.add_special_tokens(special_tokens_map)

# Configure tokenizer with settings from tokenizer_config.json
if "padding_side" in tokenizer_config:
    print('adding padding_side')
    tokenizer.padding_side = tokenizer_config["padding_side"]
if "truncation_side" in tokenizer_config:
    print('adding truncation_side')
    tokenizer.truncation_side = tokenizer_config["truncation_side"]

Load the quantized model components (text encoder and decoder) as ONNX Runtime sessions

In [None]:
# One ONNX Runtime session per quantized component, all sharing `session_options`.
model_paths = (
    './quantized_musicgen/text_encoder_quantized.onnx',
    './quantized_musicgen/decoder_model_quantized.onnx',
)
text_encoder_session, decoder_session = [
    ort.InferenceSession(path, sess_options=session_options) for path in model_paths
]

In [None]:
# Text prompt for generation, tokenized into NumPy arrays ("np") so the result
# can be fed directly to the ONNX Runtime sessions.
input_text = "80s pop track with bassy drums and synth"
inputs = tokenizer(text=input_text, return_tensors="np")

In [None]:
# Feed the tokenized prompt (ids + attention mask) through the text encoder.
# `encoded_text` is the list of all session outputs (run(None, ...) requests every output).
text_encoder_feed = {name: inputs[name] for name in ('input_ids', 'attention_mask')}
encoded_text = text_encoder_session.run(None, text_encoder_feed)

In [None]:
import numpy as np

# Preview: every row of the token-id array repeated four times along axis 0.
np.repeat(inputs['input_ids'], 4, axis=0)

In [None]:
# Inspect the "decoder" section of the config.json loaded from quantized_musicgen/.
model_config['decoder']

In [None]:
# Single exploratory pass through the decoder session.
# NOTE(review): `input_ids` here are the *text prompt's* token ids repeated four
# times along axis 0 — confirm this is what the decoder is meant to receive.
decoder_inputs = dict(
    input_ids=np.repeat(inputs['input_ids'], repeats=4, axis=0),
    encoder_hidden_states=encoded_text[0],
    encoder_attention_mask=inputs['attention_mask'],
)

# Request every output of the decoder graph.
decoder_output = decoder_session.run(None, decoder_inputs)

In [None]:
import os

# Show which files the quantization step produced.
quantized_files = os.listdir('./quantized_musicgen')
quantized_files

In [None]:
import numpy as np

# Decoder geometry used to size the placeholder key/value cache.
# NOTE(review): these values are hard-coded — confirm them against
# model_config['decoder'] for the checkpoint actually being run.
num_layers = 24       # number of decoder layers
hidden_size = 1024    # model dimension
batch_size = 1
num_heads = 16        # attention heads per layer
sequence_length = 1   # a single cached position for the first step
head_size = hidden_size // num_heads

# Every cache tensor has the same (batch, heads, sequence, head_size) shape.
kv_shape = (batch_size, num_heads, sequence_length, head_size)
cache_slots = ("decoder.key", "decoder.value", "encoder.key", "encoder.value")

# One dict of independent float32 zero tensors per decoder layer.
past_key_values = [
    {slot: np.zeros(kv_shape, dtype=np.float32) for slot in cache_slots}
    for _ in range(num_layers)
]

In [None]:
# Scratch check: shape of the hidden states restricted to the first three
# positions along axis 1.
# NOTE(review): `encoder_hidden_states` is not assigned in any cell visible
# here — presumably the text-encoder output (`encoded_text[0]`); confirm.
encoder_hidden_states[:,:3,:].shape

In [None]:
# Scratch check of the prompt's attention mask.
# NOTE(review): `input_tokens` is not assigned in any cell visible here —
# presumably the tokenizer output (`inputs` above); confirm.
input_tokens['attention_mask']

In [None]:
# Greedy autoregressive decoding with the ONNX decoder session.
#
# Relies on `decoder_input_ids`, `encoder_hidden_states`, `input_tokens`,
# `gen_config`, `past_key_values` and `decoder_session` already being defined.
# NOTE(review): no MusicGen-specific post-processing of the sampled tokens is
# done here — confirm whether the downstream audio decoder needs any.

# Initialize variables
generated_tokens = decoder_input_ids

# The key/value tensors returned by the decoder are not fed back in, so every
# step must run with the cache disabled and see the whole sequence generated so
# far. Flipping this flag to True while still passing the placeholder cache
# would make the decoder attend to zeros.
use_cache_branch = np.array([False], dtype=bool)

# Everything except `input_ids` is identical on every step: build it once.
static_inputs = {
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": input_tokens['attention_mask'],
    "use_cache_branch": use_cache_branch,
}
# The session still expects the per-layer cache inputs, so pass the placeholders.
for i, layer_past in enumerate(past_key_values):
    for slot in ("decoder.key", "decoder.value", "encoder.key", "encoder.value"):
        static_inputs[f"past_key_values.{i}.{slot}"] = layer_past[slot]

for step in range(gen_config.max_length):
    # Feed the full sequence generated so far. Feeding the initial
    # `decoder_input_ids` on every step would reproduce the same token forever.
    inputs = {"input_ids": generated_tokens, **static_inputs}

    # Run the ONNX session
    decoder_outputs = decoder_session.run(None, inputs)

    # The first output holds the logits; only the last position is needed.
    logits = decoder_outputs[0]

    # Greedy choice: most likely token for each row, as a column vector so it
    # can be appended along the sequence axis.
    next_token_id = np.argmax(logits[:, -1, :], axis=-1).reshape(-1, 1)

    # Append the next token to generated tokens
    generated_tokens = np.concatenate([generated_tokens, next_token_id], axis=1)


In [None]:
# Turn the generated tokens into audio with the EnCodec decoder session.
# NOTE(review): `encodec_decoder_session` is not created in any cell visible
# here, and the "codes" input name / expected shape should be checked against
# that session's declared inputs.
encodec_inputs = dict(codes=generated_tokens)

# Request every output of the decoder graph.
audio_outputs = encodec_decoder_session.run(None, encodec_inputs)

# The waveform is taken from the first output; adjust the index if the graph
# orders its outputs differently.
audio_waveform = audio_outputs[0]

In [None]:
import soundfile as sf

# Drop singleton axes (e.g. a batch of one) before writing.
audio_data = audio_waveform.squeeze()

# soundfile expects 1-D mono data or a 2-D (frames, channels) array. A
# multi-channel result that is still (channels, frames) after the squeeze —
# presumably the layout the decoder emits; confirm — must be transposed, or it
# would be read as a handful of frames with thousands of channels.
if audio_data.ndim == 2 and audio_data.shape[0] < audio_data.shape[1]:
    audio_data = audio_data.T

sf.write('generated_audio.wav', audio_data, samplerate=gen_config.sampling_rate)

In [None]:
# List every input the decoder session declares (name, shape and element type),
# which is useful for checking the feed dictionaries built above.
decoder_input_metas = decoder_session.get_inputs()
for input_meta in decoder_input_metas:
    print(f"Input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")

In [None]:
# Stack the vectors in `vecs` as rows, then replace the stack with its Gram
# matrix (all pairwise dot products) and print it as a fixed-width table.
# NOTE(review): `vecs` is not assigned in any cell visible here.
n_vectors, n_dims = len(vecs), len(vecs[0])
matrix = np.zeros((n_vectors, n_dims))
for row_index in range(n_vectors):
    matrix[row_index, :] = vecs[row_index]

matrix = matrix @ matrix.T

for row in matrix:
    print(" ".join(f"{value:10.2f}" for value in row))

In [None]:
# Min-max rescale `matrix` so its smallest entry maps to 0 and its largest to 1,
# then print the rescaled values as a fixed-width table.
dfmax, dfmin = matrix.max(), matrix.min()
value_range = dfmax - dfmin

matrix = (matrix - dfmin) / value_range

for row in matrix:
    print(" ".join(f"{value:10.2f}" for value in row))