In [1]:
from model_functions import predict, write_midi, get_token_flags
import json
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from transformers import GPT2Config, GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling, TrainingArguments, Trainer

# Data directory layout, relative to the notebook's working directory.
# Every path is derived from PATH_ROOT (or from its parent constant), so relocating
# the data tree only requires changing one line.
PATH_ROOT = "../0_data"
PATH_VOCAB = f"{PATH_ROOT}/5_vocabs"
PATH_DATA = f"{PATH_ROOT}/6_word_data"
PATH_MODELS = f"{PATH_ROOT}/7_models"
PATH_MODELS_CONFIG = f"{PATH_MODELS}/config"
PATH_PRED = f"{PATH_ROOT}/8_predictions"
PATH_TOKENS = f"{PATH_PRED}/tokens"
PATH_MIDI = f"{PATH_PRED}/midi"

In [2]:
# Vocabulary configurations, keyed by vocab/model name.
# Each row is (name, pitch_range, duration_steps, triole_tokens). The resulting
# per-name dicts are what get_token_flags and the prediction cell read.
_VOCAB_ROWS = (
    # name   pitches  steps  triole tokens
    ("a1",   128,     64,    False),
    ("a2",   128,     64,    True),
    ("a3",   128,     32,    False),
    ("b",    128,     64,    False),
    ("c",    36,      64,    False),
    ("d",    36,      32,    True),
)

vocab_configs = {
    name: {
        "pitch_range": pitches,
        "duration_steps": steps,
        "triole_tokens": triole,
    }
    for name, pitches, steps, triole in _VOCAB_ROWS
}

In [3]:
# Make prompt predictions: generate N_SAMPLES continuations per prompt, write each one
# as a MIDI file, and save the raw generated tokens as JSON. write_midi also prints
# the number of incorrect notes for each file (see the cell output).

TICKS_PER_BEAT = 1024
# Minimum duration = 4 beats / 32 (a 32nd note, assuming a beat is a quarter note).
# Integer division keeps the tick values integral.
MIN_DURATION_DIVISOR = 32
TICKS_PER_MIN_DURATION = TICKS_PER_BEAT * 4 // MIN_DURATION_DIVISOR

N_SAMPLES = 5      # continuations generated per prompt
MAX_LENGTH = 1024  # max_length passed to predict
SEED = 42          # fixes torch's RNG so re-running reproduces the samples
                   # (assumes predict samples via torch's RNG -- TODO confirm)

model_name = "d"
prompt_midi_dir = f"{PATH_MIDI}/{model_name}_prompts"
# exist_ok makes the cell re-runnable. The tokens directory must also exist before
# the json.dump at the end of this cell.
os.makedirs(prompt_midi_dir, exist_ok=True)
os.makedirs(PATH_TOKENS, exist_ok=True)
torch.manual_seed(SEED)

# Token flags for this vocabulary (not used further in this cell).
token_flags = get_token_flags(vocab_configs[model_name])
duration_steps = vocab_configs[model_name]["duration_steps"]
# Duration bins in ticks: 1x, 2x, ..., duration_steps x the minimum duration.
# An integer arange avoids numpy's float-step instability.
duration_bins = np.arange(1, duration_steps + 1, dtype=int) * TICKS_PER_MIN_DURATION

# Create the tokenizer from this model's vocabulary.
vocab_path = f"{PATH_VOCAB}/vocab_{model_name}.json"
tokenizer = GPT2Tokenizer(
    vocab_file=vocab_path,
    merges_file=f"{PATH_VOCAB}/merges.txt")
tokenizer.add_special_tokens({'pad_token': 'PAD', 'bos_token': 'BOS', 'eos_token': 'EOS',})

# The GPT-2 vocab file maps word -> token id; invert it so generated ids can be
# mapped back to words when writing MIDI.
with open(vocab_path, "r") as fp:
    vocab = json.load(fp)
token2word = {token: word for word, token in vocab.items()}

# Load the trained model and the prompt data (prompt key -> list of words).
model = GPT2LMHeadModel.from_pretrained(f"{PATH_MODELS_CONFIG}/{model_name}/end_version")
with open(f"{PATH_DATA}/prompt_data.json", "r") as fp:
    prompt_data = json.load(fp)

# MIDI file names only use the first three characters of each prompt key. Fail loudly
# if two prompts share a prefix; otherwise later files would silently overwrite earlier ones.
prompt_prefixes = [prompt[:3] for prompt in prompt_data]
assert len(set(prompt_prefixes)) == len(prompt_prefixes), \
    f"prompt keys share a 3-character prefix: {prompt_prefixes}"

# Make predictions and save each sample as a MIDI file.
data_generated = {}
for prompt, prompt_words in prompt_data.items():
    output = predict(model, tokenizer, prompt=" ".join(prompt_words), samples=N_SAMPLES, max_length=MAX_LENGTH)
    data_generated[prompt] = output
    for i, pred in enumerate(output):
        write_midi(pred, token2word, duration_bins, f"{prompt_midi_dir}/{prompt[:3]}_generated_midi_{i}.midi")

with open(f"{PATH_TOKENS}/{model_name}_prompts.json", "w") as fp:
    json.dump(data_generated, fp)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:90 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:90 for open-end generation.


midi saved in ../0_data/8_predictions/midi/d_prompts/040_generated_midi_0.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/040_generated_midi_1.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/040_generated_midi_2.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/040_generated_midi_3.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/040_generated_midi_4.midi
Number of incorrect notes: 0


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:90 for open-end generation.


midi saved in ../0_data/8_predictions/midi/d_prompts/095_generated_midi_0.midi
Number of incorrect notes: 1
midi saved in ../0_data/8_predictions/midi/d_prompts/095_generated_midi_1.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/095_generated_midi_2.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/095_generated_midi_3.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/095_generated_midi_4.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/600_generated_midi_0.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/600_generated_midi_1.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/600_generated_midi_2.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_predictions/midi/d_prompts/600_generated_midi_3.midi
Number of incorrect notes: 0
midi saved in ../0_data/8_pr

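To download the generated MIDI files, open a terminal in the midi folder (`../0_data/8_predictions/midi`) and run `tar chvfz predictions_midi.tar.gz *`. This creates a gzip-compressed tar archive (not a zip file) that you can then download.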