# Giant Music Transformer Training Dataset Maker (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

#### Project Los Angeles

#### Tegridy Code 2023

***

# (SETUP ENVIRONMENT)

In [None]:
#@title Install all dependencies (run only once per session)

# Clone tegridy-tools; the next cell imports the TMIDIX module from this checkout.
!git clone https://github.com/asigalov61/tegridy-tools
# tqdm provides the progress bar used by the processing cell.
!pip install tqdm

In [None]:
#@title Import all needed modules

print('Loading needed modules. Please wait...')
import os
# NOTE(review): copy, math and statistics are not used by the cells in this notebook.
import copy
import math
import statistics
import random

# joblib drives the batched parallel loop in the "Process MIDIs" cell.
from joblib import Parallel, delayed, parallel_config

from tqdm import tqdm

# Directory that receives the downloaded/extracted source MIDI dataset.
if not os.path.exists('/content/Dataset'):
    os.makedirs('/content/Dataset')

print('Loading TMIDIX module...')
# TMIDIX.py lives inside the cloned repository; changing into that directory
# lets the plain `import TMIDIX` below find it.
# NOTE(review): this relies on the current directory being on sys.path, which
# is the case in a notebook kernel — confirm before reusing outside a notebook.
os.chdir('/content/tegridy-tools/tegridy-tools')

import TMIDIX

print('Done!')

# Return to the working directory used by the rest of the notebook.
os.chdir('/content/')
print('Enjoy! :)')

# (DOWNLOAD SOURCE MIDI DATASET)

In [None]:
#@title Download original LAKH MIDI Dataset

# Download the full LAKH MIDI archive into /content/Dataset/, extract it there,
# then delete the archive to free disk space.
%cd /content/Dataset/

!wget 'http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz'
!tar -xvf 'lmd_full.tar.gz'
!rm 'lmd_full.tar.gz'

# Back to the notebook's working directory.
%cd /content/

In [None]:
#@title Mount Google Drive
# Mount Google Drive at /content/drive; the file list and the processed INT
# chunks are written under /content/drive/MyDrive/ by later cells.
from google.colab import drive
drive.mount('/content/drive')

# (FILE LIST)

In [None]:
#@title Save file list
###########

print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

# Recursively collect every file under the dataset directory. No extension
# filter is applied, so the directory is expected to contain only MIDI files
# (as it does after extracting the LAKH archive).
dataset_addr = "/content/Dataset"
filez = []
for dirpath, _dirnames, filenames in os.walk(dataset_addr):
    filez.extend(os.path.join(dirpath, file) for file in filenames)
print('=' * 70)

if not filez:
    # Do not shuffle/save an empty list: that would overwrite a previously
    # saved file list on Drive, and the processing cell divides by len(filez).
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)
else:
    print('Randomizing file list...')
    random.shuffle(filez)

    # Persist the shuffled list so the "Load file list" cell can restore it.
    TMIDIX.Tegridy_Any_Pickle_File_Writer(filez, '/content/drive/MyDrive/filez')

In [None]:
#@title Load file list
# Restore the shuffled file list pickled to Drive by the "Save file list" cell
# (same path), so the directory walk and shuffle do not need to be repeated.
filez = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/drive/MyDrive/filez')

# (LOAD TMIDIX MIDI PROCESSOR)

In [None]:
# @title Load TMIDIX MIDI Processor

def _extract_note_events(score):
    """Flatten a millisecond score into patch-tagged, de-duplicated note events.

    Args:
        score: Score returned by ``TMIDIX.midi2single_track_ms_score``; every
            element after index 0 is treated as a track (a list of events).

    Returns:
        List of note events in start-time order. Each one is the score's own
        ``['note', start, dur, channel, pitch, velocity]`` list, extended in
        place with the patch active on its channel (index 6).
    """
    # Keep only notes and patch changes from all tracks, merged by start time
    # (the sort is stable, so simultaneous events keep their track order).
    events = [event
              for track in score[1:]
              for event in track
              if event[0] in ('note', 'patch_change')]
    events.sort(key=lambda ev: ev[1])

    patches = [0] * 16  # current patch of each MIDI channel
    notes = []

    for event in events:
        if event[0] == 'patch_change':
            # ['patch_change', time, channel, patch]
            patches[event[2]] = event[3]
        else:
            event.append(patches[event[3]])
            # Drop a note repeating the previous kept note's start time,
            # channel and pitch.
            if notes and event[1] == notes[-1][1] and event[3:5] == notes[-1][3:5]:
                continue
            notes.append(event)

    return notes


