<a href="https://colab.research.google.com/github/Athanas7a/AI-Crash-Course/blob/master/MMM_LBD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To start, press the INSTALL button, which should take around 2 minutes. **Once the INSTALL cell has completed**, press the RUN DEMO button to generate some music!

How to use:

**Note that Google Colab may run out of memory when using a lot of tracks. This will occur more frequently with the 8 bar models. For best results, stick to the 4 bar models or use a local runtime with 32 GB of memory.**

1) If you don't have any MIDI files on hand, use the MIDI Example menu to select some curated examples. Otherwise use the ADD MIDI button to load your own MIDI file. Note that the demo only handles MIDI files with a 4/4 time signature.

2) To use the bar fill model, select one or more bars by clicking directly on the bars in the pianoroll.

3) To use the multi-track model, either click add track, which will add a new track where you can select the instrument and density, or click the resample toggle on a pre-existing track to resample that track. The first entries in the instrument menu allow for a set of MIDI instruments to be selected. For example, if guitar is selected, the model can pick from any of the guitar General MIDI instruments (see the Wikipedia article on General MIDI for an overview: https://en.wikipedia.org/wiki/General_MIDI). It is also possible to select a specific instrument by using the entries towards the bottom of the instrument menu. For example, rather than allowing the model to choose which guitar to use, we can specify that we want the model to use the Overdriven Guitar.

4) To use the 8 bar model, use the NBars menu to select 8.

In [None]:
#@title INSTALL
from tqdm import tqdm
from google.colab import output
# Progress bar for the install cell: one tick per update() call below (seven stages).
pbar = tqdm(total=7)

def update(x):
  """Advance the install progress bar by one step and label it with `x`.

  Args:
    x: Description text for the install stage that is about to start.
  """
  pbar.set_description(x)
  # Clear the cell output before ticking so earlier output does not pile up.
  output.clear()
  pbar.update(1)

%cd /content

update("downloading code")
!rm -rf dataset_builder_2_minimal &> /dev/null
!rm -rf dataset_builder_2.zip &> /dev/null
!wget https://www.sfu.ca/~jeffe/LBD_DEMO/dataset_builder_2.zip &> /dev/null
!unzip -qq -o dataset_builder_2.zip &> /dev/null

update("installing dependencies")
!apt-get install libprotobuf-dev protobuf-compiler &> /dev/null

%cd /content/dataset_builder_2_minimal/src/dataset_builder_2/protobuf
update("building protobuf")
!protoc --cpp_out . midi.proto &> /dev/null

%cd ../../..
update("compiling code (this takes a couple minutes)")
%pip install . &> /dev/null

update("installing python modules")
!pip install transformers==3.3.1 &> /dev/null
!pip install pyFluidSynth &> /dev/null
!apt install fluidsynth &> /dev/null
!pip install midi2audio &> /dev/null
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2 &> /dev/null

update("downloading model weights")
import os
from subprocess import call
import dataset_builder_2 as db

# Checkpoint registry.
# Key:   (4 or 8, model type) -- the integer presumably is the number of bars
#        the model handles (the 8 entries carry "num_bars_8" in their names).
# Value: (encoder type, (checkpoint name, checkpoint step)); the name/step pair
#        is what download_model_from_web() turns into a checkpoint directory.
ckpt_map_raw = {
  (4,db.MODEL_TYPE.TRACK_MODEL) : (
    db.ENCODER_TYPE.TRACK_DENSITY_ENCODER,
    ("TRACK_DENSITY_ENCODER_gpt2_version3_Aug_04_14_29_False", 405000)
  ),
  (4,db.MODEL_TYPE.BAR_INFILL_MODEL) : (
    db.ENCODER_TYPE.TRACK_BAR_FILL_DENSITY_ENCODER,
    ("TRACK_BAR_FILL_DENSITY_ENCODER_gpt2_version3_Aug_05_09_15_False", 395000)
  ),
  (8,db.MODEL_TYPE.TRACK_MODEL) : (
    db.ENCODER_TYPE.TRACK_DENSITY_ENCODER,
    ("TRACK_DENSITY_ENCODER_gpt2_version3_Aug_13_13_31_False_num_bars_8_6", 240000)
  ),
  (8,db.MODEL_TYPE.BAR_INFILL_MODEL) : (
    db.ENCODER_TYPE.TRACK_BAR_FILL_DENSITY_ENCODER,
    ("TRACK_BAR_FILL_DENSITY_ENCODER_gpt2_version3_Aug_13_13_24_False_num_bars_8_6", 175000)
  )
}

