<a href="https://colab.research.google.com/github/asigalov61/SuperPiano/blob/master/SuperPiano5_(Reformer_TPU).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super Piano 5 (v.0.5)

Uses the line-by-line dataset produced by the Intelligent VIRTUOSO / TMIDI-TXT Processor:

https://github.com/asigalov61/Intelligent-VIRTUOSO

Most of this Colab notebook is adapted from:

https://github.com/jlilleberg/Philosophical-Text-Generation-using-Reformers

In [None]:
# Install dependencies
!pip install --upgrade -q jax
!pip install --upgrade -q jaxlib
!pip install --upgrade -q trax==1.3.6
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin 

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [None]:
import gin
import os
import numpy as np
from scipy.special import softmax

# Zipping and downloading files
from google.colab import files
import shutil

# Trax
import jax
import trax
from trax.data import inputs
import jax.numpy as jnp

# NLP Vocab Generation
import sentencepiece as spm

# TensorFlow
from tensorflow.compat.v1.io.gfile import GFile
import tensorflow as tf

In [None]:
# Fetch Plato's `The Republic` from Project Gutenberg.
# NOTE(review): nothing later in the visible notebook reads TEXT_PATH or the
# downloaded file; the training text comes from the music dataset instead.
URL = 'http://www.gutenberg.org/cache/epub/1497/pg1497.txt'
FILENAME = 'the_republic.txt'
tf.keras.utils.get_file(fname=FILENAME, origin=URL, cache_dir='.')

# Presumably where get_file stores the download when cache_dir='.' - confirm.
TEXT_PATH = os.path.join('datasets', FILENAME)

In [None]:
# Read the whole music dataset into memory as one string; it is tokenized as-is.
# (The Gutenberg header/footer trimming inherited from the original text notebook
# does not apply to this dataset and has been removed.)
# NOTE(review): this cell reads the `.csv` file, while the SentencePiece training
# cell is pointed at the same name with a `.txt` extension - confirm which file is
# the intended dataset so the tokenizer is trained on the text that gets encoded.
with GFile('/content/Intelligent-Virtuoso-Music-TXT-Dataset1.csv') as f:
    text = f.read()

In [None]:
# Train a BPE model on the dataset text.
# The flags are handed to SentencePiece as one command-line style string; the
# vocabulary size (320) is the same value the model config uses for
# ReformerLM.vocab_size.
bpe_train_flags = '--input=/content/Intelligent-Virtuoso-Music-TXT-Dataset1.txt \
                                --model_prefix=cp.320 \
                                --vocab_size=320 \
                                --model_type=bpe'
spm.SentencePieceTrainer.train(bpe_train_flags)

# Load the BPE vocabulary written by the training call above
# (the `cp.320` model prefix yields the file `cp.320.model`).
TOKENIZER = spm.SentencePieceProcessor()
TOKENIZER.load('cp.320.model')

In [None]:
# Tokenize the whole dataset into one int32 id sequence.
IDS = TOKENIZER.EncodeAsIds(text)
IDS = np.asarray(IDS, dtype=np.int32)
print("Number of tokens:", IDS.shape[0])

# Every training example is padded out to exactly MAX_TOKENS ids. 512 * 1024 is
# 524288, the same value as `n_tokens` / `axial_pos_shape` in the model config.
MAX_TOKENS = 512 * 1024

# Fail fast with an actionable message: a negative PAD_AMOUNT would otherwise only
# surface later as an opaque error inside the data pipeline.
if len(IDS) > MAX_TOKENS:
    raise ValueError(
        "Dataset has %d tokens but the model is configured for at most %d; "
        "use a smaller dataset or raise the sequence length consistently here "
        "and in the gin config." % (len(IDS), MAX_TOKENS))

# Total number of padding positions to distribute around the sequence.
PAD_AMOUNT = MAX_TOKENS - len(IDS)

