Algorithm and setup

In [0]:
import numpy as np
from scipy.signal import convolve2d


def convolve(num_convs, label_img, thresh=0.6):
    """Binarise an image and shrink its blobs by repeated blur-and-threshold.

    The image is min-max normalised, thresholded at ``thresh`` and then
    ``num_convs`` times convolved with a 3x3 averaging kernel and
    thresholded again.

    Args:
        num_convs: number of convolve-and-threshold passes.
        label_img: 2-D array-like of pixel values.
        thresh: cut-off; values ``>= thresh`` become 1, the rest 0.

    Returns:
        Integer array with the shape of ``label_img`` holding only 0 and 1.
    """
    kernel = np.ones((3, 3)) / 9.

    # Work in float so that integer-typed images are not normalised with
    # floor division on Python 2.
    pixels = np.asarray(label_img, dtype=float)
    value_range = np.ptp(pixels)

    if value_range == 0:
        # A flat image has nothing above the threshold; avoid the 0/0
        # division (which would yield NaNs that compare False anyway).
        binary = np.zeros(pixels.shape, dtype=int)
    else:
        normalised = (pixels - np.min(pixels)) / value_range
        binary = (normalised >= thresh).astype(int)

    for _ in range(num_convs):
        blurred = convolve2d(binary, kernel, mode='same')
        binary = (blurred >= thresh).astype(int)
    return binary


def get_center_list(conv_label_img, radius):
    """Greedily cluster the 1-valued pixels and return one centre per cluster.

    Pixels are visited in row-major order. The first pending pixel seeds a
    cluster, every other pending pixel strictly closer than ``radius`` to
    that seed joins it, and the cluster's centre is the rounded mean of its
    members. The process repeats on the pixels that were left over.

    Args:
        conv_label_img: 2-D array; pixels equal to 1 are clustered.
        radius: maximum (exclusive) distance from the seed pixel.

    Returns:
        List of ``(i, j)`` integer tuples, one per cluster, in the order the
        clusters were seeded.
    """
    # np.nonzero yields indices in row-major order, matching a nested
    # "for i ... for j ..." scan; unpacking into two names also keeps the
    # requirement that the image is 2-D.
    rows, cols = np.nonzero(np.asarray(conv_label_img) == 1)
    pending = list(zip(rows.tolist(), cols.tolist()))
    radius_sq = radius * radius

    new_center_list = []
    while pending:
        seed_i, seed_j = pending[0]
        cluster = [(seed_i, seed_j)]
        leftover = []

        for ik, jk in pending[1:]:
            rsq = (seed_i - ik) * (seed_i - ik) + (seed_j - jk) * (seed_j - jk)
            if rsq < radius_sq:
                cluster.append((ik, jk))
            else:
                leftover.append((ik, jk))
        pending = leftover

        # Unpack the transposed cluster instead of indexing the result of
        # zip(), which is not subscriptable on Python 3.
        cluster_i, cluster_j = zip(*cluster)
        # NOTE: round() breaks exact .5 ties differently on Python 2 (away
        # from zero) and Python 3 (to even), so such a centre may differ by
        # one pixel between interpreters.
        i_new = int(round(np.mean(cluster_i)))
        j_new = int(round(np.mean(cluster_j)))
        new_center_list.append((i_new, j_new))

    return new_center_list


def detect_diff(label_center, evals_center, radius=7.5):
    """Pair label centres with prediction centres that lie within ``radius``.

    Labels are processed in order; each one is paired with the first
    still-unpaired prediction whose distance is at most ``radius``. A
    prediction is used for at most one label. The inputs are copied, not
    modified.

    Returns:
        ``(match_list, label_list, evals_list)``: the midpoints of the
        paired centres, the labels left without a partner, and the
        predictions left without a partner.
    """
    limit_sq = radius*radius
    free_evals = list(evals_center)
    paired = []
    lonely_labels = []

    for (lx, ly) in list(label_center):
        # Index of the first free prediction close enough, or None.
        hit = next(
            (idx for idx, (ex, ey) in enumerate(free_evals)
             if (lx - ex)*(lx - ex) + (ly - ey)*(ly - ey) <= limit_sq),
            None,
        )
        if hit is None:
            lonely_labels.append((lx, ly))
            continue
        (ex, ey) = free_evals.pop(hit)
        paired.append(((lx + ex)/2, (ly + ey)/2))

    return paired, lonely_labels, free_evals