def download_model_from_web(name,step,force=False):
  """Download one model checkpoint into ./checkpoints and return its directory.

  Files that are already present and non-empty are reused unless `force` is
  set. Each file is fetched to a temporary ".part" name and only moved into
  place after wget succeeds, so an interrupted or failed download is never
  mistaken for a complete file on a later run. Failures are reported but do
  not raise (the download stays best-effort).

  Args:
    name: Checkpoint name; used as the directory name under "checkpoints/".
    step: Checkpoint step; the directory is "checkpoint-<step>".
    force: Re-download every file even if it already exists.

  Returns:
    The relative path of the checkpoint directory.
  """
  file_list = ["config.json", "pytorch_model.bin", "scheduler.pt", "training_args.bin"] # don't need optimizer
  ckpt_path = "checkpoints/{}/checkpoint-{}".format(name,step)
  os.makedirs(ckpt_path, exist_ok=True)
  for file_name in file_list:
    path = os.path.join(ckpt_path,file_name)
    # A zero-byte file is what a failed "wget -O" leaves behind: treat it as missing.
    if not force and os.path.isfile(path) and os.path.getsize(path) > 0:
      continue
    url = "http://www.sfu.ca/~jeffe/{}/{}".format(ckpt_path,file_name)
    tmp_path = path + ".part"
    try:
      # Argument list instead of a shell string: nothing in the path is shell-interpreted.
      status = call(["wget", url, "-O", tmp_path])
    except OSError:
      status = -1  # wget could not be started at all
    if status == 0:
      os.replace(tmp_path, path)
    else:
      if os.path.exists(tmp_path):
        os.remove(tmp_path)
      print("warning: could not download {}".format(url))
  return ckpt_path

# Resolve every registry entry to a local checkpoint directory, downloading
# whatever is not on disk yet.
ckpt_map = {
  key: (encoder_type, download_model_from_web(model_name, model_step))
  for key, (encoder_type, (model_name, model_step)) in ckpt_map_raw.items()
}

update("finished!")


In [None]:
#@title RUN DEMO

# run the weight download again

import os
from subprocess import call
import dataset_builder_2 as db

# Copy the system soundfont to ./font.sf2, the file play_callback() hands to FluidSynth.
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2 &> /dev/null

# Checkpoint registry (same content as in the INSTALL cell, repeated so this
# cell can run on its own).
# Key:   (4 or 8, model type) -- the integer presumably is the number of bars
#        the model handles (the 8 entries carry "num_bars_8" in their names).
# Value: (encoder type, (checkpoint name, checkpoint step)).
ckpt_map_raw = {
  (4,db.MODEL_TYPE.TRACK_MODEL) : (
    db.ENCODER_TYPE.TRACK_DENSITY_ENCODER,
    ("TRACK_DENSITY_ENCODER_gpt2_version3_Aug_04_14_29_False", 405000)
  ),
  (4,db.MODEL_TYPE.BAR_INFILL_MODEL) : (
    db.ENCODER_TYPE.TRACK_BAR_FILL_DENSITY_ENCODER,
    ("TRACK_BAR_FILL_DENSITY_ENCODER_gpt2_version3_Aug_05_09_15_False", 395000)
  ),
  (8,db.MODEL_TYPE.TRACK_MODEL) : (
    db.ENCODER_TYPE.TRACK_DENSITY_ENCODER,
    ("TRACK_DENSITY_ENCODER_gpt2_version3_Aug_13_13_31_False_num_bars_8_6", 240000)
  ),
  (8,db.MODEL_TYPE.BAR_INFILL_MODEL) : (
    db.ENCODER_TYPE.TRACK_BAR_FILL_DENSITY_ENCODER,
    ("TRACK_BAR_FILL_DENSITY_ENCODER_gpt2_version3_Aug_13_13_24_False_num_bars_8_6", 175000)
  )
}

