In [None]:
import os
import json
import IPython

def install_dependencies():
    """Install the system libraries, the MT3 repository and its Python deps.

    Each step is run through ``os.system`` in sequence. Exit codes are not
    checked, so a failing step does not prevent later steps from running.
    """
    setup_steps = [
        # Native libraries for FluidSynth audio plus a build toolchain.
        "apt-get update -qq && apt-get install -qq libfluidsynth3 build-essential libasound2-dev libjack-dev",
        # Fetch MT3 and flatten the checkout into the working directory.
        "git clone --branch=main https://github.com/magenta/mt3",
        "mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp",
        # JAX with CUDA 12 wheels, plus MT3 itself installed in editable mode.
        "python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
    ]
    for step in setup_steps:
        os.system(step)

def copy_checkpoints(checkpoint_dir='checkpoints'):
    """Copy the MT3 model checkpoints from Cloud Storage into ``checkpoint_dir``.

    The children of ``gs://mt3/checkpoints`` (such as ``mt3/``) land directly
    inside ``checkpoint_dir``. With the default argument, the MT3 checkpoint
    therefore ends up at ``checkpoints/mt3``, the same layout the notebook's
    ``gsutil cp -r gs://mt3/checkpoints .`` line produces.

    Args:
      checkpoint_dir: Local destination directory; created if missing.
    """
    import shlex

    os.makedirs(checkpoint_dir, exist_ok=True)
    # Copying the bucket folder itself into an already-existing directory nests
    # it as ``checkpoint_dir/checkpoints``; copy its contents instead. The
    # wildcard is single-quoted so gsutil, not the local shell, expands it.
    os.system(
        f"gsutil -q -m cp -r 'gs://mt3/checkpoints/*' {shlex.quote(checkpoint_dir)}")

def copy_soundfont(soundfont_path='SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'):
    """Copy the SGM soundfont used for FluidSynth playback from Cloud Storage.

    Args:
      soundfont_path: Local file path for the ``.sf2`` file. If the path has a
        directory component, that directory is created first.
    """
    import shlex

    parent_dir = os.path.dirname(soundfont_path)
    # A bare filename (the default) has no directory part, and
    # ``os.makedirs('')`` raises FileNotFoundError, so only create real parents.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    os.system(
        "gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 "
        f"{shlex.quote(soundfont_path)}")

def AnalyticsSetup(analytics_id='G-4P250YRJ08'):
    """Display the gtag.js snippet that configures Google Analytics.

    The snippet loads gtag.js for ``analytics_id`` and configures it with
    ``anonymize_ip`` enabled and blank page title/referrer. It is rendered as
    HTML in the current cell output.

    Args:
      analytics_id: Measurement ID substituted into the snippet.
    """
    template = '''
<!-- Analytics Setup -->
<script async src="https://www.googletagmanager.com/gtag/js?id=@ID@"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments);}
  gtag('js', new Date());
  gtag('config', '@ID@',
       {'referrer': document.referrer.split('?')[0],
        'anonymize_ip': true,
        'page_title': '',
        'page_referrer': '',
        'cookie_prefix': 'magenta',
        'cookie_domain': 'auto',
        'cookie_expires': 0,
        'cookie_flags': 'SameSite=None;Secure'});
</script>
'''
    snippet = template.replace('@ID@', str(analytics_id))
    IPython.display.display(IPython.display.HTML(snippet))

def LogAnalyticsEvent(event_name, event_details):
    """Emit a gtag ``event`` call from the notebook.

    Expects the ``gtag`` function from the analytics snippet (see
    ``AnalyticsSetup``) to already exist in the page.

    Args:
      event_name: Name of the analytics event.
      event_details: JSON-serialisable dict of event parameters.
    """
    # json.dumps produces valid JavaScript literals for both values, so a quote
    # or backslash in the event name cannot break the generated statement.
    js_string = f"gtag('event', {json.dumps(event_name)}, {json.dumps(event_details)});"
    IPython.display.display(IPython.display.Javascript(js_string))

# Environment setup run directly as IPython shell escapes (the helper functions
# defined above are not called in this cell).
# System libraries for FluidSynth audio synthesis and native builds.
!apt-get update -qq && apt-get install -qq libfluidsynth3 build-essential libasound2-dev libjack-dev
# Clone MT3 and move the checkout's contents into the working directory.
!git clone --branch=main https://github.com/magenta/mt3
!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp
# JAX with CUDA 12 support, plus MT3 installed in editable mode from ".".
!python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Model checkpoints -> ./checkpoints/... (e.g. ./checkpoints/mt3).
!gsutil -q -m cp -r gs://mt3/checkpoints .
# Soundfont used later for FluidSynth playback.
!gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .

import json
import IPython

def load_gtag():
  """Display the Google Analytics gtag.js loader and its config call.

  The measurement ID G-4P250YRJ08 is fixed; the snippet is rendered as HTML
  in the current cell output.
  """
  measurement_id = 'G-4P250YRJ08'
  snippet = '''
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=%s"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments);}
  gtag('js', new Date());
  gtag('config', '%s',
       {'referrer': document.referrer.split('?')[0],
        'anonymize_ip': true,
        'page_title': '',
        'page_referrer': '',
        'cookie_prefix': 'magenta',
        'cookie_domain': 'auto',
        'cookie_expires': 0,
        'cookie_flags': 'SameSite=None;Secure'});
</script>
''' % (measurement_id, measurement_id)
  IPython.display.display(IPython.display.HTML(snippet))

