In [1]:
! pip install gsutil
! pip install pyfluidsynth
! pip install mido


import sys

sys.path.append("mt3")
sys.path.append("t5x")
sys.path.append("airio")

# Install system dependencies

# Commented out to prevent asking for password
# ! sudo apt-get update
# ! sudo apt-get install -y libfluidsynth3 build-essential libasound2-dev libjack-dev

# Clone and install MT3 and T5X

! cd t5x && pip install .

! cd mt3 && pip install .

! cd airio && pip install .

! python3 -m pip install jax nest-asyncio pyfluidsynth==1.3.0 -e .


# Download checkpoints and soundfonts
! /home/mikea/.local/bin/gsutil -q -m cp -r gs://mt3/checkpoints .
! /home/mikea/.local/bin/gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .

Defaulting to user installation because normal site-packages is not writeable
Processing /mnt/c/Users/mikeani/Documents/GitHub/gen-ai-mp3-to-musescore/python/t5x
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting airio@ git+https://github.com/google/airio#egg=airio
  Cloning https://github.com/google/airio to /tmp/pip-install-8nkq4fc9/airio_afbc1af731364bc3840a58f350be8347
  Running command git clone --filter=blob:none --quiet https://github.com/google/airio /tmp/pip-install-8nkq4fc9/airio_afbc1af731364bc3840a58f350be8347
  Resolved https://github.com/google/airio to commit d08752ea77d4e7352b8f84d1b228a279fcaadaf4
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting clu@ git+https://github.com/google/CommonLoopUtils#egg=clu
  Cloning https://github.com/google/CommonLoopUtils to /tmp/pip-install-8nkq4fc9/clu_5f8f4bb9aa894b58b48e9d69ccd83177
  Running command git clone --filter=blob:none --quiet https://github.com/google/CommonLoopUtils /tmp/pip-install-8nkq4fc9/c

In [24]:
! pip install tensorflow
! pip install jax
! pip install librosa
! pip install note-seq
! pip install gin-config
! pip install t5
! pip install t5x
! pip install nest_asyncio
! pip install mir_eval
! pip install cached_property

import functools
import os
import numpy as np
import tensorflow.compat.v2 as tf
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x
from mt3 import metrics_utils, models, network, note_sequences, preprocessors, spectrograms, vocabularies
import nest_asyncio

# Audio settings shared by the cells below.
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'  # SoundFont fetched by the setup cell
SAMPLE_RATE = 16000  # Hz; the rate InferenceModel documents for its input audio

# Apply nest_asyncio's patch so an event loop that is already running (such as
# the notebook kernel's) can be re-entered.
nest_asyncio.apply()

def load_audio(file_path, sample_rate):
    """Read an audio file into an array of samples at the requested rate.

    Args:
        file_path: Path of the audio file to read.
        sample_rate: Target sample rate, forwarded to `librosa.load` as `sr`.

    Returns:
        The samples returned by `librosa.load`; the sample-rate value that
        comes with them is discarded.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
    """
    if os.path.exists(file_path):
        samples, _ = librosa.load(file_path, sr=sample_rate)
        return samples
    raise FileNotFoundError(f"File {file_path} does not exist.")