def download_model_from_web(name,step,force=False):
  """Download one model checkpoint into ./checkpoints and return its directory.

  Files that are already present and non-empty are reused unless `force` is
  set. Each file is fetched to a temporary ".part" name and only moved into
  place after wget succeeds, so an interrupted or failed download is never
  mistaken for a complete file on a later run. Failures are reported but do
  not raise (the download stays best-effort).

  Args:
    name: Checkpoint name; used as the directory name under "checkpoints/".
    step: Checkpoint step; the directory is "checkpoint-<step>".
    force: Re-download every file even if it already exists.

  Returns:
    The relative path of the checkpoint directory.
  """
  file_list = ["config.json", "pytorch_model.bin", "scheduler.pt", "training_args.bin"] # don't need optimizer
  ckpt_path = "checkpoints/{}/checkpoint-{}".format(name,step)
  os.makedirs(ckpt_path, exist_ok=True)
  for file_name in file_list:
    path = os.path.join(ckpt_path,file_name)
    # A zero-byte file is what a failed "wget -O" leaves behind: treat it as missing.
    if not force and os.path.isfile(path) and os.path.getsize(path) > 0:
      continue
    url = "http://www.sfu.ca/~jeffe/{}/{}".format(ckpt_path,file_name)
    tmp_path = path + ".part"
    try:
      # Argument list instead of a shell string: nothing in the path is shell-interpreted.
      status = call(["wget", url, "-O", tmp_path])
    except OSError:
      status = -1  # wget could not be started at all
    if status == 0:
      os.replace(tmp_path, path)
    else:
      if os.path.exists(tmp_path):
        os.remove(tmp_path)
      print("warning: could not download {}".format(url))
  return ckpt_path

# Resolve every registry entry to a local checkpoint directory; files already
# on disk are not fetched again.
ckpt_map = {
  key: (encoder_type, download_model_from_web(model_name, model_step))
  for key, (encoder_type, (model_name, model_step)) in ckpt_map_raw.items()
}



from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

import os
import re
import base64
import copy
import time
import glob
import json
import torch
import torch.nn.functional as F
import numpy as np
from subprocess import call
from transformers import *
import dataset_builder_2 as db

from midi2audio import FluidSynth

import logging
# Suppress all log records at CRITICAL level and below for the whole process,
# which keeps library logging out of the demo output.
logging.disable(logging.CRITICAL)

def generate(status, piece, temperature, batch_size, verbose):
  """Sample `batch_size` token sequences from the checkpoint chosen for this request.

  Args:
    status: JSON string with the per-track settings; forwarded to db.prepare_generate.
    piece: JSON string with the current piece; forwarded to db.prepare_generate.
    temperature: Sampling temperature; logits are divided by it when it is not 1.0.
    batch_size: Number of sequences sampled in parallel.
    verbose: When true, print the newest token of the first sequence at every step.

  Returns:
    A list with one entry per sequence: the output of encoder.tokens_to_json,
    passed through db.reorder_tracks when `order` is non-empty.
  """
  prompt, ctrl, et, ckpt_path, order = db.prepare_generate(
    status, piece, temperature, batch_size, verbose, ckpt_map)

  encoder = db.getEncoder(et)
  model = GPT2LMHeadModel.from_pretrained(ckpt_path)

  # Every sequence in the batch starts from its own copy of the prompt.
  seqs = [copy.deepcopy(prompt) for _ in range(batch_size)]
  input_ids = torch.from_numpy(np.array(seqs))
  pkv = None

  ctrl.start(batch_size)
  # Pure inference: without no_grad() each step's autograd graph would stay
  # reachable through the cached key/values and memory would grow per token.
  with torch.no_grad():
    while not ctrl.all_finished():

      # After the first step only the newly sampled token is fed in; the
      # earlier context is carried by the cached key/values in pkv.
      outputs = model(input_ids=input_ids, past_key_values=pkv)
      logits = outputs[0][:,-1,:]
      pkv = outputs[1]

      if temperature != 1.0:
        logits /= temperature

      for i in range(batch_size):

        # modify the logits if necessary
        if len(seqs[i]) > 0 and ctrl.check_trigger(i, seqs[i][-1]):
          mask = np.array(ctrl.get_mask(i))
          # Entries where the mask is 0 are pushed to -inf so softmax gives them zero probability.
          logits[i][mask==0] += -float("Inf")
          ctrl.increment(i)

      probs = F.softmax(logits, dim=-1)
      input_ids = torch.multinomial(probs, num_samples=1)

      # A token is sampled for every row, but only unfinished sequences keep it.
      for i in range(batch_size):
        if not ctrl.is_finished(i):
          seqs[i].append( input_ids[i][0].item() )

      if verbose:
        print( encoder.rep.pretty(seqs[0][-1]) )

  # convert the seqs to midi or protobuf
  pieces = []
  for seq in seqs:
    generated = encoder.tokens_to_json(seq)
    if len(order):
      generated = db.reorder_tracks(generated, order)
    pieces.append( generated )

  return pieces

