In [None]:
# Show GPU status before training (IPython shell escape, not plain Python).
!nvidia-smi

In [None]:
import logging
import os
import shutil
import sys
import time
import tempfile
from glob import glob
from tqdm import tqdm
import pickle

import pandas as pd
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from matplotlib import pylab as plt
from skimage.io import imread

import monai
from monai.networks.layers import Norm
from monai.data import (
    list_data_collate,
    ITKReader,
    NumpyReader,
    decollate_batch
)
from monai.inferers import SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    AddChanneld,
    RandAdjustContrastd,
    Compose,
    DivisiblePadd,
    DataStatsd,
    EnsureChannelFirstd,
    EnsureTyped,
    Flipd,
    Lambdad,
    LoadImaged,
    RandAdjustContrastd,
    RandCropByPosNegLabeld,
    RandRotated,
    RandZoomd,
    RandFlipd,
    RandShiftIntensityd,
    RandScaleIntensityd,
    RandAffined,
    Rand3DElasticd,
    RandGaussianNoised,
    Resize,
    Resized,
    ScaleIntensity,
    ScaleIntensityd,
    SpatialPadd,
    ToTensor,
    ToTensord,
    ToNumpyd,
    ToNumpy,
    DataStats,
    Rotate90d,
)
from monai.utils import first, set_determinism
from monai.visualize import plot_2d_or_3d_image
import itk