class InferenceModel(object):
    """Wrapper of T5X model for music transcription.

    Constructing an instance parses the MT3 gin configuration, builds the T5X
    model and restores its weights from a checkpoint. Calling the instance on
    a 1-d array of audio samples returns the transcribed note sequence.
    """

    def __init__(self, checkpoint_path, model_type='mt3'):
        """Build the model for `model_type` and restore it from a checkpoint.

        Args:
            checkpoint_path: Checkpoint location passed on to
                `restore_from_checkpoint`.
            model_type: Either 'mt3' or 'ismir2021'. Selects the note encoding
                spec, the number of velocity bins, the input length and the
                model-specific gin file.

        Raises:
            ValueError: If `model_type` is not one of the supported values.
        """

        # Model Constants.
        if model_type == 'ismir2021':
            num_velocity_bins = 127
            self.encoding_spec = note_sequences.NoteEncodingSpec
            self.inputs_length = 512
        elif model_type == 'mt3':
            num_velocity_bins = 1
            self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
            self.inputs_length = 256
        else:
            raise ValueError('Unknown model_type: %s' % model_type)

        # The gin file must match the model type chosen above; otherwise the
        # parsed configuration and the codec built below can disagree.
        gin_files = self._gin_files_for(model_type)

        self.batch_size = 8
        self.outputs_length = 1024
        self.sequence_length = {'inputs': self.inputs_length, 'targets': self.outputs_length}

        self.partitioner = t5x.partitioning.PjitPartitioner(num_partitions=1)

        # Build Codecs and Vocabularies.
        self.spectrogram_config = spectrograms.SpectrogramConfig()
        self.codec = vocabularies.build_codec(
            vocab_config=vocabularies.VocabularyConfig(num_velocity_bins=num_velocity_bins)
        )
        self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
        self.output_features = {
            'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
            'targets': seqio.Feature(vocabulary=self.vocabulary),
        }

        # Create a T5X model.
        self._parse_gin(gin_files)
        self.model = self._load_model()

        # Restore from checkpoint.
        self.restore_from_checkpoint(checkpoint_path)

    @staticmethod
    def _gin_files_for(model_type):
        """Return the gin files to parse for `model_type`.

        The shared model definition comes first, followed by the file named
        after the model type, so each model type gets its own configuration
        rather than always the 'mt3' one.
        """
        return [
            'mt3/mt3/gin/model.gin',
            f'mt3/mt3/gin/{model_type}.gin',
        ]

    @property
    def input_shapes(self):
        """Encoder/decoder token shapes as (batch_size, length) tuples."""
        return {
            'encoder_input_tokens': (self.batch_size, self.inputs_length),
            'decoder_input_tokens': (self.batch_size, self.outputs_length)
        }

    def _parse_gin(self, gin_files):
        """Parse gin files used to train the model."""
        gin_bindings = [
            'from __gin__ import dynamic_registration',
            'from mt3 import vocabularies',
            'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
            'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
        ]
        # The config is unlocked for the duration of the parse and left
        # unfinalized (finalize_config=False).
        with gin.unlock_config():
            gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)

    def _load_model(self):
        """Load up a T5X `Model` after parsing training gin config."""
        model_config = gin.get_configurable(network.T5Config)()
        module = network.Transformer(config=model_config)
        return models.ContinuousInputsEncoderDecoderModel(
            module=module,
            input_vocabulary=self.output_features['inputs'].vocabulary,
            output_vocabulary=self.output_features['targets'].vocabulary,
            optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
            input_depth=spectrograms.input_depth(self.spectrogram_config)
        )

    def restore_from_checkpoint(self, checkpoint_path):
        """Restore training state from checkpoint, resets self._predict_fn().

        Args:
            checkpoint_path: Path given to `RestoreCheckpointConfig`, restored
                in 'specific' mode as float32.
        """
        train_state_initializer = t5x.utils.TrainStateInitializer(
            optimizer_def=self.model.optimizer_def,
            init_fn=self.model.get_initial_variables,
            input_shapes=self.input_shapes,
            partitioner=self.partitioner
        )

        restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
            path=checkpoint_path, mode='specific', dtype='float32'
        )

        train_state_axes = train_state_initializer.train_state_axes
        self._predict_fn = self._get_predict_fn(train_state_axes)
        self._train_state = train_state_initializer.from_checkpoint_or_scratch(
            [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0)
        )

    @functools.lru_cache()
    def _get_predict_fn(self, train_state_axes):
        """Generate a partitioned prediction function for decoding."""
        def partial_predict_fn(params, batch, decode_rng):
            # `decode_rng` is accepted to fit the partitioned signature, but
            # the model is always called with decode_rng=None.
            return self.model.predict_batch_with_aux(
                params, batch, decoder_params={'decode_rng': None}
            )
        return self.partitioner.partition(
            partial_predict_fn,
            in_axis_resources=(
                train_state_axes.params,
                t5x.partitioning.PartitionSpec('data',), None),
            out_axis_resources=t5x.partitioning.PartitionSpec('data',)
        )

    def predict_tokens(self, batch, seed=0):
        """Predict tokens from preprocessed dataset batch.

        Args:
            batch: One batch from the model-ready dataset.
            seed: Seed for the PRNG key handed to the prediction function
                (which, as built by `_get_predict_fn`, does not use the key).

        Returns:
            The predictions decoded by the vocabulary, as a numpy value.
        """
        prediction, _ = self._predict_fn(self._train_state.params, batch, jax.random.PRNGKey(seed))
        return self.vocabulary.decode_tf(prediction).numpy()

    def __call__(self, audio):
        """Infer note sequence from audio samples.

        Args:
            audio: 1-d numpy array of audio samples (16kHz) for a single example.

        Returns:
            A note_sequence of the transcribed audio.
        """
        ds = self.audio_to_dataset(audio)
        ds = self.preprocess(ds)

        model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
            ds, task_feature_lengths=self.sequence_length
        )
        model_ds = model_ds.batch(self.batch_size)

        # Lazily flatten the per-batch predictions into one stream of
        # per-example token arrays.
        inferences = (tokens for batch in model_ds.as_numpy_iterator() for tokens in self.predict_tokens(batch))

        # Pair each preprocessed example with its predicted tokens.
        predictions = []
        for example, tokens in zip(ds.as_numpy_iterator(), inferences):
            predictions.append(self.postprocess(tokens, example))

        result = metrics_utils.event_predictions_to_ns(
            predictions, codec=self.codec, encoding_spec=self.encoding_spec
        )
        return result['est_ns']

    def audio_to_dataset(self, audio):
        """Create a TF Dataset of spectrograms from input audio."""
        frames, frame_times = self._audio_to_frames(audio)
        return tf.data.Dataset.from_tensors({
            'inputs': frames,
            'input_times': frame_times,
        })

    def _audio_to_frames(self, audio):
        """Compute spectrogram frames from audio.

        Returns:
            A `(frames, times)` pair: the frames produced by
            `spectrograms.split_audio` and one timestamp per hop, in seconds.
        """
        frame_size = self.spectrogram_config.hop_width
        # Zero-pad the end up to a multiple of the hop width. When the length
        # is already a multiple, one full extra hop of zeros is appended.
        padding = [0, frame_size - len(audio) % frame_size]
        audio = np.pad(audio, padding, mode='constant')
        frames = spectrograms.split_audio(audio, self.spectrogram_config)
        num_frames = len(audio) // frame_size
        times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
        return frames, times

    def preprocess(self, ds):
        """Run the dataset through the preprocessing chain, in order.

        The chain splits 'inputs' (and 'input_times') to the model's input
        length, adds dummy targets, then computes spectrograms.
        """
        pp_chain = [
            functools.partial(
                t5.data.preprocessors.split_tokens_to_inputs_length,
                sequence_length=self.sequence_length,
                output_features=self.output_features,
                feature_key='inputs',
                additional_feature_keys=['input_times']
            ),
            preprocessors.add_dummy_targets,
            functools.partial(
                preprocessors.compute_spectrograms,
                spectrogram_config=self.spectrogram_config
            )
        ]
        for pp in pp_chain:
            ds = pp(ds)
        return ds

    def postprocess(self, tokens, example):
        """Package predicted tokens with the start time of their segment.

        Args:
            tokens: Predicted tokens for one example.
            example: The preprocessed example; its first 'input_times' entry
                is used as the segment start time.

        Returns:
            A dict with 'est_tokens', 'start_time' and an empty 'raw_inputs'.
        """
        tokens = self._trim_eos(tokens)
        start_time = example['input_times'][0]
        start_time -= start_time % (1 / self.codec.steps_per_second)  # Round down to nearest symbolic token step.
        return {
            'est_tokens': tokens,
            'start_time': start_time,
            'raw_inputs': []  # Internal MT3 code expects raw inputs, not used here.
        }

    @staticmethod
    def _trim_eos(tokens):
        """Return `tokens` as int32, cut just before the first EOS if present."""
        tokens = np.array(tokens, np.int32)
        if vocabularies.DECODED_EOS_ID in tokens:
            tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
        return tokens


  pid, fd = os.forkpty()


Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Collecting protobuf>=4.21.2
  Using cached protobuf-5.27.3-cp38-abi3-manylinux2014_x86_64.whl (309 kB)
