In [None]:
#@title 0. Preparing the environment
#@markdown Libraries are simply downloaded and imported here, nothing interesting
!pip install gdown > /dev/null
!apt install fluidsynth > /dev/null
!pip install midi2audio > /dev/null
!pip install pretty_midi > /dev/null
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2
!git clone https://github.com/ai-forever/music-composer.git

import os
import sys
import time
import glob
import torch
import gdown
import webbrowser
import ipywidgets
import pretty_midi
import numpy as np

from tqdm import tqdm
from pathlib import Path
from midi2audio import FluidSynth
from google.colab import widgets, files
from IPython.display import Audio, display, FileLink, HTML
from ipywidgets import AppLayout, HBox, VBox, Label

# Download the fine-tuned checkpoint only if it is not already present, so
# re-running this cell does not fetch the large file again.
# NOTE(review): step 2 loads '/content/model_finetune_700k.pt', so this relative
# path assumes the working directory is /content (the Colab default).
url = 'https://drive.google.com/u/0/uc?id=1lcNp0y4IZMIos0ASSsERG25WVDuEmLks'
output = 'model_finetune_700k.pt'
if not os.path.exists(output):
    gdown.download(url, output, quiet=True)

# Make the cloned repository's `lib` package importable; append only once so
# re-running the cell does not accumulate duplicate sys.path entries.
repo_src = '/content/music-composer/src/'
if repo_src not in sys.path:
    sys.path.append(repo_src)
from lib import constants
from lib import generation
from lib import midi_processing
from lib.midi_processing import PIANO_RANGE
from lib.model.transformer import MusicTransformer
from lib.colab_utils import id2genre, rugenre, genre2id, decode_and_write, convert_midi_to_wav, DownloadButton

In [None]:
#@title 1. Generation control panel
#@markdown In this panel, you can customize the music generation parameters for yourself. Let's go through all the points:
#@markdown * You can choose a __primer__ - your own MIDI file, which will be used as the beginning of the song. A long primer is trimmed: specify the desired length (in seconds) and which part to keep (From start / From end). __From start__ keeps the first N seconds, __From end__ keeps the last N seconds. A primer that is still longer than 512 tokens after trimming is additionally cut down to 512 tokens.
#@markdown * __Genre__ - allows you to select the genre in which the tracks will be generated. For each of the genres, we have selected our own set of parameters, which allows you to best generate it, but more on that later.
#@markdown * __Seed__ - allows you to fix the seed with which the tracks are generated. Fixing it can be useful if you want to study the influence of a particular parameter on generation - set a numerical seed and generate several times changing the studied parameter. In other cases, it is recommended to leave this field blank.  
#@markdown * __Batch_size__ - the number of simultaneously generated tracks. If you overdo it, you may not have enough memory. Here you need to find a balance between the number of simultaneous tracks and their length.  
#@markdown * __Sequence length__ - target track length in tokens (the primer counts towards it). Because of how generation works, we cannot say in advance how many seconds the track will last, but increasing this parameter generally makes tracks longer. Balance this parameter against Batch_size so that you don't run out of GPU memory.
#@markdown * __Remove bad generations__ - sometimes the model starts generating garbage. If this flag is enabled, we will try to detect it, filter it out and generate new compositions instead. The generation time increases accordingly.
#@markdown * __Temperature__ - temperature scaling of the output probability distribution. Higher values flatten it (tokens become closer to equiprobable), lower values sharpen it; this controls how varied the generations are.
#@markdown * __TopK__ - upper bound on the candidate set: tokens are sampled only from the k most likely ones.
#@markdown * __At least K__ - lower bound on the candidate set: sampling is guaranteed to consider at least the `at_least_k` most likely tokens.
#@markdown * __TopP__ - the top-p (nucleus sampling) threshold; see Holtzman et al., "The Curious Case of Neural Text Degeneration" (2019) for details.
#@markdown * __TopP Temperature__ - temperature scaling applied after the top-p filtering.
#@markdown * __Use Repetition Penalty__ - enables penalties for repeated notes. Turn it on only if the model produces repetitive, looping tracks.
#@markdown * __RP Penalty__ - strength of the penalty for repeating notes. The higher it is, the less the model gets stuck on the same musical phrases.
#@markdown * __Restore speed__ - how quickly a penalized note recovers from the repetition penalty.

#@markdown The genre buttons in the lower-right corner reset all generation parameters to the presets we tuned for that genre.


