<a href="https://colab.research.google.com/github/asigalov61/Music-Reformer/blob/main/Music_Reformer_TPU_Edition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Reformer (v.1.0): TPU Edition

### This is a work in progress so please check back for updates and improvements.

***

### Based on the official Reformer Google Colab and code.

***

Project Los Angeles

Tegridy Code 2021

***

# Setup the environment

In [None]:
#@title Install the dependencies
# Install dependencies

!git clone https://github.com/asigalov61/tegridy-tools
%cd /content/tegridy-tools/tegridy-tools/
%cd /content/

!wget https://github.com/asigalov61/Music-Reformer/raw/main/Dataset/Music-Reformer_TXT_Dataset.zip
!unzip Music-Reformer_TXT_Dataset.zip

!pip install --upgrade -q jax
!pip install --upgrade -q jaxlib
!pip install --upgrade -q trax==1.3.6
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin 

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [None]:
#@title Import modules
import gin
import os
import numpy as np
import torch
from scipy.special import softmax

%cd /content/tegridy-tools/tegridy-tools
import TMIDI
%cd /content/


# Zipping and downloading files
from google.colab import files
import shutil

# Trax
import jax
import trax
from trax.data import inputs
import jax.numpy as jnp

# NLP Vocab Generation
import sentencepiece as spm

# Prep the dataset

In [None]:
#@title Process the TXT MIDI dataset to plain TXT data
# Read the raw dataset as bytes and split it on whitespace.
with open('/content/Music-Reformer_TXT_Dataset.txt', 'rb') as file:
  z = file.read().split()

# Keep only the chunks that decode as UTF-8.
out = []
for chunk in z:
  try:
    out.append(chunk.decode('utf-8'))
  except UnicodeDecodeError:
    # Best effort: drop undecodable chunks, but let any other error surface.
    continue

out1 = '\n'.join(out)

# Reinterpret the UTF-8 bytes of the text as signed 8-bit integers.
# np.frombuffer replaces the deprecated binary mode of np.fromstring;
# .copy() gives a writable array, since frombuffer alone returns a read-only view.
X = np.frombuffer(out1.encode('utf-8'), dtype='int8').copy()

# One absolute value per line. Widen to int16 first so that -128 becomes 128
# instead of wrapping around inside int8.
output = ''.join(f'{value}\n' for value in np.abs(X.astype(np.int16)).tolist())

with open('/content/Music-Reformer_INT_Dataset.txt', 'w') as f:
  f.write(output)

In [None]:
#@title Create a tokenizer and its model
# Train a BPE model on the dataset
spm.SentencePieceTrainer.train('--input=/content/Music-Reformer_TXT_Dataset.txt \
                                --model_prefix=Music-Reformer-Tokenizer \
                                --vocab_size=3600 \
                                --model_type=bpe')
# Load BPE vocabulary
TOKENIZER = spm.SentencePieceProcessor()
TOKENIZER.load('Music-Reformer-Tokenizer.model')

# Load the dataset
with open('/content/Music-Reformer_TXT_Dataset.txt', encoding='utf-8') as f:
    text = f.read()


# Only the first 1,100,000 characters of the text are tokenized.
IDS = TOKENIZER.EncodeAsIds(text[:1100000])
IDS = np.asarray(IDS, dtype=np.int32)
# Number of padding positions needed to fill a 512 * 1024 token sequence.
PAD_AMOUNT = 512 * 1024 - len(IDS)
print("Number of tokens:", IDS.shape[0])

# my_inputs() draws a random offset with np.random.choice(PAD_AMOUNT, ...) and
# pads by PAD_AMOUNT - offset, so PAD_AMOUNT has to be strictly positive.
# Fail here with a clear message rather than later inside the data pipeline.
if PAD_AMOUNT <= 0:
  raise ValueError(
      'The encoded dataset has %d tokens, which does not leave room for padding '
      'inside a sequence of %d tokens; tokenize a shorter slice of the text.'
      % (len(IDS), 512 * 1024))

In [None]:
#@title Split the dataset

# Cut X at index 1,100,000: the leading part is the training split and the
# remainder is the validation split.
trX, vaX = X[:1100000], X[1100000:]

# Wrap both splits as torch tensors.
data_train = torch.from_numpy(trX)
data_val = torch.from_numpy(vaX)

# Setup the Reformer model and functions

In [None]:
#@title Initialize the functions and procedures for training
# Set up the data pipeline.
def my_inputs(n_devices):
  """Endlessly yield (inputs, targets, mask) batches built from IDS.

  Each batch has one row per device. Every row is IDS zero-padded on both
  sides by a total of PAD_AMOUNT positions, with the left/right split chosen
  at random per row. The mask row is 1.0 where IDS sits and 0.0 on padding.
  The same stacked array is yielded as both inputs and targets.
  """
  while True:
    # One random left-padding amount per device.
    left_pads = np.random.choice(PAD_AMOUNT, n_devices)
    token_rows = [
        np.pad(IDS, (left, PAD_AMOUNT - left), mode='constant')
        for left in left_pads
    ]
    mask_rows = [
        np.pad(np.ones_like(IDS, dtype=np.float32),
               (left, PAD_AMOUNT - left),
               mode='constant')
        for left in left_pads
    ]
    batch = np.stack(token_rows)
    yield (batch, batch, np.stack(mask_rows))

