# Pocket MIDI Generator Training Dataset Maker

# Установка окружения



In [None]:
#@title Установка всех зависимостей (запускается один раз)
# Clone the tegridy-tools repository (the library cell below imports TMIDIX
# from /content/tegridy-tools/tegridy-tools) and install tqdm, which the
# processing cell uses for its progress bar.
# NOTE(review): the clone is created in the current working directory, so this
# assumes the notebook is sitting in /content when the cell runs — confirm.
# git clone refuses to clone into an existing non-empty directory, which is
# why the title says this cell is meant to run only once.
!git clone https://github.com/asigalov61/tegridy-tools
!pip install tqdm

In [None]:
#@title Установка всех библиотек

# Import the modules used by the later cells, make sure the local dataset
# directory exists, and import TMIDIX from the tegridy-tools checkout.

print('Loading needed modules. Please wait...')
import os

import math
import statistics
import random

from tqdm import tqdm

# exist_ok=True replaces the racy exists()-then-makedirs() pair and is a
# no-op when the directory is already there.
os.makedirs('/content/Dataset', exist_ok=True)

print('Loading TMIDIX module...')
# TMIDIX is a plain .py file inside the cloned repo, so it is imported with
# that directory as the working directory. try/finally puts the working
# directory back to /content even if the import raises, so later cells are
# not left running from inside the tegridy-tools checkout.
os.chdir('/content/tegridy-tools/tegridy-tools')
try:
    import TMIDIX
finally:
    os.chdir('/content/')

print('Done!')

# Загрузка датасета

In [None]:
#@title Установка LAKH MIDI Dataset

# Download the full Lakh MIDI Dataset archive into /content/Dataset, unpack it
# in place, and delete the archive (presumably to save disk space). The
# file-list cell below walks /content/Dataset recursively to find inputs.
%cd /content/Dataset/

!wget 'http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz'
# -v makes tar print every extracted path into the cell output.
!tar -xvf 'lmd_full.tar.gz'
!rm 'lmd_full.tar.gz'

# Return to the notebook's base directory for the following cells.
%cd /content/

In [None]:
#@title Подключение к Google Drive
# Mount Google Drive so the later cells can read and write their pickles
# under /content/drive/MyDrive.
from google.colab import drive

_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)

# Подготовка датасета

In [None]:
#@title Сохранение датасета
###########

# Collect every file path under dataset_addr (recursively), shuffle the list,
# and pickle it to Google Drive so the reload cell below can restore exactly
# the same order without walking the dataset again.

print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/content/Dataset"
filez = [
    os.path.join(dirpath, file)
    for dirpath, _dirnames, filenames in os.walk(dataset_addr)
    for file in filenames
]
print('=' * 70)

if not filez:
    # Do not save: writing an empty list here would presumably replace a
    # previously saved file list on Drive (e.g. after a runtime reset where
    # the dataset has not been downloaded again).
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)
else:
    print('Randomizing file list...')
    random.shuffle(filez)

    TMIDIX.Tegridy_Any_Pickle_File_Writer(filez, '/content/drive/MyDrive/filez')

In [None]:
#@title Загрузка данных с диска
# Restore the shuffled file list that the file-list cell above pickled to
# Google Drive.
_filez_pickle_path = '/content/drive/MyDrive/filez'
filez = TMIDIX.Tegridy_Any_Pickle_File_Reader(_filez_pickle_path)

# Обработка данных

In [None]:
# Convert the MIDI files in filez[START_FILE_NUMBER:] into the integer token
# stream used for training. Only files whose notes map to exactly one melodic
# instrument class (optionally plus drums) and pass the density filters are
# kept. Tokens accumulate in melody_chords_f and are pickled to Google Drive
# every 5000 accepted files, plus one final save of any remainder.
#
# Token ranges produced by the encoding below:
#   0..255      time delta to the previous note (a 0 is only written once,
#               right after the per-file header)
#   256..1279   duration * 8 + velocity bucket, offset by 256
#   1280..2815  instrument class * 128 + pitch, offset by 1280
#   2816        per-file header token; 2817..3071 = marker every 50 chords
#   3072        written once, 50 chords before the last multiple of 50
#   3073/3074   file without / with drum notes
#   3075..3086  main instrument class (0..11)
#   3087        start-of-file token

