In [None]:
# ! pip install -r requirements.txt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,random_split
import numpy as np
from tqdm import tqdm


from maunet.nets.maxvit_unet.maxvit_unet import MaxVit_Unet

from maunet.nets.unet import *
from maunet.nets.dca_unet import DCA_UNet
from maunet.nets.bio_net import BiONet
from maunet.data import SegDataset
from maunet.utils import *
from maunet.loss import *
from maunet.train import Trainer


In [None]:
# --- Data: TNBC train folder split 80/20 into train/val, plus a separate test folder ---
batch_size = 4
num_workers = 4
resize = (512, 512)
split_seed = 42  # fixes both the train/val split and the training shuffle order

# Whole labelled training folder, split into train / validation subsets.
full_train_set = SegDataset("../../data/tnbc/train/", resize=resize)
train_set, val_set = random_split(
    full_train_set, [0.8, 0.2], generator=torch.Generator().manual_seed(split_seed)
)
trtain_set = train_set  # backward-compatible alias for the old (misspelled) name
test_set = SegDataset("../../data/tnbc/test/", resize=resize)

# Training batches are shuffled every epoch (previously they were served in a fixed
# order); the seeded generator keeps the shuffle reproducible across runs.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    generator=torch.Generator().manual_seed(split_seed),
)
# Validation / test are evaluated one sample at a time, in a fixed order.
val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers)
test_loader = DataLoader(test_set, batch_size=1, num_workers=num_workers)

print(len(train_set), len(val_set), len(test_set))
# Sanity check: shapes of one training batch (shown as the cell output).
img, mask = next(iter(train_loader))
img.shape, mask.shape


In [None]:
# --- Training: MaxViT-UNet, Adam, cosine-annealed learning rate ---
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model = MaxVit_Unet()
optim = torch.optim.Adam(model.parameters(), lr=3e-4)
# NOTE(review): T_max=50 while train() below receives 100. If 100 is the epoch count and
# Trainer steps the scheduler once per epoch, the LR hits its minimum at epoch 50 and
# then rises again (cosine annealing is periodic). Confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50)
loss_function = DiceBCE_Loss

trainer = Trainer(
    model,
    optim,
    loss_function,
    "./model/temp",  # output directory; presumably where model.pth (loaded later) is saved
    device,
    scheduler=scheduler,
)
# NOTE(review): the positional 100 and 5 are interpreted by Trainer.train — check its signature.
trainer.train(train_loader, val_loader, 100, 5)

In [None]:
# Checkpoint loaded by the evaluation/visualization cells below; presumably written by
# the Trainer into its "./model/temp" output directory — confirm the file name.
model_path = "./model/temp/model.pth"

In [None]:
# --- Evaluation metrics on the test set ---
# Uses `loss_function`, which is defined in the training cell; run that cell first.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled nn.Module (not a state_dict), so it must be loaded
# with weights_only=False; PyTorch >= 2.6 defaults to weights_only=True and would refuse
# it. Unpickling can execute arbitrary code — only load checkpoints you produced yourself.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()

loss_list = []
pre_list = []
acc_list = []
iou_list = []
dice_list = []
# Track the test batch with the lowest mIoU so it can be inspected afterwards.
iou_min = np.inf
iou_min_bach = None  # (imgs, label, binarized pred) of the worst batch
iou_min_index = None  # position of that batch in test_loader

# Inference only: no_grad avoids building (and holding) an autograd graph per batch.
with torch.no_grad():
    for index, (imgs, label) in enumerate(tqdm(test_loader)):
        imgs = imgs.to(device)
        label = label.to(device)

        pred = model(imgs)
        loss = loss_function(pred, label)
        loss_list.append(loss.item())

        # Binarize at 0.5 (assumes the model outputs probabilities — TODO confirm).
        pred = (pred > 0.5).to(pred.dtype)

        # Copy to CPU once and reuse for every metric.
        pred_cpu, label_cpu = pred.cpu(), label.cpu()
        pre_list.append(get_pre(pred_cpu, label_cpu))
        acc_list.append(get_acc(pred_cpu, label_cpu))
        miou = get_miou(pred_cpu, label_cpu)
        iou_list.append(miou)
        dice_list.append(get_dice(pred_cpu, label_cpu))

        if miou < iou_min:
            iou_min = miou
            iou_min_bach = (imgs, label, pred)
            iou_min_index = index

pre, acc, iou, dice = np.mean(pre_list), np.mean(acc_list), np.mean(iou_list), np.mean(dice_list)
print(f"test_set:\tpre:{pre:.3f}\tacc:{acc:.3f}\tdice:{dice:.3f}\tiou:{iou:.3f}")

In [None]:
# Model size from model_param_count, scaled by 2**20 (1024-based "M").
# NOTE(review): confirm whether model_param_count returns a parameter count or a byte size.
model_param_count(model) / 2**20

In [None]:
# Show every sample of the lowest-mIoU test batch recorded by the evaluation loop:
# the image moved channels-last (assumes imgs is (B, C, H, W) — TODO confirm), the
# ground-truth mask, and channel 0 of the binarized prediction.
imgs, label, pred = iou_min_bach
for index in range(len(imgs)):
    image_hwc = imgs[index].cpu().permute(1, 2, 0)
    result_show(image_hwc, label[index].cpu(), pred[index, 0].detach().cpu())
iou_min_index

In [None]:
# --- Visualization: predictions for the first test batch, computed on CPU ---
device = torch.device("cpu")

# The checkpoint is a whole pickled nn.Module, so weights_only=False is required on
# PyTorch >= 2.6 (default weights_only=True). Only load checkpoints you trust.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()

# Only the first batch is displayed, so fetch it directly instead of loop + break.
imgs, label = next(iter(test_loader))
imgs = imgs.to(device)
label = label.to(device)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    pred = model(imgs)
# Binarize at 0.5 (assumes the model outputs probabilities — TODO confirm).
pred = (pred > 0.5).to(pred.dtype)

for index in range(imgs.shape[0]):
    result_show(imgs[index].cpu().permute(1, 2, 0), label[index].cpu(), pred[index, 0].detach().cpu())

# Mean IoU / Dice over the whole test set, collected by the evaluation loop.
np.mean(iou_list), np.mean(dice_list)