<a href="https://colab.research.google.com/github/magenta/mt3/blob/main/mt3/colab/music_transcription_with_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Transcription with Transformers

This notebook is an interactive demo of a few [music transcription models](https://g.co/magenta/mt3) created by Google's [Magenta](https://g.co/magenta) team. You can upload audio and have one of our models automatically transcribe it.

<img src="https://magenta.tensorflow.org/assets/transcription-with-transformers/architecture_diagram.png" alt="Transformer-based transcription architecture">

The notebook supports two pre-trained models:
1. the piano transcription model from [our ISMIR 2021 paper](https://archives.ismir.net/ismir2021/paper/000030.pdf)
1. the multi-instrument transcription model from [our ICLR 2022 paper](https://openreview.net/pdf?id=iMSjopcOn0p)

**Caveat**: neither model is trained on singing. If you upload audio with vocals, the transcription will likely be unreliable. Multi-instrument transcription is also still an open problem, so you may get odd results even on purely instrumental audio.

In any case, we hope you have fun transcribing! Feel free to tweet any interesting output at [@GoogleMagenta](https://twitter.com/googlemagenta).

### Instructions for running:

* Make sure to use a GPU runtime: click __Runtime >> Change Runtime Type >> GPU__
* Press ▶️ on the left of each cell to execute the cell
* In the __Load Model__ cell, choose either `ismir2021` for piano transcription or `mt3` for multi-instrument transcription
* In the __Upload Audio__ cell, choose an MP3 or WAV file from your computer when prompted
* Transcribe the audio using the __Transcribe Audio__ cell (it may take a few minutes depending on the length of the audio)
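
Before running the cells, it can help to confirm that the runtime actually has a GPU attached. The sketch below is only a heuristic and assumes an NVIDIA GPU runtime that ships the `nvidia-smi` tool; the helper name `gpu_runtime_available` is hypothetical and not part of MT3.

```python
import shutil
import subprocess

def gpu_runtime_available() -> bool:
    """Heuristic check: True if `nvidia-smi` is installed and exits successfully."""
    nvidia_smi = shutil.which("nvidia-smi")
    if nvidia_smi is None:
        # The tool is not on PATH, so an NVIDIA GPU runtime is unlikely.
        return False
    try:
        result = subprocess.run(
            [nvidia_smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0

print("GPU runtime detected:", gpu_runtime_available())
```

If this prints `False`, switching the runtime type as described in the first step and re-running the setup cell is the likely fix.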

---

This notebook sends basic usage data to Google Analytics.  For more information, see [Google's privacy policy](https://policies.google.com/privacy).

In [None]:


#@title Setup Environment
#@markdown Install MT3 and its dependencies (may take a few minutes).

!apt-get update -qq && apt-get install -qq libfluidsynth3 build-essential libasound2-dev libjack-dev

# install mt3
!git clone --branch=main https://github.com/magenta/mt3
!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp
!python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# copy checkpoints
!gsutil -q -m cp -r gs://mt3/checkpoints .

# copy soundfont (originally from https://sites.google.com/site/soundfonts4u)
!gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .

import json
import IPython


In [None]:
# Import required libraries and modules for model inference, audio processing, and configuration.
import functools
import os
import numpy as np
import tensorflow.compat.v2 as tf
import gin         # For configuration via gin files.
import jax         # For JAX-based computations.
import librosa     # For audio processing.
import note_seq    # For working with note sequences.
import seqio       # For sequence processing.
import t5          # For T5 model support.
import t5x         # For T5X model framework.

# Import specific modules from the mt3 package used for transcription.
from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

# For file uploading in Colab
from google.colab import files

# Patch the event loop to support asynchronous operations in notebooks.
import nest_asyncio
nest_asyncio.apply()

# Define global constants for the audio sample rate and the soundfont file path.
SAMPLE_RATE = 16000
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

def upload_audio(sample_rate):
  """
  Prompt the user to upload audio files using Google Colab's file upload dialog.
  Converts the first uploaded file's raw audio data into samples using note_seq.
  """
  data = list(files.upload().values())
  if len(data) > 1:
    print('Multiple files uploaded; using only one.')
  # Convert raw audio data to waveform samples.
  return note_seq.audio_io.wav_data_to_samples_librosa(
    data[0], sample_rate=sample_rate)

class InferenceModel(object):
  """Wrapper of T5X model for music transcription.

  This class handles model configuration, loading, and inference
  for transcribing audio into musical note sequences.
  """

  def __init__(self, checkpoint_path, model_type='mt3'):
    # Set model-specific constants.
    if model_type == 'ismir2021':
      # For the ismir2021 model, which transcribes piano only with note velocities.
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      # For the mt3 model, which transcribes multiple instruments without velocities.
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)

    # List the gin configuration files for the model.
    gin_files = ['/content/mt3/gin/model.gin',
                 f'/content/mt3/gin/{model_type}.gin']

    # Set additional hyperparameters.
    self.batch_size = 8
    self.outputs_length = 1024
    self.sequence_length = {'inputs': self.inputs_length,
                            'targets': self.outputs_length}

    # Create a partitioner for distributed model execution (here with 1 partition).
    self.partitioner = t5x.partitioning.PjitPartitioner(
        num_partitions=1)

    # Build spectrogram configuration and codecs for converting note sequences.
    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    # Build vocabulary from the codec.
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    # Define output features for the model (inputs are continuous; targets use the vocabulary).
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

    # Parse the gin configuration files.
    self._parse_gin(gin_files)
    # Load the model based on the configuration.
    self.model = self._load_model()

    # Restore the model's weights and optimizer state from a checkpoint.
    self.restore_from_checkpoint(checkpoint_path)

  @property
  def input_shapes(self):
    """
    Return the expected shapes of the input tensors for the model.
    This is used for initializing the training state.
    """
    return {
          'encoder_input_tokens': (self.batch_size, self.inputs_length),
          'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }

  def _parse_gin(self, gin_files):
    """
    Parse the provided gin configuration files to configure the model.
    Gin files contain hyperparameters and other configuration settings.
    """
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    # Unlock the gin configuration to allow modifications, then parse files.
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
    """
    Load a T5X model after gin configurations have been parsed.
    Returns the model instance ready for inference.
    """
    # Retrieve the model configuration from gin.
    model_config = gin.get_configurable(network.T5Config)()
    # Instantiate a Transformer module with the model configuration.
    module = network.Transformer(config=model_config)
    # Create the final Continuous Inputs Encoder-Decoder model,
    # specifying the module, vocabularies, optimizer, and input depth.
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))

  def restore_from_checkpoint(self, checkpoint_path):
    """
    Restore the model's state from the specified checkpoint.
    This resets the prediction function and initializes training state.
    """
    train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=self.model.optimizer_def,
      init_fn=self.model.get_initial_variables,
      input_shapes=self.input_shapes,
      partitioner=self.partitioner)

    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')

    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """
    Generate a partitioned prediction function for decoding.
    This function is cached to avoid re-partitioning on each call.
    """
    def partial_predict_fn(params, batch, decode_rng):
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
    )

  def predict_tokens(self, batch, seed=0):
    """
    Predict token sequences from a preprocessed batch.
    Decodes the predicted tokens using the model's vocabulary.
    """
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
    """
    Infer a note sequence from raw audio samples.

    Args:
      audio: 1-D numpy array of audio samples (16kHz) for a single example.

    Returns:
      A note_sequence object representing the transcribed audio.
    """
    # Convert the audio into a dataset.
    ds = self.audio_to_dataset(audio)
    # Preprocess the dataset.
    ds = self.preprocess(ds)
    # Convert the dataset into batches and run predictions.
    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)
    # Predict tokens for each batch.
    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))
    predictions = []
    # Pair each example with its predicted tokens.
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))
    # Convert predicted events into a note sequence.
    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']

  def audio_to_dataset(self, audio):
    """
    Create a TensorFlow Dataset of spectrograms from input audio.
    """
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })

  def _audio_to_frames(self, audio):
    """
    Compute spectrogram frames from raw audio data.
    Pads the audio if necessary and splits it into frames.
    """
    frame_size = self.spectrogram_config.hop_width
    # Calculate the padding needed to complete the final frame.
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    # Split the audio into frames.
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    # Compute the time for each frame.
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times

  def preprocess(self, ds):
    """
    Apply a chain of preprocessing functions to the dataset.
    This includes splitting tokens, adding dummy targets, and computing spectrograms.
    """
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds

  def postprocess(self, tokens, example):
    """
    Process predicted tokens and prepare the result for conversion to a note sequence.
    Trims the end-of-sequence token and adjusts the start time.
    """
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    # Adjust start time to the nearest symbolic token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        'raw_inputs': []  # Placeholder for raw inputs (not used here).
    }

  @staticmethod
  def _trim_eos(tokens):
    """
    Trim the end-of-sequence token from the token array.
    """
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens
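
The padding and frame-time arithmetic in `_audio_to_frames` can be checked by hand. The sketch below reproduces it in plain Python. The defaults `hop_width=128` and `sample_rate=16000`, and the relation `frames_per_second = sample_rate / hop_width`, are assumptions about `SpectrogramConfig`; `frame_layout` is a hypothetical helper, not part of MT3.

```python
def frame_layout(num_samples, hop_width=128, sample_rate=16000):
    """Mirror the padding and frame-time arithmetic of `_audio_to_frames`.

    hop_width and sample_rate are assumed defaults, not values read from MT3.
    """
    # Same expression as the notebook: pad up to the next multiple of hop_width.
    padding = hop_width - num_samples % hop_width
    num_frames = (num_samples + padding) // hop_width
    # Assumed relation between sample rate, hop width, and frame rate.
    frames_per_second = sample_rate / hop_width
    times = [i / frames_per_second for i in range(num_frames)]
    return padding, num_frames, times

padding, num_frames, times = frame_layout(1000)
print(padding, num_frames, times[1])  # 24 8 0.008
```

Note that when the length is already a multiple of `hop_width`, this expression adds one full extra frame of zeros: under the same assumptions, 1024 samples give a padding of 128 and 9 frames.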

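`_trim_eos` keeps only the tokens that come before the first end-of-sequence marker. A plain-Python sketch of the same idea follows; the EOS id is passed in explicitly, and the value `-1` used in the example is an illustrative assumption rather than something read from `vocabularies.DECODED_EOS_ID`. The name `trim_at_eos` is hypothetical.

```python
def trim_at_eos(tokens, eos_id):
    """Return the tokens that precede the first occurrence of eos_id."""
    tokens = list(tokens)
    if eos_id in tokens:
        # list.index returns the position of the first EOS marker.
        return tokens[:tokens.index(eos_id)]
    # No EOS marker: the sequence is returned unchanged.
    return tokens

print(trim_at_eos([5, 7, -1, 9], eos_id=-1))  # [5, 7]
print(trim_at_eos([5, 7, 9], eos_id=-1))      # [5, 7, 9]
```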

In [None]:
#@title Load Model
#@markdown The `ismir2021` model transcribes piano only, with note velocities.
#@markdown The `mt3` model transcribes multiple simultaneous instruments,
#@markdown but without velocities.

MODEL = "mt3" #@param["ismir2021", "mt3"]
checkpoint_path = f'/content/checkpoints/{MODEL}/'
inference_model = InferenceModel(checkpoint_path, MODEL)



In [None]:
!find /content/mt3 -type f

In [None]:
!git clone https://github.com/jwdj/EasyABC.git

In [None]:
#@title Install FastAPI, uvicorn, and ngrok for serving HTTP requests
!pip install fastapi uvicorn python-multipart pyngrok
!rm -f ngrok
!wget -q -O ngrok.zip https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-v3-stable-linux-amd64.zip
!unzip -o ngrok.zip
# Replace YOUR_NGROK_AUTHTOKEN with your own token; do not commit a real token to a shared notebook.
!./ngrok config add-authtoken YOUR_NGROK_AUTHTOKEN
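
An authtoken written directly into a notebook cell is visible to anyone the notebook is shared with. One alternative, sketched below, is to read it from an environment variable at run time. The variable name `NGROK_AUTHTOKEN` and the helper `read_ngrok_token` are assumptions chosen for illustration.

```python
import os

def read_ngrok_token(env_var: str = "NGROK_AUTHTOKEN") -> str:
    """Read the ngrok authtoken from the environment instead of the notebook source."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"Set the {env_var} environment variable before starting ngrok.")
    return token
```

The returned value could then be passed to `ngrok.set_auth_token(...)` in place of a literal string.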

In [None]:
#@title Python App Code (FastAPI with Improved Security, No Shutdown Endpoint)

import os
import sys
import uuid
import subprocess
import uvicorn
import threading
import nest_asyncio
import logging
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok, conf
from note_seq import notebook_utils, sequence_proto_to_midi_file
import tensorflow as tf
from note_seq.audio_io import wav_data_to_samples_librosa

# Apply nest_asyncio so that uvicorn can run in a notebook environment
nest_asyncio.apply()

# ---------------------------
# Logging Configuration
# ---------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("FastAPI-Security")

# ---------------------------
# ngrok Configuration
# ---------------------------
conf.get_default().ngrok_path = "ngrok"
NGROK_TOKEN = "YOUR_NGROK_AUTHTOKEN"  # Replace with your actual token; do not commit real tokens
ngrok.set_auth_token(NGROK_TOKEN)

# Start ngrok tunnel on port 5000 with custom hostname
public_url = ngrok.connect(5000, bind_tls=True, hostname="secure-darling-minnow.ngrok-free.app").public_url
logger.info(f" * ngrok tunnel: {public_url}")

# ---------------------------
# Global Values
# ---------------------------
SAMPLE_RATE = 16000
SF2_PATH    = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

# Allowed origins for CORS, including localhost
ALLOWED_ORIGINS = [
    "http://localhost",
    "http://localhost:5000",
    "http://localhost:8080",
    public_url  # For testing via ngrok
]

# ---------------------------
# Model Loading Placeholder
# ---------------------------
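def load_model():
    # Look up the model created by the 'Load Model' cell. Using globals().get
    # avoids a NameError when that cell has not been run yet.
    model = globals().get("inference_model")
    if model is None:
        raise ValueError("Model not loaded. Please run the 'Load Model' cell first.")
    return model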

# ---------------------------
# Audio Transcription Function
# ---------------------------
def transcribe_audio(file_path: str) -> None:
    """
    Transcribe an audio file to MIDI in place.
    Loads the audio at file_path, runs the transcription model to estimate a
    note sequence, and overwrites file_path with the resulting MIDI data, so
    the caller can pass the same path on to MIDI-to-ABC conversion.
    """
    logger.info("Transcribing %s", file_path)

    # Open the specified file in binary mode and read its content.
    with open(file_path, 'rb') as f:
        audio_data = f.read()

    # Convert the raw audio data to waveform samples at SAMPLE_RATE (samples per second).
    audio = wav_data_to_samples_librosa(audio_data, sample_rate=SAMPLE_RATE)

    # Fetch the model created by the 'Load Model' cell.
    model = load_model()

    # Run the model on the audio to obtain an estimated note sequence.
    est_ns = model(audio)

    # Write the estimated note sequence as MIDI over the uploaded file, so
    # file_path now holds MIDI data even though it keeps its audio extension.
    sequence_proto_to_midi_file(est_ns, file_path)


# ---------------------------
# Client MIDI File to ABC Conversion Function
# ---------------------------
def midi_to_abcnotation(midi_path: str) -> dict:
    """
    Convert the given MIDI to ABC notation using EasyABC's midi2abc.py script.
    Returns a dictionary with key 'abc_notation'.
    """
    try:
        conversion_command = [sys.executable, "/content/EasyABC/midi2abc.py", midi_path]
        conversion_result = subprocess.run(
            conversion_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True
        )
        return {"abc_notation": conversion_result.stdout}
    except subprocess.CalledProcessError as e:
        logger.error("MIDI-to-ABC conversion failed.", exc_info=True)
        # Raise an HTTPException so FastAPI returns a 500.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to convert MIDI to ABC: {e.stderr.strip()}"
        )

# ---------------------------
# FastAPI Application Setup
# ---------------------------
app = FastAPI()

# Set up CORS with restricted origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/", response_class=HTMLResponse)
async def home() -> HTMLResponse:
    """
    Home endpoint that returns a simple HTML page.
    """
    html_content = (
        "<h1>Music Transcription API - Midi -> ABC Notation</h1>"
        "<p>Upload audio files to /upload</p>"
    )
    return HTMLResponse(content=html_content)

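@app.get("/test-data")
async def test_data():
    """
    Returns a saved test MIDI file converted to ABC notation as JSON.
    Assumption: the test file is /tmp/transcribed.mid, the path used by the
    'Download MIDI Transcription' cell; adjust it to match your setup.
    """
    test_midi_path = '/tmp/transcribed.mid'  # assumed location of the test MIDI
    if not os.path.exists(test_midi_path):
        raise HTTPException(status_code=404, detail="No test MIDI file found")
    try:
        # midi_to_abcnotation requires the path of the MIDI file to convert.
        midi_data = midi_to_abcnotation(test_midi_path)
        return JSONResponse(content={"transcription": midi_data})
    except Exception as e:
        logger.exception("Error in /test-data endpoint.")
        return JSONResponse(content={"error": str(e)}, status_code=500)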

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """
    Accepts an uploaded audio file (MP3 or WAV) or MIDI file (.mid or .midi).
    Audio is first transcribed to MIDI; the MIDI is then converted to ABC
    notation, the temporary file is removed, and the ABC notation is returned as JSON.
    """
    # Validate the file extension: MP3, WAV, or MIDI.
    # (Named upload_file so it does not shadow the notebook's earlier upload_audio helper.)
    if not (file.filename.lower().endswith(('.mp3', '.wav', '.mid', '.midi'))):
        raise HTTPException(status_code=400, detail="Invalid file type")

    # Extract original extension and generate a unique filename
    ext = os.path.splitext(file.filename)[1]
    unique_name = f"{uuid.uuid4().hex}{ext}"

    # Create full path in temp dir
    file_path = os.path.join("/tmp", unique_name)

    # Open the destination file in binary write mode and save the uploaded content.
    with open(file_path, "wb") as f:
        # Asynchronously read the file content.
        content = await file.read()
        # Write the file content to disk.
        f.write(content)

    # 1) if audio -> transcribe
    if file.filename.lower().endswith(('.mp3', '.wav')):
        try:
            transcribe_audio(file_path)
        except Exception as e:
            logger.error("Audio transcription failed.", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail=f"Audio transcription error: {str(e)}"
            )

    # 2) convert MIDI -> ABC (raises HTTPException on fail)
    midi_data = midi_to_abcnotation(file_path)

    # 3) cleanup temp file
    try:
        os.remove(file_path)
    except OSError:
        logger.warning("Could not delete temp file %s", file_path)

    return JSONResponse(content={"transcription": midi_data})

# ---------------------------
# Run the Server Directly
# ---------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5000, log_level="info")
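
The `/upload` handler never writes to a path built from the client-supplied name; it keeps only the extension and generates the rest of the filename. That step can be sketched in isolation with the standard library. `make_temp_path` is a hypothetical helper, and the allowed-extension set simply mirrors the check in the handler. Unlike the handler, the sketch lowercases the extension it keeps.

```python
import os
import uuid

# Mirrors the extensions accepted by the /upload handler.
ALLOWED_EXTENSIONS = {".mp3", ".wav", ".mid", ".midi"}

def make_temp_path(original_name: str, temp_dir: str = "/tmp") -> str:
    """Build a unique temp path that reuses only the (lowercased) extension."""
    ext = os.path.splitext(original_name)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"Invalid file type: {ext or '(none)'}")
    # uuid4().hex is 32 hex characters, so nothing from the client-supplied
    # name other than the extension ends up in the path.
    return os.path.join(temp_dir, f"{uuid.uuid4().hex}{ext}")

print(make_temp_path("../../etc/My Song.MP3"))
```

Because the directory components of the uploaded name are discarded, a name such as `../../etc/My Song.MP3` still resolves to a file directly under the temp directory.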


In [None]:
#@title Kill Port to restart server

# List current ngrok tunnels
# !curl http://127.0.0.1:4040/api/tunnels


# Kill the uvicorn process
# !kill 51096

# Check which processes are listening on port 5000
!netstat -tulpn | grep 5000
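
As an alternative to reading `netstat` output, a port can be probed from Python with the standard `socket` module. This is a minimal sketch: `port_in_use` is a hypothetical helper, and it only reports whether something accepts TCP connections on the port, not which process owns it.

```python
import socket

def port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP connection to (host, port) succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1.0)
        # connect_ex returns 0 on success and an error code otherwise.
        return sock.connect_ex((host, port)) == 0

print("Port 5000 in use:", port_in_use(5000))
```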

In [None]:
#@title Upload Audio

# audio = upload_audio(sample_rate=SAMPLE_RATE)

# note_seq.notebook_utils.colab_play(audio, sample_rate=SAMPLE_RATE)

In [None]:
#@title Transcribe Audio
#@markdown This may take a few minutes depending on the length of the audio file
#@markdown you uploaded.

# est_ns = inference_model(audio)

# note_seq.play_sequence(est_ns, synth=note_seq.fluidsynth,
#                        sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
# note_seq.plot_sequence(est_ns)

In [None]:
#@title Download MIDI Transcription

# note_seq.sequence_proto_to_midi_file(est_ns, '/tmp/transcribed.mid')
# files.download('/tmp/transcribed.mid')