In [1]:
from IPython.display import clear_output
import logging
import os
import shutil
import sys
import time
import tempfile
from glob import glob
from tqdm import tqdm
import pickle

import pandas as pd
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import (
    ITKReader,
    NumpyReader,
)
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Flipd,
    Lambdad,
    LoadImaged,
    RandAdjustContrastd, #check whether necessary
    RandFlipd,
    RandAffined,
    Resize,
    Resized,
    Rotate90d,
    ScaleIntensity,
    ScaleIntensityd,
    ToNumpy,
    ToTensor,
    ToTensord,
)
from monai.utils import first

from matplotlib import pylab as plt
from skimage.io import imread

import itk

monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is", device)

MONAI version: 1.0.1
Numpy version: 1.22.4
Pytorch version: 1.13.0
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8271a193229fe4437026185e218d5b06f7c8ce69
MONAI __file__: E:\Users\BerkOlcay\anaconda3\envs\DL\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.8
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.2.0
Tensorboard version: 2.11.0
gdown version: 4.6.0
TorchVision version: 0.14.0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.4
pandas version: 1.5.2
einops version: 0.4.1
transformers version: 4.24.0
mlflow version: 1.30.0
pynrrd version: 0.4.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

Device is cuda


# Load data to be classified

In [2]:

def draw_segmented_area(frame_rgb, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=None):
    # Plot segmented area
    #fig, ax = plt.subplots(figsize=(3.2,2.4))
    alpha=0.4
    fig_created = False
    if ax is None:
        fig = plt.figure(figsize=(6.4,4.8))
        #canvas = FigureCanvas(fig)
        ax = fig.subplots()
        fig_created = True
    ax.imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax.imshow(np.ma.masked_where(visible_map_masked<0.5,visible_map_masked), cmap="spring", vmax=1, vmin=0, alpha=alpha)
    ax.imshow(np.ma.masked_where(iris_map_masked<0.5,iris_map_masked), cmap="GnBu", vmax=1, vmin=0, alpha=alpha)
    ax.imshow(np.ma.masked_where(pupil_map_masked<0.5,pupil_map_masked), cmap="OrRd", vmax=1, vmin=0, alpha=alpha)
    ax.imshow(np.ma.masked_where(glints_map_masked<0.5,glints_map_masked), cmap="cool", vmax=1, vmin=0, alpha=alpha)
    ax.axis('off')
    if fig_created:
        fig.tight_layout()
        fig.canvas.draw()
        plt.show()
        
def draw_segmented_areas_separately(frame_rgb, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=None):
    # Plot segmented area
    #fig, ax = plt.subplots(figsize=(3.2,2.4))
    alpha=0.4
    fig_created = False
    if ax is None:
        fig = plt.figure(figsize=(32, 24))
        #canvas = FigureCanvas(fig)
        ax = fig.subplots(1,4)
        fig_created = True
        
    ax[0].imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax[0].imshow(np.ma.masked_where(visible_map_masked<0.5,visible_map_masked), cmap="spring", vmax=1, vmin=0)
    ax[0].axis('off')
    
    ax[1].imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax[1].imshow(np.ma.masked_where(iris_map_masked<0.5,iris_map_masked), cmap="GnBu", vmax=1, vmin=0)
    ax[1].axis('off')
    
    ax[2].imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax[2].imshow(np.ma.masked_where(pupil_map_masked<0.5,pupil_map_masked), cmap="OrRd", vmax=1, vmin=0)
    ax[2].axis('off')
    
    ax[3].imshow(frame_rgb)#, vmax=1, vmin=0, cmap="gray")
    ax[3].imshow(np.ma.masked_where(glints_map_masked<0.5,glints_map_masked), cmap="cool", vmax=1, vmin=0)
    ax[3].axis('off')
    if fig_created:
        fig.tight_layout()
        fig.canvas.draw()
        plt.show()
        

pn_code = 'E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\PYTHON'
pn_data = 'E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\data\\data_dv3d_monai_QA'
pn_classifcation = 'E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\data\\data_dv3d_monai_QA\\pickle_open_close'

df = pd.read_csv(os.path.join(pn_code,'df_dv3d_monai_files.csv'), index_col=0)
df['fn_seg_maps_np'] = [s.replace('.pkl','.pkl') for s in df.fn_seg_maps]

# set up dataset splits and dict-lists
check_idxs = np.arange(0,df.shape[0])

