In [1]:
import os
import numpy as np
import skimage.io
import cv2
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib

import yaml
from PIL import Image
from tqdm.notebook import tqdm

from train import WeldingConfig, WeldingDataset, setpath

In [7]:
# Setting

# Create model object in inference mode.
config = WeldingConfig()
MODEL_DIR = os.path.join(os.getcwd(), "logs")
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
 
# Prepare testing data
image_path = './test_data/'
img_list = os.listdir(image_path + 'pic/')
dataset = WeldingDataset()
dataset.load_img(*setpath(image_path))
dataset.prepare()
total_imgs = len(dataset.image_info)
class_name = dataset.class_names

# Testing models index range
model_test_range=[35,60]

# The boundary of postive/negtive samples (=positive samples' number)
PN_boundary = 71

In [3]:
def cal_iou(gt_mask,pre_mask):
    inter = np.count_nonzero(gt_mask+pre_mask==2)
    union = np.count_nonzero(gt_mask!=pre_mask)+inter
    return inter/union if union != 0 else 0

In [8]:
import warnings
warnings.filterwarnings('ignore')

for i in range(*model_test_range):
    TP,FP,TN,FN = 0,0,0,0
    iou_list=[]
    
    #Load weights
    model_id = ('0' if i < 9 else '') + str(i + 1)
    MODEL_PATH = os.path.join(MODEL_DIR, "welding20220531T1122/mask_rcnn_welding_00" + model_id + ".h5")
    model.load_weights(MODEL_PATH, by_name=True)
        
    #Set tqdm
    print('model', model_id)
    progress = tqdm(total=total_imgs)
    
    for idx in range(total_imgs):
        name = img_list[idx]
        iou = 0
        # Load ground_truth's information
        path = image_path + 'pic/' + name# + '.png'
        gt_mask = dataset.load_mask(idx)[0]
        a=dataset.class_names
        gt_class_id = a.index(dataset.from_yaml_get_class(idx)[0])
        
        # Prdict
        image = skimage.io.imread(path)
        r = model.detect([image], verbose=0)[0]
        pre_mask = r['masks']
        
        # Classificate samples
        if gt_class_id in r['class_ids']:
            iou = 0
            for i, class_id in enumerate(r['class_ids']):
                if class_id == gt_class_id:
                    pre_mask = r['masks'][:,:,[i]]
                    iou = max(cal_iou(gt_mask,pre_mask),iou)
            iou_list.append(iou)
            if idx < PN_boundary:
                TP+=1
            else:
                TN+=1
        else:
            if idx < PN_boundary:
                FP+=1
            else:
                FN+=1
                
        progress.update(1)

    print('iou:',sum(iou_list)/len(iou_list),'max->',max(iou_list))
    print('accuracy = ', (TP+TN)/(TP+FN+FP+TN))
    precision = TP/(TP+FP)
    recall = TP/(TP+FN)
    print('precision = ', precision)
    print('recall = ', recall)
    print('f1 score = ', 2*precision*recall/(precision+recall))
    print('-------------------------------------')

Re-starting from epoch 36
model 36


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6685562175037028 max-> 0.8979956340543759
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 37
model 37


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6744329361981248 max-> 0.8928690994308671
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 38
model 38


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6801455444061455 max-> 0.8776802797387345
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 39
model 39


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6590368940404262 max-> 0.8549437537004144
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 40
model 40


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6762299011945432 max-> 0.9078956105983133
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 41
model 41


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6704323107723155 max-> 0.9063860667634253
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 42
model 42


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6623878224996956 max-> 0.8810682804894093
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 43
model 43


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6782023076262499 max-> 0.9198638208720702
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 44
model 44


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6649877782186114 max-> 0.8746732880292734
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 45
model 45


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6713845189531549 max-> 0.8958196398054423
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 46
model 46


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6722604703160835 max-> 0.9234519104084321
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 47
model 47


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6683083779829206 max-> 0.8875228877844624
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 48
model 48


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6651786631677842 max-> 0.8940959652471533
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 49
model 49


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6672077013203213 max-> 0.8757032578830303
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 50
model 50


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6705255053502851 max-> 0.8880235602094241
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 51
model 51


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6613321248587218 max-> 0.8636955107351985
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 52
model 52


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6616149152972637 max-> 0.8677063957647733
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 53
model 53


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6468169757487633 max-> 0.8490246701090075
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 54
model 54


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.663412291208455 max-> 0.8985753496274996
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 55
model 55


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6492988975463823 max-> 0.8587497853340202
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 56
model 56


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6565769545837771 max-> 0.8663951910025208
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 57
model 57


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6676963270667601 max-> 0.8992278497578851
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 58
model 58


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6653938913356555 max-> 0.8750408470034639
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 59
model 59


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6684867844923774 max-> 0.8824796084828711
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
Re-starting from epoch 60
model 60


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

iou: 0.6590296019868059 max-> 0.8569380173368963
accuracy =  0.8571428571428571
precision =  0.8571428571428571
recall =  1.0
f1 score =  0.923076923076923
-------------------------------------
