In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
import pandas as pd
import csv
import cv2

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
from skimage import io, transform
from skimage import color
import scipy.misc
import scipy.ndimage as ndi
from glob import glob
from pathlib import Path
from pytvision import visualization as view
from pytvision.transforms import transforms as mtrans
from tqdm import tqdm
sys.path.append('../')
from torchlib.datasets import dsxbdata
from torchlib.datasets.dsxbdata import DSXBExDataset, DSXBDataset
from torchlib.datasets import imageutl as imutl
from torchlib import utils
from torchlib.models import unetpad
from torchlib.metrics import get_metrics
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.style.use('fivethirtyeight')

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
from pytvision.transforms import transforms as mtrans
from torchlib import metrics

from torchlib.segneuralnet import SegmentationNeuralNet
from torchlib import post_processing_func

In [None]:

# Post-processing helpers from torchlib.post_processing_func, one instance each.
map_post, th_post, wts_post = (
    post_processing_func.MAP_post(),
    post_processing_func.TH_post(),
    post_processing_func.WTS_post(),
)

# Per-channel mean/std normalization applied by the transform pipelines below.
# The values match the widely used ImageNet RGB statistics and are the same
# defaults NormalizeInverse uses to undo this step.
normalize = mtrans.ToMeanNormalization(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

class NormalizeInverse(torchvision.transforms.Normalize):
    """
    Invert a mean/std normalization, mapping tensors back to the input domain.

    ``Normalize`` computes ``(x - m) / s``. Passing ``m = -mean / std`` and
    ``s = 1 / std`` to the parent makes it compute ``x * std + mean``, i.e. the
    inverse of normalizing with ``mean``/``std``. A 1e-7 epsilon is added to
    ``std`` before taking the reciprocal to avoid dividing by zero.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        std_tensor = torch.as_tensor(std)
        scale = 1 / (std_tensor + 1e-7)
        shift = -torch.as_tensor(mean) * scale
        super().__init__(mean=shift, std=scale)

    def __call__(self, tensor):
        # Operate on a copy so the caller's tensor is never modified.
        working_copy = tensor.clone()
        return super().__call__(working_copy)

n = NormalizeInverse(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # shared inverse-normalizer used by tensor2image

def get_simple_transforms(pad=0):
    """Build the deterministic pipeline: constant-border pad, to tensor, normalize.

    Args:
        pad: value passed as both padding arguments of ``mtrans.ToPad``.

    Returns:
        A ``transforms.Compose`` applying the steps in order.
    """
    # A CenterCrop((1008, 1008)) step used to precede padding; it is disabled.
    steps = [
        mtrans.ToPad(pad, pad, padding_mode=cv2.BORDER_CONSTANT),
        mtrans.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


def get_flip_transforms(pad=0):
    """Build the augmentation pipeline: random V/H flips, then the same pad/tensor/normalize tail.

    Each flip is applied independently with probability 0.5.

    Args:
        pad: value passed as both padding arguments of ``mtrans.ToPad``.

    Returns:
        A ``transforms.Compose`` applying the steps in order.
    """
    # A CenterCrop((1008, 1008)) step used to precede the flips; it is disabled.
    random_flips = [
        mtrans.ToRandomTransform(mtrans.VFlip(), prob=0.5),
        mtrans.ToRandomTransform(mtrans.HFlip(), prob=0.5),
    ]
    finishing_steps = [
        mtrans.ToPad(pad, pad, padding_mode=cv2.BORDER_CONSTANT),
        mtrans.ToTensor(),
        normalize,
    ]
    return transforms.Compose(random_flips + finishing_steps)

def tensor2image(tensor, norm_inverse=True):
    """Convert a CHW (or NCHW, first item taken) tensor into an HWC uint8 image.

    Args:
        tensor: 3-D ``(C, H, W)`` or 4-D ``(N, C, H, W)`` tensor; for 4-D input
            only ``tensor[0]`` is converted.
        norm_inverse: if True, undo the mean/std normalization with the
            module-level ``NormalizeInverse`` instance ``n`` first.

    Returns:
        ``np.uint8`` array of shape ``(H, W, C)``; values are scaled by 255 and
        clipped to ``[0, 255]``.
    """
    # Detach so tensors that require grad (e.g. network outputs) can be
    # converted to numpy; .numpy() raises on them otherwise.
    tensor = tensor.detach()
    if tensor.dim() == 4:
        tensor = tensor[0]
    if norm_inverse:
        tensor = n(tensor)
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    img = (img * 255).clip(0, 255).astype(np.uint8)
    return img

def show(src, titles=[], suptitle="",
         bwidth=4, bheight=4, save_file=False,
         show_axis=True, show_cbar=False, last_max=0):
    """Plot a sequence of images side by side in one row.

    Args:
        src: sequence of images; each is multiplied by 1 before ``imshow``
            (turns boolean masks into numbers).
        titles: per-panel titles; extra panels get none.
        suptitle: figure title.
        bwidth, bheight: size of each panel in inches.
        save_file: path to save the figure to, or False to skip saving.
        show_axis: draw axes when True.
        show_cbar: a bool for all panels, or a per-panel sequence of flags.
        last_max: if non-zero, the last panel uses ``vmin=0, vmax=last_max``.
    """
    n_panels = len(src)

    plt.figure(figsize=(bwidth * n_panels, bheight))
    plt.suptitle(suptitle)

    for pos in range(n_panels):
        plt.subplot(1, n_panels, pos + 1)
        if not show_axis:
            plt.axis("off")
        if pos < len(titles):
            plt.title(titles[pos])

        is_last_panel = pos == n_panels - 1
        if is_last_panel and last_max:
            plt.imshow(src[pos] * 1, vmax=last_max, vmin=0)
        else:
            plt.imshow(src[pos] * 1)

        if isinstance(show_cbar, bool):
            wants_cbar = show_cbar
        else:
            wants_cbar = pos < len(show_cbar) and show_cbar[pos]
        if wants_cbar:
            plt.colorbar()

    plt.tight_layout()
    if save_file:
        plt.savefig(save_file)
        
def show2(src, titles=[], suptitle="",
          bwidth=4, bheight=4, save_file=False,
          show_axis=True, show_cbar=False, last_max=0):
    """Plot a sequence of images on a two-row grid, filled row by row.

    Args:
        src: sequence of images; each is multiplied by 1 before ``imshow``
            (turns boolean masks into numbers).
        titles: per-panel titles; extra panels get none.
        suptitle: figure title.
        bwidth, bheight: size of each panel in inches.
        save_file: path to save the figure to, or False to skip saving.
        show_axis: draw axes when True.
        show_cbar: a bool for all panels, or a per-panel sequence of flags.
        last_max: if non-zero, the last panel uses ``vmin=0, vmax=last_max``.
    """
    n_panels = len(src)
    # Round up so an odd number of images still fits; floor division used to
    # silently drop the final image.
    num_cols = (n_panels + 1) // 2

    plt.figure(figsize=(bwidth * num_cols, bheight * 2))
    plt.suptitle(suptitle)

    for idx in range(n_panels):
        plt.subplot(2, num_cols, idx + 1)
        if not show_axis:
            plt.axis("off")
        if idx < len(titles):
            plt.title(titles[idx])

        # The "last" panel is the final one in the grid, not the end of row one.
        if idx == n_panels - 1 and last_max:
            plt.imshow(src[idx] * 1, vmax=last_max, vmin=0)
        else:
            plt.imshow(src[idx] * 1)
        if type(show_cbar) is bool:
            if show_cbar:
                plt.colorbar()
        elif idx < len(show_cbar) and show_cbar[idx]:
            plt.colorbar()

    plt.tight_layout()
    if save_file:
        plt.savefig(save_file)
        
def get_diversity_map(preds, gt_predictionlb, th=0.5):
    """Count, per ground-truth cell, how many predictions contain a matching cell.

    A predicted cell matches a ground-truth cell when their IoU
    (|A ∩ B| / |A ∪ B|) is strictly greater than ``th``. In the returned map,
    every pixel of ground-truth cell ``k`` holds the number of label maps in
    ``preds`` that contain at least one cell matching ``k``. Background
    (label 0) stays 0.

    Args:
        preds: iterable of integer label maps (0 = background, 1..K = cells),
            each with the same shape as ``gt_predictionlb``.
        gt_predictionlb: integer ground-truth label map (0 = background).
        th: IoU threshold (strict inequality).

    Returns:
        Array with the shape and dtype of ``gt_predictionlb``.
    """
    diversity_map = np.zeros_like(gt_predictionlb)
    # Labels run 1..max inclusive (range(1, max) used to skip the highest label).
    for idx_gt in range(1, int(gt_predictionlb.max()) + 1):
        roi = gt_predictionlb == idx_gt
        if not roi.any():
            continue  # label id absent; labels need not be contiguous

        n_matches = 0
        for predlb in preds:
            # Only predicted cells that overlap the GT cell can have IoU > 0.
            candidates = np.unique(predlb[roi])
            for idx_pred in candidates[candidates > 0]:
                roi_pred = predlb == idx_pred
                intersection = np.logical_and(roi, roi_pred).sum()
                union = np.logical_or(roi, roi_pred).sum()  # > 0 since roi is non-empty
                if intersection / union > th:
                    n_matches += 1
                    break  # one matching cell per prediction is enough

        if n_matches:
            diversity_map[roi] += n_matches
    return diversity_map

In [None]:
# Dataset configuration.
# Other datasets used with this notebook: 'Seg33_1.0.4', 'Bfhsc_1.0.0', 'FluoC2DLMSC_0.0.1'.
# Model run noted alongside this dataset:
#   'Segments_Seg1009_0.3.2_unetpad_jreg__adam_map_ransac2_1_7_1'
# The dataset root can be overridden with the DATASETS_DIR environment
# variable; it defaults to the original location.
pathdataset      = os.path.expanduser(os.environ.get('DATASETS_DIR', '/home/chcp/Datasets'))
namedataset      = 'Seg1009_0.3.2'
sub_folder       = 'test'
folders_images   = 'images'
folders_contours = 'touchs'
folders_weights  = 'weights'
folders_segment  = 'outputs'
num_classes      = 2
num_channels     = 3
pad              = 0
pathname         = os.path.join(pathdataset, namedataset)
subset           = 'test'

In [None]:
def ransac_step2(net, inputs, targets, tag=None, max_deep=3, verbose=False,
                 group_size=7, rng=None):
    """Hierarchically fuse candidate segmentations through ``net``.

    ``inputs`` is split along dim 1 into the first 3 channels (image) and the
    remaining channels (candidate segmentations). At level ``lv`` (0-based),
    ``group_size ** (max_deep - lv)`` candidate indices are sampled with
    replacement. The sampled candidates are cut into consecutive groups of
    ``group_size``, and each group is concatenated with the image and passed
    through ``net``. The per-pixel argmax over dim 1 of each output becomes a
    new candidate for the next level. The last level has exactly one group, and
    its raw ``net`` output is returned.

    Args:
        net: callable taking a ``(B, 3 + group_size, ...)`` tensor and returning
            scores with classes on dim 1 (assumed from ``argmax(1)`` below).
        inputs: tensor whose dim 1 holds 3 image channels followed by candidates.
        targets: unused; kept for signature compatibility.
        tag: unused; kept for signature compatibility.
        max_deep: number of fusion levels; must be >= 1.
        verbose: print the candidate tensor shape and sample count per level.
        group_size: candidates fused per ``net`` call (default 7, as before).
        rng: optional ``numpy.random.Generator``/``RandomState`` for sampling;
            defaults to the global ``np.random`` state.

    Returns:
        ``(final_loss, output)``: ``final_loss`` is a placeholder that is
        always ``0.0``; ``output`` is the last ``net`` output.

    Raises:
        ValueError: if ``max_deep < 1`` (``net`` would never be called).
    """
    if max_deep < 1:
        raise ValueError(f"max_deep must be >= 1, got {max_deep}")
    sampler = np.random if rng is None else rng

    srcs = inputs[:, :3]
    segs = inputs[:, 3:]

    final_loss = 0.0
    mini_out = None
    for lv in range(max_deep):
        n_segs = segs.shape[1]
        new_segs = []
        actual_c = group_size ** (max_deep - lv)
        if verbose:
            print(segs.shape, actual_c)
        actual_seg_ids = sampler.choice(range(n_segs), size=actual_c)
        step_segs = segs[:, actual_seg_ids]

        for idx in range(0, actual_c, group_size):
            mini_inp = torch.cat((srcs, step_segs[:, idx:idx + group_size]), dim=1)
            mini_out = net(mini_inp)
            # Hard label map of this group's fusion; fed into the next level.
            new_segs.append(mini_out.argmax(1, keepdim=True))

        segs = torch.cat(new_segs, dim=1).float()

    return final_loss, mini_out

In [None]:
# Evaluate every checkpoint matching `model_pattern` on the test split of the
# configured dataset. For each model, save one figure per test image under
# extra/<model>/ and append cell-weighted PQ/SQ/RQ to extra/summary.csv.
pathmodel     = r'/home/chcp/Code/pytorch-unet/out/SEG1009/'
ckpt          = r'/models/model_best.pth.tar'
model_pattern = 'Segments_Seg1009_0.3.2_unetpad_jreg__adam_map_ransac2_1_7_1*'
summary_log   = "extra/summary.csv"
header        = ["dataset", 'subset', 'model', 'WPQ', 'WSQ', "WRQ", "Cells"]

# Sorted so models (and summary rows) are processed in a reproducible order.
model_list = [Path(url).name for url in sorted(glob(pathmodel + model_pattern))]

for model_url_base in tqdm(model_list):
    net = SegmentationNeuralNet(
        patchproject=pathmodel,
        nameproject=model_url_base,
        no_cuda=True, parallel=False,
        seed=2021, print_freq=False,
        gpu=True
        )

    checkpoint_path = pathmodel + model_url_base + ckpt
    # Fail with the offending path; a bare `assert(False)` gives no context and
    # is stripped under `python -O`.
    if net.load(checkpoint_path) is not True:
        raise RuntimeError(f"Failed to load checkpoint: {checkpoint_path}")
    Path(f"extra/{model_url_base}").mkdir(exist_ok=True, parents=True)

    for subset in ['test']:

        test_data = dsxbdata.ISBIDataset(
            pathname,
            subset,
            folders_labels=f'labels{num_classes}c',
            count=None,
            num_classes=num_classes,
            num_channels=num_channels,
            transform=get_simple_transforms(pad=0),
            use_weight=False,
            weight_name='',
            load_segments=True,
            shuffle_segments=True,
            use_ori=1
        )

        test_loader = DataLoader(test_data, batch_size=1, shuffle=False,
            num_workers=0, pin_memory=True, drop_last=False)

        # Per-image metrics summed with weight = number of ground-truth cells.
        wpq, wsq, wrq, total_cells = 0, 0, 0, 0

        for idx, sample in enumerate(test_loader):
            inputs, labels = sample['image'], sample['label']

            # Inference only: skip autograd bookkeeping. This also keeps
            # `outputs` (and `prob` below) from requiring grad, which would
            # break the numpy conversion done by imshow.
            with torch.no_grad():
                _, outputs = ransac_step2(net, inputs, labels)
            amax        = outputs[0].argmax(0)
            view_inputs = tensor2image(inputs[0, :3])
            view_labels = labels[0].argmax(0)
            # NOTE(review): dividing by the channel sum yields probabilities only
            # if `net` already emits non-negative scores; if it returns logits,
            # torch.softmax(outputs[0], dim=0) is probably intended — confirm.
            prob = outputs[0] / outputs[0].sum(0)

            results, n_cells, preds = get_metrics(labels, outputs, post_label='map')
            predictionlb, prediction, region, output = preds

            wpq += results['pq'] * n_cells
            wsq += results['sq'] * n_cells
            wrq += results['rq'] * n_cells
            total_cells += n_cells

            res_str = f"Nreal {n_cells} | Npred {results['n_cells']} | PQ {results['pq']:0.2f} " + \
                    f"| SQ {results['sq']:0.2f} | RQ {results['rq']:0.2f}"

            show2([view_inputs, view_labels, amax, predictionlb, prob[0], prob[1]], show_axis=False, suptitle=res_str,
                 show_cbar=[False, False, False, False, True, True, True, True], save_file=f"extra/{model_url_base}/{namedataset}_{subset}_{idx}.png",
                 titles=['Original', 'Label', 'MAP', 'Cells', 'Prob 0', 'Prob 1'], bheight=4.5)
            # The figure is already saved to disk; close it so the loop does not
            # keep one open figure per test image.
            plt.close('all')

        # Avoid ZeroDivisionError when no ground-truth cells were found: the
        # weighted metrics are then reported as nan.
        denom = total_cells if total_cells else float('nan')
        row = [namedataset, subset, model_url_base, wpq/denom, wsq/denom, wrq/denom, total_cells]
        # Stringify explicitly: csv.writer would repr() float subclasses such as
        # numpy scalars.
        row = list(map(str, row))

        write_header = not Path(summary_log).exists()
        with open(summary_log, 'a', newline='') as f:
            writer = csv.writer(f, lineterminator='\n')
            if write_header:
                writer.writerow(header)
            writer.writerow(row)