print('Starting up...')
print('=' * 70)

# Resume controls: START_FILE_NUMBER is the first index into filez to
# process; LAST_SAVED_BATCH_COUNT seeds the accepted-file counter that is used
# in the saved batch names (LAKH_INTs_<count>).
START_FILE_NUMBER = 25000
LAST_SAVED_BATCH_COUNT = 0

input_files_count = START_FILE_NUMBER
files_count = LAST_SAVED_BATCH_COUNT

melody_chords_f = []

stats = [0] * 16  # note counts per instrument class id

# MIDI program numbers grouped into the 12 instrument classes; the list index
# is the class id written into the tokens. Class 9 contains only -1 (never a
# real program) so that id 9 stays reserved for drum-channel notes.
patch_map = [
    [0, 1, 2, 3, 4, 5, 6, 7],          # Piano
    [24, 25, 26, 27, 28, 29, 30],      # Guitar
    [32, 33, 34, 35, 36, 37, 38, 39],  # Bass
    [40, 41],                          # Violin
    [42, 43],                          # Cello
    [46],                              # Harp
    [56, 57, 58, 59, 60],              # Trumpet
    [64, 65, 66, 67, 68, 69, 70, 71],  # Sax
    [72, 73, 74, 75, 76, 77, 78],      # Flute
    [-1],                              # Drums
    [52, 53],                          # Choir
    [16, 17, 18, 19, 20],              # Organ
]

# Program number -> class id, built once instead of rebuilding patch_map and
# scanning it with list.index() for every note. setdefault keeps the first
# class listing a program, i.e. the same result as a first-match scan.
patch_to_group = {}
for group_id, group_patches in enumerate(patch_map):
    for patch in group_patches:
        patch_to_group.setdefault(patch, group_id)

print('Processing MIDI files. Please wait...')
print('=' * 70)

