##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

# Generating Piano Music with Transformer
### ___Ian Simon, Anna Huang, Jesse Engel, Curtis "Fjord" Hawthorne___

This Colab notebook lets you play with pretrained [Transformer](https://arxiv.org/abs/1706.03762) models for piano music generation, based on the [Music Transformer](http://g.co/magenta/music-transformer) model introduced by [Huang et al.](https://arxiv.org/abs/1809.04281) in 2018.

The models used here were trained on over 10,000 hours of piano recordings from YouTube, transcribed using [Onsets and Frames](http://g.co/magenta/onsets-frames) and represented using the event vocabulary from [Performance RNN](http://g.co/magenta/performance-rnn).

Unlike the original Music Transformer paper, this notebook uses attention based on absolute rather than relative positions; models that use relative attention may be added in the future.

# Environment Setup

In [None]:
#@title Setup Environment
#@markdown Copy some auxiliary data from Google Cloud Storage.
#@markdown Also install and import Python dependencies needed
#@markdown for running the Transformer models.

# Colab-specific magic requesting the TensorFlow 1.x runtime. The rest of this
# notebook uses the TF1-style API through `tensorflow.compat.v1`.
%tensorflow_version 1.x

# Fetch auxiliary files from Google Cloud Storage into /content/:
#   1. everything under the `primers/` directory, and
#   2. the Salamander piano SoundFont (.sf2) referenced by SF2_PATH later on.
# NOTE(review): the progress message below only mentions the SoundFont even
# though the primers are copied as well.
print('Copying Salamander piano SoundFont (via https://sites.google.com/site/soundfonts4u) from GCS...')
!gsutil -q -m cp -r gs://magentadata/models/music_transformer/primers/* /content/
!gsutil -q -m cp gs://magentadata/soundfonts/Yamaha-C5-Salamander-JNv5.1.sf2 /content/

print('Installing dependencies...')
# System packages: the FluidSynth shared library plus build tools and the
# ALSA/JACK development packages (presumably required to build/run the Python
# audio dependencies installed below -- TODO confirm).
!apt-get update -qq && apt-get install -qq libfluidsynth1 build-essential libasound2-dev libjack-dev
# tensorflow-datasets is held below 4.0.0.
# NOTE(review): the reason for this pin is not stated in the notebook;
# presumably it is for compatibility with the TF1 stack -- confirm.
!pip install -q 'tensorflow-datasets < 4.0.0'
# NOTE(review): these packages are installed unpinned with -U, so a fresh run
# may pull different versions than the ones this notebook was written against.
!pip install -qU google-cloud magenta pyfluidsynth

print('Importing libraries...')

import numpy as np
import os
import tensorflow.compat.v1 as tf

from google.colab import files

from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_lib

from magenta.models.score2perf import score2perf
import note_seq

# Switch off TensorFlow 2.x behavior for the whole session; this is done once,
# before any model or Estimator is created in later cells.
tf.disable_v2_behavior()

print('Done!')

In [None]:
#@title Definitions
#@markdown Define a few constants and helper functions.

# Location of the piano SoundFont passed to `note_seq.play_sequence` as
# `sf2_path`; this is where the setup cell copies the .sf2 file.
SF2_PATH = '/content/Yamaha-C5-Salamander-JNv5.1.sf2'
# Value passed to `note_seq.play_sequence` as `sample_rate`
# (presumably in Hz -- confirm against note_seq).
SAMPLE_RATE = 16000

def upload_midi():
  """Prompt for a MIDI upload in Colab and convert it to a NoteSequence.

  Calls `files.upload()` and takes the values of the mapping it returns. If
  more than one file was uploaded, only the first value is used and a notice
  is printed.

  Returns:
    The result of `note_seq.midi_to_note_sequence` applied to the first
    uploaded file's contents.

  Raises:
    IndexError: If the upload returned no files at all.
  """
  uploads = list(files.upload().values())
  if not uploads:
    # The original code failed here with a bare "list index out of range".
    # The exception type is kept for backward compatibility, but the message
    # now says what actually went wrong.
    raise IndexError('No MIDI file was uploaded; please choose a MIDI file.')
  if len(uploads) > 1:
    print('Multiple files uploaded; using only one.')
  return note_seq.midi_to_note_sequence(uploads[0])

def decode(ids, encoder):
  """Decode a sequence of IDs, ignoring everything from the first EOS on.

  Args:
    ids: Iterable of IDs; it is materialized into a list first.
    encoder: Object exposing a `decode(list_of_ids)` method.

  Returns:
    Whatever `encoder.decode` returns for the IDs that precede the first
    `text_encoder.EOS_ID` (or for all IDs when no EOS is present). Elsewhere
    in this notebook the result is used as a MIDI filename.
  """
  tokens = list(ids)
  eos = text_encoder.EOS_ID
  # Cut at the first EOS if there is one; otherwise keep every token.
  cut = tokens.index(eos) if eos in tokens else len(tokens)
  return encoder.decode(tokens[:cut])
  

# Piano Performance Language Model

In [None]:
#@title Setup and Load Checkpoint
#@markdown Set up generation from an unconditional Transformer
#@markdown model.

# Names handed to tensor2tensor further down in this cell: `model_name` goes
# to `trainer_lib.create_estimator` and `hparams_set` to
# `trainer_lib.create_hparams`.
model_name = 'transformer'
hparams_set = 'transformer_tpu'
# GCS path of the pretrained unconditional checkpoint; passed to
# `estimator.predict` as `checkpoint_path` at the end of this cell.
ckpt_path = 'gs://magentadata/models/music_transformer/checkpoints/unconditional_model_16.ckpt'

class PianoPerformanceLanguageModelProblem(score2perf.Score2PerfProblem):
  """Score2Perf problem variant used for the unconditional piano model.

  The only customization visible here is that `add_eos_symbol` is True.
  """

  @property
  def add_eos_symbol(self):
    # NOTE(review): presumably tells the base problem to use an end-of-sequence
    # symbol in encoded targets (the `decode` helper in this notebook truncates
    # at EOS) -- confirm against Score2PerfProblem.
    return True

# Instantiate the problem and fetch its feature encoders. The 'targets'
# encoder is used by the generation cell to decode sampled IDs.
problem = PianoPerformanceLanguageModelProblem()
unconditional_encoders = problem.get_feature_encoders()

# Set up HParams: start from the named hparams set, attach the problem's
# hparams, then override two values for this checkpoint.
hparams = trainer_lib.create_hparams(hparams_set=hparams_set)
trainer_lib.add_problem_hparams(hparams, problem)
# NOTE(review): 16 presumably matches the depth of the pretrained checkpoint
# (its filename ends in `_16.ckpt`) -- confirm.
hparams.num_hidden_layers = 16
# NOTE(review): presumably makes decoding sample from the output distribution
# instead of taking the argmax -- confirm against tensor2tensor.
hparams.sampling_method = 'random'

# Set up decoding HParams.
# NOTE(review): alpha=0.0 with beam_size=1 presumably means no beam search and
# no length penalty, i.e. plain sampling -- confirm against tensor2tensor.
decode_hparams = decoding.decode_hparams()
decode_hparams.alpha = 0.0
decode_hparams.beam_size = 1

# Create the Estimator from the model name, model hparams, run config and
# decode hparams configured above.
run_config = trainer_lib.create_run_config(hparams)
estimator = trainer_lib.create_estimator(
    model_name, hparams, run_config,
    decode_hparams=decode_hparams)

def input_generator():
  """Endlessly yield model input features built from the current globals.

  The module-level names `targets` and `decode_length` are looked up afresh on
  every iteration, so later cells can adjust the priming sequence and the
  decode length on the fly simply by reassigning them.

  Yields:
    A dict with two int32 arrays: 'targets', which wraps `targets` in one
    extra leading dimension, and 'decode_length', built directly from
    `decode_length`.
  """
  while True:
    features = {}
    features['targets'] = np.array([targets], dtype=np.int32)
    features['decode_length'] = np.array(decode_length, dtype=np.int32)
    yield features

# Initial values for the globals that `input_generator` reads on every
# iteration. These values will be changed by subsequent cells.
targets = []
decode_length = 0

# Start the Estimator, loading from the specified checkpoint. The result is
# consumed with `next(...)` here and in later cells, one prediction per call.
input_fn = decoding.make_input_fn_from_generator(input_generator())
unconditional_samples = estimator.predict(
    input_fn, checkpoint_path=ckpt_path)

# "Burn" one: pull and discard a first prediction while `targets` is empty and
# `decode_length` is 0.
# NOTE(review): presumably this forces graph construction and checkpoint
# loading to happen in this cell rather than in the first generation -- confirm.
_ = next(unconditional_samples)

In [None]:
#@title Generate from Scratch
#@markdown Generate a piano performance from scratch.
#@markdown
#@markdown This can take a minute or so depending on the length
#@markdown of the performance the model ends up generating.
#@markdown Because we use a 
#@markdown [representation](http://g.co/magenta/performance-rnn)
#@markdown where each event corresponds to a variable amount of
#@markdown time, the actual number of seconds generated may vary.

# Reassign the globals read by `input_generator`: no priming events, and a
# decode length of 1024.
# NOTE(review): decode_length presumably bounds the number of generated
# events, not seconds (see the cell description above) -- confirm.
targets = []
decode_length = 1024

# Generate sample events: pull the next prediction and take its 'outputs'.
sample_ids = next(unconditional_samples)['outputs']

# Decode to NoteSequence. `decode` truncates at the first EOS and returns the
# encoder's result, which is treated here as the path of a MIDI file.
midi_filename = decode(
    sample_ids,
    encoder=unconditional_encoders['targets'])
unconditional_ns = note_seq.midi_file_to_note_sequence(midi_filename)

# Play (synthesized with FluidSynth using the SoundFont at SF2_PATH) and plot.
note_seq.play_sequence(
    unconditional_ns,
    synth=note_seq.fluidsynth, sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)
note_seq.plot_sequence(unconditional_ns)

In [None]:
#@title Download Performance as MIDI
#@markdown Download generated performance as MIDI (optional).

# Write the NoteSequence produced by the "Generate from Scratch" cell to a
# MIDI file under /tmp, then hand that file to Colab's download helper.
# Requires `unconditional_ns` to exist, i.e. the generation cell must have
# been run first.
note_seq.sequence_proto_to_midi_file(
    unconditional_ns, '/tmp/unconditional.mid')
files.download('/tmp/unconditional.mid')