In [None]:
# Mount the user's Google Drive into the Colab VM at /content/gdrive.
# Later cells read and write under /content/gdrive/MyDrive/Licenta/.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Environment setup for the Colab VM.
# `update-alternatives --config` prompts interactively for which binary `python3` resolves to;
# the trailing note suggests Python 3.8 is the intended choice (assumption — confirm).
!sudo update-alternatives --config python3 # python 3.8

# Install pip from apt, then upgrade pip itself through `python -m pip`.
!sudo apt install python3-pip && python -m pip install --upgrade pip
# numpy is constrained below 1.24; scipy, matplotlib and ninja are unpinned.
!pip install "numpy < 1.24" scipy matplotlib ninja
!pip install --upgrade setuptools ez_setup
# Pinned here: tensorboardX 2.2, fairseq 0.10.2, torch 1.7.1, protobuf 3.20.x; the remaining packages are unpinned.
!pip install miditoolkit jedi cvxpy tensorboardX==2.2 triton wandb fairseq==0.10.2 torch==1.7.1 protobuf==3.20.*

In [None]:
# Change into the MidiProcessor sources on the mounted Drive and install that package (verbose pip output).
# NOTE(review): presumably this is what provides the `mp-batch-encoding` command used later — confirm.
%cd /content/gdrive/MyDrive/Licenta/utils/MidiProcessor/
!pip install . -v

# Melody Extraction

In [None]:
import numpy as np
import pretty_midi
import math
import os
import tqdm

# --- Melody extraction configuration -------------------------------------------------
tol = 1e-4              # relative tolerance used to decide that two notes start together
chord_threshold = 20    # this many chords marks an instrument as harmony-based
rhythm_threshold = 20   # this many low notes marks an instrument as a rhythm channel

midi_folder_path = '/content/midis'
output_folder_path = '/content/new_midis'
bad_midis = []          # file names that could not be processed


def is_rhythm_channel(notes, max_pitch, threshold):
    """Return True when at least `threshold` notes have a pitch <= `max_pitch`.

    notes     -- iterable of objects exposing a numeric `.pitch`
    max_pitch -- highest pitch that still counts as a "low" note
    threshold -- number of low notes at which the channel is treated as rhythm
    """
    low_notes = 0
    for note in notes:
        if note.pitch <= max_pitch:
            low_notes += 1
            if low_notes >= threshold:
                # No need to scan the rest of the track.
                return True
    return low_notes >= threshold


def count_chords(notes, rel_tol, limit):
    """Count groups of three or more notes that start together, stopping at `limit`.

    Start times are sorted first, so simultaneous notes are adjacent no matter in
    which order the notes are stored. On a start-sorted track this matches the
    previous pairwise scan, but in O(n log n) instead of O(n^2).

    notes   -- iterable of objects exposing a numeric `.start`
    rel_tol -- relative tolerance passed to `math.isclose`
    limit   -- stop counting once this many chords were found
    """
    starts = sorted(note.start for note in notes)
    chords = 0
    group_begin = 0
    while group_begin < len(starts) and chords < limit:
        group_end = group_begin + 1
        # Extend the group while notes still start together with its first note.
        while group_end < len(starts) and math.isclose(
                starts[group_begin], starts[group_end], rel_tol=rel_tol):
            group_end += 1
        if group_end - group_begin >= 3:
            chords += 1
        group_begin = group_end
    return chords


def processed_midi_path(midi_path):
    """Return the output path for `midi_path`: `<output folder>/processed_<file name>`.

    Only the file name is used, so inputs found in nested folders still get a
    valid target. Files sharing a name in different folders map to the same output.
    """
    return os.path.join(output_folder_path, "processed_" + os.path.basename(midi_path))


# Collect every file below the input folder (recursively).
filez = list()
for (dirpath, dirnames, filenames) in os.walk(midi_folder_path):
    filez += [os.path.join(dirpath, file) for file in filenames]