In [None]:
# Set up the data pipeline.
def my_inputs(n_devices, ids=None, pad_amount=None):
  """Endlessly yield (inputs, targets, mask) batches with one row per device.

  Each row is the full token sequence with `pad_amount` zeros split between its
  left and right side; the split point is drawn independently for every row.

  Args:
    n_devices: number of rows (one per device) in each batch.
    ids: 1-D token id array to pad. Defaults to the notebook-level `IDS`.
    pad_amount: total number of padding positions per row. Defaults to the
      notebook-level `PAD_AMOUNT`.

  Yields:
    A tuple `(inputs, inputs, mask)`: the padded ids (the same array serves as
    inputs and targets) and a float32 mask that is 1.0 on real tokens and 0.0 on
    padding. Both arrays have shape `(n_devices, len(ids) + pad_amount)`.

  Raises:
    ValueError: if `pad_amount` is negative, i.e. the sequence does not fit.
  """
  while True:
    # Defaults are resolved on every batch so that re-running the tokenization
    # cell is picked up by an already-created generator, as before.
    seq = IDS if ids is None else np.asarray(ids)
    total_pad = PAD_AMOUNT if pad_amount is None else pad_amount
    if total_pad < 0:
      raise ValueError(
          "pad_amount is negative (%d): the token sequence is longer than the "
          "padded length" % total_pad)

    if total_pad > 0:
      # Pad IDS by a different amount for each device.
      left_pads = np.random.choice(total_pad, n_devices)
    else:
      # The sequence exactly fills the window; np.random.choice(0, ...) would
      # raise, and there is nothing to randomize anyway.
      left_pads = np.zeros(n_devices, dtype=np.int64)

    widths = [(int(left), int(total_pad - left)) for left in left_pads]
    ones = np.ones_like(seq, dtype=np.float32)
    padded = np.stack([np.pad(seq, width, mode='constant') for width in widths])
    mask = np.stack([np.pad(ones, width, mode='constant') for width in widths])
    yield (padded, padded, mask)

# Sanity check: draw one batch and show its (rows, tokens-per-row) layout.
_device_count = trax.fastmath.device_count()
print("(device count, tokens per device) = ", next(my_inputs(_device_count))[0].shape)

In [None]:
# Configure hyperparameters.
#
# Everything between the triple quotes is gin configuration text, parsed at
# runtime by gin.parse_config; the `#` lines inside it belong to that text.
#
# Values that have to stay in step with other cells of this notebook:
#   * n_tokens = 524288 is 512 * 1024, the padded sequence length used when
#     PAD_AMOUNT is computed, and also the product of ReformerLM.axial_pos_shape.
#   * ReformerLM.vocab_size = 320 is the same number passed to SentencePiece
#     as --vocab_size=320.
# NOTE(review): ReformerLM.d_axial_pos_embs (64 + 192) adds up to
# ReformerLM.d_model (256); presumably that is required - confirm before
# changing either value.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Model will have 6 layers, alternating between the LSH attention
# and local attention within a certain context window.
n_layers = 6
attn_type = [
  @trax.layers.SelfAttention,
  @LSHSelfAttention,  
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  ]
share_qk = False  # LSH attention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288

# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.01
multifactor.factors = 'constant * linear_warmup * cosine_decay'
multifactor.warmup_steps = 100
multifactor.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for SelfAttention:
# ==============================================================================
trax.layers.SelfAttention.attention_dropout = 0.05
trax.layers.SelfAttention.chunk_len = 64
trax.layers.SelfAttention.n_chunks_before = 1
trax.layers.SelfAttention.n_parallel_heads = 1

# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 64
LSHSelfAttention.n_buckets = [64, 128]
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 1
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 128
LSHSelfAttention.predict_mem_len = 1024

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs= (64, 192)
""")

In [None]:
# Trainer.
# Checkpoints go to ./model, relative to the current working directory
# (expanduser leaves the path unchanged because it has no leading `~`).
output_dir = os.path.expanduser('model')

# Remove the old model so training starts from scratch. The checkpoint has to be
# deleted from the directory the trainer actually writes to; the previous
# `rm -f ~/model/model.pkl.gz` pointed at the home directory instead.
stale_checkpoint = os.path.join(output_dir, 'model.pkl.gz')
if os.path.exists(stale_checkpoint):
    os.remove(stale_checkpoint)

# Model, optimizer and schedule hyperparameters come from the gin config.
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

In [None]:
# Train Model
import tqdm

# 50 rounds of train_epoch, each asked for 100 training steps and 20 eval steps.
n_rounds, steps_per_round, eval_steps_per_round = 50, 100, 20
for _ in tqdm.tqdm(range(n_rounds)):
  trainer.train_epoch(n_steps=steps_per_round, n_eval_steps=eval_steps_per_round)

In [None]:
# Zip directory contents.
# The archive is built in a scratch directory outside the tree being zipped;
# writing project.zip into "." while archiving "." lets the archive pick up a
# partial copy of itself (and of any project.zip left by an earlier run).
import tempfile

archive_base = os.path.join(tempfile.mkdtemp(), "project")
archive_path = shutil.make_archive(archive_base, "zip", ".")

# Download zipped directory (still named project.zip).
files.download(archive_path)

In [None]:
# In the Reformer paper, increasing the number of hashing rounds helps with quality.
# The number of hashing rounds can be increased at evaluation time only, so the
# training-time value (n_hashes = 1) is overridden here before sampling.
eval_overrides = "LSHSelfAttention.n_hashes = 8"
gin.parse_config(eval_overrides)

In [None]:
# Locate the checkpoint written by the trainer.
output_dir = os.path.expanduser('model')
checkpoint_path = os.path.join(output_dir, 'model.pkl.gz')

# Build the Reformer in 'predict' mode and load only its trained weights.
model = trax.models.ReformerLM(mode='predict')
model.init_from_file(checkpoint_path, weights_only=True)

# Sample from ReformerLM
output_token_ids = trax.supervised.decoding.autoregressive_sample(
    model, temperature=0.2, max_length=1024, accelerate=True)

# Decode token IDs
# The sampler returned a batch with one item, so take element [0];
# tolist() converts from int64 to int, the type SentencePiece expects.
first_sample = output_token_ids[0].tolist()
# NOTE: `input` shadows the Python builtin, but the MIDI conversion cell reads
# the generated text under exactly this name, so it is kept.
input = TOKENIZER.DecodeIds(first_sample)

In [None]:
#@title Install all dependencies (run only once per session)
# tegridy-tools is cloned into the current directory; the import cell later
# changes into /content/tegridy-tools/tegridy-tools to import TMIDI and MIDI from it.
!git clone https://github.com/asigalov61/tegridy-tools
# fluidsynth comes from apt; midi2audio (imported later as FluidSynth) comes from pip.
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio

In [None]:
#@title Import all needed modules

print('Loading needed modules. Please wait...')
import os
if not os.path.exists('/content/Dataset'):
    os.makedirs('/content/Dataset')

os.chdir('/content/minGPT')
# make deterministic
set_seed(42)

import tqdm.auto
import pickle
import numpy as np


import time
import math
import datetime
from datetime import datetime


os.chdir('/content/tegridy-tools/tegridy-tools')
import TMIDI
import MIDI

import tqdm.auto

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from midi2audio import FluidSynth
from IPython.display import display, Javascript, HTML, Audio

from google.colab import output, drive

dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print('Available Processing Device is:', device)
os.chdir('/content/')
print('Loading complete. Enjoy! :)')

In [None]:
#@title Convert to MIDI from TXT (w/Tegridy MIDI-TXT Processor)

#@markdown Standard MIDI timings are 400/120(80)

#@markdown Silence offset is model specific and usually the same so most likely it will need to be set up only once. Otherwise, please play with the settings utill you will find the right ones for your particular model.

#@markdown line_by_line_dataset option is for support of legacy datasets (IV 1.0-1.1). From now on IV will be using new line-by-line datasets for compatibility with AI BPE tokenizers.

line_by_line_dataset = False #@param {type:"boolean"}
dataset_MIDI_events_time_denominator = 100 #@param {type:"slider", min:1, max:100, step:1}
number_of_ticks_per_quarter = 425 #@param {type:"slider", min:1, max:1000, step:8}
start_from_this_generated_event = 0 #@param {type:"slider", min:0, max:100, step:1}
remove_generated_silence_if_needed = False #@param {type:"boolean"}
silence_offset_from_start = 75000 #@param {type:"integer"}
simulate_velocity = False #@param {type:"boolean"}

# Placeholders; n, z and detailed_stats are overwritten by the processor call below.
song_score = []
n = z = detailed_stats = 0

# The text to convert is the global `input` set by the sampling cell. To convert
# a saved composition instead, load it into `input` first:
#with open('/content/TMIDI-TXT-Composition.txt', 'r') as file:
#    input=file.read()

output_signature = 'Intelligent VIRTUOSO'

# Positional arguments for the processor, in the order it is called with.
processor_args = (
    input,
    line_by_line_dataset,
    dataset_MIDI_events_time_denominator,
    number_of_ticks_per_quarter,
    start_from_this_generated_event,
    remove_generated_silence_if_needed,
    silence_offset_from_start,
    simulate_velocity,
    output_signature,
)
midi_data, n, z, detailed_stats = TMIDI.Tegridy_TXT_MIDI_Processor(*processor_args)

# Date-time stamp for the file name: spaces, colons and dots become underscores.
now_n = str(datetime.now()).replace(' ', '_').replace(':', '_')
now = now_n.replace('.', '_')

fname = '/content/Intelligent-VIRTUOSO-Composition-generated-on-' + now + '.mid'

# The `with` block closes the file on exit.
with open(fname, 'wb') as midi_file:
    midi_file.write(midi_data)
print('Done!')

from google.colab import files
files.download(fname)
print('Detailed MIDI stats:')
detailed_stats