# Evaluate MIDI base metrics

In [1]:
import music21 as ms21
import pretty_midi
import numpy as np
import torch
import glob
from tqdm import tqdm

In [11]:
#-----set data (gt is npy files, pred is pt file)-----
ground_truth_dir = "/raid/koki-sakurai/model/dataset/data/npy/classical_jazz_1024_textcond/train"
ground_truth_pathes = glob.glob(ground_truth_dir + "/*.npy")
prediction_path = "/raid/koki-sakurai/model/train/pretrained/sample/finetuning-1600-100/sample_10_B.pt"

#-----load data-----
ground_truths = []
for path in tqdm(ground_truth_pathes, desc="loading ground truth"):
    ground_truth_tmp = np.load(path, allow_pickle=True).item()["midi"]
    ground_truths.append(ground_truth_tmp)
predictions = torch.load(prediction_path, map_location="cpu")["midi"].detach().numpy()


loading ground truth: 100%|██████████| 2022/2022 [00:02<00:00, 730.37it/s]


Get features

In [4]:
def convert_to_midi(vector, note_split):
    """
    midiに変換
    
    params
    -------------
    vector: numpy.ndarray
        変換するmany-hot vector(チャンネル数, 長さ, 128)
    note_split: int
        分割する音符の種類
    
    returns
    ---------
    piano: object
        midi_instrumentデータ
    """
    vector = ((vector +1)*63.5).clip(0, 127)
    #vector[vector >0.5] = 70
    shapes = vector.shape
    instrument_name = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
    piano = pretty_midi.Instrument(program=instrument_name)

    #パラメータ
    start_time = 0.0
    end_time = 0.0
    tempo = 120
    time_split = 60/(tempo*note_split/4)

    #ノートを追加
    for p in range(shapes[-1]):
        velocity = 0
        start_time = 0.0
        end_time = 0.0
        durting = False #音が伸びているかどうか
        on = 0 #音が鳴っているかどうか
        for t in range(shapes[-2]):
            if not durting:
                on = int(round(vector[1][t][p]))
                velocity = int(round(vector[0][t][p]))
            end_time += time_split
    
            #次の時間において伸びているかの処理
            if on > 0:
                if t != shapes[-2]-1:
                    sutain = int(round(vector[0][t+1][p]))
                    if sutain > 0 and int(round(vector[1][t+1][p])) < 1:
                        durting = True
                        continue
                    else:
                        durting = False
            else:
                start_time = end_time
                durting = False
                continue      
    
            #add pitch
            note = pretty_midi.Note(
                velocity=velocity,
                pitch=int(p),
                start=start_time, 
                end=end_time
            )
            piano.notes.append(note)
            start_time = end_time
    
    return piano

def process_midi(path):
    """
    process midi file to notes, bar, pattern, pattern_bar
    """
    s=ms21.converter.parse(path)
    note2id = {
        "C": 0, 
        "C#": 1, 
        "D": 2, 
        "E-": 3, 
        "E": 4, 
        "F": 5, 
        "F#": 6, 
        "G": 7, 
        "G#": 8, 
        "A": 9, 
        "B-": 10,
        "B": 11, 
    }

    notes, bar, pattern, pattern_bar = [], [], [], [0 for k in range(16)]
    lastoffset = 0
    i, j = 1, 0
    bar_step = 32 #1 bar per 4

    for note in s.flat.notesAndRests:
        if isinstance(note, ms21.note.Rest):
            continue
        #get notes perbar step
        if note.offset >= bar_step * i and lastoffset < bar_step * i:
            notes.append(bar)
            bar = []
            i += 1
        #get pattern per bar
        while note.offset >= 4 * (j + 1) and lastoffset < 4 * (j + 1):
            pattern.append(np.array(pattern_bar))
            pattern_bar = [0 for k in range(16)]
            j += 1
        #append bar and pattern with notes
        if isinstance(note,ms21.note.Note):
            # print(note.name, note.octave, note.pitch, note.pitch.midi, note.duration.quarterLength)
            bar.append(note2id[note.name])
            pattern_bar[int(4*(note.offset-4*j))] += 1
        else:
            try:
                for c_note in note.notes:
                    # print(c_note.name, c_note.pitch.midi, c_note.duration.quarterLength)
                    bar.append(note2id[c_note.name])
            except:
                pass
            pattern_bar[int(4*(note.offset-4*j))] += 1
            
        lastoffset = note.offset
    
    #append last bar step and bar
    notes.append(bar)
    pattern.append(np.array(pattern_bar))

    return notes, bar, pattern, pattern_bar

def get_features(melody):
    """
    get features from melody
    
    params
    -------------
    melody: np.ndarray #melody data

    returns
    ---------
    features: list #features
    """
    features = []
    for data in tqdm(melody, desc="get features"):
        #convert to midi
        piano = convert_to_midi(data, 128)
        midi = pretty_midi.PrettyMIDI()
        midi.instruments.append(piano)
        midi.write("tmp.mid")
        #process midi
        notes, bar, pattern, pattern_bar = process_midi("tmp.mid")
        features.append([notes, bar, pattern, pattern_bar])
    
    return features

In [5]:
#-----get ground truth features-----
print("get ground truth features")
ground_truth_features = get_features(ground_truths)

get ground truth features


  return self.iter().getElementsByClass(classFilterList)
get features: 100%|██████████| 2022/2022 [19:37<00:00,  1.72it/s]


In [12]:
#-----get prediction features-----
print("get prediction features")
prediction_features = get_features(predictions)

get prediction features


get features: 100%|██████████| 100/100 [00:54<00:00,  1.83it/s]


Evaluate PCHE and GPS

In [13]:
def pitch_class_histogram_entropy(notes):
    
    result = []

    for bar in notes:
        # Construct the 12-dimensional pitch class histogram
        histogram = np.zeros(12)
        for note in bar:
            pitch_class = note % 12
            histogram[pitch_class] += 1

        # Normalize the histogram
        histogram = histogram / np.sum(histogram)

        # Calculate the entropy
        entropy = -np.sum(histogram * np.log2(histogram + 1e-6))  # Added epsilon to avoid log(0)
        result.append(entropy)

    return sum(result)/len(notes)

def grooving_pattern_similarity(g_a, g_b):
    assert len(g_a) == len(g_b), "Grooving patterns must have the same length"
    Q = len(g_a)
    gs = 1 - (1/Q) * np.sum(np.bitwise_xor(g_a, g_b))
    return gs

def cal_gps(pattern):

    results = []
    for i in range(len(pattern)):
        for j in range(i + 1, len(pattern)):
            g_a, g_b = pattern[i], pattern[j]
            results.append(grooving_pattern_similarity(g_a, g_b))

    return sum(results) / len(results)

#-----calculate PCHE and GPS-----
pche_ground_truth = []
gps_ground_truth = []
pche_prediction = []
gps_prediction = []

for feature in ground_truth_features:
    pche_ground_truth.append(pitch_class_histogram_entropy(feature[0]))
    gps_ground_truth.append(cal_gps(feature[2]))

for feature in prediction_features:
    pche_prediction.append(pitch_class_histogram_entropy(feature[0]))
    gps_prediction.append(cal_gps(feature[2]))

#-----print result-----
print("PCHE ground truth", np.mean(pche_ground_truth))
print("GPS ground truth", np.mean(gps_ground_truth))
print("PCHE prediction", np.mean(pche_prediction))
print("GPS prediction", np.mean(gps_prediction))

PCHE ground truth 2.9291662676669983
GPS ground truth 0.6097825757877687
PCHE prediction 2.844852203702743
GPS prediction 0.5061830357142858
