In [None]:
import sys

sys.path.append('./TorchSemiSeg/exp.voc/voc8.res50v3+.CPS+CutMix/')
from network import Network
import dataloader
from config import config

import random
import torch
import torch.nn as nn
import numpy as np
from engine.engine import Engine
from custom_collate import SegCollate
import argparse
import matplotlib.pyplot as plt
import plotly.express as px

import cv2
from PIL import Image
import os
import pandas as pd

In [None]:
# Package versions this notebook was run with (as reported by the package manager):
# torch                     1.0.0                    pypi_0    pypi
# torchsummary              1.5.1                    pypi_0    pypi
# torchvision               0.2.2.post3              pypi_0    pypi

In [None]:
# Display names for the 14 per-layer IoU columns, in layer-index order
# (plot_dict uses them as tick labels for tickvals 0..13).
LAYER_NAMES_14 = [
    "ILM",
    "RNFL",
    "GCL",
    "IPL",
    "INL",
    "OPL",
    "ELM",
    "PR1",
    "PR2",
    "RPE",
    "Collapsed Layers",
    "Cysts",
    "Vitreous",
    "Choroid/Sclera"
]
# Root that all data paths below are resolved against.
total_path = "./"
# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def make_list(path):
    """Return sorted paths of the ``.png`` files in the first directory under `path` that has any.

    Walks `path` top-down with ``os.walk`` (the top directory itself is visited first)
    and stops at the first directory containing at least one ``.png`` file.

    Args:
        path: directory to search.

    Returns:
        Sorted list of file paths, each joined onto the directory it was found in.
        Returns ``[]`` if no ``.png`` file is found or `path` does not exist.
    """
    for root, _dirs, files in os.walk(path):
        # Join onto `root`, not `path`: when the PNGs live in a subdirectory, the
        # previous `path + "/" + f` produced paths to files that do not exist.
        pngs = sorted(os.path.join(root, f) for f in files if f.endswith(".png"))
        if pngs:
            return pngs
    return []

image_input_labeled_path = os.path.join(total_path, "HeidelbergTraining/images_input_labeled")
image_mask_path = os.path.join(total_path, "HeidelbergTraining/images_masks")
baseline_img_path = os.path.join(total_path, "HeidelbergTraining/images_baseline/baseline_images")
baseline_mask_path = os.path.join(total_path, "HeidelbergTraining/images_baseline/baseline_masks")
baseline_full_img_path = os.path.join(total_path, "HeidelbergTraining/images_baseline/baseline_full_image")
baseline_paths_file = os.path.join(total_path, "HeidelbergTraining/images_baseline_paths.txt")


# Split boundaries are indices into the sorted image/mask lists; see
# images_input_labeled_paths.txt and images_masks_paths.txt for how they were derived.
# 80:10:10 split at patient level. Indices below num_diseased appear to be the diseased
# scans and the rest normal scans (inferred from the slicing below).
num_diseased = 292
diseased_train_cutoff = 230
normal_train_cutoff = 1381
diseased_valid_cutoff = 262
normal_valid_cutoff = 1521

images_all_labeled = make_list(image_input_labeled_path)
masks_all = make_list(image_mask_path)

# Training set: the normal training scans ...
image_list_train = images_all_labeled[num_diseased:normal_train_cutoff]
mask_list_train_13 = masks_all[num_diseased:normal_train_cutoff]

# ... plus whole copies of the diseased training scans, appended until the list is at
# least twice as long as the normal portion (oversampling of the diseased class).
dimgs = images_all_labeled[:diseased_train_cutoff]
dmasks = masks_all[:diseased_train_cutoff]
target_train_size = 2 * (normal_train_cutoff - num_diseased)
if len(image_list_train) < target_train_size and not dimgs:
    # Without this guard an empty or missing image directory made the loop spin forever.
    raise ValueError("No diseased training images found in %s" % image_input_labeled_path)
while len(image_list_train) < target_train_size:
    image_list_train.extend(dimgs)
    mask_list_train_13.extend(dmasks)

del dimgs, dmasks

