In [1]:
import os

# Work from the ra_joint_predictions checkout: later cells use relative paths such as
# '../trained_models/...', and presumably rely on the working directory for the
# project-local imports (utils, dataset, model). Set RA_JOINT_PREDICTIONS_DIR to run
# this notebook from another checkout; the default is the original author's path.
PROJECT_ROOT = os.environ.get('RA_JOINT_PREDICTIONS_DIR', '/mnt/jw01-aruk-home01/projects/ra_challenge/RA_challenge/michael_dev/RA2_alpine_lads/ra_joint_predictions')
os.chdir(PROJECT_ROOT)

In [2]:
from utils.config import Config

# Project configuration object. Later cells pass it to the dataset builder and
# read `config.train_location` to locate training.csv.
config = Config()

In [3]:
from dataset.joint_val_dataset import hands_wrists_val_dataset

# Build both wrist datasets from the same training CSV:
#   erosion_flag=False -> wrist_j_data, evaluated below with the narrowing models
#   erosion_flag=True  -> wrist_e_data, evaluated below with the erosion models
# Later cells evaluate on index [1] of each result (presumably the validation split -- TODO confirm).
wrists_ds = hands_wrists_val_dataset(config)
training_csv_path = config.train_location + '/training.csv'

wrist_j_data = wrists_ds.create_wrists_joints_dataset_with_validation(training_csv_path, erosion_flag = False)
wrist_e_data = wrists_ds.create_wrists_joints_dataset_with_validation(training_csv_path, erosion_flag = True)

[{'augment': <function random_brightness_and_contrast at 0x7f29c439ea60>}, {'augment': <function random_crop at 0x7f29c439eb70>, 'params': {'min_scale': 0.9}}, {'augment': <function random_gaussian_noise at 0x7f29c439ebf8>, 'p': 0.2}, {'augment': <function random_rotation at 0x7f29c439eae8>, 'params': {'max_rot': 10}}]
[{'augment': <function random_brightness_and_contrast at 0x7f29c439ea60>}, {'augment': <function random_crop at 0x7f29c439eb70>, 'params': {'min_scale': 0.9}}, {'augment': <function random_gaussian_noise at 0x7f29c439ebf8>, 'p': 0.2}, {'augment': <function random_rotation at 0x7f29c439eae8>, 'params': {'max_rot': 10}}]


In [4]:
import numpy as np

from tensorflow import keras
from model.utils.metrics import mae_metric, rmse_metric, class_filter_rmse_metric, softmax_mae_metric, softmax_rmse_metric, class_filter_softmax_rmse_metric

def _average_model_outputs(models, x, n_outputs):
    """Average the first column of each of `n_outputs` heads across an ensemble.

    Each model's `predict(x)` result is indexed as `output[n][:, 0]` for
    `n in range(n_outputs)`. Those columns are summed over the models and divided by
    the number of models. Returns an array of shape `(x.shape[0], n_outputs)`.
    """
    total = np.zeros((x.shape[0], n_outputs))
    for model in models:
        outputs = model.predict(x)
        for n in range(n_outputs):
            total[:, n] += outputs[n][:, 0]
    return total / len(models)


def _rmse_on_nonzero_truths(t_vals, p_vals):
    """RMSE restricted to samples whose ground truth is non-zero; NaN if there are none."""
    mask = t_vals != 0.0
    if not mask.any():
        # np.mean of an empty slice also gives NaN, but emits a RuntimeWarning.
        return np.nan
    return np.sqrt(np.mean(np.square(t_vals[mask] - p_vals[mask])))