def _note_events_to_ints(notes):
    """Encode note events into a Giant Music Transformer INT sequence.

    Token layout (from the arithmetic below):
        0-255         delta start time (16 ms steps, clipped to 255)
        256-2303      8 * duration (16 ms steps, <= 255) + velocity level (0-7)
        2304-18943    129 * patch (128 = drums) + pitch (1-127)
        18945, 18946+dt, 19202+pitch   outro marker triple
        19330/19331   drums absent/present
        19332-19460   19332 + patch of the first note
        19461         intro start, 19462 end of sequence (x3)

    Args:
        notes: Events as returned by ``_extract_note_events``; their start
            times and durations are quantized and they are re-sorted in place.

    Returns:
        The INT list, or None for empty input, negative start times or
        durations, or compositions that only use channel 9 (drums).
    """
    if not notes:
        return None

    if min(e[1] for e in notes) < 0 or min(e[2] for e in notes) < 0:
        return None

    channels = {e[3] for e in notes}
    if not channels - {9}:
        return None

    # Quantize start times and durations to 16 ms steps.
    for e in notes:
        e[1] = int(e[1] / 16)
        e[2] = int(e[2] / 16)

    # Order by start time, then pitch (high to low), then patch.
    notes.sort(key=lambda e: (e[1], -e[4], e[6]))

    # Intro sequence: start marker, drums flag, patch of the first note.
    drums_present = 19331 if 9 in channels else 19330
    first_patch = 128 if notes[0][3] == 9 else notes[0][6]
    melody_chords = [19461, drums_present, 19332 + first_patch]

    # Number of distinct start times ("chords") in the composition.
    comp_chords_len = len({e[1] for e in notes})
    chords_counter = 1
    prev = notes[0]

    for e in notes:
        # Clip all values into their token ranges.
        delta_time = max(0, min(255, e[1] - prev[1]))
        dur = max(0, min(255, e[2]))
        cha = max(0, min(15, e[3]))
        pat = 128 if cha == 9 else e[6]  # drums use pseudo-patch 128
        ptc = max(1, min(127, e[4]))
        vel = max(8, min(127, e[5]))
        velocity = round(vel / 15) - 1  # octo-velocity level 0..7

        # Emit the outro marker when 50 chords remain.
        if comp_chords_len - chords_counter == 50 and delta_time != 0:
            melody_chords.extend([18945, 18946 + delta_time, 19202 + ptc])

        if delta_time != 0:
            chords_counter += 1

        melody_chords.extend([delta_time,
                              (8 * dur) + velocity + 256,
                              (129 * pat) + ptc + 2304])
        prev = e

    melody_chords.extend([19462, 19462, 19462])  # EOS

    # TOTAL DICTIONARY SIZE 19462+1=19463
    return melody_chords


def TMIDIX_MIDI_Processor(midi_file, max_file_size=1000000):
    """Convert one MIDI file into a Giant Music Transformer INT sequence.

    Args:
        midi_file: Path of the MIDI file to process.
        max_file_size: Files larger than this many bytes are skipped.

    Returns:
        The INT sequence (list of ints), or None when the file is too big,
        unreadable, fails to parse, or is filtered out during encoding.
    """
    try:
        # Filter out GIANT MIDIs.
        if os.path.getsize(midi_file) > max_file_size:
            return None

        # Close the handle promptly; this runs for every file of the dataset.
        with open(midi_file, 'rb') as midi:
            midi_bytes = midi.read()

        # Convert the MIDI to a millisecond score with the TMIDIX module.
        score = TMIDIX.midi2single_track_ms_score(midi_bytes, recalculate_channels=False)

        return _note_events_to_ints(_extract_note_events(score))

    except Exception:
        # Best effort: bad files are skipped; the caller drops None results.
        return None

# (PROCESS)

In [None]:
#@title Process MIDIs with TMIDIX MIDI processor

NUMBER_OF_PARALLEL_JOBS = 4 # Number of parallel jobs
NUMBER_OF_FILES_PER_ITERATION = 16 # Number of files to queue for each parallel iteration
SAVE_EVERY_NUMBER_OF_ITERATIONS = 160 # Save every 2560 files


