In [1]:
# Dataset download commands (presumably the Featurize platform CLI); uncomment the batch you need.
# First batch
# ! featurize dataset download 94d0bc00-f708-4206-8a34-78a4a29386c7
# Second batch
# ! featurize dataset download 5d43ecc8-d7cd-40c3-a234-c353f3018ca7

### Prepare data

In [2]:
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from TJL.data import *
from TJL.utils import *
from TJL.loss import *
from TJL.model import UNet_2DE
from TJL.train import Trainer


In [None]:
# Build the train/validation split from raw_sub_1. Images and masks are paired by
# position after sorting, so both folders must hold the same number of files whose
# sorted orders correspond.
IMG_SIZE = (256, 256)
N_VAL = 20  # the last N_VAL pairs (in sorted order) are held out for validation

img_path, mask_path = "./data/raw/raw_sub_1/slice/", "./data/raw/raw_sub_1/mask/"
img_list = sorted(os.path.join(img_path, f) for f in os.listdir(img_path))
mask_list = sorted(os.path.join(mask_path, f) for f in os.listdir(mask_path))
assert len(img_list) == len(mask_list), f"{len(img_list)} images vs {len(mask_list)} masks"
assert len(img_list) > N_VAL, f"need more than {N_VAL} samples to hold out a validation split"

train_set = MyDataset(img_list[:-N_VAL], mask_list[:-N_VAL], IMG_SIZE)
val_set = MyDataset(img_list[-N_VAL:], mask_list[-N_VAL:], IMG_SIZE)

# NOTE(review): kept for reference; mask_path below points at slice/, not mask/ — fix before re-enabling.
# img_path, mask_path = "./data/raw/raw_sub_1/slice/", "./data/raw/raw_sub_1/slice/"
# img_list, mask_list = [os.path.join(img_path,f) for f in os.listdir(img_path)][-200:], [os.path.join(mask_path,f) for f in os.listdir(mask_path)][-200:]
# img_list.sort()
# mask_list.sort()
# test_set = MyDataset(img_list,mask_list,use_clahe=True)

# Shuffle training batches: without it every epoch sees the files in the same sorted
# filename order, so each batch is made of consecutive files.
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=1)
# test_loader = DataLoader(test_set,1)

# Sanity check: one training batch yields (image, mask, edge) tensors.
img, mask, edge = next(iter(train_loader))
# imshow([img[0,0].numpy(),mask[0,0].numpy()])
img.shape, mask.shape, edge.shape




### Train

