In [14]:
import os
import sys
# Choose between the two known account names based on the current working
# directory (the `_hpc` variant presumably being the cluster account), then
# make the project root importable.
curr_dir = os.getcwd()
user_name = 's94zalek_hpc' if 's94zalek_hpc' in curr_dir else 's94zalek'
sys.path.append(f'/home/{user_name}/shape_matching')

# diffusion sampling / evaluation utilities
import my_code.diffusion_training.sample_model as sample_model
import my_code.diffusion_training.evaluate_samples as evaluate_samples


# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import yaml
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import my_code.datasets.template_dataset as template_dataset
import my_code.datasets.shape_dataset as shape_dataset

import my_code.diffusion_training.data_loading as data_loading



def get_subset(dataset, subset_fraction, generator=None):
    """Draw a random subset of ``dataset`` without replacement.

    Args:
        dataset: any sized, indexable dataset (supports ``len`` and ``[]``).
        subset_fraction: fraction of the dataset to keep. The number of kept
            samples is ``int(len(dataset) * subset_fraction)`` (truncated), so
            values >= 1 keep the whole dataset.
        generator: optional ``torch.Generator`` making the draw reproducible.
            ``None`` (default) uses torch's global RNG, as before.

    Returns:
        ``(subset, subset_indices)``: a ``torch.utils.data.Subset`` and the
        1-D tensor of selected indices, in random order.

    Raises:
        ValueError: if ``subset_fraction`` is negative (a negative count would
            otherwise silently slice from the end of the permutation).
    """
    if subset_fraction < 0:
        raise ValueError(f"subset_fraction must be non-negative, got {subset_fraction}")

    n_total = len(dataset)
    n_samples = int(n_total * subset_fraction)
    subset_indices = torch.randperm(n_total, generator=generator)[:n_samples]

    # Subset forwards its stored indices to dataset.__getitem__; plain Python
    # ints avoid handing 0-dim tensors to datasets that expect int indices.
    subset = torch.utils.data.Subset(dataset, subset_indices.tolist())
    return subset, subset_indices


def plot_pck(metrics, title, thresholds=None):
    """Plot the sample-averaged PCK curve, with the mean AUC in the legend.

    Args:
        metrics: mapping with
            'pcks' -- PCK values, averaged over dim 0 for plotting; each
                      remaining entry corresponds to one threshold
                      (assumed layout: (n_samples, n_thresholds) -- TODO confirm);
            'auc'  -- AUC values, averaged over all elements for the legend.
            Tensors may live on any device; they are moved to CPU for plotting.
        title: axes title.
        thresholds: x-values matching the PCK entries. Defaults to 40 evenly
            spaced values in [0, 0.1], the grid this plot always used.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if the averaged PCK curve and ``thresholds`` differ in shape.
    """
    if thresholds is None:
        thresholds = np.linspace(0., 0.1, 40)
    thresholds = np.asarray(thresholds, dtype=float)

    # Reduce on the torch side, then detach and move to CPU: matplotlib can
    # only convert CPU tensors / arrays.
    pck_mean = torch.as_tensor(metrics['pcks']).mean(dim=0).detach().cpu().numpy()
    auc_mean = torch.as_tensor(metrics['auc']).mean().item()

    if pck_mean.shape != thresholds.shape:
        raise ValueError(
            f"mean PCK curve has shape {pck_mean.shape} but thresholds have "
            f"shape {thresholds.shape}"
        )

    x_max = float(thresholds.max())
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.plot(thresholds, pck_mean, 'r-', label=f'auc: {auc_mean:.2f}')
    ax.set_xlim(0., x_max)
    ax.set_ylim(0, 1)
    ax.set_xscale('linear')
    # four evenly spaced ticks, e.g. 0.025 / 0.05 / 0.075 / 0.1 for the default grid
    ax.set_xticks(np.linspace(0., x_max, 5)[1:])
    ax.set_xlabel('error threshold')
    ax.set_ylabel('PCK')
    ax.grid()
    ax.legend()
    ax.set_title(title)
    return fig