# Validation set: diseased scans between the train and validation cutoffs.
d_image_list_cv = images_all_labeled[diseased_train_cutoff:diseased_valid_cutoff]
d_mask_list_cv_13 = masks_all[diseased_train_cutoff:diseased_valid_cutoff]

# Test set: mask paths listed in images_baseline_paths.txt (text after the first ": " on
# each line). The matching input image is located via the 4 characters just before
# ".png" in the mask path. Blank lines are skipped so they do not yield garbage entries.
with open(baseline_paths_file, "r") as baseline_info:
    data = [line for line in baseline_info.readlines() if line.strip()]
d_mask_list_test_13 = [elem[elem.find(":")+2:].strip() for elem in data]
d_image_list_test = [
    "./HeidelbergTraining/images_input_labeled/images_input_labeled%s.png" % elem[elem.find(".png")-4:elem.find(".png")]
    for elem in d_mask_list_test_13
]

# Image and mask lists are sorted independently.
# NOTE(review): relies on image/mask file names sorting in the same order; verify.
image_list_train = sorted(image_list_train)
mask_list_train_13 = sorted(mask_list_train_13)

d_image_list_cv = sorted(d_image_list_cv)
d_mask_list_cv_13 = sorted(d_mask_list_cv_13)

d_image_list_test = sorted(d_image_list_test)
d_mask_list_test_13 = sorted(d_mask_list_test_13)

baseline_images = make_list(baseline_img_path)
baseline_masks = make_list(baseline_mask_path)
baseline_full_images = make_list(baseline_full_img_path)


In [None]:
# Sanity check: show the path of the first diseased test image.
first_test_image_path = d_image_list_test[0]
print(first_test_image_path)

In [None]:
# Leakage check: none of the held-out diseased test scans
# (indices diseased_valid_cutoff .. num_diseased-1) may appear in the labeled
# training subset that TorchSemiSeg trains on.
save_path = "./TorchSemiSeg/DATA/pascal_voc/subset_train_aug2/train_aug_labeled_1-8.txt"
with open(save_path, "r") as labeled_file:
    data = [elem.strip() for elem in labeled_file.readlines()]

# Compare file names rather than whole lines, so that entries with a directory prefix,
# or an image/mask pair on one line, are still caught. An exact whole-line match is
# still caught as before.
labeled_names = set()
for line in data:
    labeled_names.update(os.path.basename(token) for token in line.split())

test_imgs = ["%04d.png" % i for i in range(diseased_valid_cutoff, num_diseased)]
leaked = [img for img in test_imgs if img in labeled_names]
if leaked:
    print("LEAK: %d held-out test images found in %s: %s" % (len(leaked), save_path, leaked))
else:
    print("OK: no held-out test images found in %s" % save_path)

In [None]:
# Trained CPS checkpoint and the ImageNet-pretrained backbone weights.
cps_save_path = "./TorchSemiSeg/snapshot/snapshot/epoch-last.pth"
pretrained_model_path = "./TorchSemiSeg/DATA/pytorch-weight/resnet50_v1c.pth"

# Both lines should print True before any weights are loaded below.
for weight_file in (cps_save_path, pretrained_model_path):
    print(os.path.isfile(weight_file))

In [None]:
def np_rearrange(image):
    """Move the channel axis to the front, keeping the first three channels.

    Args:
        image: array indexed as ``image[:, :, c]`` (for example (H, W, C)) with at
            least three channels.

    Returns:
        A new array of shape (3, H, W) where ``result[c] == image[:, :, c]``.

    Raises:
        IndexError: if `image` has fewer than three channels.
    """
    channel_planes = [image[:, :, channel] for channel in range(3)]
    return np.stack(channel_planes, axis=0)