for midi_name in filez:
    try:
        pm = pretty_midi.PrettyMIDI(midi_name)
        new_midi = pretty_midi.PrettyMIDI(initial_tempo=pm.estimate_tempo())
        # Notes at or below A2 are treated as rhythm/bass material.
        # (Original author's note: every detected note appears one octave lower.)
        rhythm_max_pitch = pretty_midi.note_name_to_number('A2')

        for instrument in pm.instruments:
            instrument_name = pretty_midi.program_to_instrument_name(instrument.program)

            # Skip drum instruments.
            if instrument.is_drum:
                continue

            # Skip bass instruments (matched on the General MIDI program name).
            if 'bass' in instrument_name.lower():
                continue

            # Skip channels that mostly play low notes.
            if is_rhythm_channel(instrument.notes, rhythm_max_pitch, rhythm_threshold):
                continue

            # Skip channels that mostly play chords.
            if count_chords(instrument.notes, tol, chord_threshold) >= chord_threshold:
                continue

            new_midi.instruments.append(instrument)

        if not new_midi.instruments:
            # Nothing melodic survived the filters: there is no lead track to write.
            bad_midis.append(os.path.basename(midi_name))
            continue

        # Merge every remaining instrument into the first one.
        lead_instrument = new_midi.instruments[0]
        for inst in new_midi.instruments[1:]:
            lead_instrument.notes.extend(inst.notes)
            lead_instrument.pitch_bends.extend(inst.pitch_bends)
            lead_instrument.control_changes.extend(inst.control_changes)
        new_midi.instruments = [lead_instrument]

        os.makedirs(output_folder_path, exist_ok=True)
        new_midi.write(processed_midi_path(midi_name))
    except Exception as err:
        # Best effort: remember the unreadable/unwritable file and keep going.
        # `Exception` (not a bare except) lets KeyboardInterrupt stop the loop.
        print(f"Skipping {midi_name}: {err!r}")
        bad_midis.append(os.path.basename(midi_name))
        continue

print("Bad midis: ", len(bad_midis))

# Preprocess

In [None]:
# Work from the museformer checkout on Drive; the relative paths in the cells below (data/, tools/, ...) depend on this directory.
%cd /content/gdrive/MyDrive/Licenta/museformer/

In [None]:
# Extract /content/augmented.rar (keeping its internal folder structure) into museformer's data/midi folder.
!unrar x "/content/augmented.rar" "/content/gdrive/MyDrive/Licenta/museformer/data/midi"

In [None]:
# Clean the dataset a bit

import os

midi_folder_path = "/content/gdrive/MyDrive/Licenta/museformer/data/midi"


def normalized_midi_path(path):
    """Return `path` with a MIDI extension of any casing (e.g. ``.MID``) lowered to ``.mid``.

    Only the extension is touched: "MID" appearing inside a directory or file
    name is left alone. Paths without a MIDI extension are returned unchanged.
    """
    root, ext = os.path.splitext(path)
    if ext.lower() == ".mid":
        return root + ".mid"
    return path


# Collect every file below the dataset folder (recursively).
filez = list()
for (dirpath, dirnames, filenames) in os.walk(midi_folder_path):
    filez += [os.path.join(dirpath, file) for file in filenames]

# Rename only the files whose extension actually needs normalising.
for midi in filez:
    new_name = normalized_midi_path(midi)
    if new_name != midi:
        os.rename(midi, new_name)

In [None]:
# Batch-encode MIDI files with the REMIGEN encoding, removing empty bars and sorting instruments by id.
# Original note: remove normalization when using the augmented dataset.
# NOTE(review): the two positional arguments are presumably <input dir> <output dir> — confirm against MidiProcessor.
# NOTE(review): later cells read tokens from data/token, not from the /content path given here — confirm how the output gets there.
!mp-batch-encoding '/content/tokens2' '/content/augmented_tokens_middle_2' --encoding-method REMIGEN --remove-empty-bars --sort-insts id

In [None]:
# Prepare txt files for split
import os 
from sklearn.model_selection import train_test_split

# Seed and hold-out size are named so the split can be reproduced and tuned in one place.
split_seed = 42
holdout_size = 208  # number of files in each of the test and validation splits

# os.listdir order is arbitrary; sorting makes the seeded split deterministic.
names = sorted(os.listdir('./data/midi/'))
x_main, x_test = train_test_split(names, test_size=holdout_size, random_state=split_seed)
x_train, x_valid = train_test_split(x_main, test_size=holdout_size, random_state=split_seed)

# Write one file name per line into data/meta/<split>.txt.
os.makedirs("./data/meta", exist_ok=True)
for split_name, split_files in (("train", x_train), ("test", x_test), ("valid", x_valid)):
    with open(f"./data/meta/{split_name}.txt", 'w') as f:
        f.writelines(f"{line}\n" for line in split_files)

