<a href="https://colab.research.google.com/github/AnonUserGit/TransformerVelGroove2Performance/blob/main/NIME2022_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Demo - Transforming Monotonous Velocity Grooves using Transformer Neural Networks**
---


## Environment setup

In [1]:
%%capture
#@title Setup (Remove "%%capture from top of cell" to see output!)

# !pip install -q condacolab
# import condacolab
# condacolab.install()


# Installing magenta (for note_seq)
!pip install -U -q magenta

# Getting wandb
!pip install -q wandb

# Installing fluidsynth
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -q pyfluidsynth
import ctypes.util
orig_ctypes_util_find_library = ctypes.util.find_library
def proxy_find_library(lib):
  if lib == 'fluidsynth':
    return 'libfluidsynth.so.1'
  else:
    return orig_ctypes_util_find_library(lib)
ctypes.util.find_library = proxy_find_library

# Installing and activating environment
#!conda env create -f TransformerGrooveTap2Drum/environment.yml

from google.colab import files
import IPython.display
from IPython.display import Audio
import magenta
import note_seq

In [2]:
#@title Download Data and Source Code
from google.colab import drive
drive.mount('/content/drive')

#@title
# Cloning repository
!git clone --quiet https://github.com/AnonUserGit/TransformerVelGroove2Performance.git

# Unzipping dependencies
!unzip -qq /content/TransformerVelGroove2Performance/dependencies.zip -d .

# Unzipping midi data
!unzip -qq /content/TransformerVelGroove2Performance/groove_midi_examples.zip -d .

# Unzip trained models
!unzip -qq /content/TransformerVelGroove2Performance/trained_models/misunderstood_bush_246-epoch_26.Model.zip -d .
!unzip -qq /content/TransformerVelGroove2Performance/trained_models/rosy_durian_248-epoch_26.Model.zip -d .
!unzip -qq /content/TransformerVelGroove2Performance/trained_models/solar_shadow_247-epoch_41.Model.zip -d .



Mounted at /content/drive


In [3]:
#@title Import Libraries and Utilities

from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import glob

#@title Import libraries and define util functions
import ipywidgets as widgets
import os
import torch
import sys
import note_seq
import pretty_midi as pm
import copy
import wandb
import re
import numpy as np

sys.path.insert(1, "/content/dependencies/BaseGrooveTransformers/")
sys.path.insert(1, "/content/dependencies/hvo_sequence/")

from models.train import *
from models.transformer import GrooveTransformerEncoder
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
from hvo_sequence.io_helpers import note_sequence_to_hvo_sequence
from hvo_sequence.hvo_seq import empty_like

def play(hvo_seq, sf2_path='/content/dependencies/hvo_sequence/hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'):
  audio_seq = hvo_seq.synthesize(sr=44100, sf_path=sf2_path)
  IPython.display.display(IPython.display.Audio(audio_seq, rate=44100))

def fixed_hvo_tsteps(hvo_arr, n_tsteps):
  if hvo_arr.shape[0] > n_tsteps:
    _hvo_arr = hvo_arr[:n_tsteps,:]
  elif hvo_arr.shape[0] < n_tsteps:
    _hvo_arr = np.concatenate((hvo_arr,np.zeros((n_tsteps-hvo_arr.shape[0], hvo_arr.shape[1]))))
  else:
    _hvo_arr = hvo_arr
  return _hvo_arr
  
# find file names
global file_names
file_names = glob.glob("/content/groove_midi_examples/*4-4.mid", recursive = True)


def filename_interface(ID):
  # for selecting midi files interactively down below
  global file_name
  file_name = file_names[ID]
  return file_name

## **Select Model Checkpoint**

In [4]:
model_filename = 'misunderstood_bush_246-epoch_26.Model' #@param ["misunderstood_bush_246-epoch_26.Model", "rosy_durian_248-epoch_26.Model", "solar_shadow_247-epoch_41.Model"]


In [5]:
#@title Load model
%%capture
TRAINED_MODELS_PATH = "/content/"