# read the demo.html from the web
import urllib.request
url = 'http://www.sfu.ca/~jeffe/LBD_DEMO/demo_v2.html'
# The context manager closes the HTTP response once the page has been read.
with urllib.request.urlopen(url) as response:
  html = response.read().decode("utf-8")

# read demo.html from local
#with open("../demo.html", "r") as f:
#  html = f.read()

def read_track_map():
  """Return the parsed JSON content of track_map.json in the working directory."""
  with open("track_map.json", "r") as handle:
    return json.loads(handle.read())

def write_track_map(x):
  """Write `x` as JSON to track_map.json, replacing any previous content."""
  with open("track_map.json", "w") as handle:
    json.dump(obj=x, fp=handle)

def get_current_midi():
  """Return the piece stored in current_midi.json as parsed JSON."""
  with open("current_midi.json", "r") as handle:
    return json.loads(handle.read())

def save_current_midi(midi_json):
  """Write `midi_json` as JSON to current_midi.json, replacing the stored piece."""
  with open("current_midi.json", "w") as handle:
    json.dump(obj=midi_json, fp=handle)

def save_status(status):
  """Write `status` as JSON to current_status.json, replacing any previous content."""
  with open("current_status.json", "w") as handle:
    json.dump(obj=status, fp=handle)

def update_gui_midi(midi_json):
  """Send a piece to the browser GUI by calling the page's update_midi() function.

  Args:
    midi_json: The piece as a JSON-serializable dict.
  """
  assert isinstance(midi_json, dict)
  # json.dumps output (ASCII-only by default) is itself a valid JavaScript
  # object literal, so it is passed to update_midi() directly. Wrapping it in
  # a single-quoted JS string for JSON.parse breaks as soon as the payload
  # contains an apostrophe or a backslash escape such as \".
  output.eval_js('update_midi({})'.format(json.dumps(midi_json)))

def generate_callback(status):
  """Run one generation request from the GUI and publish the result.

  Loads the stored piece, records `status` on disk, builds the per-track
  settings for generate(), then pushes the generated piece to the GUI and
  saves it as the new current piece.

  Args:
    status: Dict sent by the page. Reads "temperature", "tempo" and the
      optional "tracks" list; each track must carry "instrument".
  """
  midi_json = get_current_midi()
  save_status(status)

  temperature = status["temperature"]
  tempo = status["tempo"]

  valid_status = {"tracks" : []}
  for gui_track in status.get("tracks",[]):
    # Copy without the GUI-only flags: the caller's dict is left untouched and
    # a track that lacks "mute"/"solo" no longer raises KeyError.
    track = {key: value for key, value in gui_track.items() if key not in ("mute", "solo")}
    instrument = track["instrument"]
    # NOTE(review): the track_type codes 10/11/12 are interpreted by
    # db.prepare_generate; their meaning is not visible here -- confirm there.
    if instrument == "any":
      track["track_type"] = 12
    elif instrument != "no_drums" and "drum" in instrument:
      track["track_type"] = 11
    else:
      track["track_type"] = 10
    valid_status["tracks"].append( track )

  # Named `pieces` rather than `output` so the google.colab `output` module is not shadowed.
  pieces = generate(json.dumps(valid_status), json.dumps(midi_json), temperature, 1, False)
  midi_json = json.loads(pieces[0])
  midi_json["tempo"] = tempo

  # update the midi
  update_gui_midi(midi_json)

  # save the midi
  save_current_midi(midi_json)

