In [64]:
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from TJL.data import *
from TJL.utils import *
from TJL.loss import *
from TJL.train import Trainer
from TJL.model import UNet_2DE
import os
import shutil
from skimage.io import imsave

In [65]:
# Convert every image in a directory to single-channel, in place.
def ch3_to_ch1(path):
    """Overwrite each multi-channel image in ``path`` with its first channel.

    Only regular files are processed. Sub-directories (e.g. a Jupyter
    ``.ipynb_checkpoints`` folder) are skipped instead of being handed to
    ``imread``, which would fail on them. Images that are already 2-D are
    left untouched.

    Args:
        path: directory containing the images to convert.
    """
    for entry in os.listdir(path):
        f = os.path.join(path, entry)
        if not os.path.isfile(f):
            continue
        img = imread(f)
        if len(img.shape) == 3:
            # Keep channel 0 only. NOTE(review): this assumes the channels are
            # identical (e.g. a grey mask saved as RGB); it is not a luminance
            # conversion.
            img = img[:, :, 0]
            imsave(f, img)






In [66]:
class Edge_Trainer(Trainer):
    """Trainer for a model with a mask head and an edge head.

    Each batch yielded by the loader is unpacked as ``(imgs, label, edge)``.
    The model is expected to return ``(pred, pred_edge)``, and ``loss_func``
    is called as ``loss_func(pred, pred_edge, label, edge)``.
    """

    def __init__(self, model, optim, loss_func, save_path, device, scheduler=None):
        # Straight pass-through to the base Trainer; kept so the signature is explicit.
        super().__init__(model, optim, loss_func, save_path, device, scheduler)

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: iterable of ``(imgs, label, edge)`` tensor triples.

        Returns:
            The mean of the per-batch ``loss.item()`` values.
        """
        self.model.train()
        target_device = self.device
        batch_losses = []
        for batch_imgs, batch_label, batch_edge in tqdm(train_loader):
            batch_imgs = batch_imgs.to(target_device)
            batch_label = batch_label.to(target_device)
            batch_edge = batch_edge.to(target_device)

            mask_out, edge_out = self.model(batch_imgs)
            batch_loss = self.loss_func(mask_out, edge_out, batch_label, batch_edge)

            # Clear stale gradients, back-propagate, then update the weights.
            self.optim.zero_grad()
            batch_loss.backward()
            self.optim.step()
            batch_losses.append(batch_loss.item())
        return np.mean(batch_losses)

In [68]:


class AL:
    """Active-learning helper around a two-headed (mask, edge) segmentation model.

    The workflow in this notebook has four steps:

    1. Randomly draw crop windows from a pool of images (``get_sample``).
    2. Rank the windows by prediction entropy (``val_select``).
    3. Export predictions for the most uncertain windows (``pred_one_img``).
    4. Fine-tune the model on the labelled results (``train``).
    """

    def __init__(self,
                 model=None,
                 device=None
                 ):
        """
        Args:
            model: segmentation model whose forward returns a ``(pred, pred_edge)``
                pair (as unpacked in ``val_select`` / ``pred_one_img``).
            device: device to run on. If None, ``cuda:0`` is used when CUDA is
                available, else CPU. A given value is stored as-is.
        """
        # BUG FIX: the original tested `torch.cuda.is_available` (the function
        # object, which is always truthy) instead of calling it. As a result,
        # CPU-only machines were handed "cuda:0".
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = model

    def to(self, device: str):
        """Set the device used by later inference and training calls.

        The model itself is moved lazily by the methods that use it. Returns
        ``self`` so calls can be chained; the original returned None.
        """
        self.device = torch.device(device)
        return self

    def load_model(self, path):
        """Load a whole pickled model object (as written by ``save_model``).

        NOTE(review): ``torch.load`` unpickles arbitrary objects, so load only
        trusted checkpoints. Recent PyTorch versions default to
        ``weights_only=True``, which will likely refuse full-model pickles.
        Confirm against the installed version.
        """
        self.model = torch.load(path, map_location=self.device)

    def save_model(self, path):
        """Pickle the whole model object (architecture + weights) to ``path``."""
        torch.save(self.model, path)

    # Sampling
    def get_sample(self, src_path, num=50, size=(256, 256)):
        """Randomly draw ``num`` crop windows from the images in ``src_path``.

        Args:
            src_path: directory holding the pool images (sub-directories are skipped).
            num: number of windows to draw. Images are drawn with replacement.
            size: ``(height, width)`` of each window.

        Returns:
            List of ``[name, (h, w), size]``. ``name`` is the file name without
            its extension, and ``(h, w)`` is the window's top-left corner.

        Raises:
            ValueError: if the pool is empty, or a drawn image is smaller than ``size``.
        """
        img_paths = [os.path.join(src_path, f) for f in os.listdir(src_path)]
        img_paths = [p for p in img_paths if os.path.isfile(p)]
        samples = []
        for _ in tqdm(range(num)):
            img_path = np.random.choice(img_paths, 1)[0]
            # splitext instead of [:-4] so extensions of any length are stripped.
            name = os.path.splitext(os.path.basename(img_path))[0]
            img = imread(img_path)
            H, W = img.shape
            max_h, max_w = H - size[0], W - size[1]
            if max_h < 0 or max_w < 0:
                raise ValueError(f"{img_path} ({H}x{W}) is smaller than the sample size {size}")
            # randint's upper bound is exclusive, so +1 lets the window touch the
            # bottom/right border. The original np.random.choice(H - size[0])
            # excluded the last offset and crashed when H == size[0].
            h = np.random.randint(0, max_h + 1)
            w = np.random.randint(0, max_w + 1)
            samples.append([name, (h, w), size])
        return samples

    # Rank candidate windows and pick the ones worth labelling
    def val_select(self, src_path, sample_list, select_num=20):
        """Score each sampled window by mean prediction entropy and keep the top ones.

        Args:
            src_path: directory of the pool images.
            sample_list: output of ``get_sample``.
            select_num: how many windows to return.

        Returns:
            Up to ``select_num`` entries ``[(name, (h, w), size), entropy]``,
            most uncertain first.
        """
        self.model.to(self.device)
        self.model.eval()

        entropy_list = []
        with torch.no_grad():
            for name, hw, size in tqdm(sample_list):
                # NOTE(review): assumes pool images are .tif files; the extension
                # is re-attached here because get_sample stores names without it.
                img = self.get_img_subarea(os.path.join(src_path, f"{name}.tif"), hw, size).to(self.device)
                pred, _ = self.model(img)
                pred = pred.cpu().numpy()
                entropy_list.append([(name, hw, size), self.entropy(pred)])
        # Stable sort, highest entropy first (ties keep sampling order).
        entropy_list.sort(key=lambda x: x[-1], reverse=True)
        return entropy_list[:select_num]

    # Model training
    def train(self, img_path, mask_path, save_path):
        """Fine-tune the model for one epoch on the labelled selection, then save it.

        Args:
            img_path: passed straight to ``ALTrainDataset``.
            mask_path: passed straight to ``ALTrainDataset``.
            save_path: directory given to the trainer and used for ``model.pth``.
                It is created if missing.

        Returns:
            Whatever ``Edge_Trainer.train_one_epoch`` returns (the epoch's mean
            loss). The original returned None.
        """
        train_set = ALTrainDataset(img_path, mask_path)
        train_loader = DataLoader(train_set, 4)

        # Respect the device chosen via __init__/to() instead of re-detecting
        # it. Place the model there before building the optimizer.
        device = torch.device(self.device)
        self.model.to(device)
        optim = torch.optim.Adam(self.model.parameters(), lr=3e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50)
        loss_function = BCELoss()
        os.makedirs(save_path, exist_ok=True)
        trainer = Edge_Trainer(self.model, optim, loss_function, save_path, device, scheduler=scheduler)
        # NOTE(review): only a single epoch runs here, and nothing in this
        # method steps the scheduler. Confirm that is intended.
        epoch_loss = trainer.train_one_epoch(train_loader)
        torch.save(self.model, os.path.join(save_path, "model.pth"))
        return epoch_loss

    # Predict a single image
    def pred_one_img(self, img, thread=0.5, show=False):
        """Run the model on one image and return 2-D numpy maps.

        Args:
            img: a 4-D tensor such as the ``(1, 1, H, W)`` one returned by
                ``get_img_subarea``, or a path to an image file. A path is loaded
                with ``get_img_subarea``.
            thread: threshold; probabilities > thread become 1, the rest 0.
            show: if True, display the four maps with ``imshow``.

        Returns:
            ``(img, pred, pred_edge, entropy)`` taken from batch 0 / channel 0.
            ``pred`` is binarised. ``entropy`` is the ``min_max_norm``-ed
            per-pixel entropy of the raw (pre-threshold) prediction.
        """
        if isinstance(img, (str, os.PathLike)):
            img = self.get_img_subarea(img)
        img = img.to(self.device)
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            pred, pred_edge = self.model(img)

        img = img.cpu().numpy()[0, 0]
        pred = pred.cpu().numpy()[0, 0]
        pred_edge = pred_edge.cpu().numpy()[0, 0]
        # Entropy must be computed before `pred` is binarised in place below.
        entropy = min_max_norm(self.entropy(pred, "RAW"))

        pred[pred > thread] = 1
        pred[pred < 1] = 0
        if show:
            imshow([img, pred, pred_edge, entropy])
        return img, pred, pred_edge, entropy

    # Load an image
    def get_img_subarea(self, img_path, hw=None, size=None):
        """Load an image, normalise it, optionally crop it, and wrap it as a float tensor.

        ``min_max_norm`` is applied to the whole image before cropping. When both
        ``hw`` (top-left corner) and ``size`` (height, width) are given, only
        that window is kept. Two leading singleton dims are added, so a 2-D
        image becomes a ``(1, 1, h, w)`` tensor.
        """
        img = min_max_norm(imread(img_path))
        if (hw is not None) and (size is not None):
            h, w = hw
            img = img[h:h + size[0], w:w + size[1]]
        img = torch.tensor(img).unsqueeze(0).unsqueeze(0).float()
        return img

    # Information entropy
    def entropy(self, x, mod="MEAN"):
        """Element-wise ``-x * log2(x)`` of probability array ``x``, reduced per ``mod``.

        Args:
            x: array of probabilities (values in [0, 1] expected).
            mod: "MEAN" returns the scalar mean, "SUM" the scalar sum, and
                "RAW" the element-wise array.

        Raises:
            ValueError: for any other ``mod`` (the original silently returned None).
        """
        x = np.asarray(x)
        # 0 * log2(0) evaluates to NaN in floating point, but its limit is 0.
        # Substituting 1 inside the log where x <= 0 makes those terms 0, so
        # the NaN can no longer poison mean/sum or the sort in val_select.
        # NOTE(review): this is only the positive-class term. The binary entropy
        # would add -(1 - x) * log2(1 - x); kept as-is to preserve the ranking.
        h = -x * np.log2(np.where(x > 0, x, 1))
        if mod == "MEAN":
            return h.mean()
        if mod == "SUM":
            return h.sum()
        if mod == "RAW":
            return h
        raise ValueError(f"unknown entropy mode: {mod!r}")


# al = AL()
# al.load_model("./model/RAW/model.pth")

# Selection: sample candidate windows, rank them by entropy, export the most uncertain ones
# samples = al.get_sample("./data/AL/pool/",100)
# selected = al.val_select("./data/AL/pool/",samples,10)
# for (name,hw,size),_ in tqdm(selected):
#     img = al.get_img_subarea(os.path.join("./data/AL/pool/",f"{name}.tif"),hw,size)
#     print(img.shape)
#     img,pred,_,entropy = al.pred_one_img(img,0.5)
    
#     h,w = hw
#     imsave(f"./data/AL/selected/img/{name}_{h}_{w}.png",img)
#     imsave(f"./data/AL/selected/entropy/{name}_{h}_{w}.png",entropy)
#     imsave(f"./data/AL/selected/mask/{name}_{h}_{w}.png",pred)



# Training: flatten exported masks to one channel, then fine-tune on the selection
# ch3_to_ch1("./data/AL/selected/mask/")
# al.train("./data/AL/pool/","./data/AL/selected/","./model/AL/")

# Inspection: visualise the prediction for one image
# al.pred_one_img("./data/AL/sample/sub0283_740_871.png")