In [202]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
import pandas as pd
import csv
import cv2
import shutil 
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
from skimage import io, transform
from skimage import color
from skimage import morphology
from skimage.morphology import binary_dilation
import scipy.misc
import scipy.ndimage as ndi
from glob import glob
from pathlib import Path
from pytvision import visualization as view
from pytvision.transforms import transforms as mtrans
from tqdm import tqdm

sys.path.append('../')
from torchlib.datasets import dsxbdata
from torchlib.datasets.dsxbdata import DSXBExDataset, DSXBDataset
from torchlib.datasets import imageutl as imutl
from torchlib import utils
from torchlib.models import unetpad

import matplotlib
import matplotlib.pyplot as plt
#matplotlib.style.use('fivethirtyeight')

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
from pytvision.transforms import transforms as mtrans
from torchlib import metrics

from torchlib.segneuralnet import SegmentationNeuralNet
from torchlib import post_processing_func
from collections import defaultdict

def get_simple_transforms(pad=0):
    """Build the plain preprocessing pipeline (pad, to-tensor, normalize).

    Args:
        pad: value passed as the first two arguments of ``mtrans.ToPad``
            (presumably the vertical and horizontal padding -- confirm
            against pytvision).

    Returns:
        A ``transforms.Compose`` chaining constant-border padding, tensor
        conversion and the module-level ``normalize`` transform.
    """
    # `normalize` is a notebook global looked up when this function is called,
    # so the cell defining it has to run before the first call.
    steps = [
        mtrans.ToPad(pad, pad, padding_mode=cv2.BORDER_CONSTANT),
        mtrans.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def show(src, titles=(), suptitle="",
         bwidth=4, bheight=4, save_file=False,
         show_axis=True, show_cbar=False, last_max=0):
    """Plot a row of images side by side in a single figure.

    Args:
        src: sequence of images; one subplot is created per element.
        titles: optional per-panel titles; panels beyond ``len(titles)`` get
            no title.
        suptitle: title for the whole figure.
        bwidth: width of each panel (the figure width is ``bwidth * len(src)``).
        bheight: height of the figure.
        save_file: when truthy, the path the figure is saved to.
        show_axis: when false, axes are hidden on every panel.
        show_cbar: either a single flag applied to every panel, or a sequence
            of per-panel flags (panels beyond its length get no colorbar).
        last_max: when truthy, the last panel is drawn with ``vmin=0`` and
            ``vmax=last_max``.
    """
    num_cols = len(src)

    # Normalise `show_cbar` into one flag per panel.  Any non-iterable value
    # (bool, int, numpy bool, ...) is treated as a global on/off switch.
    try:
        cbar_flags = list(show_cbar)
    except TypeError:
        cbar_flags = [bool(show_cbar)] * num_cols

    plt.figure(figsize=(bwidth * num_cols, bheight))
    plt.suptitle(suptitle)

    for idx, image in enumerate(src):
        plt.subplot(1, num_cols, idx + 1)
        if not show_axis:
            plt.axis("off")
        if idx < len(titles):
            plt.title(titles[idx])

        # `* 1` turns boolean masks into integers before plotting.
        if idx == num_cols - 1 and last_max:
            plt.imshow(image * 1, vmax=last_max, vmin=0)
        else:
            plt.imshow(image * 1)

        if idx < len(cbar_flags) and cbar_flags[idx]:
            plt.colorbar()

    plt.tight_layout()
    if save_file:
        plt.savefig(save_file)
        
class NormalizeInverse(torchvision.transforms.Normalize):
    """Normalize transform configured to revert a mean/std normalization.

    The parent ``Normalize`` is initialised with ``std' = 1 / (std + 1e-7)``
    and ``mean' = -mean * std'`` instead of the given statistics.
    """

    def __init__(self, mean = (0.485, 0.456, 0.406), std  = (0.229, 0.224, 0.225)):
        """Derive the inverse statistics from the forward ``mean`` and ``std``."""
        forward_mean = torch.as_tensor(mean)
        forward_std = torch.as_tensor(std)
        # The small epsilon keeps the reciprocal finite when a std entry is 0.
        inverse_std = 1 / (forward_std + 1e-7)
        inverse_mean = -forward_mean * inverse_std
        super().__init__(mean=inverse_mean, std=inverse_std)

    def __call__(self, tensor):
        """Apply the parent transform to a clone of ``tensor``."""
        # The parent receives a copy, never the caller's tensor object itself.
        working_copy = tensor.clone()
        return super().__call__(working_copy)

# Inverse transform built with NormalizeInverse's default mean/std.
n = NormalizeInverse()

# Forward normalization using the same mean/std values as those defaults;
# this is the `normalize` object that get_simple_transforms() picks up.
normalize = mtrans.ToMeanNormalization(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

In [None]:
# --- Dataset location -------------------------------------------------------
# NOTE(review): absolute, user-specific path; consider making it configurable.
pathdataset      = os.path.expanduser( '/home/chcp/Datasets' )
# Alternative datasets used previously:
#namedataset      = 'Seg33_1.0.3'
#namedataset      = 'Bfhsc_1.0.0'
namedataset      = 'FluoC2DLMSC_0.0.1'

# --- Sub-folder names -------------------------------------------------------
sub_folder       = 'val'
folders_images   = 'images'
folders_contours = 'touchs'
folders_weights  = 'weights'
folders_segment  = 'outputs'

# --- Data / model dimensions ------------------------------------------------
num_classes      = 4
num_channels     = 3
pad              = 0

# Join with os.path.join instead of concatenating with a hand-written '//',
# so the dataset root contains a single, well-formed separator.
pathname         = os.path.join(pathdataset, namedataset)
subset           = 'test'

In [None]:
# Architectures to instantiate; `nets[i]` is the network built for `arcs[i]`.
arcs = ['unetpad', 'unet', 'unetresnet34', 'unetresnet101', 'segnet', 'albunet']

# Project settings shared by every architecture (same for each iteration).
model_url_base = 'baseline_test'
pathmodel = r'/home/chcp/Code/pytorch-unet/out/Fluo/'

nets = []
for arch in arcs:
    # NOTE(review): `no_cuda=True` together with `gpu=True` looks
    # contradictory -- confirm what SegmentationNeuralNet expects here.
    net = SegmentationNeuralNet(
        patchproject=pathmodel,
        nameproject=model_url_base,
        no_cuda=True,
        parallel=False,
        seed=2021,
        print_freq=False,
        gpu=True,
    )

    # Only the architecture name changes between networks.
    net.create(
        arch=arch,
        num_output_channels=num_classes,
        num_input_channels=3,
        loss='jreg',
        lr=1e-3,
        optimizer='adam',
        lrsch='fixed',
    )

    nets.append(net)

In [None]:
# Test split of the dataset, without weight maps or segment loading.
subset = 'test'

_test_data_options = dict(
    folders_labels=f'labels{num_classes}c',
    count=None,
    num_classes=num_classes,
    num_channels=num_channels,
    transform=get_simple_transforms(pad=0),
    use_weight=False,
    weight_name='',
    load_segments=False,
    shuffle_segments=False,
)

test_data = dsxbdata.ISBIDataset(pathname, subset, **_test_data_options)

In [None]:
# Zero-filled dummy inputs whose last two dimensions are the two raw image
# sizes handled in this notebook: (832, 992) and (782, 1200).
# Their padded counterparts would be (1, 3, 832, 1024) and (1, 3, 832, 1216).
in01 = np.zeros((1, 3, 832, 992))
in02 = np.zeros((1, 3, 782, 1200))

#### Input 1
- Original size (832, 992) -> padded size (832, 1024)

#### Input 2
- Original size (782, 1200) -> padded size (832, 1216)

In [None]:
# Zero tensor with the larger padded size, used as a probe input for the networks.
trial = torch.zeros(1, 3, 832, 1216)

In [None]:
# Smoke test: run the probe tensor through every architecture and record
# whether the forward call succeeds.
results = defaultdict(list)
for name, net in zip(arcs, nets):
    try:
        net(trial)
    except Exception as exc:
        # Catch only ordinary errors (a bare `except` would also swallow
        # KeyboardInterrupt) and keep the reason so failures can be diagnosed.
        results[name].append(f'fail: {type(exc).__name__}: {exc}')
    else:
        results[name].append('pass')
results

In [None]:
# One-off step: dilates every output mask with a radius-1 disk and overwrites
# the file in place.  Running the cell again dilates the masks a second time.
#assert False, 'just one time'
# NOTE(review): `new_data` is not defined in any visible cell; it has to
# exist in the kernel before this cell is run.
urls = glob(new_data + '/test/outputs/*/*.tif')

# Built once; passed positionally because the keyword was renamed from
# `selem` to `footprint` in newer scikit-image releases.
footprint = morphology.disk(1)

for url in urls:
    src = cv2.imread(url, -1)
    if src is None:
        # cv2.imread signals an unreadable file by returning None.
        raise IOError(f'could not read image: {url}')

    dst = binary_dilation(src, footprint) * 255
    dst = dst.astype(np.uint8)
    cv2.imwrite(url, dst)


In [204]:
# Pad every .tif of the dataset to a fixed target size (reflect-101 border)
# and write the result into a sibling dataset folder named FluoC2DLMSC_0.1.1.
urls = glob(pathname+'/*/*/*.tif') + glob(pathname+'/*/*/*/*.tif')

# (height, width) of a source image -> (height, width) after padding.
target_shapes = {
    (832, 992):  (832, 1024),
    (782, 1200): (832, 1216),
}

# Number of images seen per source (height, width).
shapes = defaultdict(int)
for url in tqdm(urls):
    src = cv2.imread(url, -1)
    if src is None:
        # cv2.imread signals an unreadable file by returning None.
        raise IOError(f'could not read image: {url}')

    # Only the first two dimensions matter, so multi-channel images work too.
    ini_h, ini_w = src.shape[:2]
    shapes[(ini_h, ini_w)] += 1

    # Fail loudly on an unexpected size instead of silently reusing the
    # target of the previous image.
    if (ini_h, ini_w) not in target_shapes:
        raise ValueError(f'unexpected image shape {src.shape} for {url}')
    target_h, target_w = target_shapes[(ini_h, ini_w)]

    diff_h = target_h - ini_h
    diff_w = target_w - ini_w

    # Split the padding so both sides always add up to the full difference,
    # even when the difference is odd.
    top, left = diff_h // 2, diff_w // 2
    bottom, right = diff_h - top, diff_w - left

    reflect101 = cv2.copyMakeBorder(src, top, bottom, left, right, cv2.BORDER_REFLECT_101)

    dst = url.replace("FluoC2DLMSC_0.0.1", "FluoC2DLMSC_0.1.1")
    # Never overwrite the source dataset.
    assert url!=dst
    Path(dst).parent.mkdir(exist_ok=True, parents=True)
    cv2.imwrite(dst, reflect101)

100%|██████████| 411/411 [00:07<00:00, 57.16it/s]


In [219]:
sys.path.append("../torchlib/datasets")

In [220]:
import weightmaps as wms

In [222]:
# For every label image, compute three weight maps and store each one next to
# the labels under weights/<KIND>/<name> (np.savez adds the .npz extension).
urls = glob('/home/chcp/Datasets/FluoC2DLMSC_0.1.1/*/labels4c/*')

# Weight-map kind -> folder that replaces "labels4c" in the label path.
weight_folders = {
    'BWM': 'weights/BWM',
    'DWM': 'weights/DWM',
    'SAW': 'weights/SAW',
}

for url in tqdm(urls):
    src = cv2.imread(url, -1)

    # All maps are computed before anything is written to disk.
    weight_maps = {
        'BWM': wms.balancewm(src),
        'DWM': wms.distranfwm(src, 5),
        'SAW': wms.shapeawewm(src, 5),
    }

    # Output path without the ".tif" suffix.
    targets = {
        kind: url.replace("labels4c", folder).replace(".tif", '')
        for kind, folder in weight_folders.items()
    }

    # Never write on top of the label file itself.
    for kind in ('SAW', 'BWM', 'DWM'):
        assert url != targets[kind]

    for kind in ('BWM', 'DWM', 'SAW'):
        Path(targets[kind]).parent.mkdir(parents=True, exist_ok=True)

    for kind in ('DWM', 'SAW', 'BWM'):
        np.savez(targets[kind], weight_maps[kind])

100%|██████████| 137/137 [21:58<00:00,  9.63s/it]