def log_event(event_name, event_details):
  """Emit a gtag ``event`` call with the given name and parameter dict.

  Args:
    event_name: Analytics event name.
    event_details: JSON-serialisable dict of event parameters.
  """
  # Serialise the name with json.dumps as well, so quotes in it cannot break
  # the generated JavaScript statement.
  js_string = "gtag('event', %s, %s);" % (json.dumps(event_name),
                                          json.dumps(event_details))
  IPython.display.display(IPython.display.Javascript(js_string))

# Inject the analytics snippet and record that the setup cell finished.
load_gtag()
log_event('setupComplete', {})

In [None]:
import functools
import os
import numpy as np
import tensorflow.compat.v2 as tf
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x
from mt3 import metrics_utils, models, network, note_sequences, preprocessors, spectrograms, vocabularies
from google.colab import files
import nest_asyncio

def initialize_transcription_config(model_type='mt3', batch_size=8, inputs_length=256, outputs_length=1024):
    """Build the codec/vocabulary/feature configuration for a transcription model.

    Side effect: calls ``nest_asyncio.apply()``.

    Args:
      model_type: ``'mt3'`` or ``'ismir2021'``.
      batch_size: Examples per inference batch.
      inputs_length: Encoder input length; overridden to 512 for
        ``'ismir2021'``.
      outputs_length: Decoder target length.

    Returns:
      Dict with keys ``batch_size``, ``sequence_length``,
      ``spectrogram_config``, ``codec``, ``vocabulary``, ``output_features``
      and ``encoding_spec``.

    Raises:
      ValueError: If ``model_type`` is not recognised.
    """
    nest_asyncio.apply()
    if model_type == 'mt3':
        velocity_bins, spec = 1, note_sequences.NoteEncodingWithTiesSpec
    elif model_type == 'ismir2021':
        velocity_bins, spec = 127, note_sequences.NoteEncodingSpec
        inputs_length = 512
    else:
        raise ValueError(f'Unknown model_type: {model_type}')

    codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(num_velocity_bins=velocity_bins))
    vocabulary = vocabularies.vocabulary_from_codec(codec)
    return dict(
        batch_size=batch_size,
        sequence_length={'inputs': inputs_length, 'targets': outputs_length},
        spectrogram_config=spectrograms.SpectrogramConfig(),
        codec=codec,
        vocabulary=vocabulary,
        output_features={
            'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
            'targets': seqio.Feature(vocabulary=vocabulary),
        },
        encoding_spec=spec,
    )

def parse_gin_config(gin_files, num_velocity_bins):
    """Parse the model's gin files plus vocabulary bindings (config not finalized).

    Args:
      gin_files: Paths of the gin config files to parse.
      num_velocity_bins: Value bound to ``VocabularyConfig.num_velocity_bins``.
    """
    velocity_binding = f'vocabularies.VocabularyConfig.num_velocity_bins={num_velocity_bins}'
    bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        velocity_binding,
    ]
    # unlock_config lets parsing proceed even if the config was locked earlier.
    with gin.unlock_config():
        gin.parse_config_files_and_bindings(
            gin_files, bindings, finalize_config=False)