# this should work now basically
def mix_tracks_in_json(midi_json, levels=None):
  """Overwrite the velocity of every sounding event with a fixed per-track level.

  Events are looked up in midi_json["events"] through the indices listed in
  each bar of each track; events whose velocity is not above zero are left
  alone. The dict is modified in place.

  Args:
    midi_json: Piece dict with optional "tracks" (each with optional "bars",
      each with optional "events" index lists) and an "events" list.
    levels: Optional per-track indices (0-9) into the level table. When
      omitted, every track uses table entry 8.
  """
  level_table = [12 * step for step in range(1, 11)]
  for track_index, track in enumerate(midi_json.get("tracks", [])):
    event_indices = (
      index for bar in track.get("bars", []) for index in bar.get("events", [])
    )
    for index in event_indices:
      note = midi_json["events"][index]
      if not note["velocity"] > 0:
        continue
      chosen = 8 if levels is None else levels[track_index]
      note["velocity"] = level_table[chosen]

def play_callback(status):
  """Render the audible tracks of the current piece to audio and play it in the page.

  Args:
    status: Dict sent by the page. Reads "tracks" (each with "track_id",
      "solo" and "mute"), "tempo" and "nbars".
  """
  midi_json = get_current_midi()
  tracks = []
  for track in status["tracks"]:
    tid = int(track["track_id"])
    if track["solo"]:
      # The first soloed track wins: only that track is played.
      tracks = [ tid ]
      break
    elif not track["mute"]:
      tracks.append( tid )

  encoder = db.TrackEncoder()

  midi_json["tempo"] = status["tempo"]
  #mix_tracks_in_json(midi_json)
  raw = json.dumps(midi_json)
  bars_to_keep = list(range(status["nbars"]))
  raw = db.prune_tracks(raw, tracks, bars_to_keep)
  encoder.json_to_midi(raw, "current.mid")
  FluidSynth("font.sf2").midi_to_audio('current.mid', 'current.wav')

  # set the src and play
  # Read through a context manager so the file handle is closed right away.
  with open("current.wav", "rb") as wav_file:
    sound = wav_file.read()
  sound_encoded = base64.b64encode(sound).decode('ascii')
  # The audio is embedded as a base64 data URI on the page's "#beep" element.
  script = '''<script type="text/javascript">
  var audio = document.querySelector("#beep");
  audio.src = "data:audio/wav;base64,{raw_audio}";
  audio.play();
  </script>'''.format(raw_audio=sound_encoded)
  display(HTML(script))

def add_midi_callback(status, raw):
  """Load a MIDI file uploaded from the page and make it the current piece.

  Args:
    status: Dict sent by the page; "nbars" selects how many bars are kept.
    raw: Uploaded file as text containing "base64," followed by the
      base64-encoded MIDI bytes.
  """
  match = re.search(r'base64,(.*)', raw)
  if match is None:
    # No base64 payload found: report it in the GUI instead of failing with
    # an AttributeError on the missing match.
    output.eval_js('''build_snackbar("Invalid MIDI file. Could not read the uploaded data.")''')
    return
  with open("input.mid", "wb") as f:
    f.write(base64.b64decode(match.group(1)))

  enc = db.TrackEncoder()
  midi_json = json.loads(enc.midi_to_json("input.mid"))
  bars_to_keep = list(range(status["nbars"]))
  midi_json = json.loads(db.prune_empty_tracks(json.dumps(midi_json), bars_to_keep))
  if len(midi_json.get("tracks",[])) == 0:
    output.eval_js('''build_snackbar("Invalid MIDI file. Make sure each track has {} bars.")'''.format(status["nbars"]))
    return

  # Fall back to 120 when the converted file carries no tempo.
  if "tempo" not in midi_json:
    midi_json["tempo"] = 120

  # update the midi
  update_gui_midi(midi_json)

  # save the midi
  save_current_midi(midi_json)

# =============================================
# setup

# start out with blank midi
# (get_current_midi() can then always read current_midi.json, even before
# anything was loaded or generated)
save_current_midi({})

# Switch for the Colab-specific wiring below.
GOOGLE_COLAB = True

# register python callbacks
# NOTE(review): the registered names are presumably the ones the page's
# JavaScript invokes -- verify against demo_v2.html.
if GOOGLE_COLAB:
  output.register_callback('generate_callback', generate_callback)
  output.register_callback('play_callback', play_callback)
  output.register_callback('add_midi_callback', add_midi_callback)
  output.register_callback('save_current_midi', save_current_midi)

  # Callbacks are registered before the page is shown; then the downloaded
  # demo page is injected, its start_up() is called, and the output frame is
  # set to 800px high.
  display(HTML(html))
  output.eval_js("start_up()");
  display(Javascript("google.colab.output.setIframeHeight('800px');"))

<IPython.core.display.Javascript object>