In [None]:
import numpy as np
import torch
from torch import from_numpy as from_numpy
import os
import pickle
import copy
import matplotlib.pyplot as plt
from sklearn.metrics.cluster import adjusted_rand_score
from skimage.metrics import adapted_rand_error
import pandas as pd

from func.run_pipeline_super_vox import segment_super_vox_3_channel
from func.cal_accuracy import IOU_and_Dice_Accuracy, VOI
from func.network import CellSegNet_basic_lite
from func.utils import save_obj, load_obj

In [None]:
# Dataset index for the Ovules data; its "test" entry drives the evaluation below.
Ovules_data_dict = load_obj(
    "dataset_info/Ovules_dataset_info"
)

In [None]:
# Build the network and restore the trained Ovules weights for inference.
model = CellSegNet_basic_lite(input_channel=1, n_classes=3, output_func="softmax")
load_path = 'output/model_Ovules.pkl'

# Prefer the second GPU as originally configured, but fall back to the first GPU
# (or the CPU) so the notebook also runs on single-GPU and CPU-only machines,
# where 'cuda:1' is not a valid device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# map_location remaps tensors saved on a GPU onto the selected device; without it a
# checkpoint saved from CUDA cannot be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(load_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

model.to(device)

In [None]:
# List the keys of the test split.
print("there are test imgs: {}".format(Ovules_data_dict['test'].keys()))

In [None]:
def img_3d_interpolate(img_3d, output_size, device = torch.device('cpu'), mode='nearest'):
    """Resize a 3-D numpy volume with ``torch.nn.functional.interpolate``.

    Args:
        img_3d: numpy array with exactly three dimensions.
        output_size: target size of the three spatial axes.
        device: torch device on which the interpolation runs.
        mode: interpolation mode forwarded to ``interpolate`` (e.g. ``'nearest'``,
            ``'trilinear'``). Previously this argument was accepted but ignored.

    Returns:
        float32 numpy array whose shape is ``output_size``.
    """
    depth, height, width = img_3d.shape
    # interpolate expects (batch, channel, D, H, W), so add two singleton axes.
    volume = torch.from_numpy(img_3d.reshape(1, 1, depth, height, width)).float().to(device)
    volume = torch.nn.functional.interpolate(volume, size=output_size, mode=mode)
    # Drop the singleton batch/channel axes again before handing back a numpy array.
    return volume.detach().cpu().numpy().reshape(tuple(volume.shape[2:]))