def _save_processed_chunk(chunk, good_files_count, total_files_count):
    """Print a short report on *chunk* and pickle it to Google Drive.

    Args:
        chunk: Non-empty list of INT sequences from TMIDIX_MIDI_Processor.
        good_files_count: Running total of successfully processed files; also
            the suffix of the output file name.
        total_files_count: Number of files in the file list (must be > 0).
    """
    first = chunk[0]
    print('SAVING !!!')
    print('=' * 70)
    print('Saving processed files...')
    print('=' * 70)
    print('Data check:', min(first), '===', max(first), '===', len(set(first)), '===', len(first))
    print('=' * 70)
    print('Processed so far:', good_files_count, 'out of', total_files_count, '===', good_files_count / total_files_count, 'good files ratio')
    print('=' * 70)
    TMIDIX.Tegridy_Any_Pickle_File_Writer(chunk, '/content/drive/MyDrive/LAKH_INTs_' + str(good_files_count))


print('=' * 70)
print('TMIDIX MIDI Processor')
print('=' * 70)
print('Starting up...')
print('=' * 70)

###########

melody_chords_f = []

files_count = 0

files_per_save = NUMBER_OF_FILES_PER_ITERATION * SAVE_EVERY_NUMBER_OF_ITERATIONS

print('Processing MIDI files. Please wait...')
print('=' * 70)

for i in tqdm(range(0, len(filez), NUMBER_OF_FILES_PER_ITERATION)):

    # NOTE(review): with the threading backend a pure-Python processor runs
    # under the GIL, so real parallelism may be limited — kept as-is.
    with parallel_config(backend='threading', n_jobs=NUMBER_OF_PARALLEL_JOBS, verbose=0):
        output = Parallel(n_jobs=NUMBER_OF_PARALLEL_JOBS, verbose=0)(
            delayed(TMIDIX_MIDI_Processor)(f) for f in filez[i:i + NUMBER_OF_FILES_PER_ITERATION]
        )

    good_outputs = [o for o in output if o is not None]
    melody_chords_f.extend(good_outputs)

    # Count only this batch's successes: melody_chords_f keeps growing until
    # the next save, so adding its whole length every iteration over-counts.
    files_count += len(good_outputs)

    # Save every files_per_save queued files; never save an empty chunk.
    if i % files_per_save == 0 and i != 0 and melody_chords_f:
        _save_processed_chunk(melody_chords_f, files_count, len(filez))
        melody_chords_f = []

        print('=' * 70)

# Save whatever is left. melody_chords_f is intentionally not cleared here,
# because the "Test INTs" cell samples from it.
if melody_chords_f:
    _save_processed_chunk(melody_chords_f, files_count, len(filez))
else:
    print('Nothing left to save.')

print('=' * 70)

# (TEST INTS)

In [None]:
#@title Test INTs

def _decode_ints(song):
    """Decode a Giant Music Transformer INT sequence into ms-score note events.

    Only delta-time (0-255), duration/velocity (256-2303) and patch/pitch
    (2304-18944) tokens are decoded; all other tokens are ignored.

    Args:
        song: List of INTs as produced by TMIDIX_MIDI_Processor.

    Returns:
        (song_f, patches): song_f is a list of
        ['note', time, dur, channel, pitch, vel] events (times in ms), and
        patches is the 16-entry per-channel patch list, unused channels set to 0.
    """
    song_f = []

    time = 0
    dur = 0
    vel = 90

    patches = [-1] * 16   # patch assigned to each channel, -1 = unused
    channels = [0] * 16   # 1 = channel already taken
    channels[9] = 1       # channel 9 is reserved for drums

    for ss in song:
        if 0 <= ss < 256:
            # Delta start time in 16 ms steps.
            time += ss * 16
        elif 256 <= ss < 2304:
            dur = ((ss - 256) // 8) * 16
            vel = (((ss - 256) % 8) + 1) * 15
        elif 2304 <= ss < 18945:
            patch = (ss - 2304) // 129
            if patch == 128:
                channel = 9
            else:
                if patch not in patches:
                    # New patch: take the first free channel, or reuse 15.
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15
                    patches[cha] = patch
                channel = patches.index(patch)
            pitch = (ss - 2304) % 129
            song_f.append(['note', time, dur, channel, pitch, vel])

    patches = [0 if x == -1 else x for x in patches]
    return song_f, patches


if not melody_chords_f:
    # Nothing in memory (e.g. every chunk was already saved and cleared).
    print('No processed INTs to test. Please run the processing cell first.')
else:
    train_data1 = random.choice(melody_chords_f)

    print('Sample INTs', train_data1[:15])

    out = train_data1

    if len(out) != 0:
        song_f, patches = _decode_ints(out)

        detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                                  output_signature = 'Giant Music Transformer',
                                                                  output_file_name = '/content/Giant-Music-Transformer-Composition',
                                                                  track_name='Project Los Angeles',
                                                                  list_of_MIDI_patches=patches
                                                                  )

    print('Done!')

# Congrats! You did it! :)