def truncate_midi(midi, primer_len_sec=15.0, from_end=False):
    """Trim a PrettyMIDI-like object in place to a window of `primer_len_sec` seconds.

    Args:
        midi: object with an ``instruments`` list whose items have a mutable
            ``notes`` list; each note exposes ``start``/``end`` in seconds (as in
            ``pretty_midi.PrettyMIDI``). Modified in place.
        primer_len_sec: length of the window to keep, in seconds.
        from_end: if False, keep notes that start within the first
            ``primer_len_sec`` seconds and clip their ends to the window.
            If True, keep notes that start within ``primer_len_sec`` seconds of
            the latest note end in the file (note times are not shifted).

    After the call each instrument's notes are in ascending start order.
    Instruments without notes stay empty; in ``from_end`` mode a file with no
    notes at all is left untouched.
    """
    if from_end:
        note_ends = [note.end for inst in midi.instruments for note in inst.notes]
        if not note_ends:
            return  # nothing to trim
        # Anchor on the latest end over *all* notes: notes are not guaranteed to
        # be ordered by end time, so the last note of a list is not a safe anchor.
        time0 = max(note_ends)
    else:
        time0 = 0

    for inst in midi.instruments:
        # Walk away from the anchor so the first out-of-window note ends the scan.
        notes = sorted(inst.notes, key=lambda x: x.start, reverse=from_end)
        kept = []
        for note in notes:
            if abs(note.start - time0) > primer_len_sec:
                break
            if not from_end and note.end > primer_len_sec:
                # Clip notes that keep sounding past the end of the window.
                note.end = primer_len_sec
            kept.append(note)
        if from_end:
            kept.reverse()  # restore ascending start order
        inst.notes = kept


style = {'description_width': 'initial'}


def _slider(slider_cls, value, lo, hi, step, description, fmt):
    """Build a horizontal, non-continuous slider with the panel's shared style.

    The initial `value` is also stored as `default_value`, which the genre
    preset buttons use to reset the widget.
    """
    widget = slider_cls(
        value=value,
        min=lo,
        max=hi,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format=fmt,
        style=style,
    )
    widget.default_value = value
    return widget


genre_to_generate = ipywidgets.Dropdown(options=['calm', 'jazz', 'pop', 'classic'], value='calm',
                                        description='Genre:', disabled=False)
genre_to_generate.default_value = 'calm'

seed = ipywidgets.Text(value='', placeholder='leave blank for random seed',
                       description='Seed:', disabled=False)
seed.default_value = ''

# Sampling controls: (class, value, min, max, step, label, readout format).
temp = _slider(ipywidgets.FloatSlider, 1.0, 0.01, 5.0, 0.01, 'Temperature:', '.2f')
b_size = _slider(ipywidgets.IntSlider, 8, 1, 16, 1, 'Batch size:', 'd')
seq_length = _slider(ipywidgets.IntSlider, 512, 256, 2048, 256, 'Sequence length:', 'd')
topk = _slider(ipywidgets.IntSlider, 60, 1, 300, 1, 'Top k:', 'd')
at_least_k = _slider(ipywidgets.IntSlider, 1, 1, 300, 1, 'At least k:', 'd')
topp = _slider(ipywidgets.FloatSlider, 0.99, 0.5, 1.0, 0.01, 'Topp:', '.2f')
topp_temperature = _slider(ipywidgets.FloatSlider, 1.0, 0.01, 5.0, 0.01, 'Topp Temperature:', '.2f')

use_rp = ipywidgets.Checkbox(value=False, description='Use Repetition Penalty',
                             disabled=False, style=style)
use_rp.default_value = False

# Repetition-penalty controls.
rp_penalty = _slider(ipywidgets.FloatSlider, 0.05, 0., 1.0, 0.05, 'RP penalty:', '.2f')
restore_speed = _slider(ipywidgets.FloatSlider, 0.7, 0., 1.0, 0.05, 'Restore speed:', '.2f')

remove_bad_generations = ipywidgets.Checkbox(value=True, description='Remove bad generations',
                                             disabled=False, style=style)
remove_bad_generations.default_value = True

defaulting_widgets = (genre_to_generate, seed, b_size, seq_length, remove_bad_generations,
                      temp, topk, at_least_k, topp, topp_temperature, use_rp, rp_penalty, restore_speed)

# Tuned per-genre values for the sampling parameters that differ between genres.
# Every other widget in `defaulting_widgets` is reset to its `default_value`.
GENRE_PRESETS = {
    'calm': {'temperature': 1.03, 'at_least_k': 4, 'topp': 0.98},
    'jazz': {'temperature': 0.99, 'at_least_k': 1, 'topp': 0.99},
    'pop': {'temperature': 0.98, 'at_least_k': 4, 'topp': 0.99},
    'classic': {'temperature': 1.0, 'at_least_k': 1, 'topp': 0.99},
}