params = {
    'hopeful':{ 'd_model': 512, 'embedding_sz': 27, 'n_heads': 4,
                      'dim_ff': 64, 'dropout': 0.1708, 'n_layers': 8,
                      'max_len': 32, 'device': 'cpu' },
    'misunderstood':{ 'd_model': 128, 'embedding_sz': 27, 'n_heads': 4,
                          'dim_ff': 128, 'dropout': 0.1038, 'n_layers': 11,
                          'max_len': 32, 'device': 'cpu' },
    'rosy':{ 'd_model': 512, 'embedding_sz': 27, 'n_heads': 4,
                    'dim_ff': 16, 'dropout': 0.1093, 'n_layers': 6,
                    'max_len': 32, 'device': 'cpu' },
    'solar':{ 'd_model': 128, 'embedding_sz': 27, 'n_heads': 1,
                     'dim_ff': 16, 'dropout': 0.1594, 'n_layers': 7,
                     'max_len': 32, 'device': 'cpu' }
}

selected_model_params = params[model_filename.split('_')[0]]

# Load checkpoint
checkpoint = torch.load(os.path.join(TRAINED_MODELS_PATH, model_filename),
                        map_location=torch.device(selected_model_params['device']))

# Initialize model
groove_transformer = GrooveTransformerEncoder(selected_model_params['d_model'],
                                              selected_model_params['embedding_sz'],
                                              selected_model_params['embedding_sz'],
                                              selected_model_params['n_heads'],
                                              selected_model_params['dim_ff'],
                                              selected_model_params['dropout'],
                                              selected_model_params['n_layers'],
                                              selected_model_params['max_len'],
                                              selected_model_params['device'])
# Load model and put in evaluation mode
groove_transformer.load_state_dict(checkpoint['model_state_dict'])
groove_transformer.eval()


In [21]:
#@title Select midi file
interact(filename_interface, ID=int(len(file_names)/2));


interactive(children=(IntSlider(value=545, description='ID', max=1635, min=-545), Output()), _dom_classes=('wi…

In [22]:
#@title Tappify your own drum MIDI file or use an example from the Groove MIDI Dataset
upload_myown_midi_file = False #@param {type:"boolean"}

if upload_myown_midi_file:
  uploaded_file = files.upload()
  FILEPATH = list(uploaded.keys())[0]
else:
  FILEPATH = file_name
  print(file_name)

# Getting HVO representation
gt_midi = pm.PrettyMIDI(FILEPATH)
gt_note_seq = note_seq.midi_to_note_sequence(gt_midi)
gt_hvo_seq = note_sequence_to_hvo_sequence(ns=gt_note_seq, drum_mapping=ROLAND_REDUCED_MAPPING)

# Taking first 2 bars of file, padding with 0 if necessary
gt_hvo_seq.hvo = fixed_hvo_tsteps(gt_hvo_seq.hvo, 32)

tap_hvo_seq = copy.deepcopy(gt_hvo_seq)
tap_hvo_seq.hvo = gt_hvo_seq.flatten_voices()

print("Ground truth:")
play(gt_hvo_seq)
print("Tappified:")
play(tap_hvo_seq)


/content/groove_midi_examples/13_jazz-funk_116_fill_4-4.mid
Ground truth:


Tappified:


In [24]:
#@title Generate prediction from tapped input

# hit_activation = "use_probability_distribution" #@param ["use_threshold", "use_probability_distribution"]
hit_activation_threshold = 0 #@param {type:"slider", min:0, max:1, step:0.1}

# tapped sequence to tensor
tap_hvo_tensor = torch.FloatTensor(tap_hvo_seq.hvo)


#if hit_activation == "use_threshold":
pred_h, pred_v, pred_o = groove_transformer.predict(
  tap_hvo_tensor, use_thres=True, thres=hit_activation_threshold)
#else:
#  pred_h, pred_v, pred_o = groove_transformer.predict(
#    tap_hvo_tensor, use_thres=False, use_pd=True)

prediction_hvo_seq = empty_like(tap_hvo_seq)
prediction_hvo_seq.hvo = np.zeros((32, 27))
prediction_hvo_seq.hits = pred_h.numpy()[0]
prediction_hvo_seq.velocities = pred_v.numpy()[0]
prediction_hvo_seq.offsets = pred_o.numpy()[0]

print("Tapped sequence:")
play(tap_hvo_seq)
print("Generated beat:")
play(prediction_hvo_seq)

Tapped sequence:


Generated beat:
