# Music Composition with Trained GPT-2 Model

In this Google Colab notebook, we load a GPT-2 model trained for music composition in ABC notation, use it to generate new compositions, and evaluate them with the BLEU score and Levenshtein similarity metrics.


In [None]:
import torch
from tqdm import tqdm
from argparse import ArgumentParser

import glob
import os
import pandas as pd

import sys
!pip install wandb

import wandb

# SECURITY: an API key used to be hard-coded on this line; treat that key as leaked and
# revoke it. Provide the key through the WANDB_API_KEY environment variable (for example a
# Colab secret) instead of committing it to the notebook.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
!{sys.executable} -m pip install youtokentome
!{sys.executable} -m pip install transformers
!pip install accelerate -U
from transformers import Trainer, TrainingArguments,default_data_collator
import youtokentome as yttm



In [None]:
# Normalized absolute path of the current working directory.
ORIGIN = os.path.abspath(os.curdir)
print(ORIGIN)

# Locations on the mounted Google Drive. TOKENIZER_DIR and DATASET_DIR point at files,
# despite the _DIR suffix.
TRAIN_DIR = "/content/drive/MyDrive/test2/"
VALID_DIR = "/content/drive/MyDrive/Music_project/valid_path/"
TEST_DIR = "/content/drive/MyDrive/Music_project/test_path/"
TOKENIZER_DIR = "/content/drive/MyDrive/Music_project/abc_run5.yttm"
DATASET_DIR = "/content/drive/MyDrive/Music_project/300,000_new_samples.csv"
OUTPUT_DIR = "/content/drive/MyDrive/Music_project/"


In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Base GPT-2 BPE vocabulary. Larger variants ("gpt2-medium", "gpt2-large", ...) can be
# substituted here to match a different model size.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Config

# Checkpoint directory holding config.json and pytorch_model.bin for the model to evaluate.
CHECKPOINT_DIR = os.path.join(
    OUTPUT_DIR,
    "output_GPT2_checkpoints6run9_withGPT2_300,000_new_samples",
    "checkpoint-250000",
)
model_path = os.path.join(CHECKPOINT_DIR, "pytorch_model.bin")
config_path = os.path.join(CHECKPOINT_DIR, "config.json")

# Build the architecture from the saved configuration...
config = GPT2Config.from_json_file(config_path)
model = GPT2LMHeadModel(config)

# ...then load the trained weights (mapped to CPU so no GPU is required to load them).
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# nn.Module instances start in training mode (dropout active). Switch to inference mode so
# generation below is not perturbed by dropout.
model.eval()


In [None]:
# ABC header-field prefixes ("X:", "K:", "M:", ...). Lines starting with one of these are
# treated as header ("keys") lines; every other non-comment line is note content.
USEABLE_PARAMS = [field + ":" for field in "BCDFGHIKLMmNOPQRrSsTUVWwXZ"]


def read_abc(path):
    """Split an ABC file into a header string and a compact note string.

    Args:
        path: Path to a UTF-8 text file in ABC notation.

    Returns:
        ``(keys, notes)``:

        * ``keys`` is the header lines (except the ``T:`` title), joined by single spaces.
        * ``notes`` is the concatenated note body with all spaces, all backslashes and
          every fully muted bar (``x8|`` / ``z8|``) removed, and one trailing bar line
          dropped.

        Returns ``(None, None)`` if either part is empty. Lines starting with ``%``
        (comments) are ignored.
    """
    header_prefixes = tuple(USEABLE_PARAMS)
    keys = []
    notes = []
    with open(path, encoding="utf-8") as rf:
        for line in rf:
            line = line.strip()
            if line.startswith("%"):  # skip comments
                continue
            if line.startswith(header_prefixes):
                if line.startswith("T:"):
                    continue  # skip the title for better tokenization
                keys.append(line)
            else:
                notes.append(line)

    keys = " ".join(keys)

    # Remove spaces and line-continuation backslashes.
    notes = "".join(notes).replace(" ", "").replace("\\", "")
    # 8 because the files use L:1/8, so "x8|" / "z8|" is one fully muted bar.
    notes = notes.replace("x8|", "").replace("z8|", "").strip()
    # Drop the final bar line only after the removals above. Doing it first left a
    # dangling "x8" / "z8" when the tune ended on a muted bar.
    if notes.endswith("|"):
        notes = notes[:-1]
    notes = notes.strip()

    if not keys or not notes:
        return None, None

    return keys, notes



In [None]:
def predict_notes(model, tokenizer, keys, notes):
    """Continue an ABC tune with ``model`` given its header and opening notes.

    Args:
        model: Causal LM exposing a Hugging Face style ``generate``.
        tokenizer: Object with ``encode(str) -> list[int]``, ``decode(ids, ...)`` and
            ``eos_token_id``.
        keys: ABC header string (as returned by ``read_abc``).
        notes: Prompt note string.

    Returns:
        The decoded sequence (prompt plus continuation). Spaces are removed and a newline
        is inserted after every bar line ``|``.
    """
    # Token budget for header + notes; the two marker tokens added below bring the
    # context to at most max_context + 2 tokens.
    max_context = 510
    # Ids 2 and 3 wrap the context; ids 0-3 are skipped when decoding.
    # NOTE(review): these look like youtokentome's default BOS/EOS ids -- confirm they
    # match the tokenizer the checkpoint was trained with.
    bos_id, eos_id = 2, 3
    special_ids = [0, 1, bos_id, eos_id]
    # Number of tokens to generate after the context.
    min_new_tokens, max_new_tokens = 20, 40
    # Token sequences the model must not generate.
    bad_words = ["x8 | "]

    print(notes)
    keys_tokens = list(tokenizer.encode(keys))
    notes_tokens = list(tokenizer.encode(notes))

    # Keep only the most recent notes when the context is too long. The previous slice
    # start (len(notes) - len(keys) - 510) kept len(keys) tokens too many.
    overflow = len(keys_tokens) + len(notes_tokens) - max_context
    if overflow > 0:
        notes_tokens = notes_tokens[overflow:]

    context_tokens = torch.tensor(
        [[bos_id, *keys_tokens, *notes_tokens, eos_id]], dtype=torch.long
    )
    context_len = context_tokens.shape[1]

    # bad_words_ids is a list of token-id sequences, one per banned word. Previously
    # these ids were computed but never passed to generate().
    bad_words_ids = [ids for ids in (list(tokenizer.encode(w)) for w in bad_words) if ids]

    gen_tokens = model.generate(
        input_ids=context_tokens,
        # max_length/min_length include the prompt, so offset them by the context
        # length; otherwise a long prompt leaves no room for new tokens.
        max_length=context_len + max_new_tokens,
        min_length=context_len + min_new_tokens,
        num_beams=40,  # beam width: candidate sequences kept per step
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=2,  # avoid repeating melodic fragments
        bad_words_ids=bad_words_ids or None,
    )

    decoded = tokenizer.decode(gen_tokens[0].tolist(), ignore_ids=special_ids)
    # Some tokenizers (e.g. youtokentome BPE) return a list of strings, while Hugging
    # Face tokenizers return a str; indexing a str with [0] kept only its first character.
    if isinstance(decoded, (list, tuple)):
        decoded = decoded[0]

    return decoded.replace(" ", "").replace("|", "|\n")

def _end_of_nth_bar(notes, n):
    """Return the index just past the ``n``-th ``|`` in ``notes``.

    If ``notes`` has fewer than ``n`` bar lines, returns ``len(notes)`` so the whole
    string becomes the prompt. A non-positive ``n`` returns 0.
    """
    if n <= 0:
        return 0
    count = 0
    for i, char in enumerate(notes):
        if char == "|":
            count += 1
            if count == n:
                return i + 1
    return len(notes)