def eval_model(label, dataset, model_paths, max_output, steps, filter_model_paths = None, filter_cutoff = 0.5, n_outputs = 6):
    """Evaluate an ensemble of multi-output regression models and print per-head metrics.

    Each model is loaded with `keras.models.load_model(path, compile = False)`. For every
    batch `(x, y)` in `dataset.take(steps)`, the first column of each of the `n_outputs`
    model heads is averaged across the ensemble. If filter models are given, their
    outputs are averaged the same way, and every prediction whose averaged filter score
    is strictly below `filter_cutoff` is multiplied by 0. Predictions for head `n` are
    compared with `y[n][:, 0]`.

    Printed per head: Loss (MSE), MAE, RMSE, and "Filter RMSE". Filter RMSE is the RMSE
    over samples whose ground truth is non-zero, and is NaN for a head with no such samples.

    Args:
        label: Name printed in the report header.
        dataset: Object with a `take(steps)` method yielding `(x, y)` batches
            (presumably a `tf.data.Dataset`). `y` is indexed per head.
        model_paths: Non-empty sequence of saved-model paths forming the ensemble.
        max_output: Unused; kept so existing call sites keep working.
            NOTE(review): callers pass 4 (narrowing) / 5 (erosion), presumably the
            maximum score -- confirm before relying on it.
        steps: Number of batches taken from `dataset`.
        filter_model_paths: Optional saved-model paths for the filter ensemble.
            `None` (the default) or an empty sequence disables filtering.
        filter_cutoff: Averaged filter scores strictly below this zero the prediction.
        n_outputs: Number of output heads per model (6 everywhere in this notebook).

    Returns:
        None. Results are printed.

    Raises:
        ValueError: If `model_paths` is empty (the ensemble average would be 0/0).
    """
    if filter_model_paths is None:
        filter_model_paths = []
    if len(model_paths) == 0:
        raise ValueError('model_paths must contain at least one model path')

    models = [keras.models.load_model(path, compile = False) for path in model_paths]
    filter_models = [keras.models.load_model(path, compile = False) for path in filter_model_paths]

    truths = {n: [] for n in range(n_outputs)}
    preds = {n: [] for n in range(n_outputs)}

    for x, y in dataset.take(steps):
        y_preds = _average_model_outputs(models, x, n_outputs)

        if filter_models:
            filter_preds = _average_model_outputs(filter_models, x, n_outputs)
            # Multiply by 0 where the filter ensemble scores below the cutoff, by 1 elsewhere.
            y_preds = np.multiply(y_preds, np.where(filter_preds < filter_cutoff, 0.0, 1.0))

        for n in range(n_outputs):
            truths[n].extend(y[n][:, 0].numpy())
            preds[n].extend(y_preds[:, n])

    loss = np.zeros(n_outputs)
    rmse = np.zeros(n_outputs)
    mae = np.zeros(n_outputs)
    filter_rmse = np.zeros(n_outputs)

    for idx in range(n_outputs):
        t_vals = np.array(truths[idx])
        p_vals = np.array(preds[idx])
        err = t_vals - p_vals

        loss[idx] = np.mean(np.square(err))
        rmse[idx] = np.sqrt(loss[idx])
        mae[idx] = np.mean(np.absolute(err))
        filter_rmse[idx] = _rmse_on_nonzero_truths(t_vals, p_vals)

    print('Model:', label)
    print('Loss:', loss)
    print('MAE:', mae)
    print('RMSE:', rmse)
    print('Filter RMSE:', filter_rmse)
    

In [5]:
# Baseline: a single narrowing model (the 'old_size' variant, per its filename).
eval_model(
    'Wrists Narrowing',
    wrist_j_data[1],
    ['../trained_models/narrowing/v7/wrists_narrowing_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_old_size_test.h5'],
    max_output = 4,
    steps = 3,
)

Model: Wrists Narrowing
Loss: [0.41298287 0.49431016 0.23348539 0.6192735  0.403477   0.35051739]
MAE: [0.33075184 0.38960559 0.22725141 0.44991718 0.26787389 0.26593757]
RMSE: [0.64263743 0.70307194 0.48320326 0.78693932 0.63519839 0.59204509]
Filter RMSE: [1.14413947 1.15965203 1.06855476 1.11323853 1.57343069 1.38688393]


In [8]:
# Filter ("damage type") models, averaged and thresholded by eval_model's filter_cutoff.
# Members per filename: full, left-group, right-group.
wrist_j_filter_paths = [
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_test.h5',
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_left_group_test.h5',
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_right_group_test.h5',
]

# Narrowing regression ensemble. Members per filename: right-group, left-group, full.
wrist_j_paths = [
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_right_group_full_model_1.h5',
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_left_group_full_model_1.h5',
    '../trained_models/narrowing/v7/wrists_narrowing_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_full_model_1.h5',
]

In [9]:
# Three-model narrowing ensemble, no filtering.
eval_model(
    'Wrists Narrowing',
    wrist_j_data[1],
    wrist_j_paths,
    max_output = 4,
    steps = 3,
)

