# Installing dependencies

In [1]:
# !pip install fairseq
#==0.10.2

In [2]:
# !pip install fairseq==0.10.1

In [3]:
import torch
torch.__version__

  from .autonotebook import tqdm as notebook_tqdm


'1.13.1+cu116'

In [4]:
#!pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

In [5]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

## Importing appropriate Fairseq version from Github

https://github.com/microsoft/muzic/blob/main/musicbert/eval_nsp.py

https://github.com/microsoft/muzic/issues/37

In [6]:
# !pip install fairseq==0.12.2

In [7]:
if 'PYTHONPATH' in os.environ:
    os.environ['PYTHONPATH'] += ":/app/work/TFG/fairseq"
else:
    os.environ['PYTHONPATH'] = "/app/work/TFG/fairseq"

get_ipython().system(' echo $PYTHONPATH')


/app/work/TFG/fairseq


In [8]:
import fairseq

## Set up music21 for Colab (only for displaying MIDI files)



Click the play button below to upgrade music21 and install MuseScore to use in Google Colab

In [9]:
#!pip install --upgrade music21

In [10]:
#!add-apt-repository ppa:mscore-ubuntu/mscore-stable -y
#!apt-get update
#!apt-get install musescore

In [11]:
#!apt-get install xvfb

## Importing Dependencies

In [12]:
#from music21 import *
#us = environment.UserSettings()
#us['musescoreDirectPNGPath'] = '/usr/bin/mscore'
#us['directoryScratch'] = '/tmp'

# music21 is Open Source under the BSD License
# Copyright (c) 2006-22 Michael Scott Asato Cuthbert and cuthbertLab
# Support music21 by citing it in your research or produts:
#
#     Cuthbert, Michael Scott.  
#     _music21: a Toolkit for Computer-Aided Music Research_
#     https://web.mit.edu/music21
#     2006-22

In [13]:
from tqdm import tqdm
from enum import Enum
from itertools import chain

## Source code

In [14]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#

import os
import sys
import io
import zipfile
import miditoolkit
import random
import time
import math
import signal
import hashlib
from multiprocessing import Pool, Lock, Manager

In [15]:
import numpy as np

### Setting `bar_max` used in `encoding_to_str` function

If `bar_max` is too low song gets clipped while encoding to Octuple tokens from MIDI

In [16]:
bar_max = 256 #@param {type:"integer"}

In [17]:
pos_resolution = 16  # per beat (quarter note)

velocity_quant = 4
tempo_quant = 12  # 2 ** (1 / 12)
min_tempo = 16
max_tempo = 256
duration_max = 8  # 2 ** 8 * beat
max_ts_denominator = 6  # x/1 x/2 x/4 ... x/64
max_notes_per_bar = 2  # 1/64 ... 128/64
beat_note_factor = 4  # In MIDI format a note is always 4 beats
deduplicate = True
filter_symbolic = False
filter_symbolic_ppl = 16
trunc_pos = 2 ** 16  # approx 30 minutes (1024 measures)
sample_len_max = 1000  # window length max
sample_overlap_rate = 4
ts_filter = False
pool_num = 24
max_inst = 127
max_pitch = 127
max_velocity = 127

data_zip = None
output_file = None


In [18]:
lock_file = Lock()
lock_write = Lock()
lock_set = Lock()
manager = Manager()
midi_dict = manager.dict()

In [19]:
# (0 Measure, 1 Pos, 2 Program, 3 Pitch, 4 Duration, 5 Velocity, 6 TimeSig, 7 Tempo)
# (Measure, TimeSig)
# (Pos, Tempo)
# Percussion: Program=128 Pitch=[128,255]

# drum ins -> 6 ins kick, cymbal, snare, .... 

In [20]:

ts_dict = dict()
ts_list = list()
for i in range(0, max_ts_denominator + 1):  # 1 ~ 64
    for j in range(1, ((2 ** i) * max_notes_per_bar) + 1):
        ts_dict[(j, 2 ** i)] = len(ts_dict)
        ts_list.append((j, 2 ** i))
dur_enc = list()
dur_dec = list()
for i in range(duration_max):
    for j in range(pos_resolution):
        dur_dec.append(len(dur_enc))
        for k in range(2 ** i):
            dur_enc.append(len(dur_dec) - 1)

In [21]:
print(ts_dict)
print(ts_list)

