In [1]:
import cv2
import numpy as np
import os
import pandas as ps
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
# This is for the progress bar.
from tqdm.auto import tqdm
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import save_image
from model import UNet, MyDataset
from torch.utils.data import random_split
from my_metrics import calculate_metrics
import csv

In [4]:
#设置设备，选择cuda
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('cuda' if torch.cuda.is_available() else 'cpu')
data_folder_path = '../data/FAZ/Domain1/train/imgs'
mask_folder_path = '../data/FAZ/Domain1/train/mask'

data_files = os.listdir(data_folder_path)
mask_files= os.listdir(mask_folder_path)


dataset = MyDataset(data_file=data_files,mask_files=mask_files,data_folder_path=data_folder_path,mask_folder_path=mask_folder_path)
weight_path='params/unet.pth'
#data_path=r'data'
model = UNet().to(device)

train_dataset, valid_dataset = random_split(dataset=dataset, lengths=[195, 49], generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(train_dataset,batch_size=4,shuffle=True)
valid_loader = DataLoader(valid_dataset,batch_size=4,shuffle=True)

train_loader_length = len(train_loader)

opt = optim.Adam(model.parameters())
loss_fun = nn.BCELoss()
all_epoch = 50 # 设置了早停策略暂时不care

lossArr = []
dicelist = []
jclist = []
hd95list = []
assdlist = []
splist = []
recalllist = []
prelist = []

best_score = 0  # 用于保存最好的性能指标值
patience = 6  # 容忍的 epoch 数，即连续多少个 epoch 没有改进就停止训练
counter = 0  # 计数器，记录连续 epoch 未改进的次数

dropcount = 0 # 连降计数器  记录连续 epoch 改进的次数
drop_praise = 4 # 连降超过三次即使超过patience 豁免停止训练
last_score = 99 # 

prev_dice = 0
print("model starts training\n")
for current_epoch in range(all_epoch):
    model.train()
    lossItem = 0
    with tqdm(total=train_loader_length, desc=f'Epoch {current_epoch + 1}/{all_epoch}', unit='batch') as pbar:
        for idx, (image,segment_image) in enumerate(train_loader):
            image = image.to(device)
            segment_image = segment_image.to(device)
            opt.zero_grad()
            output_img = model(image)
            train_loss = loss_fun(output_img, segment_image)
            train_loss.backward()
            opt.step()
            lossItem = lossItem + train_loss.item()
             # 更新进度条
            
            pbar.update(1)  # 更新进度条
       
        print("start to eval")

        model.eval()
        [valid_loss,all_dc,all_jc,all_hd,all_assd,all_sp,all_recall,all_pre] = [0,0,0,0,0,0,0,0]
        datalen = len(valid_dataset)
        for idx,(valid_img,valid_seg_img) in enumerate(valid_loader):
            valid_img = valid_img.to(device)
            valid_seg_img = valid_seg_img.to(device)
            with torch.no_grad():
                valid_output = model(valid_img)
            for pred,target in zip(valid_output,valid_seg_img):
                dice,jaccard,hd95_score,assd_score,sp_score,recall_score,pre_score = calculate_metrics(pred,target)
                all_dc += dice/datalen
                all_jc += jaccard/datalen
                all_hd += hd95_score/datalen
                all_assd += assd_score/datalen
                all_sp += sp_score/datalen
                all_recall += recall_score/datalen
                all_pre += pre_score/datalen
                valid_loss += loss_fun(valid_output,valid_seg_img)
    
        
        current_score = valid_loss
        
        pbar.set_postfix({'Train_Loss': lossItem,'Valid_Loss':valid_loss})  # 更新进度条显示的信息
        
        if best_score > current_score:
            best_score = current_score
            # 重置计数器
            counter = 0  
            # 保存当前最佳模型权重
            torch.save(model.state_dict(), weight_path)
        else:
            counter += 1
            if current_score < last_score:
                dropcount += 1
                if dropcount >=3:
                    counter,dropcount = 0,0 
            else:
                dropcount = 0
            
            if counter >= patience and current_epoch>20:
                print(f"Early stopping at epoch {current_epoch}!")
                break  # 停止训练
                
        last_score = valid_loss
   
        
        lossArr.append(lossItem)
        dicelist.append(all_dc)
        jclist.append(all_jc)
        hd95list.append(all_hd)
        assdlist.append(all_assd)
        splist.append(all_sp)
        recalllist.append(all_recall)
        prelist.append(all_pre)
        
transposed_lists = zip(lossArr, dicelist, jclist, hd95list, assdlist, splist, recalllist, prelist)

with open('output.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'lossArr', 'dicelist', 'jclist', 'hd95list', 'assdlist', 'splist', 'recalllist', 'prelist'])

    for epoch, row in enumerate(zip(lossArr, dicelist, jclist, hd95list, assdlist, splist, recalllist, prelist)):
        writer.writerow([epoch, *row])

cuda
model starts training



Epoch 1/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 2/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 3/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 4/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 5/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 6/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 7/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 8/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 9/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 10/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 11/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 12/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 13/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 14/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 15/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 16/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 17/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 18/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 19/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 20/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 21/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 22/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 23/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 24/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 25/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 26/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 27/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 28/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval


Epoch 29/50:   0%|          | 0/49 [00:00<?, ?batch/s]

start to eval
Early stopping at epoch 28!