def BuildTranscriptionModel(config):
    """Build the MT3 continuous-input encoder-decoder model.

    Must run after the training gin files are parsed, because the network's
    ``T5Config`` is obtained from gin.

    Args:
      config: Dict from ``initialize_transcription_config``.

    Returns:
      An ``mt3.models.ContinuousInputsEncoderDecoderModel``.
    """
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    # The continuous-input model class is defined in mt3.models (as used by
    # InferenceModel._load_model); t5x.models has no such class.
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=config['output_features']['inputs'].vocabulary,
        output_vocabulary=config['output_features']['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(config['spectrogram_config']))

def restore_model_from_checkpoint(model, checkpoint_path, config):
    """Create a partitioned predict function and restore weights from a checkpoint.

    Args:
      model: T5X model from ``BuildTranscriptionModel``.
      checkpoint_path: Checkpoint directory to restore.
      config: Dict from ``initialize_transcription_config``.

    Returns:
      ``(predict_fn, train_state)``; ``predict_fn(params, batch, rng)`` calls
      ``model.predict_batch_with_aux``.
    """
    batch_size = config['batch_size']
    lengths = config['sequence_length']
    partitioner = t5x.partitioning.PjitPartitioner(num_partitions=1)
    initializer = t5x.utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes={
            'encoder_input_tokens': (batch_size, lengths['inputs']),
            'decoder_input_tokens': (batch_size, lengths['targets']),
        },
        partitioner=partitioner)
    checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')
    state_axes = initializer.train_state_axes

    def _predict(params, batch, decode_rng):
        # decode_rng is part of the partitioned signature but unused: decoding
        # always runs with decode_rng=None.
        return model.predict_batch_with_aux(
            params, batch, decoder_params={'decode_rng': None})

    predict_fn = partitioner.partition(
        _predict,
        in_axis_resources=(state_axes.params, t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',))
    train_state = initializer.from_checkpoint_or_scratch(
        [checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
    return predict_fn, train_state

def preprocess_audio_to_dataset(audio, spectrogram_config):
    """Split raw audio into hop-sized frames and wrap them in a tf.data Dataset.

    Args:
      audio: Audio samples (presumably a 1-D array).
      spectrogram_config: ``SpectrogramConfig``; ``hop_width`` sets the frame
        size and ``frames_per_second`` the frame timestamps.

    Returns:
      A single-element Dataset with ``inputs`` (frames) and ``input_times``
      (frame start times in seconds).
    """
    hop = spectrogram_config.hop_width
    # Always pads by 1..hop samples (a full extra hop when already aligned).
    padded = np.pad(audio, [0, hop - len(audio) % hop], mode='constant')
    frames = spectrograms.split_audio(padded, spectrogram_config)
    frame_times = np.arange(len(padded) // hop) / spectrogram_config.frames_per_second
    return tf.data.Dataset.from_tensors({'inputs': frames, 'input_times': frame_times})

def preprocess_dataset(ds, config):
    """Run MT3's inference preprocessing chain over a framed-audio dataset.

    Steps, in order:
    1. Split ``inputs`` (and ``input_times`` alongside it) to the configured
       inputs length.
    2. Add dummy targets.
    3. Replace the raw frames with spectrograms.

    Args:
      ds: Dataset from ``preprocess_audio_to_dataset``.
      config: Dict from ``initialize_transcription_config``.

    Returns:
      The preprocessed Dataset.
    """
    split_fn = functools.partial(
        t5.data.preprocessors.split_tokens_to_inputs_length,
        sequence_length=config['sequence_length'],
        output_features=config['output_features'],
        feature_key='inputs',
        additional_feature_keys=['input_times'])
    spectrogram_fn = functools.partial(
        preprocessors.compute_spectrograms,
        spectrogram_config=config['spectrogram_config'])
    ds = split_fn(ds)
    ds = preprocessors.add_dummy_targets(ds)
    return spectrogram_fn(ds)

def predict_transcription_tokens(predict_fn, train_state, batch, seed=0):
    """Run the partitioned predict function on one batch.

    Args:
      predict_fn: Callable ``(params, batch, rng) -> (prediction, aux)``.
      train_state: Restored train state whose ``params`` are used.
      batch: Model-ready feature batch.
      seed: Seed for the PRNG key passed to ``predict_fn``.

    Returns:
      The ``prediction`` part of ``predict_fn``'s output (token ids that the
      caller decodes).
    """
    rng = jax.random.PRNGKey(seed)
    prediction, _aux = predict_fn(train_state.params, batch, rng)
    return prediction

def postprocess_tokens(tokens, example, codec, encoding_spec):
    """Trim decoded tokens at the first EOS and attach the chunk's start time.

    Args:
      tokens: Decoded token ids for one example.
      example: Preprocessed example providing ``input_times``.
      codec: Event codec; its ``steps_per_second`` sets the time grid that
        the start time is rounded down to.
      encoding_spec: Not used here; kept for interface compatibility.

    Returns:
      Dict with ``est_tokens``, ``start_time`` and empty ``raw_inputs``.
    """
    ids = np.array(tokens, np.int32)
    eos_hits = ids == vocabularies.DECODED_EOS_ID
    if eos_hits.any():
        ids = ids[:np.argmax(eos_hits)]
    first_time = example['input_times'][0]
    step = 1 / codec.steps_per_second
    return {
        'est_tokens': ids,
        'start_time': first_time - first_time % step,
        'raw_inputs': [],
    }

def transcribe_audio(audio, predict_fn, train_state, config):
    """Transcribe raw audio samples into a NoteSequence.

    Args:
      audio: Audio samples (the notebook uses 16 kHz audio).
      predict_fn: Partitioned predict function from
        ``restore_model_from_checkpoint``.
      train_state: The matching restored train state.
      config: Dict from ``initialize_transcription_config``.

    Returns:
      The estimated NoteSequence (``est_ns``).
    """
    ds = preprocess_dataset(
        preprocess_audio_to_dataset(audio, config['spectrogram_config']), config)
    # FEATURE_CONVERTER_CLS is defined on the model class, not on the
    # vocabulary (InferenceModel.__call__ likewise reads it from its model).
    converter = models.ContinuousInputsEncoderDecoderModel.FEATURE_CONVERTER_CLS(pack=False)
    model_ds = converter(ds, task_feature_lengths=config['sequence_length'])
    model_ds = model_ds.batch(config['batch_size'])
    vocabulary = config['vocabulary']
    # Decode each predicted batch once (as InferenceModel.predict_tokens does),
    # then yield its rows lazily so they pair up with examples below.
    inferences = (
        row
        for batch in model_ds.as_numpy_iterator()
        for row in vocabulary.decode_tf(
            predict_transcription_tokens(predict_fn, train_state, batch)).numpy())
    predictions = [
        postprocess_tokens(tokens, example, config['codec'], config['encoding_spec'])
        for example, tokens in zip(ds.as_numpy_iterator(), inferences)]
    return metrics_utils.event_predictions_to_ns(
        predictions, codec=config['codec'],
        encoding_spec=config['encoding_spec'])['est_ns']

def upload_audio_file(sample_rate=16000):
    """Prompt for a Colab file upload and decode it to audio samples.

    Args:
      sample_rate: Sample rate the audio is decoded at.

    Returns:
      Samples from ``note_seq.audio_io.wav_data_to_samples_librosa``.

    Raises:
      ValueError: If the upload dialog returned no files.
    """
    uploaded = list(files.upload().values())
    if not uploaded:
        # Otherwise ``uploaded[0]`` fails with an unexplained IndexError.
        raise ValueError('No audio file was uploaded.')
    if len(uploaded) > 1:
        print('Multiple files uploaded; using only one.')
    return note_seq.audio_io.wav_data_to_samples_librosa(uploaded[0], sample_rate=sample_rate)

import functools
import os
import numpy as np
import tensorflow.compat.v2 as tf
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x
from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies
from google.colab import files
import nest_asyncio
# Allow nested asyncio event loops; presumably needed because the notebook
# kernel already runs a loop — confirm.
nest_asyncio.apply()
# Sample rate (Hz) used for uploading, playback and duration reporting.
SAMPLE_RATE = 16000
# Soundfont downloaded by the setup cell; used for FluidSynth playback.
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

def upload_audio(sample_rate):
  """Prompt for a Colab upload and decode the first file at ``sample_rate``.

  Args:
    sample_rate: Sample rate the audio is decoded at.

  Returns:
    Samples from ``note_seq.audio_io.wav_data_to_samples_librosa``.

  Raises:
    ValueError: If the upload dialog returned no files.
  """
  uploaded = list(files.upload().values())
  if not uploaded:
    # Otherwise ``uploaded[0]`` fails with an unexplained IndexError.
    raise ValueError('No audio file was uploaded.')
  if len(uploaded) > 1:
    print('Multiple files uploaded; using only one.')
  return note_seq.audio_io.wav_data_to_samples_librosa(
    uploaded[0], sample_rate=sample_rate)

class InferenceModel(object):
  """Wrapper of T5X model for music transcription.

  Construction parses the model's gin files, builds the T5X model and
  restores its weights. Calling the instance on audio samples returns the
  transcribed NoteSequence.
  """
  def __init__(self, checkpoint_path, model_type='mt3'):
    """Configure, build and restore the model.

    Args:
      checkpoint_path: Checkpoint directory passed to
        ``restore_from_checkpoint``.
      model_type: ``'mt3'`` (1 velocity bin, inputs length 256) or
        ``'ismir2021'`` (127 velocity bins, inputs length 512).

    Raises:
      ValueError: If ``model_type`` is not one of the two supported values.
    """
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)
    # Shared model.gin plus the model-type-specific gin file. The absolute
    # paths assume the repo was unpacked under /content (Colab).
    self.gin_files = ['/content/mt3/gin/model.gin',  # kept on self for reference
                     f'/content/mt3/gin/{model_type}.gin']
    self.batch_size = 8
    self.outputs_length = 1024
    self.sequence_length = {'inputs': self.inputs_length,
                            'targets': self.outputs_length}
    # A single partition: parameters are not sharded across devices.
    self.partitioner = t5x.partitioning.PjitPartitioner(
        num_partitions=1)

    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }
    # Order matters: gin must be parsed before the model is built, and the
    # model must exist before its checkpoint can be restored.
    self._parse_gin(self.gin_files)
    self.model = self._load_model()
    self.restore_from_checkpoint(checkpoint_path)
  @property
  def input_shapes(self):
    """Encoder/decoder input shapes used to initialise the train state."""
    return {
          'encoder_input_tokens': (self.batch_size, self.inputs_length),
          'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }
  def _parse_gin(self, gin_files):
    """Parse gin files used to train the model.

    NOTE(review): the last binding refers to the gin macro
    ``%NUM_VELOCITY_BINS``, which is not defined here; presumably the parsed
    gin files define it — confirm.
    """
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)
  def _load_model(self):
    """Load up a T5X `Model` after parsing training gin config."""
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))
  def restore_from_checkpoint(self, checkpoint_path):
    """Restore training state from checkpoint, resets self._predict_fn().

    Sets ``self._predict_fn`` and ``self._train_state``.
    """
    train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=self.model.optimizer_def,
      init_fn=self.model.get_initial_variables,
      input_shapes=self.input_shapes,
      partitioner=self.partitioner)
    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')
    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """Generate a partitioned prediction function for decoding.

    Memoised per ``(self, train_state_axes)`` by ``functools.lru_cache``, which
    requires ``train_state_axes`` to be hashable. NOTE(review): the cache keeps
    ``self`` (and its model) alive while the entry remains cached.
    """
    def partial_predict_fn(params, batch, decode_rng):
      # decode_rng is part of the partitioned signature but unused: decoding
      # always runs with decode_rng=None.
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
    )
  def predict_tokens(self, batch, seed=0):
    """Predict tokens from preprocessed dataset batch.

    Returns the predictions decoded with ``self.vocabulary.decode_tf`` as a
    NumPy array.
    """
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()
  def __call__(self, audio):
    """Infer note sequence from audio samples.
    Args:
      audio: 1-d numpy array of audio samples (16kHz) for a single example.
    Returns:
      A note_sequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)
    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)
    # Lazily flatten batched predictions back to one row per example so they
    # can be zipped with the (unbatched) preprocessed examples below.
    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))
    predictions = []
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))
    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']
  def audio_to_dataset(self, audio):
    """Create a TF Dataset of raw audio frames and their start times.

    Spectrograms are computed later, in ``preprocess``.
    """
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })
  def _audio_to_frames(self, audio):
    """Split audio into hop-sized frames and compute each frame's start time."""
    frame_size = self.spectrogram_config.hop_width
    # Pads by 1..frame_size samples (a full extra frame when already aligned).
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times
  def preprocess(self, ds):
    """Split into model-length chunks, add dummy targets, compute spectrograms."""
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds
  def postprocess(self, tokens, example):
    """Trim EOS and attach the chunk start time, snapped down to the codec grid."""
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        'raw_inputs': []
    }
  @staticmethod
  def _trim_eos(tokens):
    """Return ``tokens`` as int32, truncated before the first decoded EOS id."""
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens

In [None]:

# Model type to load; InferenceModel accepts 'mt3' or 'ismir2021'.
MODEL = "mt3"
# Checkpoint directory. Presumably populated by the setup cell's
# `gsutil cp -r gs://mt3/checkpoints .` run from /content — confirm the cwd.
checkpoint_path = '/content/checkpoints/mt3/'

def SetupTranscriptionModel(checkpoint_path, model_type='mt3'):
    """Build and restore a transcription model, logging analytics events.

    Args:
      checkpoint_path: Checkpoint directory to restore weights from.
      model_type: ``'mt3'`` or ``'ismir2021'``.

    Returns:
      ``(config, model, predict_fn, train_state)``.
    """
    AnalyticsSetup()
    LogAnalyticsEvent('loadModelStart', {'event_category': model_type})
    config = initialize_transcription_config(model_type=model_type)
    # The event codec does not expose a ``vocab_config`` attribute, so derive
    # the velocity-bin count from model_type. These are the same values
    # initialize_transcription_config uses, and it has already rejected
    # unknown model types.
    num_velocity_bins = 1 if model_type == 'mt3' else 127
    parse_gin_config(['/content/mt3/gin/model.gin', f'/content/mt3/gin/{model_type}.gin'],
                     num_velocity_bins)
    model = BuildTranscriptionModel(config)
    predict_fn, train_state = restore_model_from_checkpoint(model, checkpoint_path, config)
    LogAnalyticsEvent('loadModelComplete', {'event_category': model_type})
    return config, model, predict_fn, train_state


