In [13]:
import os
import math

import numpy as np
import pandas as pd

In [99]:
class FaultMetrics:
    def __init__(self, dist_thresh=20):
        self.dist_thresh = dist_thresh

    def get_precision_recall_f1(self, gt, pred):
        tp, fp, fn = self.get_tp_fp_fn(gt, pred)
        
        if tp == 0:
            precision = 0
            recall = 0
            f1_score = 0
        else:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1_score = 2 * (precision * recall) / (precision + recall) 
        return precision, recall, f1_score

    def get_tp_fp_fn(self, gt, pred):
        tp = 0
        fp = 0
        fn = 0

        for i, fault in enumerate(pred):
            best_dist = 1000
            best_gt_idx = None

            for j, gt_point in enumerate(gt):
                if fault[2] == gt_point[2]:
                    distance = math.dist(fault[:2], gt_point[:2])
                    if distance < best_dist:
                        best_dist = distance
                        best_gt_idx = j

            if best_gt_idx is not None and best_dist <= self.dist_thresh:
                tp += 1
                del gt[best_gt_idx]
            else:
                fp += 1

        fn = len(gt)
        return tp, fp, fn
    
metrics_counter = FaultMetrics(dist_thresh=64)
gt = [[120,120, 0], [20, 190, 2], [234, 124, 6]]
pred = [[119,122,0], [21,190, 2], [342, 123, 3], [234, 1234, 4], [233, 121, 6]]
precision, recall, f1_score = metrics_counter.get_precision_recall_f1(gt, pred)
print(f'Precision: {precision} | Recall: {recall} | F1Score: {f1_score}')

Precision: 0.6 | Recall: 1.0 | F1Score: 0.7499999999999999


In [101]:
DATASET_ROOT = '/home/raid_storage/datasets/rosatom'
# GT_CSV_PATH = os.path.join(DATASET_ROOT, 'filtered_dataset.csv')
GT_CSV_PATH = 'gt_segments.csv'
# PR_CSV_PATH = './multipoint_yolo_predict.csv'
PR_CSV_PATH = 'pr_segments.csv'

In [102]:
gt_df = pd.read_csv(GT_CSV_PATH, index_col=0).sort_values(by=['filename'])
# gt_df = gt_df[gt_df['stage'] == 'test']
gt_df['img_path'] = gt_df['filename'].apply(lambda x: x.replace('/', '_').replace(' ', '_'))
gt_df

Unnamed: 0,filename,class,x,y,img_path
0,FRAMES/0/1538/frame0009.bmp,3,716,501,FRAMES_0_1538_frame0009.bmp
1,FRAMES/0/1538/frame0009.bmp,3,722,349,FRAMES_0_1538_frame0009.bmp
2,FRAMES/0/1538/frame0009.bmp,3,573,314,FRAMES_0_1538_frame0009.bmp
3,FRAMES/0/1538/frame0012.bmp,3,475,532,FRAMES_0_1538_frame0012.bmp
4,FRAMES/0/1538/frame0012.bmp,3,407,488,FRAMES_0_1538_frame0012.bmp
...,...,...,...,...,...
7883,FRAMES/2023.10.25/4_894.bmp,8,377,141,FRAMES_2023.10.25_4_894.bmp
7884,FRAMES/2023.10.25/5_486.bmp,5,809,308,FRAMES_2023.10.25_5_486.bmp
7885,FRAMES/2023.10.25/5_498.bmp,1,496,339,FRAMES_2023.10.25_5_498.bmp
7886,FRAMES/2023.10.25/5_809.bmp,8,265,375,FRAMES_2023.10.25_5_809.bmp


In [103]:
pr_df = pd.read_csv(PR_CSV_PATH).sort_values(by=['filename'])
pr_df

Unnamed: 0.1,Unnamed: 0,filename,class,x,y
0,0,FRAMES/0/1538/frame0009.bmp,3,711,497
1,1,FRAMES/0/1538/frame0009.bmp,3,573,314
2,2,FRAMES/0/1538/frame0012.bmp,3,415,494
3,3,FRAMES/0/1538/frame0012.bmp,3,390,365
4,4,FRAMES/0/1538/frame0012.bmp,3,278,327
...,...,...,...,...,...
4719,4719,FRAMES/2023.10.25/4_894.bmp,1,208,286
4720,4720,FRAMES/2023.10.25/5_486.bmp,5,810,308
4721,4721,FRAMES/2023.10.25/5_486.bmp,5,432,308
4722,4722,FRAMES/2023.10.25/5_498.bmp,8,548,117


In [130]:
metrics_counter = FaultMetrics(dist_thresh=64)
ignore_class = []

m_r = []
m_p = []
m_f = []


for i in range(0, 15):
    print(f'Class {i}:')
    selected_class = [i]

    u_images = gt_df['filename'].unique()
    precisions = []
    recalls = []
    f1 = []

    for u_image in u_images:
        gt_rows = gt_df[gt_df['filename'] == u_image]
        pr_rows = pr_df[pr_df['filename'] == u_image]

        gt_info = []
        for i, row in gt_rows.iterrows():
            if row['class'] in ignore_class:
                continue

            if row['class'] not in selected_class:
                continue

            gt_info.append([row['x'], row['y'], row['class']])
        
        if not gt_info:
            continue
            
        pr_info = []
        for i, row in pr_rows.iterrows():
            if row['class'] in ignore_class:
                continue

            if row['class'] not in selected_class:
                continue

            pr_info.append([row['x'], row['y'], row['class']])
        
        if not pr_info:
            continue
            
        p, r, f = metrics_counter.get_precision_recall_f1(gt_info, pr_info)

        precisions.append(p)
        recalls.append(r)
        f1.append(f)
    
    mean_p = np.mean(precisions)
    mean_r = np.mean(recalls)
    mean_f = np.mean(f1)
    
    if (mean_p and mean_r and mean_f) and (mean_p is not np.nan) and (mean_p is not None) and (mean_p != np.float('nan')):
        m_p.append(mean_p)
        m_r.append(mean_r)
        m_f.append(mean_f)

        print(f'Precision: {mean_p:2.2f} | Recall: {mean_r:2.2f} | F1Score: {mean_f:2.2f}')
    
    
print(f'Total mean:  Precision: {np.mean(m_p):2.2f} | Recall: {np.mean(m_r):2.2f} | F1Score: {np.mean(m_f):2.2f}')

Class 0:


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if (mean_p and mean_r and mean_f) and (mean_p is not np.nan) and (mean_p is not None) and (mean_p != np.float('nan')):


Precision: nan | Recall: nan | F1Score: nan
Class 1:


KeyboardInterrupt: 

In [127]:
np.float('nan')

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.float('nan')


nan

In [134]:
np.float('nan') is mean_p

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.float('nan') is mean_p


False

In [132]:
mean_p

nan