def pipeline(raw_img, hand_seg, model, device,
             crop_cube_size, stride,
             how_close_are_the_super_vox_to_boundary=2,
             min_touching_area=20,
             min_touching_percentage=0.51,
             min_cell_size_threshold=10,
             transposes = [[0,1,2]], reverse_transposes = [[0,1,2]],
             scale_factor=0.4):
    """Segment one volume with the super-voxel pipeline and score it against the hand segmentation.

    Two families of metrics are computed: one on the raw prediction, and a
    "revise" variant in which every voxel that is 0 in ``hand_seg`` is also set
    to 0 in the prediction before scoring.

    Args:
        raw_img: raw image volume passed to ``segment_super_vox_3_channel``.
        hand_seg: ground-truth label volume, indexed with the same shape as the prediction.
        model, device: network and torch device used for segmentation.
        crop_cube_size, stride, how_close_are_the_super_vox_to_boundary,
        min_touching_area, min_touching_percentage, min_cell_size_threshold,
        transposes, reverse_transposes: forwarded unchanged to
            ``segment_super_vox_3_channel``.
        scale_factor: per-axis factor by which both label volumes are resized
            (nearest neighbour) before the per-cell IoU/Dice computation;
            the default 0.4 matches the previously hard-coded value.

    Returns:
        Tuple ``(accuracy_record, hand_seg_after_accuracy, seg_final_after_accuracy,
        ari, voi, are, precision, recall, seg_final, are_revise, precision_revise,
        recall_revise, ari_revise, voi_revise)``. ``accuracy_record`` is the result of
        ``IOU_and_Dice_Accuracy.cal_accuracy_II``; the notebook reads its column 1 as
        IoU and column 2 as Dice.
    """
    seg_final=segment_super_vox_3_channel(raw_img, model, device,
                                          crop_cube_size=crop_cube_size, stride=stride,
                                          how_close_are_the_super_vox_to_boundary=how_close_are_the_super_vox_to_boundary,
                                          min_touching_area=min_touching_area,
                                          min_touching_percentage=min_touching_percentage,
                                          min_cell_size_threshold=min_cell_size_threshold,
                                          transposes = transposes, reverse_transposes = reverse_transposes)

    # "Revise" prediction: ignore everything outside the voxels the ground truth labels non-zero.
    seg_final_revise = copy.deepcopy(seg_final)
    seg_final_revise[hand_seg==0]=0

    # Integer label volumes, computed once. np.int was only an alias for the builtin
    # int and was removed in NumPy 1.24, so plain `int` is used here.
    hand_seg_int = hand_seg.astype(int)
    seg_final_int = seg_final.astype(int)
    seg_final_revise_int = seg_final_revise.astype(int)

    are, precision, recall = adapted_rand_error(hand_seg_int.flatten(), seg_final_int.flatten())
    ari = adjusted_rand_score(hand_seg.flatten(), seg_final.flatten())
    voi = VOI(seg_final_int, hand_seg_int)

    are_revise, precision_revise, recall_revise = adapted_rand_error(hand_seg_int.flatten(), seg_final_revise_int.flatten())
    ari_revise = adjusted_rand_score(hand_seg.flatten(), seg_final_revise.flatten())
    # Same (prediction, ground truth) argument order as the unrevised VOI call above, so that
    # voi[0]/voi[1] and voi_revise[0]/voi_revise[1] denote the same components; the previous
    # code passed them swapped here, which swaps the split/merge parts of the VOI.
    # NOTE(review): assumes VOI's signature is VOI(pred, gt) - confirm in func.cal_accuracy.
    voi_revise = VOI(seg_final_revise_int, hand_seg_int)

    # Per-cell IoU/Dice is computed on downscaled volumes to keep it tractable.
    org_shape = seg_final.shape
    output_size = tuple(int(dim * scale_factor) for dim in org_shape)
    print(str(org_shape)+" --> "+str(output_size))

    accuracy=IOU_and_Dice_Accuracy(img_3d_interpolate(hand_seg, output_size = output_size),
                                   img_3d_interpolate(seg_final, output_size = output_size))
    accuracy_record=accuracy.cal_accuracy_II()
    hand_seg_after_accuracy=accuracy.gt
    seg_final_after_accuracy=accuracy.pred

    return accuracy_record, hand_seg_after_accuracy, seg_final_after_accuracy, ari, voi, are, precision, recall, seg_final, \
are_revise, precision_revise, recall_revise, ari_revise, voi_revise

In [None]:
# Test-split mapping consumed by the evaluation loop; displayed as this cell's output.
data_dict_test = Ovules_data_dict["test"]
data_dict_test

In [9]:
# mass process: run the pipeline on every test volume, store its metrics, and print a summary
seg_final_dict={}
accuracy_record_dict = {}
ari_dict = {}
voi_dict = {}
are_dict = {}
ari_revised_dict = {}
voi_revised_dict = {}
are_revised_dict = {}
for test_file, test_path in data_dict_test.items():
    print(test_file)
    # The context manager closes the archive once both arrays have been copied out.
    # np.float was only an alias for the builtin float and was removed in NumPy 1.24.
    with np.load(test_path) as hf:
        raw_img = np.array(hf["raw"], dtype=float)
        hand_seg = np.array(hf["ins"], dtype=float)
    print("raw_img shape: "+str(raw_img.shape))
    print("hand_seg shape: "+str(hand_seg.shape))

    accuracy_record, hand_seg_after_accuracy, seg_final_after_accuracy, ari, voi, are, precision, recall, seg_final, \
    are_revise, precision_revise, recall_revise, ari_revise, voi_revise=\
    pipeline(raw_img, hand_seg, model, device,
             crop_cube_size=128,
             stride=64)

    seg_final_dict[test_file] = seg_final
    accuracy_record_dict[test_file] = accuracy_record
    ari_dict[test_file] = ari
    voi_dict[test_file] = voi
    are_dict[test_file] = (are, precision, recall)
    ari_revised_dict[test_file] = ari_revise
    voi_revised_dict[test_file] = voi_revise
    are_revised_dict[test_file] = (are_revise, precision_revise, recall_revise)

    # Fraction of cells whose IoU (column 1) / Dice (column 2) exceeds each threshold.
    for threshold in (0.7, 0.5):
        print('cell count accuracy iou >'+str(threshold)+': '+str(np.mean(accuracy_record[:,1] > threshold)))
        print('cell count accuracy dice >'+str(threshold)+': '+str(np.mean(accuracy_record[:,2] > threshold)))

    print('avg iou: '+str(np.mean(accuracy_record[:,1])))
    print('avg dice: '+str(np.mean(accuracy_record[:,2])))
    print("ari: "+str(ari))
    print("are, precision, recall: "+str((are, precision, recall)))
    print("voi: "+str(voi))
    print("ari_revise: "+str(ari_revise))
    print("are_revise, precision_revise, recall_revise: "+str((are_revise, precision_revise, recall_revise)))
    print("voi_revise: "+str(voi_revise))
    print("----------")

