In [None]:
import os
import sys
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import cv2
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader
import h5py

# The project root is the parent of the notebook's working directory; it is
# appended to sys.path so the local `src` package can be imported below.
WORK_DIR = Path.cwd().parent
sys.path.append(str(WORK_DIR))
from src.datasets import get_dataset, get_dataloader
from src.utils import parse_data_cfg, IMG, DATA_DIR, RHD

In [None]:
cfg = parse_data_cfg(WORK_DIR/'data_cfg'/'znb_handseg_data_mod2.cfg')
epoch = 200
exp_dir = cfg["exp_dir"]
data_split = 'test'
split_set = cfg[data_split + '_set']
img_size = int(cfg['img_size'])

# Visualize the data loader: one batch of images with each mask channel overlaid

In [None]:
# Build the dataset for the selected split and wrap it in a loader whose
# batch size, shuffling and worker count all come from the config.
dataset_kwargs = dict(split_set=split_set)
dataset = get_dataset(cfg, dataset_kwargs)

sampler = None  # no custom sampler; ordering is governed by `shuffle`
shuffle = cfg['shuffle']
kwargs = dict(
    batch_size=int(cfg['batch_size']),
    shuffle=shuffle,
    num_workers=int(cfg['num_workers']),
    pin_memory=True,
)
data_loader = get_dataloader(dataset, sampler, kwargs)


In [None]:
# Fetch batch number `idx` from the loader and convert it to channels-last
# numpy arrays for plotting. The transpose below assumes the loader yields
# (N, C, H, W) tensors — TODO confirm against the dataset implementation.
idx = 0

batch = None
for i, cur_batch in enumerate(data_loader):
    if i == idx:
        batch = cur_batch
        break
if batch is None:
    # Without this check `img`/`mask` would silently stay as the last batch's
    # raw CHW tensors and break the plotting cells.
    raise IndexError('idx={} is past the end of data_loader'.format(idx))

img, mask = batch
batch_size = img.shape[0]
# (N, C, H, W) -> (N, H, W, C)
img = IMG.scale_img_255(np.transpose(img.cpu().numpy(), (0, 2, 3, 1)))
mask = np.transpose(mask.cpu().numpy(), (0, 2, 3, 1))

In [None]:
# Overlay mask channel 1 (drawn in the blue channel) on each image of `img`
# in a 4x4 grid. Panels beyond the batch size are hidden instead of raising
# IndexError when the batch has fewer than 16 samples.
n_rows, n_cols = 4, 4
mask_channel = 1
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 15))
for idx, cur_ax in enumerate(ax.flat):
    if idx >= len(img):
        cur_ax.axis('off')
        continue
    cur_img = img[idx]
    cur_mask = mask[idx]
    show_mask = cur_mask[:, :, mask_channel]
    show_mask = np.stack([np.zeros(show_mask.shape), np.zeros(show_mask.shape), show_mask], 2).astype('uint8')
    show_mask = IMG.scale_img_255(show_mask)
    show_img = IMG.scale_img_255(show_mask * 0.5 + cur_img * 0.5)
    cur_ax.imshow(show_img)
    cur_ax.set_title('sample {} / mask channel {}'.format(idx, mask_channel))

In [None]:
# Overlay mask channel 0 (drawn in the blue channel) on each image of `img`
# in a 4x4 grid. Panels beyond the batch size are hidden instead of raising
# IndexError when the batch has fewer than 16 samples.
n_rows, n_cols = 4, 4
mask_channel = 0
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 15))
for idx, cur_ax in enumerate(ax.flat):
    if idx >= len(img):
        cur_ax.axis('off')
        continue
    cur_img = img[idx]
    cur_mask = mask[idx]
    show_mask = cur_mask[:, :, mask_channel]
    show_mask = np.stack([np.zeros(show_mask.shape), np.zeros(show_mask.shape), show_mask], 2).astype('uint8')
    show_mask = IMG.scale_img_255(show_mask)
    show_img = IMG.scale_img_255(show_mask * 0.5 + cur_img * 0.5)
    cur_ax.imshow(show_img)
    cur_ax.set_title('sample {} / mask channel {}'.format(idx, mask_channel))

# Evaluation: per-sample IoU of predicted vs. ground-truth hand masks

In [None]:
# Load the saved predicted masks for this epoch/split into memory.
pred_file = os.path.join(DATA_DIR, exp_dir, 'predict_{}_{}_mask.h5'.format(epoch, data_split))
# The context manager closes the HDF5 file even if reading the dataset fails.
with h5py.File(pred_file, 'r') as f:
    pred_mask = f['mask'][...]