# Load the selected model between analytics start/complete events.
# InferenceModel parses gin, builds the T5X model and restores the checkpoint.
load_gtag()
log_event('loadModelStart', {'event_category': MODEL})
inference_model = InferenceModel(checkpoint_path, MODEL)
log_event('loadModelComplete', {'event_category': MODEL})

In [None]:
def UploadAndProcessAudio(sample_rate=16000):
    """Upload one audio file via Colab and log upload analytics.

    Args:
      sample_rate: Sample rate the audio is decoded at.

    Returns:
      The decoded audio samples.
    """
    AnalyticsSetup()
    LogAnalyticsEvent('uploadAudioStart', {})
    samples = upload_audio_file(sample_rate)
    duration_seconds = round(len(samples) / sample_rate)
    LogAnalyticsEvent('uploadAudioComplete', {'value': duration_seconds})
    return samples

def Note_seq_play(audio, sample_rate=16000):
    """Show an inline Colab audio player for ``audio`` at ``sample_rate``."""
    play = note_seq.notebook_utils.colab_play
    play(audio, sample_rate=sample_rate)


# Upload one audio file, log its rounded duration in seconds, and play it back.
load_gtag()
log_event('uploadAudioStart', {})
audio = upload_audio(sample_rate=SAMPLE_RATE)
log_event('uploadAudioComplete', {'value': round(len(audio) / SAMPLE_RATE)})
note_seq.notebook_utils.colab_play(audio, sample_rate=SAMPLE_RATE)

In [None]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile
import gin
import os
import t5x
import jax
import tensorflow.compat.v2 as tf
import functools
from mt3 import network, spectrograms, preprocessors


# Audio/mel constants for this cell. NOTE(review): several of these (e.g.
# HOP_LENGTH, MEL_FMIN, MEL_FMAX, WINDOW_LENGTH) are not referenced anywhere
# in the visible notebook, and MT3's own spectrogram settings come from
# spectrograms.SpectrogramConfig — confirm they are needed before relying on them.
SAMPLE_RATE = 16000
HOP_LENGTH = 512
N_MELS = 229
MEL_FMIN = 30
MEL_FMAX = SAMPLE_RATE // 2
WINDOW_LENGTH = 2048