check_files = [{"img": os.path.join(pn_data, fn_img), "seg": os.path.join(pn_data, fn_seg)} for fn_img, fn_seg in zip(df.fn_img[check_idxs], df.fn_seg_maps_np[check_idxs])]

print(f'df.columns:\n {df.columns.tolist()}')
df.head()

df.columns:
 ['fn_img', 'fn_qa_img', 'fn_annotation', 'fn_seg_maps', 'tag_dataset', 'fn_seg_maps_np']


Unnamed: 0,fn_img,fn_qa_img,fn_annotation,fn_seg_maps,tag_dataset,fn_seg_maps_np
0,12451_ubiris2_C107_S1_I7_000000.tiff,12451_ubiris2_C107_S1_I7_000000_seg_qa.png,12451_ubiris2_C107_S1_I7_000000.txt,12451_ubiris2_C107_S1_I7_000000_seg_maps.pkl,ubiris2,12451_ubiris2_C107_S1_I7_000000_seg_maps.pkl
1,12452_ubiris2_C133_S1_I4_000000.tiff,12452_ubiris2_C133_S1_I4_000000_seg_qa.png,12452_ubiris2_C133_S1_I4_000000.txt,12452_ubiris2_C133_S1_I4_000000_seg_maps.pkl,ubiris2,12452_ubiris2_C133_S1_I4_000000_seg_maps.pkl
2,12453_ubiris2_C79_S2_I2_000000.tiff,12453_ubiris2_C79_S2_I2_000000_seg_qa.png,12453_ubiris2_C79_S2_I2_000000.txt,12453_ubiris2_C79_S2_I2_000000_seg_maps.pkl,ubiris2,12453_ubiris2_C79_S2_I2_000000_seg_maps.pkl
3,12454_ubiris2_C390_S1_I15_000000.tiff,12454_ubiris2_C390_S1_I15_000000_seg_qa.png,12454_ubiris2_C390_S1_I15_000000.txt,12454_ubiris2_C390_S1_I15_000000_seg_maps.pkl,ubiris2,12454_ubiris2_C390_S1_I15_000000_seg_maps.pkl
4,12455_ubiris2_C85_S1_I1_000000.tiff,12455_ubiris2_C85_S1_I1_000000_seg_qa.png,12455_ubiris2_C85_S1_I1_000000.txt,12455_ubiris2_C85_S1_I1_000000_seg_maps.pkl,ubiris2,12455_ubiris2_C85_S1_I1_000000_seg_maps.pkl


In [3]:
# define transforms for image and segmentation
img_size = np.array([240,320])
rot_max = 45*np.pi/180.0
shear_max = 0.5
trans_max = tuple((img_size*0.15).astype(int))
scale_max = 0.25

def gray2rgb(x):
    #print(x.shape)
    if x.shape[0]==1:
        x = x.repeat(3, 1, 1)
        x.meta['original_channel_dim'] = -1 # THIS is the important line! 
    #print(x.shape)
    return x

def clean_tiff_meta(x):
    for key in ['DocumentName', 'ImageDescription', 'Software']:
        if key in x.meta.keys():
            del x.meta[key]
    return x

train_transforms = Compose(
    [
        #Lambdad(keys=['img', 'seg'], func=lambda x: print(x), overwrite = False),
        LoadImaged(keys=["img"], reader= ITKReader, image_only = True),
        LoadImaged(keys=["seg"], reader=NumpyReader, image_only = True),
        EnsureChannelFirstd(keys=["img"]),
        Lambdad(keys=['img'], func=lambda x: gray2rgb(x)), # gray to rgb conversion
        ScaleIntensityd(keys="img"),        
        Flipd(keys=["seg"], spatial_axis=[1]), # necessary due to various readers ITKReader and NumpyReader
        Rotate90d(keys=["seg"]), # necessary due to various readers ITKReader and NumpyReader
        Rotate90d(keys=["img", "seg"]), # necessary due to various readers ITKReader and NumpyReader
        Rotate90d(keys=["img", "seg"]), # necessary due to various readers ITKReader and NumpyReader
        Rotate90d(keys=["img", "seg"]), # necessary due to various readers ITKReader and NumpyReader
        Resized(keys=["img", "seg"], spatial_size=(240,320)),
        RandAdjustContrastd(keys=["seg"], prob=1.0, gamma=(0.1, 10.0)),
        EnsureTyped(keys="img"),
        Lambdad(keys=['img'], func=lambda x: clean_tiff_meta(x)), # clean weird keys in TIFF metadata - turns out this is not necessary
        ToTensord(keys=["img", "seg"]),
    ]
)