def get_test_image_mask(N):
    """Load the N-th diseased test image and its mask as batched tensors on `device`.

    Args:
        N: index into ``d_image_list_test`` / ``d_mask_list_test_13``.

    Returns:
        (img, mask):
        - img: float tensor of shape (1, 3, H, W). It is built from the 3-channel image
          returned by ``cv2.imread``, after ``dataloader.normalize(..., 0, 1)`` and
          ``np_rearrange``. This assumes normalize preserves the (H, W, 3) layout.
        - mask: long tensor of shape (1, H, W) from the grayscale mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read either file. ``cv2.imread`` signals
            this by returning None, which previously failed later with an obscure error.
    """
    image_path = d_image_list_test[N]
    mask_path = d_mask_list_test_13[N]

    raw_image = cv2.imread(image_path)
    if raw_image is None:
        raise FileNotFoundError("Could not read test image %s" % image_path)
    raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if raw_mask is None:
        raise FileNotFoundError("Could not read test mask %s" % mask_path)

    img = np_rearrange(dataloader.normalize(raw_image, 0, 1))

    img = torch.from_numpy(np.ascontiguousarray(img))[None, :, :, :].float().to(device)
    mask = torch.from_numpy(np.ascontiguousarray(raw_mask))[None, :, :].long().to(device)

    return (img, mask)

def output_rearrange(output, num_classes):
    """Collapse per-class scores of the first sample into an (H, W) label map.

    Args:
        output: indexable whose element 0 is a (C, H, W) score tensor, for example an
            (N, C, H, W) batch. Only channels ``0 .. num_classes-1`` are considered.
        num_classes: number of leading channels to take the argmax over (1 <= num_classes <= C).

    Returns:
        numpy int64 array of shape (H, W) holding the index of the highest-scoring channel
        at each pixel. On ties, the first maximal index is returned.

    Raises:
        IndexError: if `num_classes` is outside ``1 .. C``.
    """
    scores = output[0].detach()
    if not 1 <= num_classes <= scores.shape[0]:
        raise IndexError("num_classes=%d out of range for %d channels" % (num_classes, scores.shape[0]))
    # A single argmax over the channel axis replaces the per-channel torch.cat loop,
    # which re-copied the growing tensor once for every class.
    return torch.argmax(scores[:num_classes], dim=0).cpu().numpy()


In [None]:
# Pixel-wise cross-entropy averaged over pixels; pixels labelled 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

In [None]:
# Build the CPS network with plain BatchNorm2d layers and move it to the selected device.
# Backbone weights come from config.pretrained_model (not pretrained_model_path above).
model = Network(
    config.num_classes,
    criterion=criterion,
    norm_layer=torch.nn.BatchNorm2d,
    pretrained_model=config.pretrained_model,
).to(device)

In [None]:
# Load the trained CPS weights. map_location remaps storages onto `device`, so a
# checkpoint saved on GPU also loads when the CPU fallback was selected.
state_dict = torch.load(cps_save_path, map_location=device)
model.load_state_dict(state_dict["model"])

In [None]:
def calc_iou(image, mask, n_class=0):
    """Intersection-over-union of label `n_class` between two label maps.

    Args:
        image: predicted label array.
        mask: ground-truth label array of the same shape.
        n_class: label value to score.

    Returns:
        |pred ∩ gt| / |pred ∪ gt| for pixels equal to `n_class`. Returns the integer 1
        when the label is absent from both maps (an empty union counts as a perfect match).
    """
    in_truth = mask == n_class
    in_prediction = image == n_class

    overlap = np.sum(np.logical_and(in_truth, in_prediction))
    combined = np.sum(np.logical_or(in_truth, in_prediction))

    # An empty union implies an empty intersection.
    if combined == 0:
        return 1
    return overlap / combined

def get_iou(output, mask, n_classes=list(range(14))):
    """Per-layer IoUs and their mean, mapping layer indices to label values.

    With the full 14-layer set, layer index -> label value is:
    0..9 -> 1..10, 10 -> 0, 11 -> 13, 12 -> 11, 13 -> 12.
    For any other set size, indices 0..9 -> 1..10 and 10 -> 0, and every other index
    maps to itself.

    Args:
        output: one label map (IoU is symmetric, so argument order does not matter).
        mask: the other label map.
        n_classes: layer indices to score; must support len().

    Returns:
        (ious, mean_iou): list of IoUs in the order of `n_classes`, and their mean.
    """
    full_layout = len(n_classes) == 14

    def label_for(layer_idx):
        if layer_idx < 10:
            return layer_idx + 1
        if layer_idx == 10:
            return 0
        if not full_layout:
            return layer_idx
        return 13 if layer_idx == 11 else layer_idx - 1

    layer_ious = [calc_iou(output, mask, label_for(layer_idx)) for layer_idx in n_classes]
    return layer_ious, sum(layer_ious) / len(n_classes)