def preprocess_audio(audio_path):
    """Load and validate audio (aligned with dataset.py).

    Reads the file as int16 and downmixes 2-D (multi-channel) audio to mono
    by averaging. It then peak-normalises to the full int16 range.

    Args:
      audio_path: Path to an audio file readable by ``soundfile``.

    Returns:
      ``(audio, num_samples, True)`` on success, where ``audio`` is a
      ``torch.int16`` tensor. ``(None, 0, False)`` if anything fails; the
      error is printed.
    """
    try:
        samples, sr = soundfile.read(audio_path, dtype='int16')
        # Explicit check rather than ``assert`` so it still runs under -O.
        if sr != SAMPLE_RATE:
            raise ValueError(f"expected sample rate {SAMPLE_RATE}, got {sr}")
        samples = samples.astype(np.float64)
        if samples.ndim == 2:
            samples = samples.mean(axis=1)
        peak = np.max(np.abs(samples))  # raises on empty input -> failure path
        if peak > 0:
            # Scale to 32767, not 32768: +1.0 * 32768 overflows int16. Silent
            # audio (peak 0) is left as zeros instead of dividing by zero.
            samples = samples / peak * 32767.0
        audio = torch.from_numpy(np.round(samples).astype(np.int16))
        return audio, len(audio), True
    except Exception as e:
        print(f"Error processing audio {audio_path}: {e}")
        return None, 0, False

def parse_gin_config(gin_files, num_velocity_bins):
    """Parse gin config files together with a velocity-bin vocabulary binding.

    Args:
      gin_files: Gin file paths; each must exist.
      num_velocity_bins: Value bound to ``VocabularyConfig.num_velocity_bins``.

    Raises:
      FileNotFoundError: For the first entry of ``gin_files`` that is missing.
    """
    bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        f'vocabularies.VocabularyConfig.num_velocity_bins={num_velocity_bins}',
    ]
    with gin.unlock_config():
        missing = next((path for path in gin_files if not os.path.exists(path)), None)
        if missing is not None:
            raise FileNotFoundError(f"Gin file not found: {missing}")
        gin.parse_config_files_and_bindings(gin_files, bindings, finalize_config=False)

def preprocess_audio_to_dataset(audio, spectrogram_config):
    """Split raw audio into hop-sized frames and wrap them in a tf.data Dataset.

    ``inputs`` keeps the layout returned by ``spectrograms.split_audio``. Its
    first axis must line up with ``input_times``, because the downstream
    ``split_tokens_to_inputs_length`` step chunks both features in parallel.
    Frames hold raw audio samples, not mel bins, so no N_MELS reshaping
    belongs here; spectrograms are computed later by
    ``preprocessors.compute_spectrograms``.

    Args:
      audio: Audio samples (presumably a 1-D array).
      spectrogram_config: ``SpectrogramConfig`` providing ``hop_width`` and
        ``frames_per_second``.

    Returns:
      A single-element Dataset with ``inputs`` (frames) and ``input_times``
      (frame start times in seconds).
    """
    frame_size = spectrogram_config.hop_width
    # Pads by 1..frame_size samples (a full extra frame when already aligned).
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    # No leading batch axis: a length-1 first axis would no longer match the
    # per-frame ``input_times`` below.
    frames = spectrograms.split_audio(audio, spectrogram_config)
    num_frames = len(audio) // frame_size
    times = np.arange(num_frames) / spectrogram_config.frames_per_second
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': times
    })

def preprocess_dataset(ds, config):
    """Run MT3's inference preprocessing chain over a framed-audio dataset.

    Steps, in order:
    1. Split ``inputs`` (and ``input_times`` with it) to the configured inputs
       length.
    2. Add dummy targets.
    3. Replace the raw frames with spectrograms.

    The spectrograms are left exactly as ``preprocessors.compute_spectrograms``
    produces them. The model's ``input_depth`` is derived from the same
    ``spectrogram_config``, so resizing the feature axis to another width
    (such as ``N_MELS``) would feed the encoder inputs of the wrong depth.

    Args:
      ds: Dataset from ``preprocess_audio_to_dataset``.
      config: Dict from ``initialize_transcription_config``.

    Returns:
      The preprocessed Dataset.
    """
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=config['sequence_length'],
            output_features=config['output_features'],
            feature_key='inputs',
            additional_feature_keys=['input_times']
        ),
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=config['spectrogram_config']
        )
    ]
    for pp in pp_chain:
        ds = pp(ds)
    return ds