{(1, 1): 0, (2, 1): 1, (1, 2): 2, (2, 2): 3, (3, 2): 4, (4, 2): 5, (1, 4): 6, (2, 4): 7, (3, 4): 8, (4, 4): 9, (5, 4): 10, (6, 4): 11, (7, 4): 12, (8, 4): 13, (1, 8): 14, (2, 8): 15, (3, 8): 16, (4, 8): 17, (5, 8): 18, (6, 8): 19, (7, 8): 20, (8, 8): 21, (9, 8): 22, (10, 8): 23, (11, 8): 24, (12, 8): 25, (13, 8): 26, (14, 8): 27, (15, 8): 28, (16, 8): 29, (1, 16): 30, (2, 16): 31, (3, 16): 32, (4, 16): 33, (5, 16): 34, (6, 16): 35, (7, 16): 36, (8, 16): 37, (9, 16): 38, (10, 16): 39, (11, 16): 40, (12, 16): 41, (13, 16): 42, (14, 16): 43, (15, 16): 44, (16, 16): 45, (17, 16): 46, (18, 16): 47, (19, 16): 48, (20, 16): 49, (21, 16): 50, (22, 16): 51, (23, 16): 52, (24, 16): 53, (25, 16): 54, (26, 16): 55, (27, 16): 56, (28, 16): 57, (29, 16): 58, (30, 16): 59, (31, 16): 60, (32, 16): 61, (1, 32): 62, (2, 32): 63, (3, 32): 64, (4, 32): 65, (5, 32): 66, (6, 32): 67, (7, 32): 68, (8, 32): 69, (9, 32): 70, (10, 32): 71, (11, 32): 72, (12, 32): 73, (13, 32): 74, (14, 32): 75, (15, 32): 76, (1

In [22]:
class timeout:
    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, exc_type, value, traceback):
        signal.alarm(0)


def t2e(x):
    assert x in ts_dict, 'unsupported time signature: ' + str(x)
    return ts_dict[x]


def e2t(x):
    return ts_list[x]


def d2e(x):
    return dur_enc[x] if x < len(dur_enc) else dur_enc[-1]


def e2d(x):
    return dur_dec[x] if x < len(dur_dec) else dur_dec[-1]


def v2e(x):
    return x // velocity_quant


def e2v(x):
    return (x * velocity_quant) + (velocity_quant // 2)


def b2e(x):
    x = max(x, min_tempo)
    x = min(x, max_tempo)
    x = x / min_tempo
    e = round(math.log2(x) * tempo_quant)
    return e


def e2b(x):
    return 2 ** (x / tempo_quant) * min_tempo

In [23]:
def time_signature_reduce(numerator, denominator):
    # reduction (when denominator is too large)
    while denominator > 2 ** max_ts_denominator and denominator % 2 == 0 and numerator % 2 == 0:
        denominator //= 2
        numerator //= 2
    # decomposition (when length of a bar exceed max_notes_per_bar)
    while numerator > max_notes_per_bar * denominator:
        for i in range(2, numerator + 1):
            if numerator % i == 0:
                numerator //= i
                break
    return numerator, denominator


def writer(file_name, output_str_list):
    # note: parameter "file_name" is reserved for patching
    with open(output_file, 'a') as f:
        for output_str in output_str_list:
            f.write(output_str + '\n')

In [24]:
def gen_dictionary(file_name):
    num = 0
    with open(file_name, 'w') as f:
        for j in range(bar_max):
            print('<0-{}>'.format(j), num, file=f)
        for j in range(beat_note_factor * max_notes_per_bar * pos_resolution):
            print('<1-{}>'.format(j), num, file=f)
        for j in range(max_inst + 1 + 1):
            # max_inst + 1 for percussion
            print('<2-{}>'.format(j), num, file=f)
        for j in range(2 * max_pitch + 1 + 1):
            # max_pitch + 1 ~ 2 * max_pitch + 1 for percussion
            print('<3-{}>'.format(j), num, file=f)
        for j in range(duration_max * pos_resolution):
            print('<4-{}>'.format(j), num, file=f)
        for j in range(v2e(max_velocity) + 1):
            print('<5-{}>'.format(j), num, file=f)
        for j in range(len(ts_list)):
            print('<6-{}>'.format(j), num, file=f)
        for j in range(b2e(max_tempo) + 1):
            print('<7-{}>'.format(j), num, file=f)

In [25]:
def MIDI_to_encoding(midi_obj):
    def time_to_pos(t):
        return round(t * pos_resolution / midi_obj.ticks_per_beat)
    notes_start_pos = [time_to_pos(j.start)
                       for i in midi_obj.instruments for j in i.notes]
    if len(notes_start_pos) == 0:
        return list()
    max_pos = min(max(notes_start_pos) + 1, trunc_pos)
    pos_to_info = [[None for _ in range(4)] for _ in range(
        max_pos)]  # (Measure, TimeSig, Pos, Tempo)
    tsc = midi_obj.time_signature_changes
    tpc = midi_obj.tempo_changes
    for i in range(len(tsc)):
        for j in range(time_to_pos(tsc[i].time), time_to_pos(tsc[i + 1].time) if i < len(tsc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][1] = t2e(time_signature_reduce(
                    tsc[i].numerator, tsc[i].denominator))
    for i in range(len(tpc)):
        for j in range(time_to_pos(tpc[i].time), time_to_pos(tpc[i + 1].time) if i < len(tpc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][3] = b2e(tpc[i].tempo)
    for j in range(len(pos_to_info)):
        if pos_to_info[j][1] is None:
            # MIDI default time signature
            pos_to_info[j][1] = t2e(time_signature_reduce(4, 4))
        if pos_to_info[j][3] is None:
            pos_to_info[j][3] = b2e(120.0)  # MIDI default tempo (BPM)
    cnt = 0
    bar = 0
    measure_length = None
    for j in range(len(pos_to_info)):
        ts = e2t(pos_to_info[j][1])
        if cnt == 0:
            measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1]
        pos_to_info[j][0] = bar
        pos_to_info[j][2] = cnt
        cnt += 1
        if cnt >= measure_length:
            assert cnt == measure_length, 'invalid time signature change: pos = {}'.format(
                j)
            cnt -= measure_length
            bar += 1
    encoding = []
    start_distribution = [0] * pos_resolution
    for inst in midi_obj.instruments:
        for note in inst.notes:
            if time_to_pos(note.start) >= trunc_pos:
                continue
            start_distribution[time_to_pos(note.start) % pos_resolution] += 1
            info = pos_to_info[time_to_pos(note.start)]
            encoding.append((info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch +
                             1 if inst.is_drum else note.pitch, d2e(time_to_pos(note.end) - time_to_pos(note.start)), v2e(note.velocity), info[1], info[3]))
    if len(encoding) == 0:
        return list()
    tot = sum(start_distribution)
    start_ppl = 2 ** sum((0 if x == 0 else -(x / tot) *
                          math.log2((x / tot)) for x in start_distribution))
    # filter unaligned music
    if filter_symbolic:
        assert start_ppl <= filter_symbolic_ppl, 'filtered out by the symbolic filter: ppl = {:.2f}'.format(
            start_ppl)
    encoding.sort()
    return encoding

In [26]:
def encoding_to_MIDI(encoding):
    # TODO: filter out non-valid notes and error handling
    bar_to_timesig = [list()
                      for _ in range(max(map(lambda x: x[0], encoding)) + 1)]
    for i in encoding:
        bar_to_timesig[i[0]].append(i[6])
    bar_to_timesig = [max(set(i), key=i.count) if len(
        i) > 0 else None for i in bar_to_timesig]
    for i in range(len(bar_to_timesig)):
        if bar_to_timesig[i] is None:
            bar_to_timesig[i] = t2e(time_signature_reduce(
                4, 4)) if i == 0 else bar_to_timesig[i - 1]
    bar_to_pos = [None] * len(bar_to_timesig)
    cur_pos = 0
    for i in range(len(bar_to_pos)):
        bar_to_pos[i] = cur_pos
        ts = e2t(bar_to_timesig[i])
        measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1]
        cur_pos += measure_length
    pos_to_tempo = [list() for _ in range(
        cur_pos + max(map(lambda x: x[1], encoding)))]
    for i in encoding:
        pos_to_tempo[bar_to_pos[i[0]] + i[1]].append(i[7])
    pos_to_tempo = [round(sum(i) / len(i)) if len(i) >
                    0 else None for i in pos_to_tempo]
    for i in range(len(pos_to_tempo)):
        if pos_to_tempo[i] is None:
            pos_to_tempo[i] = b2e(120.0) if i == 0 else pos_to_tempo[i - 1]
    midi_obj = miditoolkit.midi.parser.MidiFile()

    def get_tick(bar, pos):
        return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution
    midi_obj.instruments = [miditoolkit.containers.Instrument(program=(
        0 if i == 128 else i), is_drum=(i == 128), name=str(i)) for i in range(128 + 1)]
    for i in encoding:
        start = get_tick(i[0], i[1])
        program = i[2]
        pitch = (i[3] - 128 if program == 128 else i[3])
        duration = get_tick(0, e2d(i[4]))
        if duration == 0:
            duration = 1
        end = start + duration
        velocity = e2v(i[5])
        midi_obj.instruments[program].notes.append(miditoolkit.containers.Note(
            start=start, end=end, pitch=pitch, velocity=velocity))
    midi_obj.instruments = [
        i for i in midi_obj.instruments if len(i.notes) > 0]
    cur_ts = None
    for i in range(len(bar_to_timesig)):
        new_ts = bar_to_timesig[i]
        if new_ts != cur_ts:
            numerator, denominator = e2t(new_ts)
            midi_obj.time_signature_changes.append(miditoolkit.containers.TimeSignature(
                numerator=numerator, denominator=denominator, time=get_tick(i, 0)))
            cur_ts = new_ts
    cur_tp = None
    for i in range(len(pos_to_tempo)):
        new_tp = pos_to_tempo[i]
        if new_tp != cur_tp:
            tempo = e2b(new_tp)
            midi_obj.tempo_changes.append(
                miditoolkit.containers.TempoChange(tempo=tempo, time=get_tick(0, i)))
            cur_tp = new_tp
    return midi_obj

In [27]:
def get_hash(encoding):
    # add i[4] and i[5] for stricter match
    midi_tuple = tuple((i[2], i[3]) for i in encoding)
    midi_hash = hashlib.md5(str(midi_tuple).encode('ascii')).hexdigest()
    return midi_hash

In [28]:
def F(file_name):
    try_times = 10
    midi_file = None
    for _ in range(try_times):
        try:
            lock_file.acquire()
            with data_zip.open(file_name) as f:
                # this may fail due to unknown bug
                midi_file = io.BytesIO(f.read())
        except BaseException as e:
            try_times -= 1
            time.sleep(1)
            if try_times == 0:
                print('ERROR(READ): ' + file_name +
                      ' ' + str(e) + '\n', end='')
                return None
        finally:
            lock_file.release()
    try:
        with timeout(seconds=600):
            midi_obj = miditoolkit.midi.parser.MidiFile(file=midi_file)
        # check abnormal values in parse result
        assert all(0 <= j.start < 2 ** 31 and 0 <= j.end < 2 **
                   31 for i in midi_obj.instruments for j in i.notes), 'bad note time'
        assert all(0 < j.numerator < 2 ** 31 and 0 < j.denominator < 2 **
                   31 for j in midi_obj.time_signature_changes), 'bad time signature value'
        assert 0 < midi_obj.ticks_per_beat < 2 ** 31, 'bad ticks per beat'
    except BaseException as e:
        print('ERROR(PARSE): ' + file_name + ' ' + str(e) + '\n', end='')
        return None
    midi_notes_count = sum(len(inst.notes) for inst in midi_obj.instruments)
    if midi_notes_count == 0:
        print('ERROR(BLANK): ' + file_name + '\n', end='')
        return None
    try:
        e = MIDI_to_encoding(midi_obj)
        if len(e) == 0:
            print('ERROR(BLANK): ' + file_name + '\n', end='')
            return None
        if ts_filter:
            allowed_ts = t2e(time_signature_reduce(4, 4))
            if not all(i[6] == allowed_ts for i in e):
                print('ERROR(TSFILT): ' + file_name + '\n', end='')
                return None
        if deduplicate:
            duplicated = False
            dup_file_name = ''
            midi_hash = '0' * 32
            try:
                midi_hash = get_hash(e)
            except BaseException as e:
                pass
            lock_set.acquire()
            if midi_hash in midi_dict:
                dup_file_name = midi_dict[midi_hash]
                duplicated = True
            else:
                midi_dict[midi_hash] = file_name
            lock_set.release()
            if duplicated:
                print('ERROR(DUPLICATED): ' + midi_hash + ' ' +
                      file_name + ' == ' + dup_file_name + '\n', end='')
                return None
        output_str_list = []
        sample_step = max(round(sample_len_max / sample_overlap_rate), 1)
        for p in range(0 - random.randint(0, sample_len_max - 1), len(e), sample_step):
            L = max(p, 0)
            R = min(p + sample_len_max, len(e)) - 1
            bar_index_list = [e[i][0]
                              for i in range(L, R + 1) if e[i][0] is not None]
            bar_index_min = 0
            bar_index_max = 0
            if len(bar_index_list) > 0:
                bar_index_min = min(bar_index_list)
                bar_index_max = max(bar_index_list)
            offset_lower_bound = -bar_index_min
            offset_upper_bound = bar_max - 1 - bar_index_max
            # to make bar index distribute in [0, bar_max)
            bar_index_offset = random.randint(
                offset_lower_bound, offset_upper_bound) if offset_lower_bound <= offset_upper_bound else offset_lower_bound
            e_segment = []
            for i in e[L: R + 1]:
                if i[0] is None or i[0] + bar_index_offset < bar_max:
                    e_segment.append(i)
                else:
                    break
            tokens_per_note = 8
            output_words = (['<s>'] * tokens_per_note) \
                + [('<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) if k is not None else '<unk>') for i in e_segment for j, k in enumerate(i)] \
                + (['</s>'] * (tokens_per_note - 1)
                   )  # tokens_per_note - 1 for append_eos functionality of binarizer in fairseq
            output_str_list.append(' '.join(output_words))

        # no empty
        if not all(len(i.split()) > tokens_per_note * 2 - 1 for i in output_str_list):
            print('ERROR(ENCODE): ' + file_name + ' ' + str(e) + '\n', end='')
            return False
        try:
            lock_write.acquire()
            writer(file_name, output_str_list)
        except BaseException as e:
            print('ERROR(WRITE): ' + file_name + ' ' + str(e) + '\n', end='')
            return False
        finally:
            lock_write.release()
        print('SUCCESS: ' + file_name + '\n', end='')
        return True
    except BaseException as e:
        print('ERROR(PROCESS): ' + file_name + ' ' + str(e) + '\n', end='')
        return False
    print('ERROR(GENERAL): ' + file_name + '\n', end='')
    return False


def G(file_name):
    try:
        return F(file_name)
    except BaseException as e:
        print('ERROR(UNCAUGHT): ' + file_name + '\n', end='')
        return False

In [29]:
def str_to_encoding(s):
    encoding = [int(i[3: -1]) for i in s.split() if 's' not in i]
    tokens_per_note = 8
    assert len(encoding) % tokens_per_note == 0
    encoding = [tuple(encoding[i + j] for j in range(tokens_per_note))
                for i in range(0, len(encoding), tokens_per_note)]
    return encoding


In [30]:
def encoding_to_str(e, bar_max = bar_max):
    bar_index_offset = 0
    p = 0
    tokens_per_note = 8
    return ' '.join((['<s>'] * tokens_per_note)
                    + ['<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) for i in e[p: p +
                                                                                            sample_len_max] if i[0] + bar_index_offset < bar_max for j, k in enumerate(i)]
                    + (['</s>'] * (tokens_per_note
                                   - 1)))   # 8 - 1 for append_eos functionality of binarizer in fairseq

## Preprocessing zip to OctupleMIDI

In [31]:
#import os
#import gdown

#midi_dir = os.path.join("midi")
#os.makedirs(midi_dir, exist_ok=True)

#file_id1 = "14GSG63ynI-iUdQI_eqnUUppi8H0Djgl1"
#url1 = f"https://drive.google.com/uc?id={file_id1}"
#output_filename1 = "lmd_matched3.zip"

#output_path1 = os.path.join(midi_dir, output_filename1)
#gdown.download(url1, output_path1, quiet=False)


#file_id2 = "1iW8B0BlzL0pKWuDIb-KnckH3pIwGK096"
#url2 = f"https://drive.google.com/uc?id={file_id2}"
#output_filename2 = "lmd_matched.zip"

#output_path2 = os.path.join(midi_dir, output_filename2)
#gdown.download(url2, output_path2, quiet=False)


In [32]:
#file_id3 = "1DbCUGMVSzuf971KObdmxZDtzre_zmDBU"
#url3 = f"https://drive.google.com/uc?id={file_id3}"
#output_filename3 = "iwantitthatway.mid"

#output_path3 = os.path.join(midi_dir, output_filename3)
#gdown.download(url3, output_path3, quiet=False)


## Working on single MIDI

In [33]:
 sample_midi_path = '/app/work/midi/iwantitthatway.mid'

In [34]:
 midi_obj = miditoolkit.midi.parser.MidiFile(sample_midi_path)

In [35]:
 print(midi_obj)

ticks per beat: 120
max tick: 42241
tempo changes: 5
time sig: 1
key sig: 1
markers: 11
lyrics: False
instruments: 16


In [36]:
 assert all(0 <= j.start < 2 ** 31 and 0 <= j.end < 2 **
                    31 for i in midi_obj.instruments for j in i.notes), 'bad note time'
 assert all(0 < j.numerator < 2 ** 31 and 0 < j.denominator < 2 **
             31 for j in midi_obj.time_signature_changes), 'bad time signature value'
 assert 0 < midi_obj.ticks_per_beat < 2 ** 31, 'bad ticks per beat'

In [37]:
 for ins in midi_obj.instruments:
   print(ins)

Instrument(program=1, is_drum=False, name="")
Instrument(program=33, is_drum=False, name="")
Instrument(program=18, is_drum=False, name="")
Instrument(program=68, is_drum=False, name="")
Instrument(program=25, is_drum=False, name="")
Instrument(program=25, is_drum=False, name="")
Instrument(program=54, is_drum=False, name="")
Instrument(program=27, is_drum=False, name="")
Instrument(program=74, is_drum=False, name="")
Instrument(program=0, is_drum=True, name="")
Instrument(program=119, is_drum=False, name="")
Instrument(program=89, is_drum=False, name="")
Instrument(program=49, is_drum=False, name="")
Instrument(program=127, is_drum=False, name="")
Instrument(program=27, is_drum=False, name="")
Instrument(program=124, is_drum=False, name="")


In [38]:
 midi_notes_count = sum(len(inst.notes) for inst in midi_obj.instruments)
 if midi_notes_count == 0:
   print('ERROR(BLANK): ' + file_name + '\n', end='')

### Obtained 8 tuple representation

In [39]:
# (0 Measure, 1 Pos, 2 Program, 3 Pitch, 4 Duration, 5 Velocity, 6 TimeSig, 7 Tempo)
# (Measure, TimeSig)
# (Pos, Tempo)
# Percussion: Program=128 Pitch=[128,255]

In [40]:
 e = MIDI_to_encoding(midi_obj)
 print(e)

[(2, 0, 25, 42, 36, 21, 9, 32), (2, 8, 25, 54, 20, 19, 9, 32), (2, 16, 25, 61, 20, 24, 9, 32), (2, 24, 25, 57, 24, 20, 9, 32), (2, 32, 25, 54, 20, 18, 9, 32), (2, 40, 25, 62, 8, 25, 9, 32), (2, 48, 25, 61, 8, 17, 9, 32), (2, 56, 25, 57, 8, 18, 9, 32), (3, 0, 25, 50, 20, 20, 9, 32), (3, 8, 25, 57, 16, 17, 9, 32), (3, 15, 68, 73, 8, 26, 9, 32), (3, 16, 25, 62, 8, 21, 9, 32), (3, 23, 68, 71, 16, 24, 9, 32), (3, 24, 25, 45, 28, 19, 9, 32), (3, 32, 25, 52, 20, 19, 9, 32), (3, 40, 25, 57, 8, 20, 9, 32), (3, 40, 68, 69, 20, 26, 9, 32), (3, 48, 25, 64, 8, 24, 9, 32), (3, 56, 25, 57, 8, 19, 9, 32), (4, 0, 25, 42, 36, 21, 9, 32), (4, 8, 25, 54, 20, 19, 9, 32), (4, 16, 25, 61, 20, 24, 9, 32), (4, 24, 25, 57, 24, 20, 9, 32), (4, 32, 25, 54, 20, 18, 9, 32), (4, 40, 25, 62, 8, 25, 9, 32), (4, 48, 25, 61, 8, 17, 9, 32), (4, 56, 25, 57, 8, 18, 9, 32), (5, 0, 25, 50, 20, 20, 9, 32), (5, 8, 25, 57, 16, 17, 9, 32), (5, 16, 25, 62, 8, 21, 9, 32), (5, 24, 25, 45, 28, 19, 9, 32), (5, 32, 25, 52, 20, 19, 9, 

In [41]:
 print(len(e))

5975


In [42]:
 output_str_list = []
 sample_step = max(round(sample_len_max / sample_overlap_rate), 1)

In [43]:
 for p in range(0 - random.randint(0, sample_len_max - 1), len(e), sample_step):
   L = max(p, 0)
   R = min(p + sample_len_max, len(e)) - 1
   bar_index_list = [e[i][0]
                     for i in range(L, R + 1) if e[i][0] is not None]
   bar_index_min = 0
   bar_index_max = 0
   if len(bar_index_list) > 0:
       bar_index_min = min(bar_index_list)
       bar_index_max = max(bar_index_list)
   offset_lower_bound = -bar_index_min
   offset_upper_bound = bar_max - 1 - bar_index_max
   # to make bar index distribute in [0, bar_max)
   bar_index_offset = random.randint(
       offset_lower_bound, offset_upper_bound) if offset_lower_bound <= offset_upper_bound else offset_lower_bound
   e_segment = []
   for i in e[L: R + 1]:
       if i[0] is None or i[0] + bar_index_offset < bar_max:
           e_segment.append(i)
       else:
           break
   tokens_per_note = 8
   output_words = (['<s>'] * tokens_per_note) \
       + [('<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) if k is not None else '<unk>') for i in e_segment for j, k in enumerate(i)] \
       + (['</s>'] * (tokens_per_note - 1)
           )  # tokens_per_note - 1 for append_eos functionality of binarizer in fairseq
   output_str_list.append(' '.join(output_words))

In [44]:
 #print(e)

In [45]:
 #print(output_str_list)

## Loading pretrained MusicBERT (derived from fairseq RoBERTa)

In [46]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
ruta_directorio_musicbert = "/app/work/TFG/muzic/musicbert"

if ruta_directorio_musicbert not in sys.path:
    sys.path.append(ruta_directorio_musicbert)

from fairseq.models.roberta import RobertaModel
# from fairseq.models.transformer import TransformerModel
import numpy as np
import torch
import torch.nn.functional as F
import sys
from musicbert import *


disable_cp = False
mask_strategy = ['bar']
convert_encoding = OCTMIDI
crop_length = None


  from scipy.sparse.base import spmatrix


In [47]:
def gen_dictionary(file_name):
    num = 0
    with open(file_name, 'w') as f:
        for j in range(bar_max):
            print('<0-{}>'.format(j), num, file=f)
        for j in range(beat_note_factor * max_notes_per_bar * pos_resolution):
            print('<1-{}>'.format(j), num, file=f)
        for j in range(max_inst + 1 + 1):
            # max_inst + 1 for percussion
            print('<2-{}>'.format(j), num, file=f)
        for j in range(2 * max_pitch + 1 + 1):
            # max_pitch + 1 ~ 2 * max_pitch + 1 for percussion
            print('<3-{}>'.format(j), num, file=f)
        for j in range(duration_max * pos_resolution):
            print('<4-{}>'.format(j), num, file=f)
        for j in range(v2e(max_velocity) + 1):
            print('<5-{}>'.format(j), num, file=f)
        for j in range(len(ts_list)):
            print('<6-{}>'.format(j), num, file=f)
        for j in range(b2e(max_tempo) + 1):
            print('<7-{}>'.format(j), num, file=f)


In [48]:

os.makedirs('input0', exist_ok=True)

ROOT = os.path.join(os.getcwd(), 'input0')
dict_file_path = os.path.join(ROOT, 'dict.txt')

gen_dictionary(dict_file_path)


In [49]:

os.makedirs('label', exist_ok=True)

ROOT = os.path.join(os.getcwd(), 'label')
dict_file_path = os.path.join(ROOT, 'dict.txt')

gen_dictionary(dict_file_path)


mentioned in https://github.com/microsoft/muzic/issues/51#issuecomment-1115821739_

https://fairseq.readthedocs.io/en/v0.10.2/_modules/fairseq/models.html 

https://fairseq.readthedocs.io/en/v0.10.2/command_line_tools.html

https://github.com/facebookresearch/fairseq/issues/2546

In [50]:
%cd /app/work/TFG/muzic/musicbert

/app/work/TFG/muzic/musicbert


`https://github.com/microsoft/muzic/blob/main/musicbert/musicbert/__init__.py`

In [51]:
print('loading model and data')

roberta_base = MusicBERTModel.from_pretrained('.', 
  checkpoint_file = 'checkpoints/checkpoint_last_musicbert_base_w_genre_head.pt',
  # user_dir='musicbert'    # activate the MusicBERT plugin with this keyword
)
#,
 # data_name_or_path = '.')

loading model and data


  return torch._C._cuda_getDeviceCount() > 0


In [52]:
print(roberta_base)

RobertaHubInterface(
  (model): MusicBERTModel(
    (encoder): MusicBERTEncoder(
      (sentence_encoder): OctupleEncoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(1237, 768, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
        (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (layers): ModuleList(
          (0): TransformerSentenceEncoderLayer(
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (

In [53]:
samp = roberta_base.model.encoder.sentence_encoder
print(samp)
del samp

OctupleEncoder(
  (dropout_module): FairseqDropout()
  (embed_tokens): Embedding(1237, 768, padding_idx=1)
  (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
  (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (layers): ModuleList(
    (0): TransformerSentenceEncoderLayer(
      (dropout_module): FairseqDropout()
      (activation_dropout_module): FairseqDropout()
      (self_attn): MultiheadAttention(
        (dropout_module): FairseqDropout()
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bia

In [54]:
samp = roberta_base.model.encoder.lm_head
print(samp)
del samp

RobertaLMHead(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)


In [55]:
roberta_base.cuda()
roberta_base.eval()

RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

In [None]:
# print('loading model and data')

# roberta_base = RobertaModel.from_pretrained('.', 
#   checkpoint_file = 'checkpoints/checkpoint_last_musicbert_base_w_genre_head.pt',
#   user_dir='musicbert'    # activate the MusicBERT plugin with this keyword
# )
# #,
#  # data_name_or_path = '.')

In [None]:
# print('loading model and data')

# roberta_small = MusicBERTModel.from_pretrained('.', 
#   checkpoint_file = 'checkpoints/checkpoint_last_musicbert_small_w_genre_head.pt',
#   user_dir='musicbert'    # activate the MusicBERT plugin with this keyword
# )


#  PRUEBA DATASET GRANDE


## Definicion de funciones

In [None]:
import json
import concurrent.futures
from tqdm.auto import tqdm

In [None]:
# Esta funcion calcula la similitud entre dos vectores
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(vec1, vec2):
    # pasamos a array de 2D para que puede pasar al funcion consine_similarity
    vec1_2d = vec1.reshape(1, -1)
    vec2_2d = vec2.reshape(1, -1)
    
    similarity_matrix = cosine_similarity(vec1_2d, vec2_2d)
    
    similarity_score = similarity_matrix[0][0]
    
    return similarity_score


In [None]:
# Esta función extrae características de un archivo MIDI utilizando MusicBERT(RoBERTa)

def get_music_features(midi_path, MusicBERT, max_length=1024):
    # Sacar los tokens utilizando miditoolkit
    midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
    encoding = MIDI_to_encoding(midi_obj)
    encoding = shift_bar_to_front(encoding)
    octuple_midi_str = encoding_to_str(encoding)
    octuple_midi_tokenized = MusicBERT.task.label_dictionary.encode_line(octuple_midi_str)

    # cortar o rellenar la secuencia hasta max_len
    if len(octuple_midi_tokenized) > max_length:
        octuple_midi_tokenized = octuple_midi_tokenized[:max_length]
    else:
        octuple_midi_tokenized += [1] * (max_length - len(octuple_midi_tokenized))

    # Pasar tokens a lista
    if not isinstance(octuple_midi_tokenized, list):
        octuple_midi_tokenized = octuple_midi_tokenized.tolist()

    # Pasamos tokens para que MusicBERT proceda
    input_tensor = torch.tensor([octuple_midi_tokenized]).long().cpu()
    #input_tensor = torch.tensor([octuple_midi_tokenized]).long().cuda()
    with torch.no_grad():
        features = MusicBERT.extract_features(input_tensor)[0].mean(dim=1).cpu().numpy()

    return features


In [None]:
def recommend_songs(input_midi_path, midi_files, midi_features, model, X):
    # Proceder archivo MIDI de con la funcion get_music_features
    input_features = get_music_features(input_midi_path, model)
    # Calcula la similitud del archivo MIDI de entrada y el vector de todos las archivo MIDI en midi_features
    similarities = [calculate_similarity(input_features, features) for features in midi_features]
    # Ordenar segun similitud
    sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)

    recommended_songs = []
    song_count = 0
    idx = 0
    while song_count < X + 1:
        song_name, similarity = midi_files[sorted_indices[idx]], similarities[sorted_indices[idx]]
        # Verificar si la canción ya está en la lista de recomendaciones
        if song_name not in [song[0] for song in recommended_songs]:
            recommended_songs.append((song_name, similarity))
            song_count += 1
        idx += 1

    # Eliminar la primera canción, que es la canción de entrada
    recommended_songs = recommended_songs[1:]

    return recommended_songs


In [None]:
#funcion para cargar todos los archivos midis
import os

def get_midi_files(directory):
    midi_files = []
    for root, dirs, files in os.walk(directory):
        for filename in files:
            if filename.endswith('.mid'):
                midi_files.append(os.path.join(root, filename))
    return midi_files


In [None]:
# buscar el nombre que le corresponde dentro del archivo JSON
def get_new_midi_name(midi_path, json_data):
    midi_id = midi_path.split("/")[-1].split(".")[0]
    if midi_id in json_data:
        return json_data[midi_id][0]  # Elige el primer nombre en la lista
    else:
        return midi_path.split("/")[-1]

## Ejecucion

In [None]:
%cd /app/work/midi

In [None]:
#!wget http://hog.ee.columbia.edu/craffel/lmd/md5_to_paths.json

In [None]:
#!wget http://hog.ee.columbia.edu/craffel/lmd/lmd_matched.tar.gz

In [None]:
#!tar -xzvf lmd_matched.tar.gz

In [None]:
directory = '/app/work/midi'
midi_files = get_midi_files(directory)
print(len(midi_files))


In [None]:
import json

# cargamos el archivo JSON
json_file = "/app/work/midi/md5_to_paths.json" 
with open(json_file) as f:
    json_data = json.load(f)

In [None]:
## CELDA PARA PROCEDER DATASET
#import concurrent.futures
#from tqdm.auto import tqdm

#def process_midi_file(midi_file):
#    try:
#        features = get_music_features(midi_file, roberta_base)
#        return (midi_file, features)
#    except Exception as e:
#        print(f"Error al procesar el archivo MIDI: {e}")
#        return None

#midi_features = []
#processed_midi_files = []

#total_midi_files = len(midi_files)
#print(f"Total de archivos MIDI: {total_midi_files}")

#with concurrent.futures.ThreadPoolExecutor() as executor:
#    results = list(tqdm(executor.map(process_midi_file, midi_files), desc="Procesando archivos MIDI", total=len(midi_files)))

#for result in results:
#    if result is not None:
#        midi_file, features = result
#        processed_midi_files.append(midi_file)
#        midi_features.append(features)

In [None]:
import pickle

with open("midi_data.pkl", "rb") as f:
    loaded_midi_data = pickle.load(f)

In [None]:
import json

def get_new_midi_name(midi_path, json_data):
    midi_id = midi_path.split("/")[-1].split(".")[0]
    if midi_id in json_data:
        return json_data[midi_id][0]  # Elige el primer nombre en la lista
    else:
        return midi_path.split("/")[-1]

json_file = "md5_to_paths.json"
with open(json_file) as f:
    json_data = json.load(f)
    
updated_midi_data = {}
for midi_path, features in loaded_midi_data.items():
    updated_midi_name = get_new_midi_name(midi_path, json_data)
    updated_midi_data[updated_midi_name] = features


In [None]:
input_midi_path = "/app/work/midi/iwantitthatway.mid"
X = 3  # Número de recomendaciones que quieres

# Obtener las características de la canción de entrada
input_features = get_music_features(input_midi_path, roberta_base)

# Separar los nombres de las canciones y las características de los archivos MIDI cargados
loaded_midi_names = list(updated_midi_data.keys())
loaded_midi_features = list(updated_midi_data.values())

recommendations = recommend_songs(input_midi_path, loaded_midi_names, loaded_midi_features, roberta_base, X)

print(f"Top {X} recomendaciones para la canción seleccionada '{input_midi_path}':")
for idx, (midi_file, similarity) in enumerate(recommendations, 1):
    print(f"{idx}. {midi_file} (similarity: {similarity:.4f})")

In [None]:
#processed_midi_files = [get_new_midi_name(midi_path, json_data) for midi_path in midi_files]

#print(midi_features[:2])

In [None]:
#input_midi_path = "/content/midi/Michael_Jackson_-_Beat_It.mid"  

#cuantas recomendaciones quieres
#X = 3 

#recommendations = recommend_songs(input_midi_path, processed_midi_files, midi_features, roberta_base, X)

#print(f"Top {X} recomendadas para la canciones seleccionado '{input_midi_path}':")

#for idx, (midi_file, similarity) in enumerate(recommendations, 1):
#    print(f"{idx}. {midi_file} (similarity: {similarity:.4f})")


## Visualizacion grafica

In [None]:
def visualize_tsne(processed_midi_files, midi_features, perplexity=None):
    midi_features = np.array(midi_features)

    if perplexity is None:
        perplexity = min(30, len(processed_midi_files) - 1)

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embedded_features = tsne.fit_transform(midi_features)

    plt.figure(figsize=(12, 12))
    plt.scatter(embedded_features[:, 0], embedded_features[:, 1])
    for i, midi_file in enumerate(processed_midi_files):
        plt.annotate(midi_file, (embedded_features[i, 0], embedded_features[i, 1]))

    plt.title("Visualizacion grafica MIDI")
    plt.show()

visualize_tsne(processed_midi_files, midi_features)
