In [6]:
def get_error_ours(log_file, categories, name_category_map):
    """Compute the mean geodesic error per shape category from a results log.

    Only lines containing ``'Geodesic error for '`` are parsed. They are
    expected to look like::

        2025-03-18 03:56:48 - INFO - Geodesic error for <name_1> and <name_2>: 5.2

    Any other line is ignored, including other lines that merely mention
    "Geodesic error". Each pair's error is credited to the categories of
    *both* shapes. A pair whose two shapes share a category is therefore
    counted twice for that category.

    Args:
        log_file: Path to the log file to read.
        categories: Iterable of category names. Every category gets an entry
            in the returned dicts, even if no pair in the log mentions it.
        name_category_map: Mapping from shape basename, as it appears in the
            log, to its category.

    Returns:
        A tuple ``(categories_sorted, err_by_category, times_by_category)``:
            - categories_sorted: the category names sorted by ascending mean error.
            - err_by_category: the mean error per category (0.0 for categories
              never seen).
            - times_by_category: the number of pair endpoints counted per category.

    Raises:
        KeyError: If a shape name from the log is missing from
            ``name_category_map``, or maps to a category not in ``categories``.
        ValueError: If a matching line is malformed (no ``' and '`` separator
            or a non-numeric error value).
    """
    marker = 'Geodesic error for '

    # Build both dicts from the same keys. Later loops iterate the dicts, so
    # `categories` is consumed exactly once and may be any iterable.
    err_by_category = {category: 0.0 for category in categories}
    times_by_category = dict.fromkeys(err_by_category, 0)

    num_lines = 0
    # Stream the log instead of loading every line into memory at once.
    with open(log_file, 'r', encoding='utf-8') as f:
        for line in f:
            num_lines += 1
            if marker not in line:
                continue

            # Split the error value off the right-hand side so it can never be
            # confused with the shape names; float() tolerates the trailing '\n'.
            pair_text = line.split(marker, 1)[1]
            names, _, error_text = pair_text.rpartition(': ')
            file_name_1, file_name_2 = names.split(' and ', 1)
            error = float(error_text)

            # get categories for both files, e.g. Standing2HMagicAttack01034 -> crypto
            category_1 = name_category_map[file_name_1]
            category_2 = name_category_map[file_name_2]

            # credit the pair to both categories (twice if they are the same)
            for category in (category_1, category_2):
                err_by_category[category] += error
                times_by_category[category] += 1

    print(f'lines: {num_lines}')

    # turn the accumulated sums into averages; unseen categories stay at 0.0
    for category, count in times_by_category.items():
        if count > 0:
            err_by_category[category] /= count

    # sorted() is stable, so ties keep the dict's insertion order
    categories_sorted = sorted(err_by_category, key=err_by_category.get)

    return categories_sorted, err_by_category, times_by_category

In [13]:
from test_on_dataset import get_dataset
import os

# model_name='ddpm_64_SMAL'
# dataset_name='SMAL_iso'

# model_name='ddpm_32'
# model_name='ddpm_64'
model_name='ddpm_96'
# dataset_name='DT4D_inter'
dataset_name='DT4D_intra'

# per-pair results log for the selected model / dataset combination
log_file = f'checkpoints/ddpm/{model_name}/results_{dataset_name}.log'

single_dataset, _ = get_dataset(
    name=dataset_name,
    base_dir='/lustre/mlnvme/data/s94zalek_hpc-shape_matching/data_denoisfm/test',
)

# obtain a set of categories and a map from file name to category
categories = set()
name_category_map = {}

for off_file in single_dataset.off_files:

    if 'DT4D' in dataset_name:
        # DT4D: the category is the name of the file's parent directory
        shape_category = off_file.split('/')[-2]
    elif 'SMAL' in dataset_name:
        # SMAL: the category is the file-name prefix before the first '_'
        shape_category = (off_file.split('/')[-1]).split('_')[0]
    else:
        raise ValueError(f'Unknown dataset: {dataset_name}')

    categories.add(shape_category)

    # keyed by basename without extension, which is the form in which
    # get_error_ours looks up shape names parsed from the log
    basename = os.path.splitext(os.path.basename(off_file))[0]

    # a duplicated basename would make the category of a logged shape ambiguous
    if basename in name_category_map:
        raise ValueError(f'File {basename} already in the map')
    name_category_map[basename] = shape_category

print(f'Shape categories: {categories}')

categories_sorted, err_by_category_ours, times_by_category_ours = get_error_ours(
    log_file,
    categories,
    name_category_map,
)

print()
print(f'{log_file}')
for category in categories_sorted:
    print(f'{category}: {err_by_category_ours[category]:.1f}')

print()

# Overall mean: weight each category's mean by its count. get_error_ours
# credits every pair to both of its shapes' categories, so `total_pairs`
# counts pair endpoints (twice the number of pairs). Numerator and
# denominator are both doubled, so this equals the plain mean over pairs.
mean_err_ours = 0
total_pairs = 0
for category in categories:
    mean_err_ours += err_by_category_ours[category] * times_by_category_ours[category]
    total_pairs += times_by_category_ours[category]

if total_pairs > 0:
    mean_err_ours /= total_pairs
else:
    # no per-pair lines were found in the log: report NaN instead of crashing
    print(f'Warning: no geodesic errors found in {log_file}')
    mean_err_ours = float('nan')

print(f'mean error: {mean_err_ours:.1f}')
            