def apply_genre_preset(genre):
    """Reset every widget in `defaulting_widgets` and apply the preset for `genre`.

    Args:
        genre: a key of GENRE_PRESETS (also an option of the genre dropdown).

    Raises:
        KeyError: if `genre` has no preset.
    """
    preset = GENRE_PRESETS[genre]
    # Match widgets by identity rather than by their position in the tuple, so
    # reordering `defaulting_widgets` cannot apply a value to the wrong widget.
    overrides = (
        (genre_to_generate, genre),
        (temp, preset['temperature']),
        (at_least_k, preset['at_least_k']),
        (topp, preset['topp']),
    )
    for widget in defaulting_widgets:
        widget.value = next(
            (value for target, value in overrides if target is widget),
            widget.default_value,
        )


def set_classic(button):
    """Click handler for the 'Classic' preset button (`button` is unused)."""
    apply_genre_preset('classic')


def set_calm(button):
    """Click handler for the 'Calm' preset button (`button` is unused)."""
    apply_genre_preset('calm')


def set_jazz(button):
    """Click handler for the 'Jazz' preset button (`button` is unused)."""
    apply_genre_preset('jazz')


def set_pop(button):
    """Click handler for the 'Pop' preset button (`button` is unused)."""
    apply_genre_preset('pop')

def _make_button(description, handler):
    """Create a Button labelled `description` that calls `handler` on click."""
    button = ipywidgets.Button(description=description)
    button.on_click(handler)
    return button


calm_value_button = _make_button('Calm', set_calm)
jazz_value_button = _make_button('Jazz', set_jazz)
pop_value_button = _make_button('Pop', set_pop)
classic_value_button = _make_button('Classic', set_classic)
buttons = (calm_value_button, jazz_value_button, pop_value_button, classic_value_button)
# 2x2 layout: Calm/Jazz in the left column, Pop/Classic in the right column.
left_box = ipywidgets.VBox(buttons[:2])
right_box = ipywidgets.VBox(buttons[2:])

def update_dropdown():
    """Refresh the primer dropdown: 'None' first, then the sorted files in ./primers.

    Only names containing a dot are listed (glob pattern '*.*').
    """
    primer_paths = sorted(str(path) for path in Path('primers').glob('*.*'))
    select_primer_widget.options = ['None', *primer_paths]

def _iter_uploaded_files(value):
    """Yield ``(name, content)`` for each file in a FileUpload ``value``.

    ipywidgets 7 stores uploads as ``{name: {'content': bytes, ...}}``;
    ipywidgets 8 stores a sequence of dicts with ``'name'`` and ``'content'``
    keys. Both layouts are accepted.
    """
    if isinstance(value, dict):
        for name, meta in value.items():
            yield name, meta.get('content')
    else:
        for meta in value:
            yield meta.get('name'), meta.get('content')


def on_upload(x):
    """Save uploaded files to ./primers, drop those that are not valid MIDI, refresh the dropdown.

    Args:
        x: the change dict passed by ``FileUpload.observe(..., names='value')``;
            ``x['new']`` is the widget's new value.
    """
    os.makedirs('primers', exist_ok=True)
    for name, content in _iter_uploaded_files(x['new']):
        # Keep only the base name so an uploaded name cannot point outside ./primers.
        filename = os.path.basename(name or '')
        if not content or not filename:
            continue  # nothing to save or validate
        path = os.path.join('primers', filename)
        with open(path, "wb") as fp:
            fp.write(content)
        try:
            pretty_midi.PrettyMIDI(path)
        except Exception:
            # Any parse failure means the file cannot be used as a primer.
            print(f'file "{name}" is corrupted or not a MIDI!')
            os.remove(path)
    update_dropdown()

# Primer picker: a dropdown over ./primers that is refreshed after every upload.
select_primer_widget = ipywidgets.Dropdown()
update_dropdown()

uploader = ipywidgets.FileUpload(description='Upload MIDI', multiple=True)
uploader.observe(on_upload, names='value')

# Primer trimming: seconds to keep, and which side to keep (0 = start, 1 = end).
primer_len = ipywidgets.FloatSlider(value=15, min=0, max=60, step=0.1, description='Seconds:',
                                    continuous_update=False, style=style)
primer_position = ipywidgets.RadioButtons(options=[['From start',0],['From end',1]],
                                          orientation='horizontal')

primer_controls = ipywidgets.HBox([Label('Select primer:'), select_primer_widget, uploader,
                                   primer_len, primer_position])
general_controls = VBox([genre_to_generate, seed, b_size, seq_length, remove_bad_generations])
sampling_controls = VBox([temp, topk, at_least_k, topp, topp_temperature])
penalty_controls = VBox([use_rp, rp_penalty, restore_speed, ipywidgets.HBox((left_box, right_box))])

# Last expression of the cell, so Jupyter renders the panel.
AppLayout(header=primer_controls,
          left_sidebar=general_controls,
          center=sampling_controls,
          right_sidebar=penalty_controls,
          footer=None,
          pane_widths=[1, 1, 1],
          pane_heights=[1, 5, '60px'])

In [None]:
#@title 2. Start generating
import warnings