def calculate_accuracy(label_file_list, evals_file_list, num_convs):
    """Count matched and unmatched defect centres for one label/prediction pair.

    Both file lists are loaded with ``process_label``, index 1 of the last
    axis is taken, the images are shrunk with ``convolve``, their centres
    are extracted with ``get_center_list`` (radius 7.5) and compared with
    ``detect_diff``.

    Returns:
        ``(TP, FP, FN, TN)`` where TP is the number of matches, FP the
        unmatched predictions, FN the unmatched labels and TN is
        ``304*16`` minus the other three.
    """
    sources = (label_file_list, evals_file_list)

    # Each stage handles the label first, then the prediction.
    channels = [process_label(file_list)[:,:,1] for file_list in sources]
    shrunk = [convolve(num_convs, channel) for channel in channels]
    label_centers, evals_centers = [get_center_list(img, 7.5) for img in shrunk]

    matched, missed_labels, extra_evals = detect_diff(label_centers, evals_centers)

    TP, FP, FN = len(matched), len(extra_evals), len(missed_labels)
    # NOTE(review): 304*16 is presumably the number of candidate sites per
    # image -- confirm against the simulation setup.
    TN = 304*16 - TP - FP - FN

    return TP, FP, FN, TN

def get_acc_list(lbl_list, num_images, num_convs, label_dir, predi_dir, image_dir):
    """Evaluate every image of every label class and save the counts as JSON.

    For each label name in ``lbl_list`` and each image index below
    ``num_images`` the label and prediction file names are built,
    ``calculate_accuracy`` is run on them, and the per-image
    ``[TP, FP, FN, TN]`` results are transposed into four per-class
    sequences. Progress is printed to stdout.

    The resulting dictionary (label name -> ``[TPs, FPs, FNs, TNs]``) is
    written to ``image_dir + "acc_list.json"``; nothing is returned.
    """
    acc_list = {}
    for lbl in lbl_list:
        print(lbl)
        per_image = []
        for i in range(num_images):
            # Single-argument print() behaves the same on Python 2 and 3.
            print("\t" + str(i))
            label_file_list = [label_dir + 'label_' + lbl + '/image' + str(i) + '_label' + lbl + '.tiff']
            evals_file_list = [predi_dir + 'label_' + lbl + '/image' + str(i) + '_' + lbl + '_prediction.png']

            TP, FP, FN, TN = calculate_accuracy(label_file_list, evals_file_list, num_convs)

            per_image.append([TP, FP, FN, TN])

        # Transpose to one sequence per count. Materialise it as lists
        # because a lazy zip object (Python 3) cannot be serialised by json.
        acc_list[lbl] = [list(column) for column in zip(*per_image)]

    # The context manager makes sure the file is flushed and closed.
    with open(image_dir + "acc_list.json", 'w') as out_file:
        json.dump(acc_list, out_file)

Imports and parameter initialization

In [0]:
import sys
sys.path.insert(0, '../preprocessing')
from image_parse import *
import matplotlib.pyplot as plt
import json

In [0]:
# Locations of the simulated data set and of its label / prediction images.
parent_dir = "/content/drive/My Drive/stem-learning/"
image_dir = parent_dir + "data/WSeTe/full_simulation_set/"
label_dir = image_dir + 'label/'
predi_dir = image_dir + 'prediction/'

# Label classes to evaluate.
lbl_list = ['2Te', 'Se', 'TeSe', 'vacancy']

# Number of convolve-and-threshold passes and number of images per class.
num_convs = 2
num_images = 50

Evaluate all the simulated images by looking at the centers in the predictions and labels and matching them.

In [0]:
#get_acc_list(lbl_list, num_images, num_convs, label_dir, predi_dir, image_dir)

Let's look more closely into how this process happens. 
First, let's load a label and prediction image:

In [0]:
from scipy.misc import imsave
# Example: image 19 of the 'TeSe' class.
i = 19
lbl = 'TeSe'
image_stem = '/image' + str(i)
label_file_list = [label_dir + 'label_' + lbl + image_stem + '_label' + lbl + '.tiff']
evals_file_list = [predi_dir + 'label_' + lbl + image_stem + '_' + lbl + '_prediction.png']

# Keep only index 1 of the last axis of what process_label returns.
label_img, evals_img = [process_label(file_list)[:,:,1]
                        for file_list in (label_file_list, evals_file_list)]

# Save both raw images for reference.
# NOTE(review): scipy.misc.imsave appears to have been removed from recent
# SciPy releases -- confirm the installed version still provides it.
for out_name, img in (("1_ex_label.png", label_img),
                      ("1_ex_evals.png", evals_img)):
    imsave(out_name, img)