Model: Wrists Narrowing
Loss: [0.25692827 0.45194101 0.25832472 0.45354612 0.31924141 0.26828662]
MAE: [0.25618418 0.36225228 0.22777919 0.38731078 0.23672284 0.21678375]
RMSE: [0.50688092 0.67226558 0.50825655 0.67345833 0.56501452 0.51796392]
Filter RMSE: [0.9248765  1.14462449 1.14436223 1.04152911 1.41569497 1.28361224]


In [10]:
# Same ensemble, with predictions zeroed wherever the averaged filter-model score is below 0.1.
eval_model(
    'Wrists Narrowing',
    wrist_j_data[1],
    wrist_j_paths,
    max_output = 4,
    steps = 3,
    filter_model_paths = wrist_j_filter_paths,
    filter_cutoff = 0.1,
)

Model: Wrists Narrowing
Loss: [0.25269351 0.47315245 0.25753187 0.45785671 0.32811172 0.26600226]
MAE: [0.23144865 0.34715637 0.21658161 0.36589289 0.22263693 0.19502092]
RMSE: [0.5026863  0.68786078 0.50747598 0.6766511  0.57281037 0.51575407]
Filter RMSE: [0.93324716 1.17827627 1.14436223 1.04987136 1.44243138 1.29246049]


In [6]:
# Baseline: a single erosion model (the 'old_size' variant, per its filename).
eval_model(
    'Wrists Erosion (Old Shape)',
    wrist_e_data[1],
    ['../trained_models/erosion/v7/wrists_erosion_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_old_size_test.h5'],
    max_output = 5,
    steps = 3,
)

Model: Wrists Erosion (Old Shape)
Loss: [0.80382921 0.67938874 1.02230631 0.80139358 0.9448288  1.11114931]
MAE: [0.31490099 0.39404771 0.36846846 0.5383482  0.35302002 0.47218565]
RMSE: [0.89656523 0.82425041 1.01109164 0.89520589 0.97202305 1.05411067]
Filter RMSE: [2.40432655 1.74506273 2.72983373 1.51914659 2.68067585 2.24064731]


In [5]:
# Erosion regression ensemble. Members per filename: full, left-group, right-group.
erosion_models = [
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_full_model_1.h5',
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_left_group_full_model_1.h5',
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_model_complex_rewritten_64bs_60steps_300epochs_adamW_1e3_1e6_new_shape_030maj_right_group_full_model_1.h5',
]

eval_model(
    'Wrists Erosion (0.3 MAJ + R + L)',
    wrist_e_data[1],
    erosion_models,
    max_output = 5,
    steps = 3,
)

Model: Wrists Erosion (0.3 MAJ + R + L)
Loss: [0.80560382 0.81774166 0.87464176 0.67659616 0.84454972 0.83854979]
MAE: [0.31893922 0.44191916 0.34811926 0.53394698 0.33815143 0.41196839]
RMSE: [0.89755436 0.90429069 0.93522284 0.82255466 0.91899386 0.91572364]
Filter RMSE: [2.4091458  1.90917123 2.45923003 1.41132039 2.52675087 1.93584804]


In [6]:
# Erosion filter ("damage type") models. Members per filename: right-group, left-group, full.
e_filter_model_paths = [
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_right_group_test.h5',
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_left_group_test.h5',
    '../trained_models/erosion/v7/wrists_erosion_joint_damage_type_model_complex_rewritten_gap_64bs_3normsteps_75epochs_adamW_3e4_1e6_new_shape_test.h5',
]

# Same erosion ensemble, with predictions zeroed wherever the averaged filter score is below 0.1.
eval_model(
    'Wrists Erosion (0.3 MAJ + R + L)',
    wrist_e_data[1],
    erosion_models,
    max_output = 5,
    steps = 3,
    filter_model_paths = e_filter_model_paths,
    filter_cutoff = 0.1,
)

Model: Wrists Erosion (0.3 MAJ + R + L)
Loss: [0.80292114 0.81910066 0.88334319 0.66875179 0.8407279  0.85123031]
MAE: [0.29906113 0.4215763  0.33474956 0.49339566 0.31398794 0.38770162]
RMSE: [0.89605867 0.9050418  0.93986339 0.81777246 0.91691216 0.92262143]
Filter RMSE: [2.407548   1.9183853  2.47348167 1.41670982 2.52689935 1.95656902]
