In [None]:
# Mount Google Drive: later cells read/write the corpus and checkpoints under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# System + Python dependencies (run once per Colab runtime).
# -y answers apt-get's confirmation prompt so "Run all" does not stall on it.
!apt-get install -y fluidsynth
# music21 is imported by the midi2text.py / text2midi.py cells below, so list it explicitly.
%pip install transformers==4.21.3 wandb blobfile mpi4py pretty_midi music21

In [None]:
# Quick look at the working directory (debug aid; safe to skip).
!ls

## MAESTRO dataset (download)

In [None]:
# Download the MAESTRO v3.0.0 MIDI-only archive. -nc skips the download when the zip is
# already present, so re-running does not create maestro-v3.0.0-midi.zip.1, .2, ...
!wget -nc https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip

In [None]:
# -o overwrites already-extracted files without asking. Otherwise a re-run blocks on unzip's
# "replace?" prompt. -q keeps the long file listing out of the cell output.
!unzip -o -q maestro-v3.0.0-midi.zip

In [None]:
# Sanity check: count the .midi files under the current directory after unzipping.
!find . -name "*.midi" | wc -l

## MIDI Functions

In [None]:
import os
import json
import argparse
import pretty_midi

import numpy as np
import scipy.io.wavfile as wav

# Smallest duration unit as a fraction of one beat: one unit lasts
# MIN_NOTE_MULTIPLIER * 60 / tempo seconds (see get_note_duration).
MIN_NOTE_MULTIPLIER = 0.125
# File extensions (compared lower-cased) that load_file treats as MIDI.
MIDI_EXTENSIONS = [".mid", ".midi"]

def load(datapath, pitch_range=(30, 96), velocity_range=(32, 127, 4), fs=1000, tempo=120, augmentation=(1, 1, 1)):
    """Encode a MIDI file, or a directory tree of MIDI files, and return the token vocabulary.

    Args:
        datapath: path to a single file or to a directory (walked recursively by load_dir).
        pitch_range, velocity_range, fs, tempo, augmentation: forwarded unchanged to
            load_file / load_dir.

    Returns:
        set[str]: every distinct token produced. As a side effect, load_file caches a
        ``.txt`` encoding next to each MIDI file it encodes.
    """
    if not os.path.isfile(datapath):
        return load_dir(datapath, pitch_range, velocity_range, fs, tempo, augmentation)

    text = load_file(datapath, pitch_range, velocity_range, fs, tempo, augmentation)
    # Match load_dir: a file that yields no text (non-MIDI or unreadable) contributes no
    # tokens, rather than the bogus empty token that "".split(" ") would produce.
    return set(text.split(" ")) if text else set()

def load_dir(dirpath, pitch_range=(30, 96), velocity_range=(32, 127, 4), fs=1000, tempo=120, augmentation=(1, 1, 1)):
    """Encode every file below ``dirpath`` (recursively) and return the union of their tokens.

    Files for which load_file returns "" (non-MIDI or unreadable) are ignored. A missing
    directory simply yields an empty set.
    """
    tokens = set()
    for root, _subdirs, filenames in os.walk(dirpath):
        for name in filenames:
            encoded = load_file(os.path.join(root, name), pitch_range, velocity_range, fs, tempo, augmentation)
            if encoded == '':
                continue
            tokens.update(encoded.split(" "))
    return tokens

def load_file(filepath, pitch_range=(30, 96), velocity_range=(32, 127, 4), fs=1000, tempo=120, augmentation=(1, 1, 1)):
    """Encode one MIDI file into a space-separated token string, caching the result.

    Non-MIDI files (extension not in MIDI_EXTENSIONS) and files pretty_midi cannot parse
    yield "". Otherwise the encoding is read from, or written to, a ``.txt`` file with the
    same base name next to the MIDI file.

    NOTE: the cache is keyed only on the file name, not on the encoding parameters. Delete
    the ``.txt`` files after changing pitch_range/velocity_range/fs/tempo/augmentation.

    Returns:
        str: tokens joined by single spaces ("" when nothing was encoded).
    """
    filename, extension = os.path.splitext(filepath)
    if extension.lower() not in MIDI_EXTENSIONS:
        return ""

    print("Encoding file...", filepath)
    cache_path = filename + ".txt"

    # If a txt version of the midi already exists, reuse it verbatim.
    if os.path.isfile(cache_path):
        with open(cache_path, "r") as midi_txt:
            return midi_txt.read()

    try:
        midi_data = pretty_midi.PrettyMIDI(filepath)
    except Exception as err:
        # Best effort: skip corrupt or unsupported files, but report why. KeyboardInterrupt is
        # not caught, so interrupting the cell stops the run without shutting the kernel down.
        print("Skipping unreadable MIDI file", filepath, "-", err)
        return ""

    # Call the pretty_midi encoder through a private alias. The "midi2text.py" cell later in
    # this notebook rebinds the global name ``midi2text`` to a music21 function with a
    # different signature, which would otherwise break this call.
    text = " ".join(_pretty_midi2text(midi_data, pitch_range, velocity_range, fs, tempo, augmentation))
    with open(cache_path, "w") as midi_txt:
        midi_txt.write(text)

    return text