In [4]:
class Edge_Trainer(Trainer):
    """Trainer for a model with two outputs: a segmentation prediction and an edge prediction.

    Overrides the per-epoch loops of ``TJL.train.Trainer``. Each loader batch is an
    ``(imgs, label, edge)`` triple, the model returns ``(pred, pred_edge)``, and the loss
    is computed as ``loss_func(pred, pred_edge, label, edge)``.
    """

    def __init__(self, model, optim, loss_func, save_path, device, scheduler=None):
        super().__init__(model, optim, loss_func, save_path, device, scheduler)

    def _batch_loss(self, imgs, label, edge):
        """Move one batch to ``self.device``, run the model and return the loss tensor."""
        imgs = imgs.to(self.device)
        label = label.to(self.device)
        edge = edge.to(self.device)
        pred, pred_edge = self.model(imgs)
        return self.loss_func(pred, pred_edge, label, edge)

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            The mean of the per-batch loss values (``np.mean``; NaN if the loader is empty).
        """
        self.model.train()
        loss_list = []
        for imgs, label, edge in tqdm(train_loader):
            loss = self._batch_loss(imgs, label, edge)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            loss_list.append(loss.item())
        return np.mean(loss_list)

    def val_one_epoch(self, val_loader):
        """Compute the loss over ``val_loader`` without updating the model.

        Returns:
            The mean of the per-batch loss values (``np.mean``; NaN if the loader is empty).
        """
        self.model.eval()
        loss_list = []
        # Validation needs no gradients; without no_grad every forward pass would build
        # an autograd graph and hold its activations in memory.
        with torch.no_grad():
            for imgs, label, edge in tqdm(val_loader):
                loss_list.append(self._batch_loss(imgs, label, edge).item())
        return np.mean(loss_list)


In [None]:
# Training configuration.
SEED = 42
EPOCHS = 100  # passed to trainer.train() and used as the cosine schedule length
LR = 3e-4

torch.manual_seed(SEED)  # reproducible weight initialisation
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet_2DE(in_ch=1)
optim = torch.optim.Adam(model.parameters(), lr=LR)
# NOTE(review): assumes the third argument of Trainer.train is the epoch count and that the
# scheduler is stepped once per epoch — confirm in TJL.train. The previous T_max=50 with a
# 100-epoch run annealed to the minimum LR at epoch 50 and then climbed back to the initial LR.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS)
# Presumably the BCELoss from TJL.loss (star import); Edge_Trainer calls it as
# loss(pred, pred_edge, label, edge).
loss_function = BCELoss()
trainer = Edge_Trainer(model, optim, loss_function, "./model/RAW/", device, scheduler=scheduler)
# NOTE(review): the meaning of the last argument (10) is defined by TJL.train.Trainer.train — confirm.
trainer.train(train_loader, val_loader, EPOCHS, 10)

### Validation

In [6]:
# Checkpoint evaluated below; presumably written by the training run (save_path "./model/RAW/") — confirm the filename in TJL.train.Trainer.
model_path = "/".join(("./model/RAW", "model.pth"))

In [None]:
# Evaluation metrics on the validation loader: binarise the mask prediction at `thread`,
# accumulate per-batch precision / accuracy / IoU / Dice / recall, and keep the batch
# with the lowest IoU for later inspection.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# SECURITY: torch.load unpickles a whole model object — only load trusted checkpoints.
# NOTE(review): on PyTorch >= 2.6 `weights_only` defaults to True, which rejects full-model
# pickles; pass weights_only=False for this trusted file or switch to saving a state_dict.
model = torch.load(model_path, map_location=device)
model.eval()
pre_list = []
acc_list = []
iou_list = []
dice_list = []
recall_list = []
iou_min = np.inf
iou_min_bach = None  # (imgs, label, binarised pred) of the lowest-IoU batch
iou_min_index = None

thread = 0.5  # threshold applied to the mask prediction
# Inference only: no_grad avoids building autograd graphs, and it keeps the tensors
# stored in iou_min_bach from holding on to one.
with torch.no_grad():
    for index, (imgs, label, _edge) in enumerate(tqdm(val_loader)):
        imgs = imgs.to(device)
        label = label.to(device)

        pred, _pred_edge = model(imgs)
        # 1 where pred > thread, 0 elsewhere; dtype kept as the model's output dtype.
        pred = (pred > thread).to(pred.dtype)

        pred_cpu, label_cpu = pred.cpu(), label.cpu()
        miou = get_miou(pred_cpu, label_cpu)
        pre_list.append(get_pre(pred_cpu, label_cpu))
        acc_list.append(get_acc(pred_cpu, label_cpu))
        iou_list.append(miou)
        dice_list.append(get_dice(pred_cpu, label_cpu))
        recall_list.append(get_recall(pred_cpu, label_cpu))

        if miou < iou_min:
            iou_min = miou
            iou_min_bach = (imgs, label, pred)
            iou_min_index = index

pre, acc, iou, dice, recall = (
    np.mean(metric) for metric in (pre_list, acc_list, iou_list, dice_list, recall_list)
)
# This cell evaluates val_loader, so the summary is labelled val_set.
print(f"val_set:\tpre:{pre:.3f}\tacc:{acc:.3f}\tdice:{dice:.3f}\tiou:{iou:.3f}\trecall:{recall:.3f}")