In [None]:
# For each split, run tools/generate_token_data_by_file_list.py with data/meta/<split>.txt, data/token and data/split.
# NOTE(review): presumably this gathers the token files listed in the txt into data/split/<split>.data (consumed by fairseq-preprocess below) — confirm.
!for split in train valid test; do python tools/generate_token_data_by_file_list.py data/meta/${split}.txt data/token data/split; done

In [None]:
# Make sure the output folder for binarized data exists.
!mkdir -p data-bin

# Binarize the train/valid/test token files (source side only) into data-bin/lmd6remi,
# using the existing dictionary data/meta/general_use_dict.txt.
!fairseq-preprocess \
  --only-source \
  --trainpref data/split/train.data \
  --validpref data/split/valid.data \
  --testpref data/split/test.data \
  --destdir data-bin/lmd6remi \
  --srcdict data/meta/general_use_dict.txt

# Train and Evaluate

In [None]:
# Run the training script from the museformer checkout (the script path is relative to it).
%cd /content/gdrive/MyDrive/Licenta/museformer/
!bash ttrain/mf-lmd6remi-1.sh

In [None]:
# Run the validation script on checkpoint_best.pt.
# NOTE(review): the arguments `1` and `10240` are presumably the model/run id and a maximum sequence length — confirm in the script.
!bash tval/val__mf-lmd6remi-x.sh 1 checkpoint_best.pt 10240

## WandB logging

In [1]:
# Install the Weights & Biases client and force a fresh login (prompts for an API key).
!pip install wandb
!wandb login --relogin

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.15.4-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m38.1 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.25.0-py2.py3-none-any.whl (206 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.5/206.5 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools (from wandb)
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting

In [2]:
import wandb
import sys


def parse_valid_loss(line):
    """Extract the validation loss from a log line containing 'valid on'.

    The line is split on '|' and the loss is sliced out of the sixth field,
    which is expected to look like ' loss <value> '.
    """
    fields = line.split("|")
    return float(fields[5][6:-1])


def parse_train_loss(line):
    """Extract the training loss from a 'train_inner' log line.

    Parsing starts at the first 'epoch'; the sixth whitespace-separated token is
    expected to look like 'loss=<value>,'. Returns None when the logged value is 'None'.
    """
    tokens = line[line.find("epoch"):].split()
    loss = tokens[5][5:-1]
    if loss == 'None':
        return None
    return float(loss)


# Read the whole training log once; the context manager closes the file handle.
with open("/content/non_augmented_log.log") as log:
    log_lines = log.readlines()

wandb.init(
    project="Museformer-lauta",
    config={
        "learning_rate": 5e-4,
        "architecture": "Museformer GPT",
        "dataset": "lauta",
        "epochs": 4,  # change this when needed
    }
)

# First pass: replay every validation loss.
for line in log_lines:
    if 'valid on' in line:
        wandb.log({"validation loss": parse_valid_loss(line)})

# Second pass: replay every training loss that has a value.
for line in log_lines:
    if 'train_inner' not in line:
        continue
    loss = parse_train_loss(line)
    if loss is None:
        continue
    wandb.log({"training loss": loss})

# Close the run so it is marked finished and all data is uploaded.
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mrazvan-gabriel-budaca[0m ([33mlicenta_razvan[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Music generation


In [None]:
# Generate music with checkpoint_211_350070.pt and keep a copy of the console output in output_log/generation.log.
%cd /content/gdrive/MyDrive/Licenta/museformer/
!mkdir -p output_log
# The printf feeds a series of empty lines to the script's stdin.
# NOTE(review): presumably each empty line answers one interactive prompt of the generation script — confirm.
!printf '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n' | bash tgen/generation__mf-lmd6remi-x.sh 1 checkpoint_211_350070.pt | tee output_log/generation.log

In [None]:
# Run tools/batch_extract_log.py on the saved generation log, targeting output/generation, with --start_idx 1.
# NOTE(review): presumably this writes one token file per generated sample, numbered from 1 — confirm.
!python tools/batch_extract_log.py output_log/generation.log output/generation --start_idx 1

In [None]:
# Run tools/batch_generate_midis.py with output/generation as both input and output directory.
# NOTE(review): this uses REMIGEN2 while the encoding cell above used REMIGEN — confirm that is intended.
!python tools/batch_generate_midis.py --encoding-method REMIGEN2 --input-dir output/generation --output-dir output/generation