Next, let's get the centers of these defects. To do this, we first normalize the image so that the range of the pixels is $[0,1]$ and threshold it, setting values below 0.6 to 0 and values of at least 0.6 to 1. Next, we convolve the resulting binary image with the averaging kernel

$\frac{1}{9}\left(\begin{matrix}1&1&1\\1&1&1\\1&1&1\end{matrix}\right)$

The convolved values already lie in $[0,1]$, so we threshold them again in the same way.

The convolve-and-threshold step is done 2 times to decrease the size of the labels.

In [0]:
# 3x3 averaging kernel and the cut-off used for every thresholding step.
kernel = np.ones((3,3))/9.
thresh = 0.6


def _binarise(img):
    """Return 1 where img >= thresh and 0 elsewhere, as an integer array."""
    return np.array(img >= thresh).astype(int)


def _shrink_with_snapshots(img, snapshot_names):
    """Normalise and binarise img, then convolve/threshold once per name pair.

    Each pair holds the file name for the convolved image followed by the
    file name for the re-thresholded image of that pass.
    """
    shrunk = _binarise((img - np.min(img))/np.ptp(img))
    for blurred_name, binary_name in snapshot_names:
        shrunk = convolve2d(shrunk, kernel, mode='same')
        imsave(blurred_name, shrunk)
        shrunk = _binarise(shrunk)
        imsave(binary_name, shrunk)
    return shrunk


# label: two passes
conv_label_img = _shrink_with_snapshots(label_img, [
    ("2_ex_label_conv_11.png", "2_ex_label_conv_12.png"),
    ("2_ex_label_conv_21.png", "2_ex_label_conv_22.png"),
])

# evals: two passes
conv_evals_img = _shrink_with_snapshots(evals_img, [
    ("2_ex_evals_conv_11.png", "2_ex_evals_conv_12.png"),
    ("2_ex_evals_conv_21.png", "2_ex_evals_conv_22.png"),
])

This process filters out any predictions that are not as confidently labeled. That is, the big dots are confident labels, while the small dots are not as confident.

Next we find the centers of these defects.

In [0]:
# Cluster the shrunken blobs into one centre per defect.
conv_label_cen = get_center_list(conv_label_img, 7.5)
conv_evals_cen = get_center_list(conv_evals_img, 7.5)

# Draw each image with its centres on top and save it, label first.
for img, centers, out_name in (
        (label_img, conv_label_cen, "3_ex_label_centers.png"),
        (evals_img, conv_evals_cen, "3_ex_evals_centers.png")):
    x_list, y_list = zip(*centers)

    # Borderless 20x20 inch figure whose single axes fills the canvas.
    fig = plt.figure(frameon=False)
    fig.set_size_inches(20,20)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    ax.imshow(img, cmap='gray')
    # Centres are (row, column), so the column goes on the horizontal axis.
    ax.scatter(y_list, x_list, alpha = 0.5, color='y')
    ax.set_xlim(0,1024)
    ax.set_ylim(1024,0)

    plt.savefig(out_name)
    plt.close()


Next we compare the centers from the labels and the predictions. 

In [0]:
# Pair label centres with prediction centres (default matching radius).
match_list, label_list, evals_list = detect_diff(conv_label_cen, conv_evals_cen)

# Matches are true positives, unmatched labels are false negatives and
# unmatched predictions are false positives.
TP = len(match_list)
FN = len(label_list)
FP = len(evals_list)
# NOTE(review): 304*16 is presumably the number of candidate sites per
# image -- confirm against the simulation setup.
TN = 304*16 - (TP + FN) - FP

# Single-argument print() calls behave the same on Python 2 and Python 3,
# unlike the print statement, which is a syntax error on Python 3.
print("True  Positives:   {}         False Positives:   {}".format(TP, FP))
print("True  Negatives:   {}         False Negatives:   {}".format(TN, FN))

True  Positives:   277         False Positives:   12
True  Negatives:   4570         False Negatives:   5


In [0]:
# Split each group of points into row and column coordinates; an empty
# group cannot be unzipped, so fall back to two empty lists.
if TP == 0:
    mx_list, my_list = [], []
else:
    mx_list, my_list = zip(*match_list)

if FN == 0:
    lx_list, ly_list = [], []
else:
    lx_list, ly_list = zip(*label_list)

if FP == 0:
    ex_list, ey_list = [], []
else:
    ex_list, ey_list = zip(*evals_list)

# Borderless 20x20 inch figure whose single axes fills the canvas.
fig = plt.figure(frameon=False)
fig.set_size_inches(20,20)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)

