In [None]:
import os
import json

from sklearn.metrics import r2_score

In [3]:
def match_json_files(folder1, folder2):
    matched_files = []

    if not os.path.isdir(folder1) or not os.path.isdir(folder2):
        raise ValueError("Both inputs must be valid directories.")

    folder1_files = {f for f in os.listdir(folder1) if f.endswith(".json")}
    folder2_files = {f for f in os.listdir(folder2) if f.endswith(".json")}

    common_files = folder1_files.intersection(folder2_files)

    for file in common_files:
        path1 = os.path.join(folder1, file)
        path2 = os.path.join(folder2, file)
        matched_files.append((path1, path2))

    return matched_files

def load_validation_data(file_path, translations : list, rotations: list):
    with open(file_path, 'r') as f:
        data = json.load(f)

    object_in_image = False

    for i_line in range(len(data['objects'])):
        info = data['objects'][i_line]

        if "Ketchup".lower() not in info['class'].lower():
            continue                     

        # Parse translations
        location = info.get('location', [0, 0, 0])
        translations.extend(location)

        # Parse quaternion
        rot = info.get("quaternion_xyzw", [0, 0, 0, 1])
        rotations.extend(rot)
        object_in_image = True
    
    if not object_in_image:
        translations.extend([0,0,0])
        rotations.extend([0,0,0,1])     

In [4]:
from scipy.spatial.transform import Rotation as R
import numpy as np

validation_folder = r"C:\github\POSEIDON\dataset\test_frame_images"
predictions_folder = r"C:\github\POSEIDON\output\net_batchweights_ketchup_500"

matched_files = match_json_files(validation_folder, predictions_folder)

val_translations = []
val_rotations = []

pred_translations = []
pred_rotations = []

for i, (validation_file, prediction_file) in enumerate(matched_files):    
    load_validation_data(validation_file, val_translations, val_rotations)
    load_validation_data(prediction_file, pred_translations, pred_rotations)

pred_translations = [x / 100 for x in pred_translations]

print("Validation Translations:", len(val_translations))
print("Validation Rotations:", len(val_rotations))
print("Prediction Translations:", len(pred_translations))
print("Prediction Rotations:", len(pred_rotations))


transformation_matrix = np.array([[1, 0, 0],
                                   [0, 0, -1],
                                   [0, 1, 0]])

transformed_rotations = []

for i in range(0, len(pred_rotations), 4):
    quat = pred_rotations[i:i+4]
    r = R.from_quat(quat)
    rotation_matrix = r.as_matrix()
    transformed_matrix = transformation_matrix @ rotation_matrix  
    transformed_quat = R.from_matrix(transformed_matrix).as_quat()
    transformed_rotations.extend(transformed_quat)

pred_rotations = transformed_rotations

print("Validation Translations:", val_translations)
print("Validation Rotations:", val_rotations)
print("Prediction Translations:", pred_translations)
print("Prediction Rotations:", pred_rotations)

Validation Translations: 63
Validation Rotations: 84
Prediction Translations: 63
Prediction Rotations: 84
Validation Translations: [-0.03060629777610302, 0.26856061816215515, 1.6560226678848267, -0.099363312125206, 0.006707297638058662, 0.5613508820533752, 0.2388124316930771, 0.09952881187200546, 1.8451529741287231, 0.4185505211353302, 0.17979469895362854, 1.9451348781585693, 0.13804328441619873, 0.2839696407318115, 1.7495462894439697, 0.08595508337020874, 0.21126225590705872, 1.0031615495681763, 0.3615546226501465, 0.0398947149515152, 1.9613964557647705, -0.09741877019405365, -0.09579642862081528, 0.48916372656822205, 0.41922274231910706, 0.15390260517597198, 1.9022607803344727, 0.1892232894897461, 0.15322816371917725, 1.8001601696014404, -0.14410650730133057, 0.12787561118602753, 1.433038592338562, 0.043604526668787, 0.23469677567481995, 1.705459713935852, -0.28109052777290344, 0.2908279001712799, 1.6075611114501953, 0, 0, 0, 0.12310633808374405, 0.1988159418106079, 1.752760887145996

In [5]:
# Calculate R² score
r2_translation = r2_score(val_translations, pred_translations)
r2_rotation = r2_score(val_rotations, pred_rotations)
print("R² score for translations:", r2_translation)
print("R² score for rotations:", r2_rotation)

R² score for translations: 0.8420014640386587
R² score for rotations: -3.522203330058897
