In [1]:
import json
import tensorflow as tf
import os
import pandas as pd
from DataGenerator import DataGenerator
from metrics import *

In [6]:
OCTAVES = 8
TEST_INDEX = 0

In [7]:
model_folder = f"saved/2024-03-30/0_1_{OCTAVES}_octaves"
model_path = f"{model_folder}/model"
checkpoint_path = f"{model_folder}/checkpoints"
model = tf.keras.models.load_model(model_path)

In [8]:
ids_path = "../../data/archived/id_22050.csv"
list_IDs = list(pd.read_csv(ids_path, header=None)[0])
testing_ids = [id for id in list_IDs if int(id.split("_")[0]) == TEST_INDEX]

In [9]:
test_generator = DataGenerator(testing_ids,
                               batch_size=400,
                               data_path=f"../../data/archived/GuitarSet/{OCTAVES}_octaves/",
                               shuffle=False,
                               con_win_size=9)

In [10]:
print(testing_ids)

['00_Jazz3-137-Eb_solo_0', '00_Jazz3-137-Eb_solo_1', '00_Jazz3-137-Eb_solo_2', '00_Jazz3-137-Eb_solo_3', '00_Jazz3-137-Eb_solo_4', '00_Jazz3-137-Eb_solo_5', '00_Jazz3-137-Eb_solo_6', '00_Jazz3-137-Eb_solo_7', '00_Jazz3-137-Eb_solo_8', '00_Jazz3-137-Eb_solo_9', '00_Jazz3-137-Eb_solo_10', '00_Jazz3-137-Eb_solo_11', '00_Jazz3-137-Eb_solo_12', '00_Jazz3-137-Eb_solo_13', '00_Jazz3-137-Eb_solo_14', '00_Jazz3-137-Eb_solo_15', '00_Jazz3-137-Eb_solo_16', '00_Jazz3-137-Eb_solo_17', '00_Jazz3-137-Eb_solo_18', '00_Jazz3-137-Eb_solo_19', '00_Jazz3-137-Eb_solo_20', '00_Jazz3-137-Eb_solo_21', '00_Jazz3-137-Eb_solo_22', '00_Jazz3-137-Eb_solo_23', '00_Jazz3-137-Eb_solo_24', '00_Jazz3-137-Eb_solo_25', '00_Jazz3-137-Eb_solo_26', '00_Jazz3-137-Eb_solo_27', '00_Jazz3-137-Eb_solo_28', '00_Jazz3-137-Eb_solo_29', '00_Jazz3-137-Eb_solo_30', '00_Jazz3-137-Eb_solo_31', '00_Jazz3-137-Eb_solo_32', '00_Jazz3-137-Eb_solo_33', '00_Jazz3-137-Eb_solo_34', '00_Jazz3-137-Eb_solo_35', '00_Jazz3-137-Eb_solo_36', '00_Jazz3-

In [11]:
def test(model, test_ids):
    y_gt = np.empty((len(test_ids), 6, 21))
    y_pred = np.empty((len(test_ids), 6, 21))
    index = 0
    for i in range(len(test_generator)):
        X_test, gt = test_generator[i]
        pred = model.predict(X_test, verbose=0)
        size = len(pred[0])
        for sample_index in range(size):
            EString_pred = pred[0][sample_index]
            AString_pred = pred[1][sample_index]
            DString_pred = pred[2][sample_index]
            GString_pred = pred[3][sample_index]
            BString_pred = pred[4][sample_index]
            eString_pred = pred[5][sample_index]
            sample_tab_pred = np.array([EString_pred, AString_pred, DString_pred, GString_pred, BString_pred, eString_pred])
            y_pred[index,] = sample_tab_pred

            EString_gt = gt["EString"][sample_index]
            AString_gt = gt["AString"][sample_index]
            DString_gt = gt["DString"][sample_index]
            GString_gt = gt["GString"][sample_index]
            BString_gt = gt["BString"][sample_index]
            eString_gt = gt["eString"][sample_index]
            sample_tab_gt = np.array([EString_gt, AString_gt, DString_gt, GString_gt, BString_gt, eString_gt])
            y_gt[index,] = sample_tab_gt
            index += 1
    return y_gt, y_pred


In [12]:
def save_predictions(predictions, ground_truths, save_path, filename):
    np.savez(f"{save_path}/{filename}.npz", y_pred = predictions, y_gt = ground_truths)

In [13]:
def evaluate(predictions, ground_truths):
    metrics = {
        "pitch_precision" : [],
        "pitch_recall" : [],
        "pitch_f_score" : [],
        "tab_precision" : [],
        "tab_recall" : [],
        "tab_f_score" : []
    }
    metrics["pitch_precision"].append(pitch_precision(predictions, ground_truths))
    metrics["pitch_recall"].append(pitch_recall(predictions, ground_truths))
    metrics["pitch_f_score"].append(pitch_f_score(predictions, ground_truths))
    metrics["tab_precision"].append(tab_precision(predictions, ground_truths))
    metrics["tab_recall"].append(tab_recall(predictions, ground_truths))
    metrics["tab_f_score"].append(tab_f_measure(predictions, ground_truths))
    return metrics

In [14]:
def test_epochs(model, checkpoints_path):
    result = {}
    for i, epoch_file in enumerate(os.listdir(checkpoints_path)):
        npz_already_created = f"predictions_{i+1}_epoch.npz" in os.listdir(model_folder)
        model.load_weights(f"{checkpoints_path}/{epoch_file}")
        #check if npz file already exists
        if npz_already_created:
            print(f"Loading epoch {i+1}...")
            archive = np.load(f"{model_folder}/predictions_{i+1}_epoch.npz")
            y_gt = archive["y_gt"]
            y_pred = archive["y_pred"]
        else:
            print(f"Testing epoch {i+1}...")
            y_gt, y_pred = test(model, testing_ids)
        print(f"Loaded predictions...")
        if not npz_already_created:
            print("Saving predictions...")
            save_predictions(y_pred, y_gt, model_folder, f"predictions_{i+1}_epoch")
        print(f"Evaluating...")
        metrics = evaluate(y_pred, y_gt)
        print(metrics)
        result[str(i+1)] = metrics
    print("Dumping in json file...")
    with open(f"{model_folder}/pred_by_epoch.json", "w") as f:
        json.dump(result, f, indent=4)

In [15]:
test_epochs(model, checkpoint_path)

Testing epoch 1...
Loaded predictions...
Saving predictions...
Evaluating...
{'pitch_precision': [0.8606242552936556], 'pitch_recall': [0.7030358725618675], 'pitch_f_score': [0.7738890485083562], 'tab_precision': [0.6521725758433187], 'tab_recall': [0.574592731490012], 'tab_f_score': [0.6109296040841963]}
Testing epoch 2...
Loaded predictions...
Saving predictions...
Evaluating...
{'pitch_precision': [0.8934339570146762], 'pitch_recall': [0.6984681975628191], 'pitch_f_score': [0.7840120120486707], 'tab_precision': [0.7028805886420268], 'tab_recall': [0.5839574684861426], 'tab_f_score': [0.6379238893625083]}
Testing epoch 3...
Loaded predictions...
Saving predictions...
Evaluating...
{'pitch_precision': [0.9085140526772963], 'pitch_recall': [0.6711110869435185], 'pitch_f_score': [0.7719728410274372], 'tab_precision': [0.7442453607661486], 'tab_recall': [0.5677209221664625], 'tab_f_score': [0.6441075018906904]}
Testing epoch 4...
Loaded predictions...
Saving predictions...
Evaluating...