# Sanity check: pull a single batch and show the shape of its inputs.
print(
    "(device count, tokens per device) = ",
    next(my_inputs(trax.fastmath.device_count()))[0].shape,
)

In [None]:
#@title Configure hyperparameters
# Configure hyperparameters.
# ReformerLM.vocab_size must cover every id the tokenizer can emit. The
# SentencePiece model in this notebook is trained with --vocab_size=3600, so
# the model vocabulary is set to 3600 to match.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Model will have 6 layers, alternating between the LSH attention
# and local attention within a certain context window.
n_layers = 6
attn_type = [
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  @trax.layers.SelfAttention,
  @LSHSelfAttention,
  ]
share_qk = False  # LSH attention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288

# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.01
multifactor.factors = 'constant * linear_warmup * cosine_decay'
multifactor.warmup_steps = 100
multifactor.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for SelfAttention:
# ==============================================================================
trax.layers.SelfAttention.attention_dropout = 0.05
trax.layers.SelfAttention.chunk_len = 64
trax.layers.SelfAttention.n_chunks_before = 1
trax.layers.SelfAttention.n_parallel_heads = 1

# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 64
LSHSelfAttention.n_buckets = [64, 128]
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 1
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 128
LSHSelfAttention.predict_mem_len = 1024

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 3600
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs= (64, 192)
""")

In [None]:
#@title Setup the model and the trainer routines
# Trainer.
# 'model' contains no '~', so expanduser() leaves it as a path relative to the
# current working directory.
output_dir = os.path.expanduser('model')

# Remove old model.
# Delete the checkpoint inside output_dir itself. The previous shell command
# removed ~/model/model.pkl.gz, which is a different directory whenever the
# working directory is not the home directory.
stale_checkpoint = os.path.join(output_dir, 'model.pkl.gz')
if os.path.exists(stale_checkpoint):
  os.remove(stale_checkpoint)

trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

# Train

In [None]:
#@title Train the model
# Train Model
import tqdm

# 50 calls to train_epoch, each with n_steps=100 and n_eval_steps=20,
# wrapped in a tqdm progress bar.
for _ in tqdm.tqdm(range(0, 50)):
  trainer.train_epoch(n_eval_steps=20, n_steps=100)

In [None]:
#@title Zip and download your trained model checkpoint here
# Zip directory contents
# The archive is built outside the directory being zipped and moved into place
# afterwards. Writing project.zip directly into "." while "." is being walked
# can pull a partially written copy of the archive into itself.
archive_name = 'project.zip'
if os.path.exists(archive_name):
  # Drop the archive from a previous run so it is not zipped again.
  os.remove(archive_name)
built_archive = shutil.make_archive('/tmp/project', 'zip', '.')
shutil.move(built_archive, archive_name)

# Download zipped directory
files.download(archive_name)

# Generate Music

In [None]:
#@title Increase hashing rounds number for better quality here
# In the Reformer paper, increasing the number of hashing rounds helps with quality.
# Here the number of hashing rounds is raised at evaluation time only.
gin.parse_config('LSHSelfAttention.n_hashes = 4')

In [None]:
#@title Load the trained Reformer in 'predict' mode
# Directory that holds the trained checkpoint.
output_dir = os.path.expanduser('model')

# Build the Reformer in 'predict' mode, then load only the weights from the
# saved checkpoint file.
model = trax.models.ReformerLM(mode='predict')
model.init_from_file(
    os.path.join(output_dir, 'model.pkl.gz'),
    weights_only=True,
)

In [None]:
#@title Generate and decode music from the model
# Sample from ReformerLM: a batch of 4 sequences, up to 1024 tokens each.
output_token_ids = trax.supervised.decoding.autoregressive_sample(
    model,
    batch_size=4,
    max_length=1024,
    temperature=0.8,
)

# Decode token IDs
# Only the first sequence of the sampled batch ([0]) is decoded.
# tolist() converts the ids to plain Python ints, the type SentencePiece expects.
# NOTE: the name `input` shadows the Python builtin; the MIDI conversion cell
# reads it, so it is kept as is.
input = TOKENIZER.DecodeIds(output_token_ids[0].tolist())

In [None]:
#@title Convert generated output to MIDI.
# Converts the generated output to MIDI.
# If you get errors or halts, regenerate the output and try again.
# The model must be sufficiently trained: 0.90+ accuracy is recommended for the
# output to make sense and pass error control.

# Step 1: generated string -> TXT representation.
TXT = TMIDI.Tegridy_INT_String_to_TXT_Converter(
    input,
    line_by_line_input=False,
)

# Step 2: TXT -> notes.
SONG = TMIDI.Tegridy_Reduced_TXT_to_Notes_Converter(
    TXT,
    has_MIDI_channels=False,
    has_velocities=False,
    dataset_includes_beat=True,
)

# Step 3: write the first element of SONG out as a MIDI file and print the
# value the converter returns.
stats = TMIDI.Tegridy_SONG_to_MIDI_Converter(
    SONG=SONG[0],
    output_file_name='/content/Music-Reformer_MIDI',
)
print(stats)

# Congrats!!! You did it!!! :)