load_path = '/content/model_finetune_700k.pt'
out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S')
device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

if device == 'cpu':
    print('Generating on CPU. Expect lower generation speed.')

os.makedirs(out_dir, exist_ok=True)
genre_id = genre2id[genre_to_generate.value]

# Sampling parameters for generation.generate(), read from the step-1 control panel.
params = dict(
    target_seq_length = seq_length.value,
    temperature = temp.value,
    topk = topk.value,
    topp = topp.value,
    topp_temperature = topp_temperature.value,
    at_least_k = at_least_k.value,
    use_rp = use_rp.value,
    rp_penalty = rp_penalty.value,
    rp_restore_speed = restore_speed.value,
    seed = int(seed.value) if seed.value else None
)
# Maximum number of primer tokens fed to the model (applied after trimming in seconds).
max_primer_tokens = 512
# The genre conditioning token id is genre_token_offset + genre_id. The same offset is
# subtracted before decode_and_write so both directions stay consistent.
genre_token_offset = constants.VOCAB_SIZE - 4

# Init model
print('loading model...')
model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()
model.load_state_dict(torch.load(load_path, map_location=device))

# Add genre and primer: each sequence in the batch starts with the genre token, shape (batch, 1).
primer_genre = np.repeat([genre_id], b_size.value)[:, None] + genre_token_offset
if primer_len.value > 0 and select_primer_widget.value != 'None':
    file = select_primer_widget.value
    midi = pretty_midi.PrettyMIDI(file)
    from_end = primer_position.value == 1
    truncate_midi(midi, primer_len.value, from_end)  # truncate to specified length (in seconds)
    encoded = midi_processing.encode(midi)
    # Truncate to max_primer_tokens (in tokens), keeping the same side as the user chose.
    n_tokens = len(encoded)
    if n_tokens > max_primer_tokens:
        warnings.warn(f'Primer MIDI is too long (length > {max_primer_tokens}), '
                      f'it will be truncated to {max_primer_tokens}!')
        encoded = encoded[n_tokens - max_primer_tokens:] if from_end else encoded[:max_primer_tokens]
    # Explicit int64 so an empty encoding does not turn the concatenated primer into floats.
    primer_seq = np.repeat(np.asarray(encoded, dtype=np.int64)[None], b_size.value, 0)
    primer = np.concatenate([primer_genre, primer_seq], -1)
else:
    primer = primer_genre
primer = torch.tensor(primer, dtype=torch.int64)

# Generation
if primer.shape[-1] >= seq_length.value:
    print('Nothing to generate. Try to set larger "Sequence length" parameter!')
else:
    # Repeat until exactly b_size tracks are on disk: post_process may drop bad
    # generations, so one pass can write fewer files than requested.
    while len(glob.glob(out_dir + '/*.mid')) != b_size.value:
        print('generating to:', os.path.abspath(out_dir))
        generated = generation.generate(model, primer, **params)
        generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations.value)
        decode_and_write(generated, primer, primer_genre.squeeze(-1) - genre_token_offset, out_dir)
        midi_files = glob.glob(out_dir + '/*.mid')
        if len(midi_files) > b_size.value:
            # Remove the overshoot: the files with the highest track index.
            # NOTE(review): assumes the 2nd-to-last '_'-separated field of each path is that index.
            midi_files.sort(key=lambda x: int(x.split('_')[-2]))
            for extra_file in midi_files[b_size.value:]:
                os.remove(extra_file)
    for midi_name in glob.glob(out_dir + '/*.mid'):
        convert_midi_to_wav(midi_name)

In [None]:
#@title 3. Listen and download the generation results
# Show every generated track in a grid of at most 3 columns, each cell holding an
# audio player plus .wav/.midi download buttons. On-screen labels are in Russian.
wav_files = sorted(glob.glob(out_dir + '/*.wav'))  # sorted for a stable track order
if not wav_files:
    print(f'No .wav files found in {os.path.abspath(out_dir)} - run step 2 first.')
else:
    columns = min(3, len(wav_files))
    rows = -(-len(wav_files) // columns)  # ceiling division: no empty trailing row
    grid = widgets.Grid(rows, columns)
    for current_position, wav_file in enumerate(wav_files):
        row, column = divmod(current_position, columns)
        with grid.output_to(row, column):
            # The genre key is the last '_'-separated field of the file name before '.wav';
            # rugenre presumably maps it to a display (Russian) genre name.
            genre = rugenre[wav_file.split('.wav')[0].split('_')[-1]]
            print(f'Номер трека: {current_position + 1}')
            print(f'Жанр: {genre}')
            display(Audio(wav_file))
            display(DownloadButton(filename=wav_file, description='Скачать .wav'))
            display(DownloadButton(filename=wav_file.replace('.wav', '.mid'), description='Скачать .midi'))