In [1]:
import os
import numpy as np
import skimage.io
import cv2
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib

import yaml
from PIL import Image
from tqdm.notebook import tqdm

from train import WeldingConfig, WeldingDataset, setpath

In [2]:
# Setting

# Create model object in inference mode.
config = WeldingConfig()
MODEL_DIR = os.path.join(os.getcwd(), "logs")
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
 
# Prepare testing data
image_path = './total_data/'
dataset = WeldingDataset()
dataset.load_img(*setpath(image_path))
dataset.prepare()
total_imgs = len(dataset.image_info)
class_name = dataset.class_names

# Testing models index range
model_test_range=[35,60]

# The boundary of postive/negtive samples (=positive samples' number)
PN_boundary = 71

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Instructions for updating:
Use fn_output_signature instead


In [3]:
def predict(image_path, dataset):
    # Load ground_truth's information
    path = image_path + 'pic/' + name + '.png'
    gt_mask = dataset.load_mask(idx)[0]
    gt_class_id = dataset.class_names.index(dataset.from_yaml_get_class(idx)[0])

    # Prdict
    image = skimage.io.imread(path)
    r = model.detect([image], verbose=0)[0]
    pre_mask = r['masks']
    pre_class_id = r['class_ids']
    return gt_class_id, pre_class_id, gt_mask, pre_mask
    
def cal_iou(gt_mask, pre_mask):
    inter = np.count_nonzero(gt_mask+pre_mask==2)
    union = np.count_nonzero(gt_mask!=pre_mask)+inter
    return inter/union if union != 0 else 0

def get_max_iou(gt_class_id, pre_class_id,
                gt_mask, pre_mask):
    iou = 0
    for i, class_id in enumerate(pre_class_id):
        if class_id == gt_class_id:
            _pre_mask = pre_mask[:,:,[i]]
            iou = max(cal_iou(gt_mask,_pre_mask),iou)
    return iou

# Classificate samples
def classification(gt_class_id, pre_class_id, PN_boundary,
                   TP, TN, FP, FN):
    if gt_class_id in pre_class_id:
        inter = np.count_nonzero(gt_mask+pre_mask==2)
        union = np.count_nonzero(gt_mask!=pre_mask)+inter
        iou = inter/union if union != 0 else 0
        iou_list.append(iou)
        if idx < PN_boundary:
            TP+=1
        else:
            TN+=1
    else:
        if idx < PN_boundary:
            FP+=1
        else:
            FN+=1
    return TP, TN, FP, FN

# Ouput
def score(TP, TN, FP, FN):
    print('accuracy = ', (TP+TN)/(TP+FN+FP+TN))
    precision = TP/(TP+FP)
    recall = TP/(TP+FN)
    print('precision = ', precision)
    print('recall = ', recall)
    print('f1 score = ', 2*precision*recall/(precision+recall))

In [None]:
import warnings
warnings.filterwarnings('ignore')

for i in range(*model_test_range):
    TP,FP,TN,FN = 0,0,0,0
    iou_list=[]
    
    #Load weights
    model_id = ('0' if i < 9 else '') + str(i + 1)
    MODEL_PATH = os.path.join(MODEL_DIR, "welding20220531T1122/mask_rcnn_welding_00" + model_id + ".h5")
    model.load_weights(MODEL_PATH, by_name=True)
        
    #Set tqdm
    print('model', model_id)
    progress = tqdm(total=total_imgs)
    
    for idx in range(total_imgs):
        name = ('0' if idx < 9 else '') + str(idx + 1)
        
        gt_class_id, pre_class_id, gt_mask, pre_mask = predict(image_path, dataset)
        TP, TN, FP, FN = classification(gt_class_id, pre_class_id, PN_boundary,
                                        TP, TN, FP, FN)
        if gt_class_id in pre_class_id:
            iou_list.append(get_max_iou(gt_class_id, pre_class_id,
                                        gt_mask, pre_mask))
        progress.update(1)
    
    print('iou:',sum(iou_list)/len(iou_list),'max->',max(iou_list))
    score(TP, TN, FP, FN)
    print('-------------------------------------')

Re-starting from epoch 36
model 36


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7066181376451062 max-> 0.9207042626423088
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 37
model 37


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.718340380521789 max-> 0.9204434300619498
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 38
model 38


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7157689151080422 max-> 0.9236432477913337
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 39
model 39


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.6938895404764261 max-> 0.9288032893313136
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 40
model 40


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7004131050953856 max-> 0.915359477124183
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 41
model 41


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7149987321933211 max-> 0.9070623916811091
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 42
model 42


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7008195667871435 max-> 0.9159414556962026
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 43
model 43


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.710455894359312 max-> 0.9266019164330878
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 44
model 44


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.6922510611638444 max-> 0.9100758396533044
accuracy =  0.9259259259259259
precision =  0.9154929577464789
recall =  1.0
f1 score =  0.9558823529411764
-------------------------------------
Re-starting from epoch 45
model 45


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.703730268189782 max-> 0.9155042381365548
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 46
model 46


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.7091236087268266 max-> 0.9234519104084321
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 47
model 47


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))

iou: 0.6998736433006947 max-> 0.9244337542593706
accuracy =  0.9135802469135802
precision =  0.9014084507042254
recall =  1.0
f1 score =  0.9481481481481481
-------------------------------------
Re-starting from epoch 48
model 48


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))