In [None]:
# Re-print the stored per-volume metrics without re-running segmentation.
for item in seg_final_dict:
    print(item)
    accuracy_record = accuracy_record_dict[item]
    ari = ari_dict[item]
    voi = voi_dict[item]
    (are, precision, recall) = are_dict[item]
    ari_revise = ari_revised_dict[item]
    voi_revise = voi_revised_dict[item]
    (are_revise, precision_revise, recall_revise) = are_revised_dict[item]

    # Fraction of cells whose IoU (column 1) / Dice (column 2) exceeds each threshold.
    # (np.float, used previously, was removed in NumPy 1.24.)
    for threshold in (0.7, 0.5):
        print('cell count accuracy iou >'+str(threshold)+': '+str(np.mean(accuracy_record[:,1] > threshold)))
        print('cell count accuracy dice >'+str(threshold)+': '+str(np.mean(accuracy_record[:,2] > threshold)))

    print('avg iou: '+str(np.mean(accuracy_record[:,1])))
    print('avg dice: '+str(np.mean(accuracy_record[:,2])))
    print("ari: "+str(ari))
    print("are, precision, recall: "+str((are, precision, recall)))
    print("voi: "+str(voi))
    print("ari_revise: "+str(ari_revise))
    print("are_revise, precision_revise, recall_revise: "+str((are_revise, precision_revise, recall_revise)))
    print("voi_revise: "+str(voi_revise))
    print("----------")

In [None]:
# Build one summary row per test volume.
# DataFrame.append was removed in pandas 2.0 (and growing a frame row by row is
# quadratic), so rows are collected as dicts and turned into a frame once.
summary_columns = ["name", "ari", "are", "are_precision", "are_recall", "voi_split", "voi_merge",
                   "ari_revise", "are_revise", "are_precision_revise", "are_recall_revise", "voi_split_revise", "voi_merge_revise",
                   "avg iou", "avg dice", "iou>0.7", "dice>0.7", "iou>0.5", "dice>0.5"]
summary_rows = []

for item in seg_final_dict:
    accuracy_record = accuracy_record_dict[item]
    iou_scores = accuracy_record[:,1]
    dice_scores = accuracy_record[:,2]
    voi = voi_dict[item]
    voi_revise = voi_revised_dict[item]
    (are, precision, recall) = are_dict[item]
    (are_revise, precision_revise, recall_revise) = are_revised_dict[item]

    summary_rows.append({"name": item,
                         "ari": ari_dict[item],
                         "are": are,
                         "are_precision": precision,
                         "are_recall": recall,
                         "voi_split": voi[0],
                         "voi_merge": voi[1],
                         "ari_revise": ari_revised_dict[item],
                         "are_revise": are_revise,
                         "are_precision_revise": precision_revise,
                         "are_recall_revise": recall_revise,
                         "voi_split_revise": voi_revise[0],
                         "voi_merge_revise": voi_revise[1],
                         "avg iou": np.mean(iou_scores),
                         "avg dice": np.mean(dice_scores),
                         # fraction of cells above each IoU / Dice threshold
                         "iou>0.7": np.mean(iou_scores > 0.7),
                         "dice>0.7": np.mean(dice_scores > 0.7),
                         "iou>0.5": np.mean(iou_scores > 0.5),
                         "dice>0.5": np.mean(dice_scores > 0.5)})

# Passing columns keeps the column order (and an empty table's header) fixed.
df_show = pd.DataFrame(summary_rows, columns=summary_columns)

In [None]:
# Per-volume summary table, shown as the cell output.
df_show

In [None]:
# Average every metric over the test volumes. The string "name" column is dropped
# explicitly (DataFrame.mean raises TypeError on it in pandas >= 2.0), and the values
# are cast to float so object-dtype columns are averaged as numbers.
df_show.drop(columns="name").astype(float).mean()