In [1]:
import os
import numpy as np
import skimage.io
import cv2
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib

import yaml
from PIL import Image
from tqdm.notebook import tqdm

from train import WeldingConfig, WeldingDataset, setpath

In [2]:
# Setting

# Create model object in inference mode.
config = WeldingConfig()
MODEL_DIR = os.path.join(os.getcwd(), "logs")
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
 
# Prepare testing data
image_path = './small_flaw_data/'
dataset = WeldingDataset()
dataset.load_img(*setpath(image_path))
dataset.prepare()
total_imgs = len(dataset.image_info)
class_name = dataset.class_names

# Testing models index range
model_test_range=[40,60]

# The boundary of postive/negtive samples (=positive samples' number)
PN_boundary = 20

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Instructions for updating:
Use fn_output_signature instead


In [3]:
def predict(image_path, dataset):
    # Load ground_truth's information
    path = image_path + 'pic/' + name + '.png'
    gt_mask = dataset.load_mask(idx)[0]
    gt_class_id = dataset.class_names.index(dataset.from_yaml_get_class(idx)[0])

    # Prdict
    image = skimage.io.imread(path)
    r = model.detect([image], verbose=0)[0]
    pre_mask = r['masks']
    pre_class_id = r['class_ids']
    return gt_class_id, pre_class_id, gt_mask, pre_mask
    
def cal_iou(gt_mask, pre_mask):
    inter = np.count_nonzero(gt_mask+pre_mask==2)
    union = np.count_nonzero(gt_mask!=pre_mask)+inter
    return inter/union if union != 0 else 0

def get_max_iou(gt_class_id, pre_class_id,
                gt_mask, pre_mask):
    iou = 0
    for i, class_id in enumerate(pre_class_id):
        if class_id == gt_class_id:
            _pre_mask = pre_mask[:,:,[i]]
            iou = max(cal_iou(gt_mask,_pre_mask),iou)
    return iou

# Classificate samples
def classification(gt_class_id, pre_class_id, PN_boundary,
                   TP, TN, FP, FN):
    if gt_class_id in pre_class_id:
        inter = np.count_nonzero(gt_mask+pre_mask==2)
        union = np.count_nonzero(gt_mask!=pre_mask)+inter
        iou = inter/union if union != 0 else 0
        iou_list.append(iou)
        if idx < PN_boundary:
            TP+=1
        else:
            TN+=1
    else:
        if idx < PN_boundary:
            FP+=1
        else:
            FN+=1
    return TP, TN, FP, FN

# Ouput
def score(TP, TN, FP, FN):
    print('accuracy = ', (TP+TN)/(TP+FN+FP+TN))
    precision = TP/(TP+FP)
    recall = TP/(TP+FN)
    print('precision = ', precision)
    print('recall = ', recall)
    print('f1 score = ', 2*precision*recall/(precision+recall))

In [4]:
import warnings
warnings.filterwarnings('ignore')

for i in range(*model_test_range):
    TP,FP,TN,FN = 0,0,0,0
    iou_list=[]
    
    #Load weights
    model_id = ('0' if i < 9 else '') + str(i + 1)
    MODEL_PATH = os.path.join(MODEL_DIR, "welding20220531T1122/mask_rcnn_welding_00" + model_id + ".h5")
    model.load_weights(MODEL_PATH, by_name=True)
        
    #Set tqdm
    print('model', model_id)
    progress = tqdm(total=total_imgs)
    
    for idx in range(total_imgs):
        name = ('0' if idx < 9 else '') + str(idx + 1)
        
        gt_class_id, pre_class_id, gt_mask, pre_mask = predict(image_path, dataset)
        TP, TN, FP, FN = classification(gt_class_id, pre_class_id, PN_boundary,
                                        TP, TN, FP, FN)
        if gt_class_id in pre_class_id:
            iou_list.append(get_max_iou(gt_class_id, pre_class_id,
                                        gt_mask, pre_mask))
        progress.update(1)
    
    print('iou:',sum(iou_list)/len(iou_list),'max->',max(iou_list))
    score(TP, TN, FP, FN)
    print('-------------------------------------')

Re-starting from epoch 41
model 41


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.13839890163899407 max-> 0.6964285714285714
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 42
model 42


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.12017293845674662 max-> 0.6909090909090909
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 43
model 43


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.12421702451237199 max-> 0.65625
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 44
model 44


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.11083195640938746 max-> 0.5
accuracy =  0.7
precision =  0.7
recall =  1.0
f1 score =  0.8235294117647058
-------------------------------------
Re-starting from epoch 45
model 45


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.09255886378395738 max-> 0.5072463768115942
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 46
model 46


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.10969445201653823 max-> 0.5192307692307693
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 47
model 47


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.11473980955163496 max-> 0.45161290322580644
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 48
model 48


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.11659786257560932 max-> 0.5223880597014925
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 49
model 49


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.11848985844578097 max-> 0.5573770491803278
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 50
model 50


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.12355552094520648 max-> 0.5362318840579711
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 51
model 51


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.1164800082506679 max-> 0.7192982456140351
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 52
model 52


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.13306289929400997 max-> 0.6
accuracy =  0.7
precision =  0.7
recall =  1.0
f1 score =  0.8235294117647058
-------------------------------------
Re-starting from epoch 53
model 53


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.1304554033340903 max-> 0.547945205479452
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 54
model 54


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.15018948582497615 max-> 0.6805555555555556
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 55
model 55


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.10510433227564277 max-> 0.4
accuracy =  0.65
precision =  0.65
recall =  1.0
f1 score =  0.787878787878788
-------------------------------------
Re-starting from epoch 56
model 56


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.10721928729035708 max-> 0.509090909090909
accuracy =  0.75
precision =  0.75
recall =  1.0
f1 score =  0.8571428571428571
-------------------------------------
Re-starting from epoch 57
model 57


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.1467492738277491 max-> 0.6065573770491803
accuracy =  0.8
precision =  0.8
recall =  1.0
f1 score =  0.888888888888889
-------------------------------------
Re-starting from epoch 58
model 58


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.15109167753551067 max-> 0.7017543859649122
accuracy =  0.85
precision =  0.85
recall =  1.0
f1 score =  0.9189189189189189
-------------------------------------
Re-starting from epoch 59
model 59


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.14256703911044952 max-> 0.6197183098591549
accuracy =  0.8
precision =  0.8
recall =  1.0
f1 score =  0.888888888888889
-------------------------------------
Re-starting from epoch 60
model 60


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

iou: 0.1370322529505229 max-> 0.5409836065573771
accuracy =  0.8
precision =  0.8
recall =  1.0
f1 score =  0.888888888888889
-------------------------------------
