In [1]:
import os
import pandas as pd
import torch
import numpy as np
from cv2 import distanceTransform, DIST_L2, DIST_MASK_PRECISE
from tqdm import tqdm
from PIL import Image
import torch.utils.data as data
from sklearn.metrics import confusion_matrix
from multiprocessing import Pool

import cv2
from matplotlib import pyplot as plt

In [2]:
predPath = '../inferences/knet/test/tta/'
gtPath = '../../Dataset/Labels/'

In [3]:
def computeBConfMat(pred, gt, nClasses, distances):
    confMatPerDist = np.zeros((len(distances), nClasses+1, nClasses+1), dtype=np.int64)
    if np.any(np.isinf(distances)):
        confMatPerDist[-1] += confusion_matrix(gt.flatten(), pred.flatten(), labels=np.arange(nClasses+1))
    
    distances = [int(d) for d in distances if not np.isinf(d)]
    if len(distances) == 0:
        return confMatPerDist
    
    predOneHot = torch.nn.functional.one_hot(pred, 4).permute(2,0,1).numpy().astype(np.uint8)
    gtOneHot = torch.nn.functional.one_hot(gt, 4).permute(2,0,1).numpy().astype(np.uint8)
    distPredOneHot = np.array([cv2.distanceTransform(binImg, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) for binImg in predOneHot])
    distGtOneHot = np.array([cv2.distanceTransform(binImg, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) for binImg in gtOneHot])
    
    for i, d in enumerate(distances):
        boundaryPredOneHot = np.bitwise_and(distPredOneHot <= d, predOneHot)
        boundaryGtOneHot = np.bitwise_and(distGtOneHot <= d, gtOneHot)
        # concatenate np.ones_like(boundaryPredOneHot[0]) at the end of the array to account for the background class
        boundaryPredOneHot = np.concatenate((boundaryPredOneHot, np.ones_like(boundaryPredOneHot[0])[np.newaxis]), axis=0)
        boundaryGtOneHot = np.concatenate((boundaryGtOneHot, np.ones_like(boundaryGtOneHot[0])[np.newaxis]), axis=0)
        boundaryPred = np.argmax(boundaryPredOneHot, axis=0)
        boundaryGt = np.argmax(boundaryGtOneHot, axis=0)
        
        confMatPerDist[i] += confusion_matrix(boundaryGt.flatten(), boundaryPred.flatten(), labels=np.arange(nClasses+1))

    return confMatPerDist

In [4]:
distances = [1,3,5,10,np.inf]
nClasses = 4

In [8]:
# imgFileNames = os.listdir(predPath)
# totalConfMatPerDist = np.zeros((len(distances), nClasses+1, nClasses+1), dtype=np.int64)
# for imgFileName in tqdm(imgFileNames):
#     pred = torch.from_numpy(np.array(Image.open(predPath + imgFileName))).long()
#     gt = torch.from_numpy(np.array(Image.open(gtPath + imgFileName))).long()
#     totalConfMatPerDist += computeBConfMat(pred, gt, nClasses, distances)

# Parallelize the above code using Pool
imgFileNames = os.listdir(predPath)[:50]
totalConfMatPerDist = np.zeros((len(distances), nClasses+1, nClasses+1), dtype=np.int64)
def computeBConfMatWrapper(imgFileName):
    pred = torch.from_numpy(np.array(Image.open(predPath + imgFileName))).long()
    gt = torch.from_numpy(np.array(Image.open(gtPath + imgFileName))).long()
    return computeBConfMat(pred, gt, nClasses, distances)

print('Computing boundary confusion matrices...')
with Pool(8) as p:
    confMats = list(tqdm(p.imap_unordered(computeBConfMatWrapper, imgFileNames), total=len(imgFileNames)))
    for confMat in confMats:
        totalConfMatPerDist += confMat

confMats = np.sum(confMats, axis=0)
print(totalConfMatPerDist)

Computing boundary confusion matrices...


100%|██████████| 50/50 [00:10<00:00,  4.90it/s]

[[[    75094     16187      3085      2116     87490]
  [    33566     89844       177      5995     46064]
  [     3248       159     10940      2525     11439]
  [    33930     13480      2092     12217     28495]
  [    38626     53959     10133      8764 103090375]]

 [[   406423     20421      4057      3140    117820]
  [    49436    413802       687      9654     49284]
  [     5621       631     54902      3939     16610]
  [    49244     15968      3112     49032     18770]
  [    42839     68303     14728      6150 102255427]]

 [[   763419     21628      4415      3375    126856]
  [    54804    748849       956     10747     50353]
  [     6235      1019     98531      4335     20847]
  [    50296     16417      3320     75339     15395]
  [    48185     73435     17848      4032 101459364]]

 [[  1666999     24363      5267      3503    138363]
  [    68890   1572832      1132     11480     43462]
  [     8016      1347    194428      4701     27886]
  [    50733     16755




In [None]:
# Save the confusion matrices
np.save('boundary_confusion_matrices.npy', totalConfMatPerDist)