def midi2text(midi_data, pitch_range=(30, 96), velocity_range=(32, 127, 4), fs=1000, tempo=120, augmentation=(1, 1, 1)):
    """Convert a PrettyMIDI object into a list of event tokens.

    Tokens (decoded by parse_notes_from_text):
        "d_<k>"  set the current duration to k units (see get_note_duration);
        "a"      advance the clock by the current duration;
        "v_<x>"  set the current velocity;
        "n_<p>"  play pitch p now with the current duration and velocity.
    A newline token closes each encoded pass over the piece.

    Args:
        midi_data: pretty_midi.PrettyMIDI to encode (notes of all instruments are merged).
        pitch_range: (low, high) passed to clamp_pitch (octave folding).
        velocity_range: (min, max, step) passed to clamp_velocity.
        fs: samples per second used to quantise onsets and durations.
        tempo: BPM defining the duration unit.
        augmentation: (transpose, time_stretch, velo_stretch), the number of variants per
            axis. The piece is encoded once per combination: semitone transpositions,
            time-stretch steps of get_note_duration's ``percentage``, and velocity shifts
            of ``8 * velocity_range[2]``. (1, 1, 1) encodes the piece once, unmodified.

    Returns:
        list[str]: the tokens.
    """
    text = []

    # Group note onsets (tempo changes are not read)
    midi_notes = parse_notes_from_midi(midi_data, fs)

    transpose, time_stretch, velo_stretch = augmentation
    # Offset ranges with exactly n values each, e.g. n=3 -> -1, 0, 1 and n=1 -> 0 only.
    transpose_range    = (-transpose//2 + 1, transpose//2 + 1)
    time_stretch_range = (-time_stretch//2 + 1, time_stretch//2 + 1)
    velo_stretch_range = (-velo_stretch//2 + 1, velo_stretch//2 + 1)

    for i in range(transpose_range[0], transpose_range[1]):
        for j in range(time_stretch_range[0], time_stretch_range[1]):
            for k in range(velo_stretch_range[0], velo_stretch_range[1]):
                # Decoder state is shared by "a" and "n" tokens, so track the last emitted values.
                last_start = last_duration = last_velocity = 0

                for start, time_step_notes in sorted(midi_notes.items()):
                    wait_duration = get_note_duration((start - last_start)/fs, tempo, stretch=j)
                    if wait_duration > 0:
                        if wait_duration != last_duration:
                            text.append("d_" + str(wait_duration))
                            last_duration = wait_duration

                        text.append("a")

                    for note in time_step_notes:
                        note_pitch  = clamp_pitch(note["pitch"] + i, pitch_range)
                        note_velocity = clamp_velocity(note["velocity"] + k * 8 * velocity_range[2], velocity_range)
                        note_duration = get_note_duration(note["duration"]/fs, tempo, stretch=j)

                        if note_velocity > 0 and note_duration > 0:
                            if note_velocity != last_velocity:
                                text.append("v_" + str(note_velocity))
                                last_velocity = note_velocity

                            if note_duration != last_duration:
                                text.append("d_" + str(note_duration))
                                last_duration = note_duration

                            text.append("n_" + str(note_pitch))

                    last_start = start

                text.append("\n")

    return text

# Private handle on the pretty_midi encoder above. load_file calls this, so a later cell that
# rebinds the global ``midi2text`` (the music21 "midi2text.py" cell) cannot break it.
_pretty_midi2text = midi2text

def parse_notes_from_midi(midi_data, fs):
    """Group the notes of every instrument by their quantized onset.

    Onsets and ends are converted to integer sample indices via ``int(fs * seconds)``.

    Returns:
        dict[int, list[dict]]: onset index -> list of {"pitch", "duration", "velocity"},
        where duration is measured in the same ``fs`` units. Notes keep instrument/file order.
    """
    grouped = {}
    for track in midi_data.instruments:
        for n in track.notes:
            onset = int(fs * n.start)
            offset = int(fs * n.end)
            grouped.setdefault(onset, []).append(
                {"pitch": n.pitch, "duration": offset - onset, "velocity": n.velocity}
            )
    return grouped

def text2midi(text, tempo):
    """Decode a token string (see parse_notes_from_text) into a one-track piano PrettyMIDI."""
    decoded_notes = parse_notes_from_text(text, tempo)

    midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)
    program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
    piano = pretty_midi.Instrument(program=program)
    piano.notes.extend(decoded_notes)
    midi.instruments.append(piano)
    return midi

def parse_total_duration_from_text(text, tempo=120):
    """Return the time in seconds spanned by the "a" (advance) tokens of ``text``.

    Uses the same time bookkeeping as parse_notes_from_text: "d_<k>" sets the current
    duration and every "a" advances by it. The result is the onset of the last time step;
    the sounding tail of the final notes is not included.
    """
    # Same default as parse_notes_from_text, so both agree on text (e.g. model samples)
    # that contains "a" before any "d_" token.
    duration = 8
    total_duration = 0
    for token in text.split(" "):
        if not token:
            # Leading, trailing or doubled spaces yield empty tokens; token[0] would raise.
            continue
        if token[0] == "a":
            total_duration += duration
        elif token[0] == "d":
            duration = int(token.split("_")[1])

    # Duration of the shortest note, in seconds
    min_duration = MIN_NOTE_MULTIPLIER * 60/tempo

    return total_duration * min_duration

def parse_notes_from_text(text, tempo):
    """Decode a token string into a list of ``pretty_midi.Note`` objects.

    Space-separated tokens:
        "d_<k>" sets the current duration to k units;
        "v_<x>" sets the current velocity;
        "a" advances the clock by the current duration;
        "n_<p>" emits pitch p at the clock time, lasting the current duration.
    One unit is MIN_NOTE_MULTIPLIER * 60 / tempo seconds. Tokens starting with any other
    character (e.g. leftover special tokens or newlines) are ignored.
    """
    notes = []

    # Defaults used until the first "v_" / "d_" token
    velocity = 100
    duration = 8

    # Compute duration of shortest note
    min_duration = MIN_NOTE_MULTIPLIER * 60/tempo

    i = 0
    for token in text.split(" "):
        if not token:
            # "" and doubled/leading/trailing spaces (e.g. after stripping special tokens
            # from model samples) give empty tokens; token[0] would raise IndexError.
            continue

        if token[0] == "a":
            i += duration

        elif token[0] == "n":
            pitch = int(token.split("_")[1])
            note = pretty_midi.Note(velocity, pitch, start=i * min_duration, end=(i + duration) * min_duration)
            notes.append(note)

        elif token[0] == "d":
            duration = int(token.split("_")[1])

        elif token[0] == "v":
            velocity = int(token.split("_")[1])

    return notes

def clamp_velocity(velocity, velocity_range):
    """Clamp ``velocity`` into [min, max] and snap it down onto the ``step`` grid.

    Args:
        velocity: raw MIDI velocity (midi2text's augmentation may push it out of range).
        velocity_range: ``(min_velocity, max_velocity, step)``.

    Returns:
        A multiple of ``step`` that is >= ``min_velocity`` (32..124 for (32, 127, 4)).
    """
    min_velocity, max_velocity, step = velocity_range

    velocity = max(min(velocity, max_velocity), min_velocity)
    velocity = (velocity//step) * step
    # Flooring undershoots the lower bound when min_velocity is not a multiple of step
    # (e.g. (30, 127, 4): 30 -> 28), so move up to the first grid value >= min_velocity.
    if velocity < min_velocity:
        velocity += step

    return velocity

def clamp_pitch(pitch, pitch_range):
    """Shift ``pitch`` by whole octaves toward the half-open range [low, high).

    Too-low pitches are raised first, then too-high pitches are lowered. Both passes move in
    steps of 12 semitones, so the pitch class is preserved.
    """
    low, high = pitch_range

    while pitch < low:
        pitch += 12
    while pitch >= high:
        pitch -= 12

    return pitch

def get_note_duration(dt, tempo, stretch=0, max_duration=56, percentage=0.15):
    """Quantize a duration in seconds to an integer count of the smallest note unit.

    One unit is MIN_NOTE_MULTIPLIER beats (a 32nd note when the beat is a quarter note),
    i.e. MIN_NOTE_MULTIPLIER * 60 / tempo seconds. ``dt`` is first stretched by
    ``dt * percentage * stretch`` (time-stretch augmentation). It is then rounded to the
    nearest unit and capped at ``max_duration``.
    """
    unit_seconds = MIN_NOTE_MULTIPLIER * 60/tempo
    stretched = dt + dt * percentage * stretch
    return min(round(stretched/unit_seconds), max_duration)

def save_vocab(vocab, vocab_path):
    """Write a token -> index JSON mapping, with indices assigned in sorted token order."""
    char2idx = {}
    for index, token in enumerate(sorted(vocab)):
        char2idx[token] = index

    with open(vocab_path, "w") as out:
        json.dump(char2idx, out)

def write(text, path, synthesize=False, tempo=120, sf2_path="soundfonts/salc5light-piano.sf2", sample_rate=44100):
    """Decode ``text`` to ``<path>.mid`` and optionally render it to ``<path>.wav``.

    Args:
        text: token string understood by text2midi.
        path: output path without extension.
        synthesize: also render audio with FluidSynth (through pretty_midi).
        tempo: tempo in BPM used for decoding.
        sf2_path: SoundFont used for synthesis.
        sample_rate: passed to both fluidsynth and the WAV header so they agree.
    """
    midi = text2midi(text, tempo)
    midi.write(path + ".mid")

    if not synthesize:
        return

    audio = midi.fluidsynth(fs=sample_rate, sf2_path=sf2_path)
    peak = np.max(np.abs(audio)) if audio.size else 0
    if peak > 0:
        # Normalise to INT32_MAX. Scaling by 2**31 puts the loudest sample outside the int32
        # range, and numpy's float->int cast there is undefined (typically INT32_MIN: a click).
        audio = audio / peak * np.iinfo(np.int32).max
    # A silent or empty render is written as zeros instead of dividing by zero (NaN -> garbage).
    wav.write(path + ".wav", sample_rate, audio.astype(np.int32))


## midi2text.py — music21 interval encoder (redefines `midi2text`)

In [None]:
import glob
from music21 import converter, instrument, note, chord, interval, pitch
import sys
from tqdm import tqdm

def valid_note(note_number):
    """True when an interval in semitones stays within five octaves (60) either way."""
    return abs(note_number) <= 12*5

def first_note(notes_to_parse):
    """Return the pitch number (music21 ``pitch.ps``) of the first note or chord, else 60.

    For a chord, the pitch of its last listed note is used. midi2text below uses this value
    as the initial reference for its relative (interval) encoding.
    """
    for element in notes_to_parse:
        try:
            if isinstance(element, note.Note):
                return int(element.pitch.ps)
            if isinstance(element, chord.Chord):
                return int(element.notes[-1].pitch.ps)
        except Exception:
            # Best effort: skip elements whose pitch cannot be read. Unlike a bare
            # ``except:``, this no longer swallows KeyboardInterrupt / SystemExit.
            continue
    return 60

def midi2text(midis_folder, output_path='notes.txt'):
    """Encode every ``*.mid`` directly inside ``midis_folder`` as interval tokens.

    NOTE: this rebinds the global name ``midi2text``, replacing the pretty_midi encoder of
    the same name from the "MIDI Functions" cell (which takes a PrettyMIDI object instead).

    Each piece is transposed so that its analysed key's tonic becomes C. It is then written
    as one line of ``output_path``: a "<mode> =>" header followed by tokens
    "<interval>|<quarterLength>" (note), "<i1>.<i2>...|<quarterLength>" (chord) or
    "R|<quarterLength>" (rest). Intervals are semitones relative to a running reference
    pitch (seeded by first_note). create_midi in the text2midi.py cell decodes the same
    convention. Files music21 cannot process are skipped with a message.
    """
    notes = []

    midis = sorted(glob.glob(f"{midis_folder}/*.mid"))

    for file in tqdm(midis):
        try:
            # Encode into a separate list so a failure half-way through a file cannot leave
            # partial tokens without the trailing "\n" (which would merge two pieces).
            notes.extend(_encode_midi_file(file))
        except Exception as err:
            print(f"Skipping {file}: {err}")

    with open(output_path, 'w') as filepath:
        filepath.write(" ".join(notes).replace("\n ","\n").strip())


def _encode_midi_file(file):
    """Return the token list for one MIDI file, ending with a "\\n" token."""
    midi = converter.parse(file)

    # Transpose to C
    k = midi.analyze('key')
    i = interval.Interval(k.tonic, pitch.Pitch('C4'))
    midi = midi.transpose(i)
    mode = str(k.mode)

    try: # file has instrument parts
        s2 = instrument.partitionByInstrument(midi)
        notes_to_parse = s2.parts[0].recurse()
    except Exception: # file has notes in a flat structure
        notes_to_parse = midi.flat.notes

    tokens = [f"{mode} =>"]

    last_note = first_note(notes_to_parse)

    for element in notes_to_parse:
        try:
            if isinstance(element, note.Note):
                duration = element.duration.quarterLength
                new_note = int(element.pitch.ps)

                if duration > 0 and valid_note(new_note - last_note):
                    tokens.append(f"{str(new_note - last_note)}|{duration}")
                    last_note = new_note
            elif isinstance(element, note.Rest):
                duration = element.duration.quarterLength
                if 0 < duration < 32:
                    tokens.append(f"R|{duration}")
            elif isinstance(element, chord.Chord):
                duration = element.duration.quarterLength
                chord_notes = [int(new_note.pitch.ps)-last_note for new_note in element.notes if valid_note(int(new_note.pitch.ps)-last_note)]
                if duration > 0 and len(chord_notes) > 0:
                    tokens.append('.'.join(map(str,chord_notes)) + "|" + str(duration))
                    # create_midi moves its reference to the FIRST emitted interval, so track
                    # that pitch. element.notes[0] differs from it when valid_note filtered
                    # that note out.
                    last_note = last_note + chord_notes[0]
        except Exception:
            # Best effort: skip an element whose pitch or duration cannot be read.
            continue

    tokens.append("\n")
    return tokens


## text2midi.py — music21 interval decoder

In [None]:
import numpy
from music21 import instrument, note, stream, chord
import sys
import numpy as np

def convert_to_float(frac_str):
    """Parse "0.5", a fraction like "1/3", or a mixed number like "1 1/2" / "-1 1/2" to float.

    Malformed text raises ValueError; a zero denominator raises ZeroDivisionError.
    """
    try:
        return float(frac_str)
    except ValueError:
        numerator, denominator = frac_str.split('/')
        whole = 0
        pieces = numerator.split(' ')
        if len(pieces) == 2:
            # Mixed number: the part after the space is the fraction's numerator.
            numerator = pieces[1]
            try:
                whole = float(pieces[0])
            except ValueError:
                whole = 0
        fraction = float(numerator) / float(denominator)
        if whole < 0:
            return whole - fraction
        return whole + fraction

def note_number_to_name(note_number):
    """Convert a (possibly fractional) MIDI note number to a name such as "C#4" (60 -> "C4").

    The number is rounded with ``np.round`` first; the octave follows the MIDI convention
    where note 0 is "C-1".
    """
    names = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')
    octave, semitone = divmod(int(np.round(note_number)), 12)
    return f"{names[semitone]}{octave - 1}"


def duration(note_text):
    """Return the quarterLength after the last "|" of a token, or 0.5 when there is none."""
    if "|" not in note_text:
        # Malformed token without a duration field: fall back to 0.5 so decoding continues.
        return 0.5
    return convert_to_float(note_text.split("|")[-1])

def create_midi(notes_sequence,output_file="output.mid"):
    """Decode interval tokens (as written by the music21 midi2text) into a MIDI file.

    Args:
        notes_sequence: iterable of tokens such as "3|1.0" (note 3 semitones from the
            reference pitch), "0.4.7|0.5" (chord), "R|1.0" (rest) or "S" (start marker,
            skipped). The text after "|" is a music21 quarterLength.
        output_file: path of the MIDI file to write.

    The reference pitch starts at 60 and moves to each decoded note (for a chord, to its
    first listed note). Tokens that are not note/chord/rest tokens are skipped instead of
    aborting with ValueError. Examples are the "major"/"minor" and "=>" header that
    midi2text writes at the start of each piece, or garbled model output.
    """
    offset = 0
    output_notes = []

    prev_note = 60
    # create note and chord objects based on the values generated by the model
    for token in notes_sequence:
        # starting token
        if token == "S":
            continue
        try:
            token_duration = duration(token)
        except (ValueError, ZeroDivisionError):
            continue  # unparseable duration field
        token = token.split("|")[0]

        # token is a chord
        if '.' in token:
            try:
                intervals = [int(step) for step in token.split('.')]
            except ValueError:
                continue  # malformed chord
            notes = []
            for step in intervals:
                new_note = note.Note(note_number_to_name(prev_note + step))
                new_note.storedInstrument = instrument.Piano()
                notes.append(new_note)
            prev_note = intervals[0] + prev_note
            new_chord = chord.Chord(notes,quarterLength=token_duration)
            new_chord.offset = offset
            output_notes.append(new_chord)
        # token is a rest
        elif "R" in token:
            # Give rests an explicit offset like notes and chords (a new Rest's offset is 0).
            rest = note.Rest(quarterLength=token_duration)
            rest.offset = offset
            output_notes.append(rest)
        # token is a single note
        else:
            try:
                step = int(token)
            except ValueError:
                continue  # header word ("major", "=>") or other non-note text
            new_note = note.Note(note_number_to_name(prev_note + step),quarterLength=token_duration)
            prev_note = step + prev_note
            new_note.offset = offset
            new_note.storedInstrument = instrument.Piano()
            output_notes.append(new_note)

        # increase offset each iteration so that notes do not stack
        offset += token_duration

    midi_stream = stream.Stream(output_notes)

    midi_stream.write('midi', fp=output_file)


## Data Preprocess

In [None]:
# Work from /content and remove any previous copy of the ADL Piano MIDI corpus.
# NOTE(review): the cell that re-downloads/unzips adl-piano-midi (further below) is commented
# out. Unless it is run (or the data restored) before the load() cell, load() finds no files
# and the following cell writes an empty midi.txt to Drive.
%cd /content
!rm -fr adl-piano-midi

In [None]:
# Encode every MIDI file under adl-piano-midi (a .txt encoding is cached next to each file).
# Keep the vocabulary in a variable instead of dumping the whole set as the cell output.
vocab = load('adl-piano-midi', augmentation=(1, 1, 1))
len(vocab)

In [None]:
# Concatenate the cached per-piece encodings (.txt files written by load_file) into the
# training corpus on Drive. Run from /content; no `cd` is needed.
!find adl-piano-midi -name "*.txt" -exec cat {} \; > /content/drive/MyDrive/midi.txt

In [None]:
# Download and extract the ADL Piano MIDI corpus (presumably into ./adl-piano-midi).
# Kept commented out; run once per runtime before the load() cell above.
# !wget https://github.com/lucasnfe/adl-piano-midi/archive/refs/heads/master.zip
# !unzip master.zip; unzip adl-piano-midi-master/midi/adl-piano-midi.zip

In [None]:
# Alternative encoder: the music21 midi2text(midis_folder) from the "midi2text.py" cell,
# which writes notes.txt. Only works while that cell's midi2text is the active definition
# (it shares its name with the pretty_midi encoder).
# midi2text('adl-piano-midi')
# midi2text('adl-piano-midi/Folk/Nordic Folk/Mari Boine/')

In [None]:
# Inspect the music21 encoder's output file.
# !cat notes.txt

## Train

In [None]:
# Clone the text-diffusion training code into /content. The clone is skipped when the repo is
# already there, so a re-run does not fail with "destination path already exists".
![ -d /content/text-diffusion ] || git clone https://github.com/infinfin/text-diffusion /content/text-diffusion

In [None]:
# Absolute path, so re-running this cell (cwd already text-diffusion) still works.
%cd /content/text-diffusion

In [None]:
# Update the clone from upstream.
# NOTE(review): tracking the latest commit makes runs non-reproducible; consider pinning a commit.
!git pull

In [None]:
# Copy the corpus from Drive to local disk and create the dataset directory used below.
!cp /content/drive/MyDrive/midi.txt /content
!mkdir -p data/midi

In [None]:
# Split each encoded piece (one line of /content/midi.txt) into training examples of at most
# MAX_LEN tokens. The "- 2" presumably reserves room for the two special tokens the tokenizer
# adds ([CLS] ... [SEP]). TODO confirm against src/utils/custom_tokenizer.py.
MAX_LEN = 256 - 2
# Every piece's final chunk is kept even when shorter than MAX_LEN (short examples get padded;
# [PAD] tokens appear in the decoded samples below). Set True to keep only full-length chunks.
DROP_SHORT_TAIL = False

data = []
with open('/content/midi.txt', 'r') as corpus:
    for line in corpus:
        tokens = line.split()  # a blank line becomes [] rather than ['']
        for start in range(0, len(tokens), MAX_LEN):
            chunk = tokens[start:start + MAX_LEN]
            if DROP_SHORT_TAIL and len(chunk) < MAX_LEN:
                break
            data.append(chunk)

# Fail clearly here rather than with "max() arg is an empty sequence".
assert data, "no training examples: /content/midi.txt is empty"
print(len(data), max(len(chunk) for chunk in data))

with open('data/midi/midi.txt', 'w') as out:
    for chunk in data:
        out.write(' '.join(chunk) + '\n')

In [None]:
# Peek at the first training examples.
!head data/midi/midi.txt

In [None]:
# Train the repo's word-level tokenizer on the chunked corpus.
!python src/utils/custom_tokenizer.py train-word-level data/midi/midi.txt

In [None]:
# Split the chunked corpus by line: the first N_TRAIN examples train, all remaining lines test.
# Deriving the test range from N_TRAIN keeps the files disjoint and uses every line whatever
# the total count. A fixed `tail -n 2878` only fits a corpus of exactly 122878 lines.
# For a larger training set use e.g. N_TRAIN = 240000.
N_TRAIN = 120000
TEST_START_LINE = N_TRAIN + 1
!head -n {N_TRAIN} data/midi/midi.txt > data/midi-train.txt
!tail -n +{TEST_START_LINE} data/midi/midi.txt > data/midi-test.txt
!wc -l data/midi-train.txt
!wc -l data/midi-test.txt

In [None]:
# Patch run_train.sh in place: log_interval 2 -> 10 and save_interval 10 -> 500.
# Then print the affected lines to confirm the edit.
!sed -i 's/log_interval 2/log_interval 10/g' scripts/run_train.sh
!sed -i 's/save_interval 10 /save_interval 500 /g' scripts/run_train.sh
!grep log_interval scripts/run_train.sh
!grep save_interval scripts/run_train.sh

In [None]:
# Choose the predictor network: flip USE_BERT from 1 to 0 in transformer_model.py
# (what USE_BERT = 0 selects is defined in that file), then show the resulting line.
!sed -i 's/USE_BERT = 1/USE_BERT = 0/' src/modeling/predictor/transformer_model.py
!grep USE_BERT src/modeling/predictor/transformer_model.py

In [None]:
# Fresh local ckpts/ directory. ckpts/midi is a symlink into Drive, so files written there
# persist across Colab runtimes.
!rm -fr ckpts; mkdir ckpts
!mkdir -p /content/drive/MyDrive/midi/bert2; ln -s /content/drive/MyDrive/midi/bert2 ckpts/midi

In [None]:
# Positional parameters of scripts/run_train.sh and their defaults:
# DSET=${1:-simple}

# GPU=${2:-0}

# INIT_PRETRAINED_MODEL=${3:-"True"}
# USE_PRETRAINED_EMBEDDINGS=${4:-"True"}
# FREEZE_EMBEDDINGS=${5:-"False"}

# LR_ANNEAL_STEPS=${6:-25001}
# LR=${7:-0.0001}

# DIFFUSION_STEPS=${8:-2000}
# NOISE_SCHEDULE=${9:-sqrt}

# BATCH_SIZE=${10:-64}
# SEQ_LEN=${11:-50}

# Environment variables are given as a command prefix (VAR=value cmd) so run_train.sh actually
# receives them. `VAR=value;` only sets an unexported shell variable in the `!` subshell.
# NOTE(review): SEQ_LEN=512 although the chunking cell caps examples at MAX_LEN = 254 tokens.
!PYTHONPATH=.:src TOKENIZERS_PARALLELISM=false bash scripts/run_train.sh midi 0  False False False  5000 0.0001  2000 sqrt  16  512

In [None]:
# Delete checkpoints whose names contain "000000" (presumably the step-0 saves) and back up
# ckpts/ to Drive.
# NOTE(review): ckpts/midi is a symlink (created in the ckpts setup cell); `cp -r` without -L
# likely copies the link rather than the checkpoint files. Confirm the backup holds the data.
!rm -fr ckpts/midi/*000000*
!cp -r ckpts /content/drive/MyDrive/

## Generation

In [None]:
# Restore the checkpoints backed up to Drive (e.g. in a fresh runtime).
!cp -r /content/drive/MyDrive/ckpts .

In [None]:
# Optional: lower text_sample.sh's default BATCH_SIZE from 50 to 10.
# !sed -i 's/BATCH_SIZE=${5:-50}/BATCH_SIZE=${5:-10}/g' scripts/text_sample.sh

In [None]:
# Sample from the step-2000 EMA checkpoint (judging by the output file names read below:
# 1000 diffusion steps, 10 samples). Environment variables are given as a command prefix
# so text_sample.sh actually receives them; `VAR=value;` would leave them unexported.
!PYTHONPATH=.:src CUDA_VISIBLE_DEVICES=0 bash scripts/text_sample.sh ckpts/midi/ema_0.9999_002000.pt 1000 10 '' 10
# !PYTHONPATH=.:src bash scripts/text_sample.sh ckpts/midi/ema_0.9999_001000.pt 20 3

In [None]:
# Peek at the generated token sequences.
# NOTE(review): the active sampling command above only samples ema_0.9999_002000.pt; the
# 001000 samples file exists only if it was generated in an earlier run.
!head ckpts/midi/ema_0.9999_001000.pt.samples_10.steps-1000.clamp-no_clamp.txt
!head ckpts/midi/ema_0.9999_002000.pt.samples_10.steps-1000.clamp-no_clamp.txt

In [None]:
# Decode generated samples into MIDI files 0.mid, 1.mid, ... in the current directory.
# Reads the checkpoint sampled by the generation cell above (ema_0.9999_002000.pt).
SAMPLES_PATH = 'ckpts/midi/ema_0.9999_002000.pt.samples_10.steps-1000.clamp-no_clamp.txt'
N_SAMPLES = 10
SPECIAL_TOKENS = {'[CLS]', '[PAD]', '[UNK]', '[MASK]', '[SEP]'}

with open(SAMPLES_PATH) as samples_file:
    for i, line in enumerate(samples_file):
        if i >= N_SAMPLES:
            break
        text = line.strip()
        print(text)
        # Remove tokenizer special tokens by whole-token match, wherever they occur
        # (including a leading [PAD] or a [SEP] without a preceding space).
        text = ' '.join(token for token in text.split() if token not in SPECIAL_TOKENS)
        if not text:
            print(f"sample {i} has no musical tokens; skipped")
            continue
        midi = text2midi(text, tempo=120)
        midi.write(f"{i}.mid")

In [None]:
# Download and unpack the SalC5 Light piano SoundFont used below to render the samples.
!gdown "https://drive.google.com/u/0/uc?id=0B5gPxvwx-I4KWjZ2SHZOLU42dHM&export=download"
!unzip SALC5-Light-SF-v2_7.zip

In [None]:
%%bash
# Render every generated .mid to .wav with FluidSynth and the SalC5Light2 SoundFont.
# This runs as a %%bash cell because in a `!` line IPython substitutes `$f` with the Python
# variable `f` (earlier cells bind it to a file object), which corrupts the command.
for midi_file in *.mid; do
  fluidsynth -ni SalC5Light2.sf2 "$midi_file" -F "$midi_file.wav" -r 44100
done

In [None]:
from IPython.display import Audio
# Listen to the tenth rendered sample (no autoplay).
display(Audio('9.mid.wav', autoplay=False))

In [None]:
# Optional: play the rendered .wav files (the `break` stops after the first one).
# import glob
# from IPython.display import Audio
# for f in glob.glob("*.wav"):
#     display(Audio(f, autoplay=not True))
#     break