# Utility functions
def np_sigmoid(x):
    """Element-wise logistic sigmoid 1 / (1 + exp(-x)) using NumPy.

    Accepts scalars or arrays; the result has the broadcast shape of ``x``.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def arrayStats(arr):
    """Summary statistics of all values in ``arr`` (flattened).

    Returns a dict with keys, in this order: 'min', 'max', 'mean',
    'perc1', 'perc5', 'perc25', 'perc50', 'perc75', 'perc95', 'perc99', 'median'.
    """
    flat = arr.ravel()
    summary = {
        'min': np.min(flat),
        'max': np.max(flat),
        'mean': np.mean(flat),
    }
    for q in (1, 5, 25, 50, 75, 95, 99):
        summary['perc%d' % q] = np.percentile(flat, float(q))
    summary['median'] = np.median(flat)
    return summary

def draw_segmented_area(frame_rgb, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=None, alpha=0.4, threshold=0.5):
    """Overlay the four eye-segmentation maps on top of an image.

    Each map is drawn only where its value is >= ``threshold`` (smaller values are
    masked out), with its own colormap and transparency ``alpha``. Draw order is
    visible, iris, pupil, glints, so later layers sit on top of earlier ones.

    Args:
        frame_rgb: background image, passed directly to ``ax.imshow``.
        pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked:
            2D maps (binary masks or probabilities); presumably the same
            height/width as ``frame_rgb`` -- TODO confirm at call sites.
        ax: matplotlib Axes to draw into. If None, a new 3.2x2.4 inch figure is
            created, laid out and shown immediately.
        alpha: overlay transparency (default 0.4, the previously hard-coded value).
        threshold: values below this are hidden (default 0.5, previously hard-coded).

    Returns:
        None.
    """
    fig_created = ax is None
    if fig_created:
        fig = plt.figure(figsize=(3.2, 2.4))
        ax = fig.subplots()
    ax.imshow(frame_rgb)
    # Layer order matters: later overlays are drawn over earlier ones.
    overlays = (
        (visible_map_masked, "spring"),
        (iris_map_masked, "GnBu"),
        (pupil_map_masked, "OrRd"),
        (glints_map_masked, "cool"),
    )
    for seg_map, cmap in overlays:
        ax.imshow(np.ma.masked_where(seg_map < threshold, seg_map), cmap=cmap, vmax=1, vmin=0, alpha=alpha)
    ax.axis('off')
    if fig_created:
        fig.tight_layout()
        fig.canvas.draw()
        plt.show()

def draw_segmented_area_4d(frame_rgb, seg_4d, ax=None, channel_order=(0, 1, 2, 3)):
    """Overlay a stacked, channel-last segmentation ``seg_4d`` on ``frame_rgb``.

    ``seg_4d`` is indexed as ``seg_4d[:, :, c]`` (channels on the last axis).
    ``channel_order`` gives the channel index of (pupil, iris, glints, visible).
    The default (0, 1, 2, 3) matches the seg-map .npy export in this notebook
    (0=pupil, 1=iris, 2=glints, 3=visible).

    NOTE(review): the check/visualisation cells unpack channel 2 as *visible* and
    channel 3 as *glints*. If that is the layout of the loaded data, call with
    ``channel_order=(0, 1, 3, 2)`` -- confirm which layout is correct.
    """
    pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked = (
        np.squeeze(seg_4d[:, :, c]) for c in channel_order
    )
    draw_segmented_area(frame_rgb, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=ax)

# Paths: override via the DV3D_CODE_DIR / DV3D_DATA_DIR environment variables
# instead of editing these machine-specific defaults.
pn_code = os.environ.get('DV3D_CODE_DIR', 'E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\PYTHON')
pn_data = os.environ.get('DV3D_DATA_DIR', 'E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\data\\data_dv3d_monai_QA')

monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# One row per sample; fn_img / fn_seg_maps are joined onto pn_data below.
df = pd.read_csv(os.path.join(pn_code,'df_dv3d_monai_files.csv'), index_col=0)

# set up dataset splits and dict-lists: the first nr_train_samples rows (by position)
# train, the remaining rows validate.
nr_train_samples = min(4000, df.shape[0])
train_idxs = np.arange(nr_train_samples)
valid_idxs = np.arange(nr_train_samples, df.shape[0])

def _make_file_dicts(idxs):
    """Build MONAI-style [{"img": path, "seg": path}, ...] for the given row *positions*."""
    # .iloc: `series[int_array]` is label-based on an integer index and only selects the
    # intended rows when the CSV index happens to be exactly 0..N-1.
    rows = df.iloc[idxs]
    return [{"img": os.path.join(pn_data, fn_img), "seg": os.path.join(pn_data, fn_seg)}
            for fn_img, fn_seg in zip(rows.fn_img, rows.fn_seg_maps)]

train_files = _make_file_dicts(train_idxs)
val_files = _make_file_dicts(valid_idxs)

print(f'df.columns:\n {df.columns.tolist()}')
df.head()

In [None]:
# Disabled one-off converter (kept as a string literal so it never runs): loads each
# training segmentation pickle and saves it again as .npy next to the original.
# NOTE(review): pickle.load can execute arbitrary code from the file; only run on trusted data.
'''
# Pickle to npy converter
for index, item in enumerate(train_files):
    with open(item["seg"], 'rb') as file:
        arr = pickle.load(file)
        np.save(os.path.join(pn_data, item["seg"].replace('.pkl','.npy')), arr)
        '''

In [None]:
# Seed the global RNGs via MONAI's set_determinism (seed 0) so runs are reproducible.
set_determinism(0)

In [None]:
print(len(train_files), "train files")

# First training-sample position for each distinct image file extension
# (dicts keep insertion order, so formats are listed in order of first appearance).
first_index_per_type = {}
for index, item in enumerate(train_files):
    extension = os.path.splitext(item['img'])[1]
    first_index_per_type.setdefault(extension, index)
file_types = list(first_index_per_type)
indexes = list(first_index_per_type.values())

print("available formats ", file_types)
print(indexes)

for i in indexes:
    print(train_files[i]['img'])

flag_text_loader_check = True
if flag_text_loader_check and indexes:
    # One panel per format: a hard-coded 4 panels breaks with more than 4 formats.
    # squeeze=False keeps `axs` 2D even when there is a single panel.
    fig, axs = plt.subplots(1, len(indexes), squeeze=False)
    loader = monai.transforms.LoadImage(reader="ITKReader")
    for ax, extension, index in zip(axs[0], file_types, indexes):
        img = loader(train_files[index]['img'])
        # NOTE(review): img[0] assumes LoadImage returns (image, meta); with image_only=True
        # it would instead be the first image row -- confirm for the installed MONAI version.
        arr = np.array(img[0])
        print(arr.shape)
        ax.imshow(arr / 255)
        ax.set_xlabel(extension)
    plt.show()


In [None]:
# define own transform: gray2rgb
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from monai.config import KeysCollection
from monai.transforms.compose import Transform, MapTransform
from skimage.color import rgb2gray

class Gray2Rgb(Transform):
    """
    Converts a gray image (a single color channel) to RGB (three identical color channels).

    Expects a channel-last numpy array:
      * a 2D (H, W) image first gets a trailing channel axis;
      * if the last axis has size 1, that channel is replicated three times;
      * if the last axis already has size 3, the image is returned unchanged.

    Args:
        None

    Raises:
        ValueError: if the last axis has a size other than 1 or 3.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        Apply the transform to `img` and return the RGB array.
        """
        if img.ndim == 2:
            img = img[..., np.newaxis]
        n_channels = img.shape[-1]
        if n_channels == 1:
            # Repeat along the *last* (channel) axis; a fixed axis=2 is only correct for 3D input.
            img = np.repeat(img, 3, axis=-1)
        elif n_channels != 3:
            raise ValueError('Input img to Gray2Rgb needs to have 1 or three channels, but not %d.'%(n_channels))
        return img

class Gray2Rgbd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`Gray2Rgb`: applies it to every entry
    named in ``keys`` and returns a new (shallow-copied) dict.
    """

    def __init__(self, keys: KeysCollection) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
        """
        super().__init__(keys)
        self.transform = Gray2Rgb()

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        # Copy first so the caller's mapping itself is not modified.
        result = dict(data)
        for name in self.keys:
            original_value = result[name]
            result[name] = self.transform(original_value)
        return result

class Rgb2Gray(Transform):
    """
    Converts an RGB image (three channels, channel-last) to a single-channel gray image.

    A 2D (H, W) image is treated as already gray and returned unchanged. A 3D
    (H, W, 3) image is converted with ``skimage.color.rgb2gray``, which drops the
    channel axis.

    Args:
        None

    Raises:
        ValueError: if a 3D image does not have exactly 3 channels, or the image is not 2D/3D.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        Apply the transform to `img`.
        """
        if img.ndim == 2:
            return img
        if img.ndim == 3:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if img.shape[2] != 3:
                raise ValueError('Input img to Rgb2Gray needs to have three channels, but not %d.'%(img.shape[2]))
            return rgb2gray(img)
        raise ValueError('Input img to Rgb2Gray needs to be 2D or 3D (H, W, 3), but has %d dimensions.'%(img.ndim))

class Rgb2Grayd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`Rgb2Gray`: converts each entry named
    in ``keys`` from RGB to gray and returns a new (shallow-copied) dict.
    """

    def __init__(self, keys: KeysCollection) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
        """
        super().__init__(keys)
        self.transform = Rgb2Gray()

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        converted = dict(data)
        for item_key in self.keys:
            source = converted[item_key]
            converted[item_key] = self.transform(source)
        return converted

# Reference note only (a no-op string literal, never executed): adding a channel axis to
# 2D images via a Lambda transform, from a MONAI tutorials GitHub issue.
'''
Here's a nice trick with Lambda function, from Eric Kerfoot, answered in github issue:
https://github.com/Project-MONAI/tutorials/issues/65

test_transforms = Compose(
    [
        LoadImage(image_only=True, reader=PILReader()),
        Lambda(lambda im: im[...,None] if im.ndim==2 else im),
        AsChannelsFirst(),
        Resize(crop_size, "area"),
        ToTensor(),
    ]
)

Why does this work:
test = np.random.rand(3,3)
print(test.shape)
test2 = test[...,None]
print(test2.shape)

'''

def _preview_color_transform(idx, transform, name):
    """Load train_files[idx]['img'] with ITKReader and show it before (left) / after (right) `transform`."""
    fig, axs = plt.subplots(1, 2)
    loader = monai.transforms.LoadImage(reader="ITKReader")
    img = loader(train_files[idx]['img'])
    # NOTE(review): img[0] assumes LoadImage returns (image, meta) -- confirm for the installed MONAI version.
    arr = np.array(img[0])
    print(f'Image shape before {name}: {arr.shape}')
    axs[0].imshow(arr / 255)
    arr = transform(arr)
    print(f'Image shape after {name}: {arr.shape}')
    # Draw explicitly on the right-hand panel rather than via the pyplot state machine.
    axs[1].imshow(arr / 255)
    plt.show()

flag_test_gray2rgb = True
if flag_test_gray2rgb:
    _preview_color_transform(1833, Gray2Rgb(), 'Gray2Rgb')

flag_test_rgb2gray = True
if flag_test_rgb2gray:
    _preview_color_transform(0, Rgb2Gray(), 'Rgb2Gray')  # alternative index: 1834

In [None]:
# define transforms for image and segmentation
# Augmentation parameters used by the training transform pipeline.
# img_size: target image size, presumably (H, W) -- TODO confirm orientation.
img_size = np.array([240,320])
# rot_max: maximum rotation of 45 degrees, expressed in radians.
rot_max = 45*np.pi/180.0
shear_max = 0.5
# trans_max: maximum translation, 15% of each img_size entry, truncated to integer pixels.
trans_max = tuple((img_size*0.15).astype(int))
# NOTE(review): scale_max is not passed to any transform here -- confirm whether scaling is wanted.
scale_max = 0.25
'''
further ideas for augmentations
- Random colorization 
    * Uniform hue, contrast, hist.eq., white balance, sharpen
    * Torchvision ColorJitter (https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#ColorJitter)
- Gaussian noise!
    * see pupillometry video, at 250 Hz, super noisy, vertical banding pattern too
'''

def gray2rgb(x):
    """Replicate a single-channel, channel-first tensor to three channels.

    If ``x.shape[0] == 1`` the channel is repeated three times along dim 0 (all
    other dims are kept, so any spatial rank works). Otherwise ``x`` is returned
    unchanged. ``x`` must expose a ``meta`` dict (e.g. a MONAI MetaTensor).
    """
    if x.shape[0] == 1:
        # repeat() needs one factor per dim: 3 for the channel dim, 1 for every other dim.
        x = x.repeat(3, *([1] * (x.ndim - 1)))
        # THIS is the important line (per the original author)! Records the channel dim as
        # last in the metadata; presumably consumed by later MONAI transforms -- TODO confirm.
        x.meta['original_channel_dim'] = -1
    return x

def clean_tiff_meta(x):
    """Remove the 'DocumentName', 'ImageDescription' and 'Software' entries from ``x.meta``.

    Missing entries are ignored; ``x.meta`` is modified in place and ``x`` is returned.
    """
    meta = x.meta
    for tag in ('DocumentName', 'ImageDescription', 'Software'):
        meta.pop(tag, None)
    return x

def _load_and_align_transforms():
    """Deterministic loading/orientation/resizing steps shared by the train and val pipelines.

    Returns fresh transform instances on every call, so the two Compose objects share no state.
    """
    return [
        #Lambdad(keys=['img', 'seg'], func=lambda x: print(x), overwrite = False),
        #DataStatsd(keys=["img", "seg"]),
        LoadImaged(keys=["img"], reader=ITKReader, image_only=True),
        LoadImaged(keys=["seg"], reader=NumpyReader, image_only=True),
        EnsureChannelFirstd(keys=["img"]),
        Lambdad(keys=['img'], func=gray2rgb), # gray to rgb conversion
        ScaleIntensityd(keys="img"),
        # The next three steps are necessary due to the different readers (ITKReader vs NumpyReader).
        Flipd(keys=["seg"], spatial_axis=[1]),
        Rotate90d(keys=["seg"]),
        Rotate90d(keys=["img", "seg"]),
        Resized(keys=["img", "seg"], spatial_size=tuple(int(s) for s in img_size)),
    ]

def _to_tensor_transforms():
    """Final type-conversion / metadata-cleanup steps shared by both pipelines."""
    return [
        EnsureTyped(keys="img"),
        Lambdad(keys=['img'], func=clean_tiff_meta), # clean weird keys in TIFF metadata - turns out this is not necessary
        ToTensord(keys=["img", "seg"]),
    ]

train_transforms = Compose(
    _load_and_align_transforms()
    + [
        # NOTE(review): this contrast (gamma) augmentation is applied to the *labels* ("seg"), not
        # the image. If Resized interpolates the label maps, their fractional boundary values are
        # reshaped at random. If image contrast augmentation was intended, this should be keys=["img"].
        RandAdjustContrastd(keys=["seg"], prob=1.0, gamma=(0.1, 10.0)),
        # NOTE(review): spatial_axis=[0, 1] flips both axes together (a 180-degree rotation)
        # with p=0.5, not each axis independently.
        RandFlipd(keys=["img", "seg"], prob=0.5, spatial_axis=[0,1]),
        RandAffined(keys=["img", "seg"], prob=0.5,
                    rotate_range=rot_max,
                    shear_range=shear_max,
                    translate_range=trans_max,
                    padding_mode='zeros'),
        #                     scale_range=scale_max,
    ]
    + _to_tensor_transforms()
)

# Validation: identical preprocessing, no random augmentation.
val_transforms = Compose(
    _load_and_align_transforms()
    + _to_tensor_transforms()
)
#         Rgb2Grayd(keys=["img"]),
#         ToNumpyd(keys=["img"]),

In [None]:
# Disabled one-off converter (kept as a no-op string literal). For each validation pickle it:
#   - loads the dict of per-class maps;
#   - stacks every entry except "useful" along a new leading axis into `ylabel`;
#   - overwrites the same .pkl file with that array.
# NOTE(review): pickle.load can execute arbitrary code. The stacking order follows the dict's
# key order, which is not visible here -- confirm it matches the channel order used elsewhere.
'''
# Pickle to npy converter
for index, item in enumerate(val_files):
    with open( item["seg"], 'rb') as file:
        arr = pickle.load(file)
        print("item ", index, item["seg"], "loaded.")        

    ylabel = np.array([])
    #print (arr)
    #print ("arr.shape ", arr.shape)

    for key, value in arr.items() :
        if (key != "useful"):
            #print (key, "\n", value)
            #print ("val ", value.sum())
            #print ("len ", len(value))
            #print ("len(value[0]) ", len(value[0]))
            tmp_arr = np.asarray(value)
            tmp_arr = tmp_arr[np.newaxis]
            #print("tmp_arr.shape ", tmp_arr.shape)
            if (ylabel.size == 0):
                ylabel = tmp_arr
            else:
                ylabel = np.append(ylabel, tmp_arr, axis=0)
            #print ("ylabel.shape in loop ", ylabel.shape)        

    print ("ylabel.shape ", ylabel.shape)

    save = True
    if (save):
        fileObject = open(item["seg"], 'wb')
        pickle.dump(ylabel, fileObject)
        fileObject.close()
        print(item["seg"], " updated.")
'''

In [None]:
def check_empty_labels(transforms=None):
    """Find training/validation samples whose segmentation is entirely zero.

    Args:
        transforms: pipeline applied to *both* splits before checking. Defaults to
            ``val_transforms``, which has no random augmentation. The check is then
            deterministic and does not advance the random state of ``train_transforms``.
            Running the random training augmentations here would make the result vary
            between runs and shift the augmentation sequence later used in training.

    Returns:
        (empty_labels_train_indexes, empty_labels_validation_indexes): positions in
        ``train_files`` / ``val_files`` whose "seg" maximum equals 0.
    """
    if transforms is None:
        transforms = val_transforms

    def _empty_indexes(files, desc):
        # Index the Dataset directly: no batching, collation or device transfer is
        # needed just to take a max.
        ds = monai.data.Dataset(data=files, transform=transforms)
        return [index for index in tqdm(range(len(ds)), desc=desc)
                if ds[index]["seg"].max() == 0.]

    print ("iterating training data")
    empty_labels_train_indexes = _empty_indexes(train_files, "train")
    print ("iterating validation data")
    empty_labels_validation_indexes = _empty_indexes(val_files, "validation")
    return empty_labels_train_indexes, empty_labels_validation_indexes
    
empty_labels_train_indexes, empty_labels_validation_indexes = check_empty_labels()

# Report unlabeled samples per split (the validation header previously said "training").
for split_name, empty_indexes, files in (
    ("training", empty_labels_train_indexes, train_files),
    ("validation", empty_labels_validation_indexes, val_files),
):
    print (f"indexes of not labeled {split_name} data: (", len(empty_indexes), "images )")
    print(empty_indexes)
    for index in empty_indexes:
        print (index, files[index]["img"])

In [None]:
print("train files length before", len(train_files))
print("validation files length before", len(val_files))

# Drop samples with all-zero labels. Not idempotent: the indexes refer to the lists as they
# were when check_empty_labels() ran, so re-running only this cell would drop the wrong entries.
_drop_train = set(empty_labels_train_indexes)
_drop_val = set(empty_labels_validation_indexes)
train_files = [item for position, item in enumerate(train_files) if position not in _drop_train]
val_files = [item for position, item in enumerate(val_files) if position not in _drop_val]

print("train files length after", len(train_files))
print("validation files length after", len(val_files))

In [None]:
# Earlier version of the training pipeline, kept for reference as a no-op string literal.
# It uses the custom Gray2Rgbd before the channel move rather than the gray2rgb Lambda.
'''train_transforms = Compose(
    [
        LoadImaged(keys=["img"]),
        LoadImaged(keys=["seg"], reader=NumpyReader),
        Gray2Rgbd(keys=["img"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys="img"),
        RandFlipd(keys=["seg"], prob=1, spatial_axis=[1]),
        Rotate90d(keys=["seg"]),
        Resized(keys=["img", "seg"], spatial_size=(240,320)),
        RandAdjustContrastd(keys=["seg"], prob=1.0, gamma=(0.1, 10.0)),
        RandFlipd(keys=["img", "seg"], prob=0.5, spatial_axis=[0,1]),
        RandAffined(keys=["img", "seg"], prob=0.5,
                    rotate_range=rot_max, 
                    shear_range=shear_max, 
                    translate_range=trans_max, 
                    padding_mode='zeros'),
        ToTensord(keys=["img", "seg"]),
    ]
)
'''
# Quick visual sanity check of the first augmented training sample.
npc = ToNumpy()
check_ds = monai.data.Dataset(data=train_files[1000:], transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=10, num_workers=0)
check_data = first(check_loader)

sample = 0
# channel-first arrays for this sample
img_cf = np.squeeze(npc(check_data["img"])[sample, :, :, :])
seg_cf = np.squeeze(npc(check_data["seg"])[sample, :, :, :])
# sum of each of the first four label channels (foreground area)
for channel in range(4):
    print(seg_cf[channel, :, :].sum())
# NOTE(review): channel 2 is unpacked as "visible" and 3 as "glints" here, whereas the seg-map
# .npy export in this notebook writes 2=glints, 3=visible -- confirm the loaded layout.
pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = (
    np.squeeze(seg_cf[channel, :, :]) for channel in range(4)
)
# channel-last image for plotting: (C, H, W) -> (H, W, C)
img = np.moveaxis(img_cf, [0, 1, 2], [-1, -3, -2])
draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)

In [None]:
# check loaded images and augmentations
# define check dataset, check data loader
npc = ToNumpy()
check_batch_size = 10
check_ds = monai.data.Dataset(data=train_files[1000:], transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=check_batch_size, num_workers=0)
check_data = first(check_loader) #check_data = monai.utils.misc.first(check_loader)
# Convert the whole batch once, not once per sample.
imgs_cf = npc(check_data["img"])
segs_cf = npc(check_data["seg"])
# The batch holds fewer than check_batch_size samples if the dataset is smaller,
# so iterate over what was actually loaded.
for i in range(len(imgs_cf)):
    # channel first versions
    img_cf = np.squeeze(imgs_cf[i,:,:,:])
    seg_cf = np.squeeze(segs_cf[i,:,:,:])
    # NOTE(review): channel 2 is unpacked as "visible", 3 as "glints" -- confirm against the data layout.
    pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = tuple([np.squeeze(seg_cf[c,:,:]) for c in [0,1,2,3]])
    # channel last version for plotting: (C, H, W) -> (H, W, C)
    img = np.moveaxis(img_cf, [0,1,2], [-1,-3,-2])
    draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked)

# check filenames in minibatch (LoadImaged uses image_only=True, so metadata lives on the tensor):
# check_data['img'].meta['filename_or_obj']

In [None]:
# Disabled filter (kept as a no-op string literal): would drop .mp4-derived JPGs and .tiff files
# from train_files because of metadata problems noted below.
'''
print(len(train_files))
        
# remove elements from list that contain given string 
# jpg files violate document name because of .mp4 in the name
train_files = [item for item in train_files if ".mp4" not in item['img']]
print("after jpgs ", len(train_files))
# tiff files procude error when you want to print batch_data['img_meta_dict']['format']
train_files = [item for item in train_files if ".tiff" not in item['img']]
print("after tiffs ", len(train_files))
'''



In [None]:
# Build cached datasets and loaders for training and validation, and report the time taken.
t0 = time.time()
# settings shared by both loaders (earlier runs used num_workers=6 / 4)
common_loader_kwargs = {"num_workers": 0, "collate_fn": list_data_collate}

train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    **common_loader_kwargs,
)

val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, **common_loader_kwargs)

t1 = time.time()
print('Elapsed time: %0.2f sec.'%(t1-t0))

In [None]:
max_epochs = 300

# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

model = monai.networks.nets.UNet(
    spatial_dims=2,  # `dimensions` is the deprecated name of this argument
    in_channels=3,   # RGB input (gray images are replicated to 3 channels in the transforms)
    out_channels=5,  # one output channel per class
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    dropout=0.2,
    #dropout=0.5,
    num_res_units=2,
).to(device)

# softmax over the output channels inside the loss; to_onehot_y=False means the labels are
# expected to already carry one channel per class.
loss_function = monai.losses.DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose(
    [Activations(softmax=True), AsDiscrete(threshold=0.5)]
)

# use amp to accelerate training; enabled only when CUDA is present (on CPU-only machines
# GradScaler would otherwise warn and disable itself).
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

inferer = SimpleInferer()

In [None]:
val_interval = 1
best_metric = -1
best_metric_epoch = -1
# [best mean-Dice values, epoch of each, seconds since training start], appended on each improvement
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_pupil = []
metric_values_iris = []
metric_values_visible = []
metric_values_glints = []
metric_values_irrelevant = []
# checkpoint overwritten whenever the mean validation Dice improves
best_model_path = "best_metric_model_dv3d_segmentation2d_dict_withdropout.pth"

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training (mixed precision: autocast + GradScaler) ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["img"].to(device),
            batch_data["seg"].to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    # LR scheduler is stepped once per epoch, after the optimizer steps
    lr_scheduler.step()
    # max(step, 1) guards against an empty loader (ZeroDivisionError)
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation: mean and per-channel Dice ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():

            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["img"].to(device),
                    val_data["seg"].to(device),
                )
                val_outputs = inferer(val_inputs, model)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            # per-channel Dice; channel order as printed below
            metric_batch = dice_metric_batch.aggregate()

            metric_pupil = metric_batch[0].item()
            metric_values_pupil.append(metric_pupil)
            metric_iris = metric_batch[1].item()
            metric_values_iris.append(metric_iris)
            metric_visible = metric_batch[2].item()
            metric_values_visible.append(metric_visible)
            metric_glints = metric_batch[3].item()
            metric_values_glints.append(metric_glints)
            metric_irrelevant = metric_batch[4].item()
            metric_values_irrelevant.append(metric_irrelevant)

            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(model.state_dict(), best_model_path)
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" pupil: {metric_pupil:.4f} iris: {metric_iris:.4f} visible: {metric_visible:.4f} glints: {metric_glints:.4f} irrelevant: {metric_irrelevant:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

In [None]:
# look at test results: run the best checkpoint over the validation set and keep the inputs,
# labels and raw network outputs (CPU copies) for later inspection.
val_loader_test = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate)
ff_model_weights = 'best_metric_model_dv3d_segmentation2d_dict_withdropout.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(ff_model_weights, map_location=device))
model.eval()

list_imgs  = []
list_segs  = []
list_preds = []
val_images = None
val_labels = None
val_outputs = None
with torch.no_grad():
    for val_data in tqdm(val_loader_test):
        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
        val_outputs = inferer(val_images, model)
        # Keep CPU copies: holding every validation tensor on the GPU makes device memory
        # grow with the size of the validation set.
        list_imgs.append(val_images.cpu())
        list_segs.append(val_labels.cpu())
        list_preds.append(val_outputs.cpu())
    

In [None]:
# Image (left), ground truth (middle) and predicted class probabilities (right) for a range
# of validation samples.
# Explicit class dim: stored outputs are (batch, class, H, W). nn.Softmax() without `dim`
# is deprecated and only picks dim=1 implicitly.
sof_act = torch.nn.Softmax(dim=1)
for idx in range(100, min(200, len(list_imgs))):
    img  = np.squeeze(list_imgs[idx].cpu().numpy()).transpose([1,2,0])
    seg  = np.squeeze(list_segs[idx].cpu().numpy()).transpose([1,2,0])
    pred = np.squeeze(sof_act(list_preds[idx]).cpu().numpy()).transpose([1,2,0])

    print('Image index: %d'%idx)
    fig,axs = plt.subplots(1,3,figsize=(18,6))
    axs[0].imshow(img)
    axs[0].axis('off')
    # ground truth
    # NOTE(review): channel 2 is unpacked as "visible", 3 as "glints" -- confirm against the data layout.
    pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = tuple([np.squeeze(seg[:,:,c]) for c in [0,1,2,3]])
    draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=axs[1])
    # prediction
    pupil_map_masked, iris_map_masked, visible_map_masked, glints_map_masked = tuple([np.squeeze(pred[:,:,c]) for c in [0,1,2,3]])
    draw_segmented_area(img, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=axs[2])
    plt.show()
    # Release the figure; otherwise every figure in the loop stays alive in memory.
    plt.close(fig)


In [None]:
# Try inferring from a skimage loaded numpy array (no MONAI loader involved).
# NOTE(review): the training pipeline loads with ITKReader and applies Flip/Rotate90; a
# skimage-loaded array may be in a different orientation -- confirm before trusting results.

# Channel-axis helpers. AsChannelFirst/AsChannelLast are not imported in this notebook and
# are absent from recent MONAI releases, so move the axis with NumPy directly.
def acf(x):
    """(H, W, C) -> (C, H, W)."""
    return np.moveaxis(x, -1, 0)

def acl(x):
    """(C, H, W) -> (H, W, C)."""
    return np.moveaxis(x, 0, -1)

pred_transforms = Compose(
    [
        Gray2Rgb(),        # (H, W) / (H, W, 1) -> (H, W, 3)
        acf,               # channel-first for the MONAI array transforms below
        ScaleIntensity(),
        Resize(spatial_size=(240,320)),
        ToTensor(),
    ]
)
# Load explicitly rather than reusing whatever `img` an earlier cell left in the kernel.
img = imread(train_files[0]['img'])
img_proc = pred_transforms(img)

model.eval()
with torch.no_grad():
    logits = model(img_proc.unsqueeze(0).to(device))
    # class probabilities, matching how validation predictions are visualised
    seg_out = torch.softmax(logits, dim=1).cpu().numpy()
seg_out = acl(np.squeeze(seg_out))

fig, axs = plt.subplots(1,2)
axs[0].imshow(img)
# img_proc is already resized and intensity-scaled; only move channels last for display.
draw_segmented_area_4d(acl(np.asarray(img_proc)), seg_out, ax=axs[1])
axs[1].axis('on')
plt.show()

# now, try for an image stack
print(img.shape)
imgs = np.stack(10*[img],axis=0)
imgs_proc = torch.stack([pred_transforms(I) for I in imgs])
print(imgs_proc.shape)

In [None]:
# Indexing with np.newaxis (== None) prepends a length-1 axis to `img`.
test = img[np.newaxis]
print(img.shape)
print(test.shape)

# Supplemental steps

In [None]:
from PIL import Image

# Inspect how PIL reports size, array shape, attributes, class and `info` metadata for one
# sample of each image format.
# NOTE(review): these positions are assumed to hold the named formats. train_files is filtered
# for empty labels earlier, which can shift positions -- confirm the printed paths.
for fmt_label, sample_index in (('PNG rgb', 0), ('PNG gray', 1024), ('JPG', 2723), ('BMP', 2600)):
    print(fmt_label)
    ff = train_files[sample_index]['img']
    print(ff)
    # context manager closes the underlying file handle
    with Image.open(ff) as img:
        print(img.size)
        print(np.asarray(img).shape)
        print(dir(img))
        print(img.__class__)
        print(img.info)
    print('\n\n')

In [None]:
# File names of the last training batch.
# NOTE(review): relies on `batch_data` left over from the training loop (hidden kernel state).
# With LoadImaged(image_only=True) there is no 'img_meta_dict' key; metadata travels on the MetaTensor.
batch_data['img'].meta['filename_or_obj']

In [None]:
# Print PIL's `info` metadata for each file of the last training batch.
# NOTE(review): relies on `batch_data` left over from the training loop (hidden kernel state).
for ff in batch_data['img'].meta['filename_or_obj']:
    with Image.open(ff) as img:
        print(f'{img.info}\t: {ff}')

In [None]:
# test whether we can get rid of the 'dpi' key in img.info
ff = '/data/mangotee/Projects/DeepVOG/data/dv3d_monai/data_dv3d_monai_QA/10262_mmuiris_subject028_ongbll3_000000.png'
# Work on a copy in the current directory. shutil.copy returns the destination path, so
# ff_in is defined before it is read (previously it was used before assignment).
ff_in = shutil.copy(ff, os.getcwd())
with Image.open(ff_in) as img:
    if 'dpi' in img.info.keys():
        print('Has dpi!')

ff_out = os.path.splitext(ff_in)[0] + '_corr.png'
with Image.open(ff_in) as img:
    print(img.info)
    img.info = {}
    img.save(ff_out)
with Image.open(ff_out) as img_corr:
    print(img_corr.info)

In [None]:
# Load the pickled seg_maps, stack them into an (H, W, 4) boolean array and write them out as
# .npy files, to be used with monai.transforms.LoadNumpyd.
# This has already been done once; the `if False` guard keeps it disabled.
#
# Channel format from deepvog3D:
#   0/1/2/3 ... pupil_map / iris_map / combined_glints_map / visible_map
# deepvog3D unpacks a prediction as
#     pupil_map = pred[:,:,0]; iris_map = pred[:,:,1]
#     combined_glints_map = pred[:,:,2]; visible_map = pred[:,:,3]
#     useful_map[(pupil_map == 0) & (iris_map == 1) & (visible_map == 1) & (combined_glints_map == 0)] = 1
if False:
    # has to be done only once
    df = pd.read_csv(os.path.join(pn_data, '..', 'df_dv3d_monai_files.csv'), index_col=0)
    # swap only the extension (str.replace would also hit '.pkl' elsewhere in the path)
    df['fn_seg_maps_np'] = [os.path.splitext(s)[0] + '.npy' for s in df.fn_seg_maps]
    print(df.columns)

    # zip pairs the columns positionally; df.fn_seg_maps[idx] would look up by index *label*,
    # which breaks when the CSV index is not 0..n-1
    for fn_pkl, fn_np in zip(tqdm(df.fn_seg_maps), df.fn_seg_maps_np):
        # NOTE: pickle.load can execute arbitrary code — only use on trusted, self-generated files
        with open(os.path.join(pn_data, fn_pkl), 'rb') as file:
            seg_maps = pickle.load(file)
        # np.bool was removed in NumPy 1.24; the builtin bool is the equivalent dtype
        seg_maps_np = np.zeros(seg_maps['pupil'].shape + (4,), dtype=bool)
        for channel, key in enumerate(('pupil', 'iris', 'glints', 'visible')):
            seg_maps_np[:, :, channel] = seg_maps[key]
        # write next to the source .pkl: the input was read relative to pn_data, so the output is too
        np.save(os.path.join(pn_data, fn_np), seg_maps_np)

def _convert_images_to_png(df, extensions):
    """Re-encode the images in ``df.fn_img`` whose extension is in ``extensions`` as PNG.

    Files are read from and written below the global ``pn_data`` directory; each PNG keeps
    the stem (directory + base name) of its source file.

    Args:
        df: table with an ``fn_img`` column of image paths relative to ``pn_data``.
        extensions: extensions to convert, including the dot; compared case-insensitively.

    Returns:
        List of filenames in ``df.fn_img`` order: the new ``.png`` name for converted files,
        the unchanged name for all others.
    """
    from skimage.io import imread, imsave
    wanted = {ext.lower() for ext in extensions}
    fns_png = []
    for fn_img in tqdm(df.fn_img):
        stem, ext = os.path.splitext(fn_img)
        if ext.lower() not in wanted:
            fns_png.append(fn_img)
            continue
        # swap only the extension (str.replace would also match inside directory names)
        fn_png = stem + '.png'
        imsave(os.path.join(pn_data, fn_png), imread(os.path.join(pn_data, fn_img)))
        fns_png.append(fn_png)
    return fns_png


# there is another fix we need to do - replace TIF files with PNG files
if False:
    fns_png = _convert_images_to_png(df, ('.tif', '.tiff'))
    df['fn_img_with_tiff'] = df.fn_img.tolist()
    df['fn_img'] = fns_png

# and another fix we need to do - replace BMP and JPG files with PNG files
if False:
    fns_png = _convert_images_to_png(df, ('.bmp', '.jpg'))
    df['fn_img_with_bmp_jpg'] = df.fn_img.tolist()
    df['fn_img'] = fns_png
    # takes 2 mins
    # takes 2 mins

# The format conversions above did not fix loading. The actual problem was that some images
# (the mmuiris and delhi subsets) carried a 'dpi' key in their PIL img.info dict.
# Each affected image is loaded, its info dict is cleared, and it is saved back in place.
# Only files that actually carry 'dpi' are rewritten. This makes re-running the cell cheap and
# avoids re-encoding (and, for lossy formats, degrading) images that are already clean.
# NOTE(review): `df` is only loaded inside a disabled `if False` block above — confirm it is
# defined on a fresh kernel before running this.
# First run: "Fixed 952 files with dpi tag.", took ~5 mins.
if True:
    counter = 0
    for fn_img in tqdm(df.fn_img):
        ff = os.path.join(pn_data, fn_img)
        with Image.open(ff) as img:
            if 'dpi' not in img.info:
                continue
            print(f'Has dpi! {ff}')
            img.load()  # read all pixel data before the source file is overwritten
            img.info = {}
            img.save(ff)
        counter += 1
    print('Fixed %d files with dpi tag.' % counter)

# save the df
# df.to_csv(os.path.join(pn_data,'..','df_dv3d_monai_files.csv'))
        

# For Demo Purposes

In [None]:
# Point the first two validation entries at the binocular demo frames (left eye, right eye).
# NOTE(review): absolute local Windows paths — this only runs on the machine they came from,
# and it mutates val_files in place (the next cell relies on that).
_demo_images = (
    "E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\PYTHON\\binkoularLeft.png",
    "E:\\Users\\BerkOlcay\\DeepVOG3D\\DeepVOG3D\\PYTHON\\binkoularRight.png",
)
for entry, demo_path in zip(val_files, _demo_images):
    entry["img"] = demo_path

# Show both demo frames side by side, keeping only the first three (colour) channels.
fig, axs = plt.subplots(1, 2)
loader = monai.transforms.LoadImage(reader= "ITKReader")
for ax, entry in zip(axs, val_files[:2]):
    img = loader(entry['img'])
    arr = np.array(img[0])[:, :, :3]
    print(arr.shape)
    ax.imshow(arr / 255)
plt.show()

In [None]:
# Run the trained model on the two demo images and keep inputs/outputs for plotting.
val_ds = monai.data.Dataset(data=val_files[:2], transform=val_transforms)
val_loader_test = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate)

ff_model_weights = 'best_metric_model_dv3d_segmentation2d_dict_withdropout.pth'
# map_location lets weights saved from a GPU load onto whatever `device` this kernel uses (e.g. CPU-only)
model.load_state_dict(torch.load(ff_model_weights, map_location=device))
model.eval()  # inference mode: dropout off

list_imgs = []
list_preds = []
with torch.no_grad():
    for val_data in tqdm(val_loader_test):
        # keep only the first 3 channels — presumably drops the alpha channel of the RGBA demo PNGs (confirm)
        val_images = val_data["img"].to(device)[:, :3, :, :]
        val_outputs = inferer(val_images, model)
        list_imgs.append(val_images)
        list_preds.append(val_outputs)
        


In [None]:
# Convert the stored inference results into (H, W, C) numpy arrays for plotting.
# NOTE(review): `sof_act` is defined in an earlier cell (presumably a softmax activation) — confirm.
def _to_hwc_flipped(tensor):
    """Return a batch-of-one channel-first tensor as a vertically flipped (H, W, C) numpy array.

    np.squeeze drops the size-1 batch axis (the demo loader uses batch_size=1); the vertical
    flip matches the orientation used for display.
    """
    return np.flipud(np.squeeze(tensor.cpu().numpy()).transpose([1, 2, 0]))


img1 = _to_hwc_flipped(list_imgs[0])
pred1 = _to_hwc_flipped(sof_act(list_preds[0]))
img2 = _to_hwc_flipped(list_imgs[1])
pred2 = _to_hwc_flipped(sof_act(list_preds[1]))

# raw input frames
fig, axs = plt.subplots(1, 2, figsize=(12, 9))
for ax, frame in zip(axs, (img1, img2)):
    ax.imshow(frame)
    ax.axis('off')
plt.show()

# segmentation overlays. Prediction channels follow the label order pupil/iris/glints/visible
# (the order the .npy label maps are written in), so unpack them in exactly that order —
# the previous code swapped glints and visible.
fig, axs = plt.subplots(1, 2, figsize=(12, 9))
for ax, frame, pred in zip(axs, (img1, img2), (pred1, pred2)):
    pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked = (
        np.squeeze(pred[:, :, c]) for c in range(4)
    )
    draw_segmented_area(frame, pupil_map_masked, iris_map_masked, glints_map_masked, visible_map_masked, ax=ax)
# save before show(): show() may close the figure, which can leave an empty file on disk
fig.savefig('segmentation.png', bbox_inches='tight')
plt.show()
