In [None]:
import glob

# Root of the public bucket that holds the pre-trained MusicVAE checkpoints.
# The model-loading cell appends '/checkpoints/<name>.ckpt' to this value.
BASE_DIR = "gs://download.magenta.tensorflow.org/models/music_vae/colab2"

print('Installing dependencies...')
# System packages: the FluidSynth library, a General MIDI soundfont, build
# tools, and the ALSA/JACK development headers.
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
# Python bindings for FluidSynth.
!pip install -q pyfluidsynth
# NOTE(review): -U installs the newest magenta with no version pin, so running
# this cell at a later date may resolve to a different release.
!pip install -qU magenta

# Colab workaround: make ctypes resolve the name 'fluidsynth' to the shared
# object installed above. Only the hosted Colab environment needs this.
import ctypes.util
orig_ctypes_util_find_library = ctypes.util.find_library
def proxy_find_library(lib):
  """Returns the FluidSynth soname for 'fluidsynth'; otherwise defers to the original lookup."""
  if lib != 'fluidsynth':
    return orig_ctypes_util_find_library(lib)
  return 'libfluidsynth.so.1'
ctypes.util.find_library = proxy_find_library


print('Importing libraries and defining some helper functions...')
# Colab-only module; later cells call files.upload() and files.download().
from google.colab import files
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import os
import tensorflow.compat.v1 as tf

# Use TensorFlow through its v1 compatibility API with v2 behavior switched off.
tf.disable_v2_behavior()

# Necessary until pyfluidsynth is updated (>1.2.5).
# Note: this filter hides DeprecationWarning from every module, not only
# pyfluidsynth.
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

print('Done')

In [None]:
def play(note_sequence):
  """Plays `note_sequence` through `mm.play_sequence` using the FluidSynth synth.

  Args:
    note_sequence: The sequence object handed to `mm.play_sequence`.
  """
  mm.play_sequence(
      note_sequence,
      synth=mm.fluidsynth)

def interpolate(model, start_seq, end_seq, num_steps, max_length=32,
                assert_same_length=True, temperature=0.5,
                individual_duration=4.0):
  """Builds a single sequence that steps from `start_seq` to `end_seq`.

  Args:
    model: Object exposing `interpolate(start, end, num_steps=..., length=...,
      temperature=..., assert_same_length=...)`, e.g. a `TrainedModel`.
    start_seq: Sequence the interpolation starts from.
    end_seq: Sequence the interpolation ends at.
    num_steps: Forwarded to `model.interpolate` as `num_steps`.
    max_length: Forwarded to `model.interpolate` as `length`.
    assert_same_length: Forwarded unchanged to `model.interpolate`.
    temperature: Forwarded unchanged to `model.interpolate`.
    individual_duration: Duration given to each interpolation step when the
      steps are concatenated (presumably seconds -- confirm).

  Returns:
    The concatenation of every step returned by `model.interpolate`, in order.
  """
  steps = model.interpolate(
      start_seq,
      end_seq,
      num_steps=num_steps,
      length=max_length,
      temperature=temperature,
      assert_same_length=assert_same_length)

  # Every step gets the same duration slot in the concatenated result.
  step_durations = [individual_duration] * len(steps)
  return mm.sequences_lib.concatenate_sequences(steps, step_durations)

def download(note_sequence, filename):
  """Saves `note_sequence` as a MIDI file and passes it to `files.download`.

  Args:
    note_sequence: Sequence written out with `mm.sequence_proto_to_midi_file`.
    filename: Path the MIDI file is written to; the same path is then passed
      to `files.download`.
  """
  # The file is written first, then the same path is handed to the downloader.
  mm.sequence_proto_to_midi_file(note_sequence, filename)
  files.download(filename)

In [None]:
def _load_drums_model(config, checkpoint_suffix):
  """Builds a batch-size-4 TrainedModel from the checkpoint at BASE_DIR + `checkpoint_suffix`."""
  return TrainedModel(
      config, batch_size=4, checkpoint_dir_or_path=BASE_DIR + checkpoint_suffix)

# Loaded drum models, keyed by the names offered in the form cells further down.
drums_models = {}

# One-hot encoded: a single config shared by two checkpoints (lokl / hikl).
drums_config = configs.CONFIG_MAP['cat-drums_2bar_small']
drums_models['drums_2bar_oh_lokl'] = _load_drums_model(
    drums_config, '/checkpoints/drums_2bar_small.lokl.ckpt')
drums_models['drums_2bar_oh_hikl'] = _load_drums_model(
    drums_config, '/checkpoints/drums_2bar_small.hikl.ckpt')

# Multi-label NADE: reduced and full variants, each with its own config.
drums_nade_reduced_config = configs.CONFIG_MAP['nade-drums_2bar_reduced']
drums_models['drums_2bar_nade_reduced'] = _load_drums_model(
    drums_nade_reduced_config, '/checkpoints/drums_2bar_nade.reduced.ckpt')