def BuildTranscriptionModel(config):
    """Build the MT3 continuous-input encoder-decoder model.

    Must run after the training gin files are parsed, since the network's
    ``T5Config`` comes from gin.

    Args:
      config: Dict from ``initialize_transcription_config``.

    Returns:
      An ``mt3.models.ContinuousInputsEncoderDecoderModel``.
    """
    # The continuous-input model is defined in mt3.models, not t5x.models.
    # Import it locally because this cell's top-level mt3 import omits models.
    from mt3 import models as mt3_models

    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    # There is deliberately no fallback to t5x.models.EncoderDecoderModel: it
    # takes no input_depth and does not match the spectrogram-input MT3
    # network, so it would only fail later with a less clear error. The input
    # depth comes from the module-level ``spectrograms.input_depth`` helper.
    return mt3_models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=config['output_features']['inputs'].vocabulary,
        output_vocabulary=config['output_features']['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(config['spectrogram_config']))

def SetupTranscriptionModel(checkpoint_path, model_type='mt3'):
    """Build the transcription model and restore its checkpoint, with analytics.

    Args:
      checkpoint_path: Checkpoint directory to restore.
      model_type: ``'mt3'`` or ``'ismir2021'``.

    Returns:
      ``(config, model, predict_fn, train_state)``.
    """
    AnalyticsSetup()
    LogAnalyticsEvent('loadModelStart', {'event_category': model_type})

    config = initialize_transcription_config(model_type=model_type)
    velocity_bins = 127
    if model_type == 'mt3':
        velocity_bins = 1
    gin_paths = ['/content/mt3/gin/model.gin', f'/content/mt3/gin/{model_type}.gin']
    parse_gin_config(gin_paths, velocity_bins)
    model = BuildTranscriptionModel(config)
    predict_fn, train_state = restore_model_from_checkpoint(model, checkpoint_path, config)

    LogAnalyticsEvent('loadModelComplete', {'event_category': model_type})
    return config, model, predict_fn, train_state

def TranscribeAudio(audio, config, predict_fn, train_state, model_type='mt3', sample_rate=16000):
    """Transcribe audio to a NoteSequence, logging start/complete analytics.

    Args:
      audio: Audio samples.
      config: Dict from ``initialize_transcription_config``.
      predict_fn: Partitioned predict function.
      train_state: Restored train state.
      model_type: Reported as the analytics ``event_category``.
      sample_rate: Sample rate of ``audio``; only used to report its duration
        in seconds. Defaults to 16000.

    Returns:
      The transcribed NoteSequence.
    """
    duration_seconds = round(len(audio) / sample_rate)
    AnalyticsSetup()
    LogAnalyticsEvent('transcribeStart', {
        'event_category': model_type,
        'value': duration_seconds,
    })
    est_ns = transcribe_audio(audio, predict_fn, train_state, config)
    pitched = [note for note in est_ns.notes if not note.is_drum]
    LogAnalyticsEvent('transcribeComplete', {
        'event_category': model_type,
        'value': duration_seconds,
        'numNotes': len(pitched),
        'numDrumNotes': len(est_ns.notes) - len(pitched),
        'numPrograms': len({note.program for note in pitched}),
    })
    return est_ns

def Note_seq_play_sequence(note_sequence, sample_rate=16000, sf2_path='SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'):
    """Synthesize ``note_sequence`` with FluidSynth and play it inline.

    Args:
      note_sequence: NoteSequence to play.
      sample_rate: Synthesis sample rate.
      sf2_path: Soundfont file used by FluidSynth.
    """
    playback_options = dict(synth=note_seq.fluidsynth, sample_rate=sample_rate, sf2_path=sf2_path)
    note_seq.play_sequence(note_sequence, **playback_options)

def Note_seq_visualize(note_sequence):
    """Plot ``note_sequence`` with ``note_seq.plot_sequence``."""
    plot = note_seq.plot_sequence
    plot(note_sequence)



def main_transcription(audio, config, predict_fn, train_state, model_type='mt3', sample_rate=16000, sf2_path='SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'):
    """Transcribe ``audio`` and present the result several ways.

    Steps, in order:
    1. Transcribe (with analytics).
    2. Show the mel spectrogram of the input.
    3. Show the TSV content of the transcription.
    4. Play it through FluidSynth.
    5. Plot it.

    NOTE(review): ``DisplayMelSpectrogram`` and ``DisplayTSVContent`` are not
    defined anywhere in the visible notebook. This raises NameError unless
    they are defined elsewhere — confirm.

    Returns:
      The transcribed NoteSequence.
    """
    transcription = TranscribeAudio(audio, config, predict_fn, train_state, model_type)
    DisplayMelSpectrogram(audio, sample_rate)
    DisplayTSVContent(transcription)
    Note_seq_play_sequence(transcription, sample_rate, sf2_path)
    Note_seq_visualize(transcription)
    return transcription


# Transcribe the uploaded audio with the InferenceModel loaded earlier, log
# summary statistics, then play the result via FluidSynth and plot it.
load_gtag()
log_event('transcribeStart', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE)
})
est_ns = inference_model(audio)
log_event('transcribeComplete', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE),
    'numNotes': sum(1 for note in est_ns.notes if not note.is_drum),
    'numDrumNotes': sum(1 for note in est_ns.notes if note.is_drum),
    'numPrograms': len(set(note.program for note in est_ns.notes
                           if not note.is_drum))
})
note_seq.play_sequence(est_ns, synth=note_seq.fluidsynth,
                       sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
note_seq.plot_sequence(est_ns)



In [None]:
def Note_seq_to_midi(note_sequence, output_path='/tmp/transcribed.mid', audio=None,
                     sample_rate=16000, model_type='mt3'):
    """Write ``note_sequence`` to a MIDI file and log a download analytics event.

    Args:
      note_sequence: NoteSequence to export.
      output_path: Destination ``.mid`` path.
      audio: Source audio samples, used only for the reported duration. If
        omitted, the duration comes from ``note_sequence.total_time`` rather
        than from a notebook-global ``audio`` variable.
      sample_rate: Sample rate of ``audio``.
      model_type: Reported as the analytics ``event_category``.
    """
    if audio is not None:
        duration_seconds = round(len(audio) / sample_rate)
    else:
        duration_seconds = round(note_sequence.total_time)
    pitched = [note for note in note_sequence.notes if not note.is_drum]
    AnalyticsSetup()
    LogAnalyticsEvent('downloadTranscription', {
        'event_category': model_type,
        'value': duration_seconds,
        'numNotes': len(pitched),
        'numDrumNotes': len(note_sequence.notes) - len(pitched),
        'numPrograms': len({note.program for note in pitched}),
    })
    note_seq.sequence_proto_to_midi_file(note_sequence, output_path)

def PrintHumanReadableMidi(note_sequence, time_step=0.1):
    """Print the pitched notes sounding at each sampled time, e.g. ``C4 E4``.

    Times are sampled every ``time_step`` seconds from 0 up to the latest note
    end. A note is active at ``t`` when ``start_time <= t < end_time``. Drum
    notes are skipped, and times with no active notes print nothing. Octaves
    follow the MIDI convention where pitch 60 is C4.

    Args:
      note_sequence: Object with a ``notes`` sequence whose items have
        ``pitch``, ``start_time``, ``end_time`` and ``is_drum``.
      time_step: Sampling interval in seconds; must be positive.

    Raises:
      ValueError: If ``time_step`` is not positive (otherwise ``np.arange``
        would error obscurely or silently produce no time bins).
    """
    if time_step <= 0:
        raise ValueError(f"time_step must be positive, got {time_step}")
    if not note_sequence.notes:
        print("No notes in the sequence.")
        return
    pitch_names = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')
    # Drums have no meaningful pitch name, so filter them once up front.
    pitched_notes = [note for note in note_sequence.notes if not note.is_drum]
    max_time = max(note.end_time for note in note_sequence.notes)
    for t in np.arange(0, max_time + time_step, time_step):
        active = [
            f"{pitch_names[note.pitch % 12]}{note.pitch // 12 - 1}"
            for note in pitched_notes
            if note.start_time <= t < note.end_time
        ]
        if active:
            print(f"At time {t:.2f}s: {' '.join(active)}")


# Log the download event, export the transcription to MIDI (the sheet-music
# cell reads it back from this path), and print a text summary of the notes.
load_gtag()
log_event('downloadTranscription', {
    'event_category': MODEL,
    'value': round(len(audio) / SAMPLE_RATE),
    'numNotes': sum(1 for note in est_ns.notes if not note.is_drum),
    'numDrumNotes': sum(1 for note in est_ns.notes if note.is_drum),
    'numPrograms': len(set(note.program for note in est_ns.notes
                           if not note.is_drum))
})
note_seq.sequence_proto_to_midi_file(est_ns, '/tmp/transcribed.mid')


PrintHumanReadableMidi(est_ns)

In [None]:
from music21 import converter, metadata, environment
from IPython.display import Image, display
import music21

def install_lilypond():
    """Install LilyPond through apt so music21 can render ``lily.png`` sheets.

    Commands run in order via ``os.system``; exit codes are not checked.
    """
    for command in ("apt-get update", "apt-get install lilypond -y"):
        os.system(command)

def configure_music21_environment(lilypond_path='/usr/bin/lilypond'):
    """Point music21 at the LilyPond binary used for ``lily.png`` rendering.

    Only ``lilypondPath`` is set. ``musicxmlPath`` should name a MusicXML
    reader application (such as MuseScore). Pointing it at LilyPond, which
    does not open MusicXML files, makes a default ``score.show()`` launch the
    wrong program.

    Args:
      lilypond_path: Path to the ``lilypond`` executable.

    Returns:
      The configured ``music21.environment.Environment``.
    """
    env = environment.Environment()
    env['lilypondPath'] = lilypond_path
    return env

def Music21_convert(midi_path, title="Music Sheet Created with Love @SoC"):
    """Parse a MIDI file into a music21 score with fresh, title-only metadata.

    Args:
      midi_path: Path of the MIDI file to parse.
      title: Title stored in the score's metadata.

    Returns:
      The parsed music21 score.
    """
    parsed = converter.parse(midi_path)
    parsed.metadata = metadata.Metadata()
    parsed.metadata.title = title
    return parsed

def DisplayMusicSheet(score, output_format='lily.png'):
    """Render ``score`` once in ``output_format`` and display the image inline.

    Args:
      score: music21 score/stream to render.
      output_format: music21 ``write`` format that produces an image file,
        such as the default ``'lily.png'``.
    """
    # Render exactly once: ``write`` returns the output path, so a second call
    # would only repeat the (slow) LilyPond run.
    image_path = score.write(output_format)
    display(Image(filename=image_path))


from music21 import converter, metadata
from IPython.display import Image, display
from music21 import environment
# Point music21 at LilyPond for PNG rendering. musicxmlPath is left alone: it
# should name a MusicXML reader application, which LilyPond is not.
env = environment.Environment()
env['lilypondPath'] = '/usr/bin/lilypond'
midi_file_path = '/tmp/transcribed.mid'  # written by the MIDI export cell above
score = converter.parse(midi_file_path)
score.metadata = metadata.Metadata()
score.metadata.title = "Music Sheet Generated with Love @SoC"
# Render once and display the resulting PNG. Calling score.show('lily.png') as
# well would run LilyPond a second time for the same image.
display(Image(filename=score.write('lily.png')))