Installing collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.20.3
    Uninstalling protobuf-3.20.3:
      Successfully uninstalled protobuf-3.20.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.17.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 5.27.3 which is incompatible.
tensorflow-metadata 1.15.0 requires proto

In [23]:
# One constant drives both the checkpoint directory and the model type, so the
# two cannot drift apart. It must be a value InferenceModel accepts
# ('mt3' or 'ismir2021').
# NOTE(review): assumes the downloaded `checkpoints` directory holds one
# sub-directory per model type -- confirm before switching to 'ismir2021'.
MODEL_TYPE = "mt3"
checkpoint_path = f'checkpoints/{MODEL_TYPE}'

inference_model = InferenceModel(checkpoint_path, MODEL_TYPE)

In [36]:
# Transcribe each separated stem to MIDI. The directory and the stem names are
# defined once instead of being repeated in every call.
STEM_DIR = "../outputs/Beginning"
STEMS = ("bass", "drums", "other", "vocals")

# Same order of work as before: load every stem, then transcribe every stem,
# then write every MIDI file (a missing input fails before any inference runs).
stem_audio = {
    stem: load_audio(f"{STEM_DIR}/{stem}.wav", SAMPLE_RATE) for stem in STEMS
}
stem_ns = {stem: inference_model(stem_audio[stem]) for stem in STEMS}

for stem in STEMS:
    note_seq.sequence_proto_to_midi_file(stem_ns[stem], f"{STEM_DIR}/{stem}.mid")

# Keep the per-stem variable names, since other cells may refer to them.
bass_audio, drums_audio, other_audio, vocals_audio = (stem_audio[stem] for stem in STEMS)
bass_ns, drums_ns, other_ns, vocals_ns = (stem_ns[stem] for stem in STEMS)

2024-08-21 12:18:20.363361: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [None]:
# Load the recording from ../inputs at SAMPLE_RATE.
# NOTE(review): this rebinds `vocals_audio`, which an earlier cell assigns from
# ../outputs/Beginning/vocals.wav -- confirm the overwrite is intended, or give
# this recording its own name.
vocals_audio = load_audio("../inputs/Beginning.wav", SAMPLE_RATE)