# Black background, then one scatter series per outcome.
ax.imshow(np.zeros((1024, 1024)), cmap='gray')
for cols, rows, series_name in ((my_list, mx_list, 'True Positive'),
                                (ly_list, lx_list, 'False Negative'),
                                (ey_list, ex_list, 'False Positive')):
    ax.scatter(cols, rows, label=series_name)
ax.set_xlim(0,1024)
ax.set_ylim(1024, 0)
#ax.legend(loc='best')

plt.savefig("4_ex_defect_results.png")
plt.close()

In [0]:
def fmt(arr):
    """Return ``arr`` as a NumPy array of floats."""
    # The builtin float is what the removed ``np.float`` alias referred to.
    return np.array(arr).astype(float)

def print_results(recall, precision, F1, bal_acc):
    """Print the four metrics as percentages with two decimals.

    Each argument is a fraction; it is multiplied by 100 before printing.
    """
    # Single-argument print() calls behave the same on Python 2 and 3.
    print("\t\trecall:                {:0.2f}%".format(recall*100))
    print("\t\tprecision:             {:0.2f}%".format(precision*100))
    print("\t\tF1 score:              {:0.2f}%".format(F1*100))
    print("\t\tbalanced accuracy:     {:0.2f}%".format(bal_acc*100))
    
def get_results(TP, FP, FN, TN):
    """Compute recall, precision, F1 score and balanced accuracy.

    Args:
        TP, FP, FN, TN: true-positive, false-positive, false-negative and
            true-negative counts (numbers or NumPy arrays).

    Returns:
        ``(recall, precision, F1, bal_acc)`` as fractions. A ratio whose
        denominator is zero is NaN rather than an error.
    """
    def _ratio(numerator, denominator):
        # Float division regardless of input type, so integer counts are
        # not truncated on Python 2 and a zero denominator gives NaN/inf
        # instead of raising ZeroDivisionError.
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        with np.errstate(divide='ignore', invalid='ignore'):
            return numerator / denominator

    TNR = _ratio(TN, TN + FP)
    recall = _ratio(TP, TP + FN)  # also the true-positive rate
    precision = _ratio(TP, TP + FP)

    F1 = _ratio(2*recall*precision, recall + precision)
    bal_acc = 0.5*(TNR + recall)
    return recall, precision, F1, bal_acc

# Load the per-class counts; the context manager closes the file again.
with open(image_dir + "acc_list.json", 'r') as acc_file:
    acc_list = json.load(acc_file)
import matplotlib.pyplot as plt
for lbl in lbl_list:
    # One float array per count, indexed by image number.
    TP_list, FP_list, FN_list, TN_list = map(fmt, acc_list[lbl])

    # The first three images are summed as "train", the remaining ones as
    # "test", and all of them together as "both".
    TP_train, FP_train, FN_train, TN_train = map(sum, [TP_list[:3], FP_list[:3], FN_list[:3], TN_list[:3]])
    TP_test,  FP_test,  FN_test,  TN_test  = map(sum, [TP_list[3:], FP_list[3:], FN_list[3:], TN_list[3:]])
    TP_both,  FP_both,  FN_both,  TN_both  = map(sum, [TP_list,     FP_list,     FN_list,     TN_list])

    # Single-argument print() calls behave the same on Python 2 and 3.
    print(lbl)
    print("\tTrain")
    print_results(*get_results(TP_train, FP_train, FN_train, TN_train))
    print("\n\n")

    print("\tTest")
    print_results(*get_results(TP_test, FP_test, FN_test, TN_test))
    print("\n\n")

    print("\tTrain and Test")
    print_results(*get_results(TP_both, FP_both, FN_both, TN_both))




2Te
	Train
		recall:                100.00%
		precision:             100.00%
		F1 score:              100.00%
		balanced accuracy:     100.00%



	Test
		recall:                99.60%
		precision:             99.35%
		F1 score:              99.47%
		balanced accuracy:     99.77%



	Train and Test
		recall:                99.62%
		precision:             99.39%
		F1 score:              99.51%
		balanced accuracy:     99.79%
Se
	Train
		recall:                100.00%
		precision:             100.00%
		F1 score:              100.00%
		balanced accuracy:     100.00%



	Test
		recall:                99.56%
		precision:             99.89%
		F1 score:              99.73%
		balanced accuracy:     99.78%



	Train and Test
		recall:                99.59%
		precision:             99.90%
		F1 score:              99.74%
		balanced accuracy:     99.79%
TeSe
	Train
		recall:                100.00%
		precision:             100.00%
		F1 score:              100.00%
		balanced accuracy:     100.00%