# classify images and check whether segmentations are correct in the meanwhile

In [4]:
# add them to dataset, and data loader
npc = ToNumpy()
check_batch_size = 5506
check_ds = monai.data.Dataset(data=check_files, transform=train_transforms)
check_loader = DataLoader(
    check_ds, 
    batch_size=check_batch_size, 
    shuffle=False,
    num_workers=0)
check_data = first(check_loader)

In [7]:
# related to segmentation check 
checkSegmentations = []
extremeCases = []
notPrecise = []

# if there are pickles, load them and resume to check
pn_checkSegmentationsPkl = os.path.join(pn_code, "checkSegmentations.pkl")
pn_extremeCasesPkl = os.path.join(pn_code, "extremeCases.pkl")
pn_notPrecisePkl = os.path.join(pn_code, "notPrecise.pkl")

if os.path.isfile(pn_checkSegmentationsPkl):
    checkSegmentations = pickle.load( open(pn_checkSegmentationsPkl, "rb" ))
if os.path.isfile(pn_extremeCasesPkl):
    extremeCases = pickle.load( open(pn_extremeCasesPkl, "rb" ))
if os.path.isfile(pn_notPrecisePkl):
    extremeCases = pickle.load( open(pn_notPrecisePkl, "rb" ))

In [34]:
# classify images whether the eye is open or closed
# also control segmentations whether they're correct
filename = ""
index = 0
while index < check_batch_size:
    img_cf = np.squeeze(npc(check_data["img"])[index,:,:,:])
    seg_cf = np.squeeze(npc(check_data["seg"])[index,:,:,:])
    
    # channel last versions for plotting
    pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = tuple([np.squeeze(seg_cf[c,:,:]) for c in [0,1,2,3]])
    img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
    draw_segmented_areas_separately(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)
    draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)
    
    previousFilename = filename
    filename = check_files[index]["img"]
    changed_folder = filename.replace('\data_dv3d_monai_QA','\data_dv3d_monai_QA\pickle_open_close')
    without_extension = os.path.splitext(changed_folder)[0]
    new_filename = without_extension +'_classification.pkl'
    print ("image index ", index, ": ", filename.split('\\')[-1])
    
    a = -1
    while (a not in range(7)):
        print("Enter a number")
        print("0 go back")
        print("1 open")
        print("2 open not precise")
        print("3 open, check segmentation")
        print("4 open, extreme cases")
        print("5 close")
        print("6 close, check segmentation")
        a = int(input("Enter a number: "))
        if (a == 0 and index != 0): 
            index = index - 2
            if (previousFilename in checkSegmentations):
                checkSegmentations.remove(previousFilename)
            if (previousFilename in extremeCases):
                extremeCases.remove(previousFilename)
            if (previousFilename in notPrecise):
                notPrecise.remove(previousFilename)
        elif (a == 1):
            ylabel = np.ones(1)
        elif (a == 2): 
            ylabel = np.ones(1)
            notPrecise.append(filename)
        elif (a == 3):
            ylabel = np.ones(1)
            checkSegmentations.append(filename)
        elif (a == 4):
            ylabel = np.ones(1)
            extremeCases.append(filename)
        elif (a == 5):
            ylabel = np.zeros(1)
        elif (a == 6):
            ylabel = np.zeros(1)
            checkSegmentations.append(filename)
        else:
            print ("Wrong input. Enter a number")
            print("0 go back")
            print("1 open")
            print("2 open not precise")
            print("3 open, check segmentation")
            print("4 open, extreme cases")
            print("5 close")
            print("6 close, check segmentation")
        
    fileObject = open(new_filename, 'wb')
    pickle.dump(ylabel, fileObject)
    fileObject.close()
    print(new_filename, " updated.")
    
    clear_output(wait=False)
    index = index + 1

Elapsed time: 273169.59 sec for 20 images


In [117]:
# update excel table for the classifications
def convertFromImgToClassification(imgName):
    without_extension = os.path.splitext(imgName)[0]
    classification_filename = without_extension +'_classification.pkl'
    target_file = os.path.join("pickle_open_close", classification_filename)
    return target_file

dff = pd.read_csv(os.path.join(pn_code,'df_dv3d_oc_monai_files.csv'), index_col=0)
dff['fn_cls'] = [convertFromImgToClassification(s) for s in df.fn_seg_maps]
dff.to_csv('df_dv3d_oc_monai_files.csv', index= True)