In [None]:
# Inspect one sample: blend the predicted hand mask (red), the ground-truth
# hand mask (blue) and the RGB image, then report that sample's IoU.
idx = 0

rhd_split_dir = os.path.join(DATA_DIR, 'RHD_published_v2', split_set)
sample_file = '%.5d.png' % idx

image = Image.open(os.path.join(rhd_split_dir, 'color', sample_file))
image = np.asarray(image.resize((img_size, img_size)))

# Label maps must be resized with nearest-neighbour sampling: Pillow's default
# filter for most image modes interpolates label ids and corrupts boundaries.
mask_gt = Image.open(os.path.join(rhd_split_dir, 'mask', sample_file))
mask_gt = np.asarray(mask_gt.resize((img_size, img_size), Image.NEAREST))
# Label ids greater than 1 are treated as hand pixels.
show_mask_gt_eval = np.greater(mask_gt, 1).astype('float32')
show_mask_gt = np.stack([np.zeros(show_mask_gt_eval.shape), np.zeros(show_mask_gt_eval.shape), show_mask_gt_eval], 2).astype('uint8')
show_mask_gt = IMG.scale_img_255(show_mask_gt)

# NOTE(review): "channel 1 > 0" is the hand decision; this is only a sensible
# threshold if pred_mask stores logits rather than probabilities — confirm.
show_mask_eval = (pred_mask[idx][:, :, 1] > 0).astype('float32')
show_mask = np.stack([show_mask_eval, np.zeros(show_mask_eval.shape), np.zeros(show_mask_eval.shape)], 2).astype('uint8')
show_mask = IMG.scale_img_255(show_mask)

show_img = IMG.scale_img_255(show_mask * 0.5 + show_mask_gt * 0.5)
show_img = IMG.scale_img_255(show_img * 0.5 + image * 0.5)

intersection = np.logical_and(show_mask_gt_eval, show_mask_eval)
union = np.logical_or(show_mask_gt_eval, show_mask_eval)
# No hand in either mask is a perfect match; avoids 0/0 = nan.
iou_score = np.sum(intersection) / np.sum(union) if np.any(union) else 1.0

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(show_img)
ax.set_title('sample {}: pred (red) vs GT (blue), IoU = {:.3f}'.format(idx, iou_score))
print(iou_score)

In [None]:
# Per-sample IoU between each predicted hand mask and the resized RHD ground
# truth. Assumes prediction i corresponds to mask file '%.5d.png' % i — TODO
# confirm the prediction file preserves dataset order.
n_worst_to_show = 20  # how many lowest-IoU samples to list; None lists all
rhd_mask_dir = os.path.join(DATA_DIR, 'RHD_published_v2', split_set, 'mask')

iou_list = []
for i in tqdm(range(pred_mask.shape[0])):
    with Image.open(os.path.join(rhd_mask_dir, '%.5d.png' % i)) as gt_file:
        # Nearest-neighbour keeps label ids intact when resizing a label map.
        mask_gt = np.asarray(gt_file.resize((img_size, img_size), Image.NEAREST))
    # Label ids greater than 1 are treated as hand pixels.
    hand_gt = np.greater(mask_gt, 1)

    # NOTE(review): "channel 1 > 0" assumes logits in pred_mask — confirm.
    hand_pred = pred_mask[i][:, :, 1] > 0

    union_px = np.sum(np.logical_or(hand_gt, hand_pred))
    intersection_px = np.sum(np.logical_and(hand_gt, hand_pred))
    # An empty union (no hand in GT or prediction) is a perfect match; a 0/0
    # nan here would otherwise make Avg_IOU and argmin meaningless.
    iou_score = intersection_px / union_px if union_px > 0 else 1.0
    iou_list.append(iou_score)

if iou_list:
    print('Avg_IOU:', np.mean(iou_list))
    print('Worst mask id:', np.argmin(iou_list))
    for idx in np.argsort(iou_list)[:n_worst_to_show]:
        print(idx, iou_list[idx])
else:
    print('pred_mask contains no samples; nothing to evaluate')

In [None]:
# one_box, one_mask = RHD.get_hand_from_pred(pred_mask[idx])
# fig, ax = plt.subplots(figsize=(10, 10))
# ax.imshow(one_mask)
# RHD.draw_bbox(ax, one_box, (1,1))