In [5]:
import sys
import os
from os.path import exists

import json

from sklearn.metrics import r2_score

In [6]:
def match_json_files(folder1, folder2):
    matched_files = []

    if not os.path.isdir(folder1) or not os.path.isdir(folder2):
        raise ValueError("Both inputs must be valid directories.")

    folder1_files = {f for f in os.listdir(folder1) if f.endswith(".json")}
    folder2_files = {f for f in os.listdir(folder2) if f.endswith(".json")}

    common_files = folder1_files.intersection(folder2_files)

    for file in common_files:
        path1 = os.path.join(folder1, file)
        path2 = os.path.join(folder2, file)
        matched_files.append((path1, path2))

    return matched_files

def load_validation_data(file_path, translations : list, rotations: list):
    with open(file_path, 'r') as f:
        data = json.load(f)

    object_in_image = False

    for i_line in range(len(data['objects'])):
        info = data['objects'][i_line]

        if "Ketchup".lower() not in info['class'].lower():
            continue                     

        # Parse translations
        location = info.get('location', [0, 0, 0])
        translations.extend(location)

        # Parse quaternion
        rot = info.get("quaternion_xyzw", [0, 0, 0, 1])
        rotations.extend(rot)
        object_in_image = True
    
    if not object_in_image:
        translations.extend([0,0,0])
        rotations.extend([0,0,0,1])     

In [7]:
validation_folder = r"C:\github\POSEIDON\dataset\test_frame_images"
predictions_folder = r"C:\github\POSEIDON\output\net_batchweights_ketchup_500"

matched_files = match_json_files(validation_folder, predictions_folder)

val_translations = []
val_rotations = []

pred_translations = []
pred_rotations = []

for i, (validation_file, prediction_file) in enumerate(matched_files):    
    load_validation_data(validation_file, val_translations, val_rotations)
    load_validation_data(prediction_file, pred_translations, pred_rotations)

pred_translations = [x / 100 for x in pred_translations]

print("Validation Translations:", len(val_translations))
print("Validation Rotations:", len(val_rotations))
print("Prediction Translations:", len(pred_translations))
print("Prediction Rotations:", len(pred_rotations))


print("Validation Translations:", val_translations)
print("Validation Rotations:", val_rotations)
print("Prediction Translations:", pred_translations)
print("Prediction Rotations:", pred_rotations)

Validation Translations: 63
Validation Rotations: 84
Prediction Translations: 63
Prediction Rotations: 84
Validation Translations: [-0.099363312125206, 0.006707297638058662, 0.5613508820533752, 0.41922274231910706, 0.15390260517597198, 1.9022607803344727, 0.10910351574420929, 0.051370102912187576, 0.5766260623931885, 0.06977595388889313, 0.0055173118598759174, 0.6518203020095825, -0.03060629777610302, 0.26856061816215515, 1.6560226678848267, 0.29005536437034607, 0.045599211007356644, 1.8884477615356445, 0.08595508337020874, 0.21126225590705872, 1.0031615495681763, 0.2388124316930771, 0.09952881187200546, 1.8451529741287231, 0.043604526668787, 0.23469677567481995, 1.705459713935852, 0.3615546226501465, 0.0398947149515152, 1.9613964557647705, 0, 0, 0, 0.13804328441619873, 0.2839696407318115, 1.7495462894439697, 0.12310633808374405, 0.1988159418106079, 1.752760887145996, 0.39508888125419617, 0.03898092731833458, 1.9541226625442505, 0.1892232894897461, 0.15322816371917725, 1.80016016960144

In [8]:
# Calculate R² score
r2_translation = r2_score(val_translations, pred_translations)
r2_rotation = r2_score(val_rotations, pred_rotations)
print("R² score for translations:", r2_translation)
print("R² score for rotations:", r2_rotation)

R² score for translations: 0.7408044460434837
R² score for rotations: -2.842170963325525
