In [1]:
import os
import numpy as np
import skimage.io
import cv2
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib

import yaml
from PIL import Image
from tqdm.notebook import tqdm

from train import WeldingConfig, WeldingDataset, setpath

In [2]:
# --- Settings --------------------------------------------------------------

# Mask R-CNN in inference mode; training checkpoints are read from ./logs.
config = WeldingConfig()
MODEL_DIR = os.path.join(os.getcwd(), "logs")
model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)

# Test data lives under ./total_data-2/, with the images in its pic/ sub-folder.
image_path = './total_data-2/'
img_list = os.listdir(os.path.join(image_path, 'pic'))

dataset = WeldingDataset()
dataset.load_img(*setpath(image_path))
dataset.prepare()

total_imgs = len(dataset.image_info)
class_name = dataset.class_names

# Checkpoint indices to evaluate, unpacked into range() (half-open [start, stop)).
model_test_range = [40, 60]

# Index boundary between positive and negative samples: samples with an index
# below PN_boundary are positives (the value equals the number of positives).
PN_boundary = 51

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Instructions for updating:
Use fn_output_signature instead


In [3]:
def cal_iou(gt_mask, pre_mask):
    """Intersection-over-union of two binary masks.

    Any non-zero element counts as foreground, so bool, 0/1 integer and
    0/255 masks are all handled. The two inputs must be broadcast-compatible
    (e.g. an (H, W, N) ground-truth mask against an (H, W, 1) prediction);
    pixel counts are taken over the broadcast result.

    Returns:
        inter / union as a float, or 0 when both masks are empty.
    """
    gt = np.asarray(gt_mask) != 0
    pre = np.asarray(pre_mask) != 0
    # Use logical ops rather than `gt + pre == 2`: for two bool arrays NumPy's
    # `+` is a logical OR, so that test never matched and IoU was always 0.
    inter = np.count_nonzero(gt & pre)
    union = np.count_nonzero(gt | pre)
    return inter / union if union != 0 else 0

In [4]:
import warnings
warnings.filterwarnings('ignore')  # silence TF/Keras deprecation noise during inference

# Checkpoints of the training run under evaluation. Matterport Mask R-CNN saves
# weights as mask_rcnn_<name>_<epoch:04d>.h5 with 1-based epoch numbers.
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "welding20220531T1122")


def safe_div(num, den):
    """Return num / den, or 0.0 when den is 0 (e.g. no positive predictions)."""
    return num / den if den else 0.0


for ckpt_idx in range(*model_test_range):
    TP, FP, TN, FN = 0, 0, 0, 0
    iou_list = []

    # Load the weights saved after epoch ckpt_idx + 1.
    epoch = ckpt_idx + 1
    model_id = '{:02d}'.format(epoch)
    MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'mask_rcnn_welding_{:04d}.h5'.format(epoch))
    model.load_weights(MODEL_PATH, by_name=True)

    print('model', model_id)
    for idx in tqdm(range(total_imgs)):
        # Ground truth for sample idx.
        gt_mask = dataset.load_mask(idx)[0]
        gt_class_id = dataset.class_names.index(dataset.from_yaml_get_class(idx)[0])

        # Read the image through the dataset so it is the very image the ground
        # truth above belongs to; indexing the os.listdir() result is unsafe
        # because listdir order is arbitrary and need not match dataset order.
        image = dataset.load_image(idx)
        r = model.detect([image], verbose=0)[0]

        is_positive = idx < PN_boundary
        hits = [k for k, class_id in enumerate(r['class_ids']) if class_id == gt_class_id]
        if hits:
            # Correct class detected: keep the best-overlapping instance's IoU.
            iou_list.append(max(cal_iou(gt_mask, r['masks'][:, :, [k]]) for k in hits))
            if is_positive:
                TP += 1
            else:
                TN += 1
        else:
            # A positive whose class was missed is a false negative; a
            # misclassified negative is a false positive.
            if is_positive:
                FN += 1
            else:
                FP += 1

    if iou_list:
        print('iou:', sum(iou_list) / len(iou_list), 'max->', max(iou_list))
    else:
        print('iou: n/a (no sample had its ground-truth class detected)')
    precision = safe_div(TP, TP + FP)
    recall = safe_div(TP, TP + FN)
    print('accuracy = ', safe_div(TP + TN, TP + FN + FP + TN))
    print('precision = ', precision)
    print('recall = ', recall)
    print('f1 score = ', safe_div(2 * precision * recall, precision + recall))
    print('-------------------------------------')
    print('-------------------------------------')

Re-starting from epoch 41
model 41


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8411796685699329 max-> 0.9070623916811091
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 42
model 42


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8275564705970562 max-> 0.9159414556962026
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 43
model 43


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8414633115775678 max-> 0.9266019164330878
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 44
model 44


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8315722297838138 max-> 0.9100758396533044
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 45
model 45


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8370570805805593 max-> 0.9155709804445038
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 46
model 46


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8425731886043168 max-> 0.9234519104084321
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 47
model 47


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8338260174382149 max-> 0.9244337542593706
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 48
model 48


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8320135109311478 max-> 0.9128431308011322
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 49
model 49


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8335570524390232 max-> 0.9116161616161617
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 50
model 50


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8349081291798179 max-> 0.921994884910486
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 51
model 51


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8307105325964012 max-> 0.9197570332480819
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 52
model 52


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8410164378305737 max-> 0.921146953405018
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 53
model 53


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8230898063059278 max-> 0.9206816747317483
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 54
model 54


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8303900061395805 max-> 0.9218624025799517
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 55
model 55


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8296139076484043 max-> 0.9143710587454364
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 56
model 56


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8290226153311614 max-> 0.9264874334613571
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 57
model 57


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8337604236762394 max-> 0.9236143768894861
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 58
model 58


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8351931432591745 max-> 0.9118492494410732
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 59
model 59


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8346610722397118 max-> 0.9180635644361003
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
Re-starting from epoch 60
model 60


HBox(children=(FloatProgress(value=0.0, max=61.0), HTML(value='')))

iou: 0.8301948924108714 max-> 0.9147615937295885
accuracy =  1.0
precision =  1.0
recall =  1.0
f1 score =  1.0
-------------------------------------