def _paint_label_squares(target, hits, label, stride=2, size=5):
    """For each True cell (y, x) of `hits`, write `label` into target[2y:2y+5, 2x:2x+5] (defaults)."""
    ys, xs = np.where(hits)
    for y, x in zip(ys, xs):
        target[y * stride:y * stride + size, x * stride:x * stride + size] = label


def add_13_label(output, image, mask, label=13, percentile=25):
    """Paint `label` into dark regions that are labelled 0, in both prediction and ground truth.

    The input's first channel, the prediction and the mask are average-pooled (4x4 window,
    stride 2). A pooled cell counts as "dark" when its intensity is below the given
    percentile of the pooled image. Where a dark cell is also 0 in the pooled prediction
    (or mask), a 5x5 square at (2y, 2x) in the full-resolution prediction (or mask) is
    set to `label`. Both tensors are modified in place and also returned.

    Args:
        output: (H, W) integer label tensor (prediction).
        image: (1, C, H, W) float tensor; only channel 0 is used.
        mask: (1, H, W) integer label tensor (ground truth); only mask[0] is painted.
        label: value to paint (default 13).
        percentile: darkness threshold as a percentile of the pooled image (default 25).

    Returns:
        (output, mask): the same tensor objects, after painting.
    """
    # AvgPool2d has no parameters, so it runs on whatever device its input lives on.
    avg_pooler = nn.AvgPool2d(4, stride=2)
    reduced_mask = avg_pooler(mask[None, :, :, :].float())[0][0].detach().cpu().numpy()
    reduced_output = avg_pooler(output[None, None, :, :].float())[0][0].detach().cpu().numpy()
    reduced_image = avg_pooler(image[:, 0:1, :, :])[0][0].detach().cpu().numpy()

    dark = reduced_image < np.percentile(reduced_image, percentile)
    pred_hits = np.logical_and(dark, reduced_output == 0)
    mask_hits = np.logical_and(dark, reduced_mask == 0)

    # mask[0] is a view, so painting it updates `mask` in place.
    _paint_label_squares(mask[0], mask_hits, label)
    _paint_label_squares(output, pred_hits, label)

    return output, mask

def evaluate_CPS_model(num_images, eval_model, show=False, name="DL_pytorch", ignore_0=False, n_classes=list(range(14))):
    """Evaluate `eval_model` on the first `num_images` diseased test scans.

    For each scan:
    - the prediction is argmax-ed over 13 channels;
    - label 13 is painted into both the prediction and the ground truth by add_13_label;
    - per-layer IoUs are computed with get_iou.

    Args:
        num_images: number of scans to evaluate (indices 0..num_images-1 of the test list).
        eval_model: network to evaluate; it is switched to eval mode here.
        show: if True, plot ground truth / prediction / input for every scan.
        name: value stored in the "Class" column of the result (used for colouring plots).
        ignore_0: currently unused; kept for backward compatibility.
        n_classes: layer indices to score, forwarded to get_iou (previously ignored).

    Returns:
        (iou_data, mean_iou):
        - iou_data: long-format DataFrame with columns ["Class", "layer", "IOU"], one row
          per (scan, layer).
        - mean_iou: average over scans of each scan's mean IoU.

    Raises:
        ValueError: if `num_images` is not positive.
    """
    if num_images <= 0:
        raise ValueError("num_images must be positive, got %r" % (num_images,))

    n_classes = list(n_classes)
    ious = [[] for _ in n_classes]
    means = []

    eval_model.eval()

    # Inference only: no autograd graph is needed, which saves memory on every forward pass.
    with torch.no_grad():
        for i in range(num_images):
            image, mask = get_test_image_mask(i)

            output = torch.from_numpy(output_rearrange(eval_model(image), 13))
            # Paint on a CPU copy of the mask so the per-square writes are not GPU kernel launches.
            output, mask = add_13_label(output, image, mask.cpu())

            output = output.numpy()
            image = image[0][0].cpu().numpy()
            mask = mask[0].cpu().numpy()

            if show:
                fig, arr = plt.subplots(1, 3, figsize=(14, 10))
                arr[0].imshow(mask)
                arr[0].set_title("gt %s" % i)
                arr[1].imshow(output)
                arr[1].set_title("prediction %s" % i)
                arr[2].imshow(image)
                arr[2].set_title("input %s" % i)
                plt.show()
                plt.close(fig)

            iou_c, mean_c = get_iou(mask, output, n_classes=n_classes)

            means.append(mean_c)
            for layer_ious, value in zip(ious, iou_c):
                layer_ious.append(value)

    mean_iou = sum(means) / len(means)
    records = [
        (name, layer_idx, value)
        for layer_idx, layer_ious in enumerate(ious)
        for value in layer_ious
    ]
    iou_data = pd.DataFrame(records, columns=["Class", "layer", "IOU"])

    return iou_data, mean_iou


