In [6]:
from skimage.io import imread,imsave
import numpy as np
import torch
from tqdm import tqdm
from TJL.utils import *
import os

In [8]:
def segment(model,img_src,k=512,s=512,thread=0.5,device=None,use_clahe=False):
    """Binary-segment a 2-D image with ``model`` using a sliding-window tiling.

    Steps:
      1. Optionally apply ``hist_clahe``, then ``min_max_norm`` (both come from
         ``TJL.utils``).
      2. Zero-pad the bottom/right edges so that ``k x k`` tiles placed every
         ``s`` pixels cover the whole image. There is always at least one tile
         per axis, even when the image is smaller than ``k``.
      3. Run the model on every tile that contains a non-zero pixel. All-zero
         tiles are never sent to the model and count as background.
      4. Binarise each prediction at ``thread`` and merge overlapping tiles
         with a logical OR, then crop the padding away.

    Args:
        model: callable taking a float32 tensor of shape ``(1, 1, k, k)`` and
            returning a pair ``(pred, pred_edge)``. Only ``pred`` is used.
            ``pred[0, 0]`` is assumed to be a ``(k, k)`` score map compared
            against ``thread`` (presumably probabilities — TODO confirm).
        img_src: 2-D image array of shape ``(H, W)``.
        k: tile side length in pixels.
        s: stride between tile origins. ``s < k`` gives overlapping tiles;
            ``s > k`` leaves uncovered gaps that stay background.
        thread: threshold; pixels with ``pred > thread`` become foreground.
        device: optional ``torch.device`` the tiles are moved to.
        use_clahe: apply ``hist_clahe`` before normalisation.

    Returns:
        ``float64`` array of shape ``(H, W)`` containing only 0.0 and 1.0.

    Raises:
        ValueError: if ``k`` or ``s`` is not positive.
    """
    if k <= 0 or s <= 0:
        raise ValueError(f"tile size k and stride s must be positive, got k={k}, s={s}")
    if use_clahe:
        img_src = hist_clahe(img_src)
    img_src = min_max_norm(img_src)
    H_SRC, W_SRC = img_src.shape

    def _padded(size):
        # Smallest k + n*s (n >= 0) that is >= size. Unlike the modulo-based
        # formula, this never yields a padded size < k (which meant zero tiles
        # for images smaller than k - s).
        n_steps = -(-max(size - k, 0) // s)  # ceil division
        return k + n_steps * s

    H_DST, W_DST = _padded(H_SRC), _padded(W_SRC)
    H_NUM, W_NUM = (H_DST - k) // s + 1, (W_DST - k) // s + 1
    img_pad = np.zeros((H_DST, W_DST))
    img_pad[:H_SRC, :W_SRC] = img_src

    # Top-left corners of the tiles that contain at least one non-zero pixel.
    corners = [
        (h * s, w * s)
        for h in range(H_NUM)
        for w in range(W_NUM)
        if img_pad[h * s:h * s + k, w * s:w * s + k].any()
    ]

    img_dst = np.zeros((H_DST, W_DST))
    with torch.no_grad():  # inference only: do not build an autograd graph
        for y, x in tqdm(corners):
            m = torch.tensor(img_pad[y:y + k, x:x + k], dtype=torch.float32)
            m = m.unsqueeze(0).unsqueeze(0)  # -> (1, 1, k, k)
            if device is not None:
                m = m.to(device)
            pred, _ = model(m)  # the second (edge) output is not used here
            pred = pred.cpu().numpy()[0, 0]
            # Binarise in a single comparison so no raw score can survive, then
            # OR-merge with whatever overlapping tiles already wrote here.
            img_dst[y:y + k, x:x + k] = np.maximum(img_dst[y:y + k, x:x + k], pred > thread)
    return img_dst[:H_SRC, :W_SRC]







In [None]:
# Inference over a stack of reconstructed slices: load the model, segment each
# view, and write a 3-class label PNG per view.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles a full model object, which can execute
# arbitrary code — only load trusted checkpoints. Newer torch releases default to
# weights_only=True and may refuse this file unless weights_only=False is passed.
model = torch.load("./model/RAW/model.pth",map_location=device)
model.eval()  # inference behaviour for dropout / batch-norm layers

TILE, STRIDE, THRESHOLD = 512, 256, 0.6
INTENSITY_MIN = 50                  # raw grey value above which a non-segmented pixel gets label 1
FIRST_VIEW, LAST_VIEW = 275, 411    # half-open range of view indices to process
OUT_DIR = "./result/new"

img_path = "./data/sub"
views = range(FIRST_VIEW, LAST_VIEW)
imgs = [os.path.join(img_path,f"reco_hzb_047_50cycNo2_DischarUp_50cycNo3_DischarDown_a_0005.tif.view{i}.tif") for i in views]
os.makedirs(OUT_DIR, exist_ok=True)

for view, f in zip(views, imgs):
    img_src = imread(f)
    mask = segment(model,img_src,k=TILE,s=STRIDE,thread=THRESHOLD,device=device)
    # Labels: 2 = model foreground, 1 = raw intensity > INTENSITY_MIN but not
    # segmented, 0 = everything else. Saved as uint8 so the label values are
    # written verbatim; a float64 array may be rejected or rescaled by the writer.
    img_dst = np.zeros(mask.shape, dtype=np.uint8)
    img_dst[img_src > INTENSITY_MIN] = 1
    img_dst[mask > 0] = 2
    imsave(os.path.join(OUT_DIR, f"{view:04d}.png"), img_dst, check_contrast=False)