def preprocess_metrics(metrics):
    """Reduce per-sample metric tensors to a flat dict of rounded scalars.

    Args:
        metrics: mapping with tensor entries 'auc', 'geo_err_est',
            'geo_err_ratio' and 'mse_abs' (anything supporting torch's
            ``mean`` / ``median`` / ``max`` / ``min`` and ``.item()``).

    Returns:
        dict with keys, in this order:
          'auc'          -- mean of metrics['auc'] over dim 0, 2 decimals
          'geo_err_mean' -- mean of metrics['geo_err_est'] times 100, 1 decimal
          'geo_err_ratio_{mean,median,max,min}' -- stats of
                            metrics['geo_err_ratio'], 2 decimals
          'mse_{mean,median,max,min}' -- stats of metrics['mse_abs'], 2 decimals

    Note:
        torch's ``median()`` returns the lower of the two middle values for
        even-length inputs, not their average.
    """
    summary_stats = ('mean', 'median', 'max', 'min')

    payload = {
        'auc': round(metrics['auc'].mean(dim=0).item(), 2),
        # scaled by 100 (presumably to report a percentage)
        'geo_err_mean': round(metrics['geo_err_est'].mean().item() * 100, 1),
    }

    # (output key prefix, source key) pairs summarized with the same four stats
    for prefix, source_key in (('geo_err_ratio', 'geo_err_ratio'), ('mse', 'mse_abs')):
        values = metrics[source_key]
        for stat in summary_stats:
            payload[f'{prefix}_{stat}'] = round(getattr(values, stat)().item(), 2)

    return payload


In [None]:
experiment_name = 'test_faceScaling_faustRA'
checkpoint_name = 'checkpoint_60'
# NOTE(review): despite its name, this value does not subsample anything in
# this cell -- it is only interpolated into the PCK plot title below.
subset_fraction = 100
dataset_name = 'FAUST_orig'


### config
exp_base_folder = f'my_code/experiments/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
# Build on the GPU once and load the weights. map_location makes the load
# independent of the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = DiagConditionedUnet(config["model_params"]).to('cuda')
model.load_state_dict(
    torch.load(f"{exp_base_folder}/checkpoints/{checkpoint_name}.pt", map_location='cuda')
)
# Inference only: put dropout / normalization layers into evaluation mode.
model.eval()


### data
# get_val_dataset returns an indexable collection; element [1] is used as the
# evaluation set. NOTE(review): the split argument is 'train' -- confirm this
# is intended for an evaluation run.
test_dataset = data_loading.get_val_dataset(
    dataset_name, 'train', config["model_params"]["sample_size"]
)[1]


### sample the model
# NOTE(review): scheduler settings are hard-coded rather than read from the
# config; they must match the ones used during training.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True)
x_sampled = sample_model.sample_dataset(model, test_dataset, noise_scheduler)


### assign gt signs and unnormalize the samples
x_gt = torch.stack([test_dataset[i]['second']['C_gt_xy'] for i in range(len(test_dataset))])
if isinstance(x_sampled, torch.Tensor):
    # dataset tensors and sampler output may live on different devices
    x_gt = x_gt.to(x_sampled.device)
# Map samples from [-1, 1] to [0, 1] (presumably the range enforced by
# clip_sample=True) and take the signs from the ground-truth map; entries where
# the ground truth is exactly 0 become 0.
# NOTE(review): x_gt and x_sampled are assumed to have identical shapes -- a
# missing or extra singleton dim would broadcast silently here.
fmap_sampled = torch.sign(x_gt) * (x_sampled + 1) / 2


### calculate metrics
metrics = evaluate_samples.calculate_metrics(
    fmap_sampled,
    test_dataset
)
metrics_payload = preprocess_metrics(metrics)
fig = plot_pck(metrics, title=f"PCK on {dataset_name}_{subset_fraction}")


### print the metrics
print(f"AUC mean: {metrics_payload['auc']}\n")
print(f'GeoErr mean: {metrics_payload["geo_err_mean"]}\n')
print(f"GeoErr ratio mean: {metrics_payload['geo_err_ratio_mean']}")
print(f"GeoErr ratio median: {metrics_payload['geo_err_ratio_median']}")
print(f'GeoErr ratio max: {metrics_payload["geo_err_ratio_max"]}', f'min: {metrics_payload["geo_err_ratio_min"]}\n')
print(f"MSE mean: {metrics_payload['mse_mean']}")
print(f"MSE median: {metrics_payload['mse_median']}")
print(f"MSE max: {metrics_payload['mse_max']}", f"min: {metrics_payload['mse_min']}")


# pyplot-managed display; Figure.show() can warn on non-GUI (inline) backends
plt.show()