drums_nade_full_config = configs.CONFIG_MAP['nade-drums_2bar_full']
drums_models['drums_2bar_nade_full'] = _load_drums_model(
    drums_nade_full_config, '/checkpoints/drums_2bar_nade.full.ckpt')

In [None]:
# Using my own model: register a custom checkpoint stored under /content/drive
# (presumably a mounted Google Drive -- confirm it is mounted before running).
# NOTE(review): the checkpoint is loaded with the 'nade-drums_2bar_full' config
# object; confirm it was trained with that same config.
drums_models['drums_2bar_my_full'] = TrainedModel(
    drums_nade_full_config,
    batch_size=4,
    checkpoint_dir_or_path='/content/drive/MyDrive/groove_2bar_vae/train/model_checkpoint/model.ckpt-10340')

In [None]:
#@title Generate 4 samples from the prior of one of the models listed above.
drums_sample_model = "drums_2bar_my_full" #@param ["drums_2bar_oh_lokl", "drums_2bar_oh_hikl", "drums_2bar_nade_reduced", "drums_2bar_nade_full", "drums_2bar_my_full"]
temperature = 0.5 #@param {type:"slider", min:0.1, max:1.5, step:0.1}

# Look up the model chosen in the form, draw four samples, and play each one.
selected_sample_model = drums_models[drums_sample_model]
drums_samples = selected_sample_model.sample(
    n=4, length=64, temperature=temperature)
for ns in drums_samples:
  play(ns)

In [None]:
# Prompt for file uploads and keep the values of the returned mapping (the next
# cell parses each value as MIDI data). If the upload result is empty, the `or`
# falls back to the previous value of input_drums_midi_data -- which raises
# NameError when no earlier run of this cell has defined it.
input_drums_midi_data = files.upload().values() or input_drums_midi_data

In [None]:
# Parse every uploaded MIDI payload into a sequence proto.
drums_input_seqs = [mm.midi_to_sequence_proto(m) for m in input_drums_midi_data]

# Round-trip each sequence through the full-NADE config's data converter:
# to_tensors(...)[1] is converted back with from_tensors, and the resulting
# sequences are collected as the extracted beats.
# NOTE(review): index 1 is presumably the converter's output tensors -- confirm.
drums_converter = drums_nade_full_config.data_converter
extracted_beats = []
for ns in drums_input_seqs:
  beat_tensors = drums_converter.to_tensors(ns)[1]
  extracted_beats.extend(drums_converter.from_tensors(beat_tensors))

# Label and play each extracted beat so they can be told apart by index.
for i, ns in enumerate(extracted_beats):
  print("Beat", i)
  play(ns)

In [None]:
drums_interp_model = "drums_2bar_my_full" #@param ["drums_2bar_oh_lokl", "drums_2bar_oh_hikl", "drums_2bar_nade_reduced", "drums_2bar_nade_full", "drums_2bar_my_full"]
# The form value is the same for every pair, so it is set once here instead of
# being re-assigned on each pass of the nested loop.
temperature = 0.5 #@param {type:"slider", min:0.1, max:1.5, step:0.1}

# Resolve the selected model once; an unknown key fails here, before any work.
selected_interp_model = drums_models[drums_interp_model]

# Interpolate every ordered (start, end) pair of extracted beats, skipping
# pairs whose two beats compare equal.
# NOTE(review): num_steps=2 presumably yields only the two endpoint
# reconstructions with no intermediate steps -- confirm against
# TrainedModel.interpolate.
drums_interp = [
    interpolate(selected_interp_model, start_beat, end_beat,
                num_steps=2, temperature=temperature)
    for start_beat in extracted_beats
    for end_beat in extracted_beats
    if start_beat != end_beat
]

# Kept as a named variable for any later cell that reads it.
cnt = len(drums_interp)
print(f"{cnt} beats are generated")

In [None]:
# Play every interpolated sequence, in the order it was generated.
for interp_seq in drums_interp:
    play(interp_seq)

In [None]:
import os

# Output folder for the generated MIDI files.
# NOTE(review): assumes Google Drive is already mounted at /content/drive;
# no mount call is visible in this notebook -- confirm.
path = '/content/drive/MyDrive/groove/drum_samples/'

# makedirs(..., exist_ok=True) keeps this cell re-runnable: it creates missing
# parent folders and does not fail when the folder already exists, whereas
# os.mkdir raised FileExistsError / FileNotFoundError in those cases.
os.makedirs(path, exist_ok=True)

# Write each interpolation as drum_sample_v1.mid, drum_sample_v2.mid, ... and
# trigger a browser download for it.
for idx, item in enumerate(drums_interp, start=1):
    download(item, path + f'drum_sample_v{idx}.mid')
print('done!')