mask.npyを読み込んでセグメンテーションの評価を行う

In [8]:
#関数定義
import os, sys
import numpy as np

def obj_detection(mask, class_id:int):
    """
    Input:
        mask : [width, height](ndarray), image data
        class_id : int , class id(ex : 1day -> 1)
    Return:
        mask : [object num(int), width(int), height(int)]
        cls_idxs : [nobject num(int)]
    """
    data = mask
    labels = []
    for label in np.unique(data):
        #: ラベルID==0は背景
        if label == 0:
            continue
        else:
            labels.append(label)

    if len(labels) == 0:
        #: 対象オブジェクトがない場合はNone
        return None, None
    else:
        mask = np.zeros((mask.shape)+(len(labels),), dtype=np.uint8)
        for n, label in enumerate(labels):
            mask[:, :, n] = np.uint8(data == label)
        cls_idxs = np.ones([mask.shape[-1]], dtype=np.int32) * class_id

        return mask.transpose(2, 0 ,1), cls_idxs

################################################################################### 
###################################################################################     
#F1スコアの計算#   
###################################################################################   
###################################################################################   
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

def calculate_iou(mask1, mask2):
    """
    Calculate Intersection over Union (IoU) between two masks.

    Parameters:
    - mask1, mask2: Binary masks to calculate IoU.

    Returns:
    - IoU value.
    """
    intersection = np.logical_and(mask1, mask2)
    union = np.logical_or(mask1, mask2)
    iou = np.sum(intersection) / np.sum(union)
    return iou

def calculate_f1_score(inf_masks, true_masks, iou_threshold=0.5):
    """
    Calculate F1 score based on IoU matching between inference masks and true masks.

    Parameters:
    - inf_masks: Inference masks.
    - true_masks: True masks.
    - iou_threshold: IoU threshold for considering a match.

    Returns:
    - Average F1 score.
    """
    def find_best_iou_match(inf_mask, true_masks, iou_threshold):
        """
        Find the best IoU match for a given inference mask among true masks.

        Parameters:
        - inf_mask: Inference mask.
        - true_masks: List of true masks.
        - iou_threshold: IoU threshold for considering a match.

        Returns:
        - True if a match is found, False otherwise.
        """
        best_iou = 0
        for true_mask in true_masks:
            iou = calculate_iou(inf_mask, true_mask)
            best_iou = max(best_iou, iou)
        return best_iou >= iou_threshold

    f1_scores = []

    for inf_mask in inf_masks:
        matching_indices = [find_best_iou_match(inf_mask, true_masks, iou_threshold)]
        precision, recall, f1, _ = precision_recall_fscore_support(np.array([True]), np.array(matching_indices), average='binary')
        f1_scores.append(f1)

    average_f1 = np.mean(f1_scores)

    return average_f1

Path Define

In [10]:
from glob import glob
import os
import numpy as np

INF_FOLDER = "./Valid_outpur_MyoSothes"
TRUE_PATH = "./Valid_GroundTrue"
days = ["0day", "3day", "5day", "7day", "11day", "14day"]
#days = ["11day", "14day"]




for day in days:

    true_files = glob(os.path.join(TRUE_PATH , day, "*.npy"))
    data = np.zeros((len(true_files), 2))#np.zeros((len(true_files), 2))#F1scoreとbbox数を保存する
    print("***********************************************")
    print("patch num:", len(true_files))
    
    for i, file in enumerate(true_files):
        inf_path = os.path.join(INF_FOLDER , day, os.path.basename(file))
        true_path = file 
        
        try:
            inf_mask = np.load(inf_path)
            true_mask = np.load(true_path)

            inf_masks, _ = obj_detection(inf_mask, 1)
            true_masks, _ = obj_detection(true_mask, 1)
            
            # Example: Calculate F1 score with IOU threshold 0.5
            f1_score = calculate_f1_score(inf_masks, true_masks, iou_threshold=0.5)
            #print("F1 Score:", f1_score)
            data[i][0] = f1_score
            data[i][1] = len(true_masks)
            
        except:
            print(inf_path)
            data[i][0] = 0
            data[i][1] = 0
            pass
    
    #dayごとのデータ保存
    np.save(day, data)




***********************************************
patch num: 180
./Valid_outpur_MyoSothes/0day/178CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/177CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/176CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/175CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/0CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/100CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/101CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/102CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/104CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/105CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/106CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/107CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/108CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/109CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/10CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/110CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/111CTX day0.tiff.npy
./Valid_outpur_MyoSothes/0day/112CTX day0.tiff.npy
./Valid_outpur_MyoSoth

In [26]:

    inf_mask = np.load(INF_PATH)
    print(inf_mask.shape)
    print(np.unique(inf_mask))

    true_mask = np.load(TRUE_PATH)
    print(true_mask.shape)
    print(np.unique(true_mask))

    import matplotlib.pyplot as plt
    plt.matshow(inf_mask)
    plt.matshow(true_mask)

    inf_masks, _ = obj_detection(inf_mask, 1)
    true_masks, _ = obj_detection(true_mask, 1)
    print(inf_masks.shape)
    print(true_masks.shape)
    
        # Example: Calculate F1 score with IOU threshold 0.5
    f1_score = calculate_f1_score(inf_masks, true_masks, iou_threshold=0.5)
    print("F1 Score:", f1_score)



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


F1 Score: 0.9117647058823529