def predict(model, tokenizer, text_path, output_dir, prompt_bars=8):
    """Generate a continuation for the opening bars of an ABC tune.

    Args:
        model: Causal LM forwarded to ``predict_notes``.
        tokenizer: Tokenizer forwarded to ``predict_notes``.
        text_path: ABC file to read. If it is not an existing file (the notebook passes
            a directory), the module-level ``RUN5_DIR`` file is read instead. That keeps
            the previously hard-coded behavior.
        output_dir: Currently unused; kept for interface compatibility.
        prompt_bars: Number of ``|``-terminated bars used as the prompt.

    Returns:
        ``[predicted_notes, original_notes]``, where ``original_notes`` is the
        ground-truth remainder of the tune after the prompt. Returns ``None`` if the
        file has no usable notes.
    """
    abc_path = text_path if os.path.isfile(text_path) else RUN5_DIR
    keys, notes = read_abc(abc_path)
    print("Finished read_abc")
    print(keys)
    print(notes)
    # Check before using notes. A leftover debug `return (notes)` previously made the
    # rest of this function unreachable.
    if notes is None:
        print("No notes found")
        return None

    # The prompt ends right after the prompt_bars-th '|' and the remainder starts after
    # it. Previously the bar line appeared in both halves.
    prompt_end = _end_of_nth_bar(notes, prompt_bars)
    prompt_notes = notes[:prompt_end]
    original_notes = notes[prompt_end:]

    predicted_tokens = predict_notes(model, tokenizer, keys, prompt_notes)

    return [predicted_tokens, original_notes]

In [None]:
import os
import shutil

file_path = '/content/181854_118815.abc'

# Sanity-check that the sample ABC file is present before running generation.
print("---" if os.path.exists(file_path) else "File does not exist at the specified path.")


In [None]:

# Escaped backslash: "\A" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
# The value is unchanged. Note that abc_notation is re-bound by a later assignment.
abc_notation = "D4 A4| \\A2 DA ^A=A DA|C4 A4|A2 CA ^A=A CA|^A,4 =A4|A2 ^A,=A ^A=A ^A,=A|"
# Candidate prompts in ABC notation. Most hold note bodies only;
# abc_notation_282689_76955 also embeds its header fields (X:, M:, L:, Q:, K:, V:).
# The generation cell below uses abc_notation_221394_134131.
abc_notation_281804_59008 = "A,3[A4-F4-][A/2F/2]x/2|[B2-G2-D2-B,2-][B/2-G/2-D/2-B,/2][B/2G/2D/2]c<BGDD|DE6-E/2x/2|x2AE2A3|[A2-F2-C2-A,2-][A/2-F/2-C/2-A,/2][A/2F/2C/2](3C2A2F2C|[GD]DBG2DBD|E2^GEeE^GE|"
abc_notation_tamny_3aleek = "E/2F/2=G2F/2E/2 =D4|=C/2=D/2E2D/2C/2 B,4|E/2F/2=G2F/2E/2 B4|=C/2=D/2E2D/2C/2 F4"
abc_notation_282689_76955 = "X:1M:4/4L:1/8Q:1/4=120K:C%0sharpsV:1[_A-B,][A-F]A-[d2A2-]A-[A-F]A-|[A-A,][A-E]A-[c2A2-]A-[A-E]A-|[A-B,][A-F]A-[d-A]dAFA|A,EAc2AEA|B,FAd2AFA|A,EAc2AEA|EBegfeBA|G[geA]FE[geB,]CEF|B,F[d'aA]d2[d'aA]FA|"
# Re-binds abc_notation, replacing the earlier value. Each trailing backslash is a line
# continuation inside the string literal, so the value is a single line.
abc_notation = "[c4-G4-E4-][c/2G/2E/2-]E/2E-[G/2-E/2-][G/2-E/2-C/2-][GECG,]| \
[g2-e2-B2G2][g/2e/2]x/2[geBG][a2-f2-d2A2][a/2f/2]x/2[afdA]|\
[b2-g2-d2B2][b/2g/2]x/2[b/2g/2d/2B/2]x/2[b2-g2-d2B2-][b/2g/2B/2]x3/2|\
x3[b/2g/2d/2B/2]x/2[bgdB]x[b3/2g3/2d3/2B3/2]x/2|\
[g2-e2-B2G2][g/2e/2]x/2[g4-e4-B4-G4-][g-e-B-G-]|\
[g2e2B2G2]x[g/2e/2B/2G/2]x/2[geBG]x[g3/2e3/2B3/2G3/2]x/2|\
[b2-g2-d2B2][b/2g/2]x/2[b4-g4-d4-B4-][b-g-d-B-]|"