In [None]:
# Evaluate the CPS model on every diseased test scan, plotting each one.
iou_cps, mean_cps = evaluate_CPS_model(
    len(d_image_list_test),
    model,
    show=True,
    name="SSL",
    n_classes=list(range(14)),
)


In [None]:
def plot_dict(iou):
    """Show a box plot of IoU per layer, with one colour per model.

    Args:
        iou: DataFrame with "layer", "IOU" and "Class" columns, as returned by
            evaluate_CPS_model. Layer indices 0..13 are labelled with LAYER_NAMES_14.
    """
    figure = px.box(iou, x="layer", y="IOU", color="Class")
    layout = {
        "xaxis": {
            "tickmode": 'array',
            "tickvals": list(range(14)),
            "ticktext": LAYER_NAMES_14,
        },
        "boxgap": .7,
        "title": "IOU comparison of different models on the diseased test set",
        "legend": {"y": -0.35, "x": 0.01},
        "height": 650,
    }
    figure.update_layout(**layout)
    figure.show()

# Per-layer IoU box plot, followed by the overall mean IoU.
plot_dict(iou=iou_cps)
print(mean_cps)

In [None]:
# Persist the per-scan, per-layer IoUs for later comparison with other models.
cps_csv_path = "./HeidelbergTraining/CPS_data.csv"
iou_cps.to_csv(cps_csv_path)

In [None]:
# Qualitative comparison for each diseased test scan. The four panels are:
# the input, the CPS prediction (with label 13 added), the prediction blended over the
# input, and the matching baseline full image.
# Scan N is paired with baseline_full_images[N].
# NOTE(review): assumes both sorted lists are in the same order; verify.
model.eval()
# no_grad: inference only, no autograd graph needed.
with torch.no_grad():
    # Separate loop name so the global baseline_img_path (a directory) is not overwritten.
    for N, baseline_full_path in enumerate(baseline_full_images[:len(d_image_list_test)]):
        image, mask = get_test_image_mask(N)
        baseline_full_image = np.array(Image.open(baseline_full_path).convert("L"))
        output = torch.from_numpy(output_rearrange(model(image), 13))

        # Paint on a CPU copy of the mask; only the painted prediction is displayed below.
        output, mask = add_13_label(output, image, mask.cpu())
        output = output.numpy().astype("uint8")

        image_copy = Image.open(d_image_list_test[N]).convert("L")
        # Labels 0..13 scaled by 19 (max 247) so the overlay is visible in grayscale.
        img_blended = Image.blend(Image.fromarray(output * 19), image_copy, 0.4)

        fig, arr = plt.subplots(1, 4, figsize=(14, 10))
        for a in arr:
            a.get_xaxis().set_visible(False)
            a.get_yaxis().set_visible(False)

        arr[0].imshow(image[0][0].cpu().numpy())
        arr[1].imshow(output)
        arr[2].imshow(np.array(img_blended))
        arr[3].imshow(baseline_full_image)
        plt.show()
        plt.close(fig)
    
    