# print sizes of the arrays and save them

In [55]:
print('checkSegmentations ', len(checkSegmentations))
#print(checkSegmentations)
print('extremeCases ', len(extremeCases))
#print(extremeCases)
print('notPrecise ', len(notPrecise))
#print(notPrecise)

checkSegmentations  178
extremeCases  139
notPrecise  154


In [56]:
# save checkSegmentations
fileObject = open(pn_checkSegmentationsPkl, 'wb')
pickle.dump(checkSegmentations, fileObject)
fileObject.close()
print(pn_checkSegmentationsPkl, " saved.")

# save extremeCases
fileObject = open(pn_extremeCasesPkl, 'wb')
pickle.dump(extremeCases, fileObject)
fileObject.close()
print(pn_extremeCasesPkl, " saved.")

# save checkSegmentations
fileObject = open(pn_notPrecisePkl, 'wb')
pickle.dump(notPrecise, fileObject)
fileObject.close()
print(pn_notPrecisePkl, " saved.")

E:\Users\BerkOlcay\DeepVOG3D\DeepVOG3D\PYTHON\checkSegmentations.pkl  saved.
E:\Users\BerkOlcay\DeepVOG3D\DeepVOG3D\PYTHON\extremeCases.pkl  saved.
E:\Users\BerkOlcay\DeepVOG3D\DeepVOG3D\PYTHON\notPrecise.pkl  saved.


In [None]:
#Remove 05464_dsgz3_P32_R.mp4_112078.png

# double check the arrays (optional)

In [None]:
checkSegmentationsFiles = []
for file in check_files:
    if (file['img'] in checkSegmentations):
        checkSegmentationsFiles.append(file)
print(checkSegmentationsFiles[:3])
print(len(checkSegmentationsFiles))

In [65]:
check_batch_size = 368
check_ds = monai.data.Dataset(data=checkSegmentationsFiles, transform=train_transforms)
check_loader = DataLoader(
    check_ds, 
    batch_size=check_batch_size, 
    shuffle=False,
    num_workers=0)
check_data = first(check_loader)

In [None]:
#index = 0
while index < check_batch_size:
    img_cf = np.squeeze(npc(check_data["img"])[index,:,:,:])
    seg_cf = np.squeeze(npc(check_data["seg"])[index,:,:,:])
    
    # channel last versions for plotting
    pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = tuple([np.squeeze(seg_cf[c,:,:]) for c in [0,1,2,3]])
    img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
    draw_segmented_areas_separately(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)
    draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)
    
    filename = checkSegmentationsFiles[index]["img"]
    print ("image index ", index, ": ", filename.split('\\')[-1])
    
    a = int(input("Enter a number: "))
    if (a == 2):
        checkSegmentations.remove(filename)
        
    clear_output(wait=False)
    index = index + 1
    

In [63]:
print(len(checkSegmentations))
print (filename in checkSegmentations)

139
False


# copy the images and segmentation comprasions to another folder

In [80]:
print('checkSegmentations ', len(notPrecise))
pn_targetFile = os.path.join(pn_data,'1check_segmentations\\')

for file_name in checkSegmentations:
    without_extension = os.path.splitext(file_name)[0]
    seg_qa_file_name = without_extension +'_seg_qa.png'
    
    shutil.copy(file_name, file_name.replace(pn_data,pn_targetFile))
    shutil.copy(seg_qa_file_name, seg_qa_file_name.replace(pn_data,pn_targetFile))

checkSegmentations  178


In [81]:
print('checkSegmentations ', len(notPrecise))
pn_targetFile = os.path.join(pn_data,'2not_precise\\')

for file_name in notPrecise:
    without_extension = os.path.splitext(file_name)[0]
    seg_qa_file_name = without_extension +'_seg_qa.png'
    
    shutil.copy(file_name, file_name.replace(pn_data,pn_targetFile))
    shutil.copy(seg_qa_file_name, seg_qa_file_name.replace(pn_data,pn_targetFile))

checkSegmentations  154


In [None]:
print('checkSegmentations ', len(extremeCases))
pn_targetFile = os.path.join(pn_data,'3extreme_cases\\')

for file_name in extremeCases:
    without_extension = os.path.splitext(file_name)[0]
    seg_qa_file_name = without_extension +'_seg_qa.png'
    
    shutil.copy(file_name, file_name.replace(pn_data,pn_targetFile))
    shutil.copy(seg_qa_file_name, seg_qa_file_name.replace(pn_data,pn_targetFile))