# Two-line literal joined by a backslash line continuation.
abc_notation_test_9596 = "[E2-C2-G,2-C,2-C,,2-][E/2-C/2-G,/2C,/2-C,,/2-][E/2C/2C,/2-C,,/2][E/2-C/2G,/2-C,/2][E/2G,/2-][D2-B,2-G,2-D,2-G,,2-][D/2B,/2-G,/2D,/2-G,,/2-][B,/2D,/2G,,/2-][DB,G,G,,]| \
[E2C2A,2A,,2-][E/2-C/2-A,/2-A,,/2][E/2-C/2-A,/2-][E/2-C/2-A,/2A,,/2-][E/2C/2A,,/2][E2B,2G,2E,,2-][E/2-B,/2-G,/2-E,,/2][E/2B,/2G,/2]E,,|"
abc_notation_215688_124450 = "x2[DB,F,]x2B,-[DB,]F-|[f2-d2-B2-F2][f2-d2-B2-][f/2d/2-B/2-][d/2B/2]A,-[=B,A,]E-|[e2-=B2-A2-E2-][e/2-=B/2-A/2-E/2][e=B-A-][=B/2A/2]xE,-[G,-E,][=A,G,]|[c4-=A4-G4-E4-C4-][c=A-G-E-C][=A/2G/2E/2D,/2-]D,/2-[G,D,-][_A,/2-D,/2]A,/2-|[A2-G2-D2-A,2][A3/2G3/2D3/2]x3/2_B,-[D-B,-][F-D-B,-]|"
#E/2F/2=G2F/2E/2 B4|=C/2=D/2E2D/2C/2 F4|E/2F/2=G2F/2E/2 GF/2E/2 GF/2E/2|FF2E/2=D/2 D2- D/2E/2D/2=C/2| \=C/2=D/2E2D/2C/2 ED/2C/2 ED/2C/2|FF2E/2=D/2 D4|E/2F/2=G2F/2E/2 GF/2E/2 GF/2E/2|FF2E/2=D/2 D2- D/2E/2D/2=C/2|=C/2=D/2E2D/2C/2 ED/2C/2 ED/2C/2|FF2E/2=D/2 D2- D/2E/2D/2=C/2|B,B, =C=D2<E2E|E=D3/2=CB,/2 E3D/2C/2|B,B, =C=D2<F2F/2F/2|F3/2=D=CB,/2 F3=G/2F/2|


abc_notation_234675_21548 = "[c'4A,4-]A,-[fA,][aE,-][bE,]|[c'2C2-F,2-][fC-F,-][a-CF,][aG,-][c'2G,2-][b-G,]|[bF,-][aF,-][gF,-][f-F,][f2F,2][B,-F,-][eB,F,]|[f3/2A,3/2-D,3/2-][c'3/2A,3/2-D,3/2-][b-A,D,][b2G,2-C,2-][G,-C,-][eG,C,]|[f3/2A,3/2-D,3/2-][c'3/2A,3/2-D,3/2-][b-A,D,][bB,-E,-][aB,-E,-][gB,-E,-][f-B,E,]|[f8-A,8-D,8-]|[f4A,4D,4][A,2D,2][A,2E,2]|[E8-C8-A,8-]|"
# NOTE(review): this literal ends mid-chord ("[A,/2F,/2-"); confirm it was truncated on purpose.
abc_notation_106429_180318 = "[E/2-C/2-G,/2][E/2C/2-]C/2G,/2[E/2-C/2-G,/2][E/2C/2-]C/2-[C/2G,/2][EC-A,]C/2A,/2[E/2-C/2-A,/2][EC]A,/2|[D/2-A,/2-F,/2][DA,]F,/2[D/2-A,/2-F,/2][D/2-A,/2]D/2F,/2[D/2-=B,/2-G,/2][D=B,]G,/2[D/2-=B,/2-G,/2][D=B,]G,/2|[E/2-C/2-G,/2][E/2C/2-]C/2G,/2[E/2-C/2-G,/2][E/2C/2-D,/2-][C/2D,/2][G,/2E,/2][F/2-C/2-A,/2F,/2-][F/2C/2-F,/2-][C/2F,/2-][A,/2F,/2-"