for f in tqdm(filez[START_FILE_NUMBER:]):
    try:
        input_files_count += 1
        file_size = os.path.getsize(f)

        if file_size < 250000:
            # `with` closes the handle right away instead of leaving one open
            # file per MIDI for the garbage collector.
            with open(f, 'rb') as midi_file:
                score = TMIDIX.midi2ms_score(midi_file.read())

            patches = [0] * 16  # current program per MIDI channel

            # score[0] is not a track (presumably the timing header of the
            # TMIDIX score format); gather note/patch events from all tracks.
            events_matrix = []
            for track in score[1:]:
                for event in track:
                    if event[0] == 'note' or event[0] == 'patch_change':
                        events_matrix.append(event)

            events_matrix.sort(key=lambda x: x[1])
            events_matrix1 = []

            for event in events_matrix:
                if event[0] == 'patch_change':
                    # event[2] is the channel, event[3] the new program.
                    patches[event[2]] = event[3]

                if event[0] == 'note':
                    # event = [name, start, dur, channel, pitch, velocity];
                    # append the channel's current program as event[6].
                    event.append(patches[event[3]])
                    # Replace the channel with the instrument class id. The
                    # drum channel (9) is kept; unknown programs become 15.
                    if event[3] != 9:
                        event[3] = patch_to_group.get(event[6], 15)
                    event[5] = max(80, event[5])  # velocity floor of 80
                    # Class 15 (unmapped program) is dropped here.
                    if event[3] < 12:
                        events_matrix1.append(event)

            if events_matrix1:
                instruments = list({e[3] for e in events_matrix1 if e[3] != 9})
                # Keep only files with exactly one melodic instrument class.
                if len(instruments) == 1:
                    main_instrument = instruments[0]

                    # Coarse time quantization of start times and durations.
                    for e in events_matrix1:
                        e[1] = int(e[1] / 8)
                        e[2] = int(e[2] / 32)

                    # Order by time; within equal times, highest pitch first
                    # (the second sort is stable).
                    events_matrix1.sort(key=lambda x: x[4], reverse=True)
                    events_matrix1.sort(key=lambda x: x[1])

                    melody_chords = []
                    pe = events_matrix1[0]

                    for e in events_matrix1:
                        time = max(0, min(255, e[1] - pe[1]))
                        dur = max(1, min(127, e[2]))
                        cha = max(0, min(11, e[3]))
                        ptc = max(1, min(127, e[4]))
                        vel = max(8, min(127, e[5]))
                        velocity = round(vel / 15) - 1  # bucket in 0..7
                        melody_chords.append([time, dur, cha, ptc, velocity])
                        pe = e

                    if len([y for y in melody_chords if y[2] != 9]) > 12:
                        times = [y[0] for y in melody_chords[12:]]
                        avg_time = sum(times) / len(times) if times else 0
                        try:
                            mode_dur = statistics.mode([y[1] for y in melody_chords if y[2] != 9])
                        except statistics.StatisticsError:
                            # Narrow handler: a bare except here also
                            # swallowed KeyboardInterrupt, so Ctrl+C could not
                            # reach the interrupt-and-save handler below.
                            mode_dur = 1
                        num_chords = len([y for y in melody_chords if y[0] != 0])

                        if avg_time < 64 and mode_dur < 64 and 0 in times and 600 < num_chords < 256 * 50:

                            # Per-file header tokens.
                            melody_chords_f.extend([
                                3087,
                                3074 if 9 in [y[2] for y in melody_chords] else 3073,
                                3075 + main_instrument,
                                2816
                            ])
                            if melody_chords[0][0] == 0:
                                melody_chords_f.append(0)

                            chords_count = 0
                            for m in melody_chords:
                                time = m[0]
                                dur_vel = (m[1] * 8) + m[4]
                                cha_ptc = (m[2] * 128) + m[3]

                                if (((num_chords // 50) * 50) - chords_count == 50) and time != 0:
                                    melody_chords_f.append(3072)
                                if chords_count % 50 == 0 and chords_count != 0 and time != 0:
                                    melody_chords_f.append(2816 + min(255, (chords_count // 50)))

                                # A non-zero delta starts a new chord and is
                                # written explicitly; simultaneous notes omit it.
                                if time != 0:
                                    melody_chords_f.extend([time, dur_vel + 256, cha_ptc + 1280])
                                    chords_count += 1
                                else:
                                    melody_chords_f.extend([dur_vel + 256, cha_ptc + 1280])

                                stats[m[2]] += 1

                            files_count += 1

                            if files_count % 5000 == 0:
                                print('SAVING !!!')
                                count = str(files_count)
                                TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, f'/content/drive/MyDrive/LAKH_INTs_{count}')
                                melody_chords_f = []

    except KeyboardInterrupt:
        print('Saving current progress and quitting...')
        break

    except Exception as ex:
        # Best-effort: report the unreadable file and move on.
        print('WARNING !!!')
        print('Bad MIDI:', f)
        print('Error detected:', ex)
        continue

# Final save of the partial batch. Skipped when nothing is pending: right
# after a periodic save (or a resume with no new accepted files) files_count
# still names the last saved batch, so saving [] here would presumably
# overwrite that batch's file on Drive.
if melody_chords_f:
    print('SAVING !!!')
    count = str(files_count)
    TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, f'/content/drive/MyDrive/LAKH_INTs_{count}')

print('Done!')
print('Total good processed MIDI files:', files_count)
print('Instruments stats:')
instruments_names = ['Piano', 'Guitar', 'Bass', 'Violin', 'Cello', 'Harp', 'Trumpet', 'Sax', 'Flute', 'Drums', 'Choir', 'Organ']
for i, name in enumerate(instruments_names):
    print(f'{name}:', stats[i])