abc_notation_191238_36222 ="F,,,-[B/2F/2C/2F,,,/2]x/2[B/2F/2C/2F,,,/2-]F,,,/2-[B/2F/2C/2F,,,/2]x/2[B-F-C][B/2F/2]x/2[AFC-]C/2x/2|[F-DB,F,,,-][F/2F,,,/2]x/2[FDB,F,,,-]F,,,/2x/2[F-D-B,][F/2-D/2C,,/2][F/2D,,/2][G/2-E/2-C/2-F,,/2-][G/2-E/2-C/2F,,/2D,,/2][G/2E/2C,,/2]A,,,/2|F,,,-[B/2F/2C/2F,,,/2]x/2[B/2F/2C/2F,,,/2-]F,,,/2-[B/2F/2C/2F,,,/2]x/2[B-F-C][B/2F/2]x/2[A-FC-][A/2C/2]x/2|"
# Prompt used by the generation cell below.
abc_notation_221394_134131 = "A,,2-[e4A4E4A,,4]E,2-|[a2e2A2E,2]A,-[a3-e3-A3-A,3-][a/2e/2-A/2-A,/2E,/2-][e/2A/2-E,/2-][A/2E,/2-]E,/2|E,,2-[e3-B3-E3-E,,3-][e/2-B/2E/2E,,/2-][e/2E,,/2]B,,2-|[g3/2-e3/2-B3/2B,,3/2-][g/2e/2B,,/2]E,-[g-e-B-E,-][g/2-e/2-B/2-E,/2F,,/2-][g3/2e3/2B3/2-F,,3/2][B/2G,,/2-]G,,3/2|A,,2-[e4A4E4A,,4]E,2-|[a3/2e3/2-A3/2E,3/2-][e/2E,/2]A,-[a3-e3-A3-A,3][aeA-E,-][A/2E,/2-]E,/2|A,,2-[e4A4E4A,,4]E,2-|[a2e2A2E,2-][A,/2-E,/2]A,/2-[a3-e3-A3-A,3-][a/2-e/2-A/2-A,/2E,/2-][aeA-E,-][A/2E,/2]|"
# Character length of the prompt (not its token count), printed for comparison with the
# length of the generated text below.
input_length = len(abc_notation_221394_134131)
print(input_length)

input_ids = tokenizer.encode(abc_notation_221394_134131, return_tensors="pt")
# Deterministic beam search. max_length/min_length count the prompt tokens as well as the
# generated ones. `temperature` was dropped: it only applies when do_sample=True, so here
# it had no effect besides a warning.
output = model.generate(
    input_ids=input_ids,
    # Single unpadded sequence: attend to every prompt token.
    attention_mask=torch.ones_like(input_ids),
    max_length=600,
    min_length=550,
    num_beams=50,  # beam width: candidate sequences kept per step
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=10,  # no 10-gram may occur twice (limits melodic loops)
)
generated_abc = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_abc)
print(len(generated_abc))

In [None]:
RUN5_DIR = "/content/181854_118815.abc"
p = "/content/drive/MyDrive/Run_5/"
output_dir = "/content/drive/MyDrive/Run_5/"
print("Starts generation")

results = predict(model, tokenizer, p, output_dir)
print(results)
# predict() returns None when the ABC file has no usable notes; indexing None would
# raise TypeError, so only show the prediction/reference pair when there is one.
if results is None:
    print("No results to display.")
else:
    print(results[0])
    print("========")
    print(results[1])
