In [None]:
import os
initial_directory = os.getcwd()
os.chdir("..")

import datetime
import time
import numpy as np
import pandas as pd
import torch
import torchsummary
import torchvision.transforms as transforms
from torch import nn
from tqdm.notebook import tqdm, tnrange
from sklearn.model_selection import train_test_split
from fvcore.nn import FlopCountAnalysis, parameter_count_table
from network.PG_SN import PG_SN

%matplotlib inline
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # 以下面设置的第一个卡为主卡
os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"  # 物理卡号

config={}


In [None]:
config.update(
    {
        "size": 180,
        "in_channels": 1,
        "encoder_channels": [32, 64, 128, 256, 512],
        "decoder_channels": [512, 256, 128, 64, 32],
        "out_channels": 2,
    }
)
config.update({"net_name": "PG_SN"})
net = PG_SN(config)

# Probe with a random batch of 8 images. The image shape is taken from config,
# so the check follows any change to "size" or "in_channels".
probe_input = torch.randn(8, config["in_channels"], config["size"], config["size"])
# A freshly built module is in training mode, so the probes would let random
# noise update any running statistics the network keeps (e.g. BatchNorm
# buffers, if PG_SN has them). Snapshot every parameter/buffer and restore it
# afterwards, so training starts from the untouched initialisation and the
# FLOP count is still taken in the same mode as before.
initial_state = {name: tensor.clone() for name, tensor in net.state_dict().items()}
with torch.no_grad():  # no backward pass is needed for either probe
    print(net(probe_input).shape)
    config.update({"parameters": sum(param.numel() for param in net.parameters())})
    print(config["parameters"])
    # FLOP count for the whole batch of 8 as traced by fvcore (not per image).
    config.update({"flops": FlopCountAnalysis(net, (probe_input,)).total()})
    print(config["flops"])
net.load_state_dict(initial_state)
del initial_state

In [None]:
class My_Dataset(torch.utils.data.Dataset):
    """In-memory dataset of 2-D slices cut from paired image / mask volumes.

    ``data`` is a 2-D array with exactly three columns: a path to an image
    array, a path to a segmentation array (both loaded with ``np.load``), and
    an integer label. Every slice along the last axis whose mask holds more
    than one distinct value becomes one sample.
    """

    def __init__(self, data):
        super().__init__()
        assert data.shape[1] == 3, "The data does not meet the requirements."
        self.data = self.get_data(data)

    def __getitem__(self, index):
        """Return ``(image, mask, label)`` for the slice at ``index``.

        The image is standardised per channel: the mean is subtracted, and the
        result is divided by ``max(std, 1 / sqrt(num_pixels))`` so that
        constant slices do not divide by zero. The mask is cast to ``long``
        (truncating) and the label to ``int``.
        """
        image, mask, label = self.data[index]
        pixels = image.reshape(image.shape[0], -1)
        std_floor = 1.0 / torch.tensor(pixels.shape[1] * 1.0).sqrt()
        centred = pixels - pixels.mean(dim=1, keepdim=True)
        scale = torch.max(pixels.std(dim=1, keepdim=True), std_floor)
        image = (centred / scale).reshape(image.shape)
        return image.float(), mask.long(), int(label)

    def __len__(self):
        return len(self.data)

    def get_data(self, data):
        """Load every listed volume and keep the slices with a non-trivial mask.

        Returns a list of ``[image_tensor, mask_tensor, label]`` entries.
        """
        to_tensor = transforms.Compose([transforms.ToTensor()])
        samples = []
        for row in tnrange(data.shape[0], dynamic_ncols=True, desc="get_data"):
            image_path, mask_path = data[row][0], data[row][1]
            assert (os.path.exists(image_path) and os.path.isfile(image_path) and os.path.exists(mask_path) and os.path.isfile(mask_path))
            # NOTE(review): np.uint16 wraps negative values; confirm the stored
            # volumes never contain negative intensities.
            volume = np.uint16(np.load(image_path))
            mask_volume = np.uint16(np.load(mask_path))
            assert (volume.ndim == 3 and mask_volume.ndim == 3 and volume.shape == mask_volume.shape)
            for depth in range(volume.shape[2]):
                mask_slice = mask_volume[:, :, depth]
                if np.unique(mask_slice).size <= 1:
                    continue  # background-only (or constant) slice
                # NOTE(review): ToTensor presumably leaves float32 input unscaled,
                # so masks become raw / 65535, and __getitem__'s .long() then maps
                # only 65535-valued voxels to class 1. Confirm masks are stored
                # as 0/65535 and not 0/1.
                samples.append([
                    to_tensor(np.float32(volume[:, :, depth])),
                    to_tensor(np.float32(mask_slice) / 65535.0),
                    int(data[row][2]),
                ])
        return samples

In [None]:
# Metric keys, in the order they are logged and stored.
_METRIC_NAMES = ("Acc", "SE", "SP", "PC", "F1", "JS", "DC")


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _confusion_counts(logits, target):
    """Count (TP, FP, FN, TN) pixels for a batch of binary segmentations.

    The predicted class is the argmax of ``logits`` over dim 1 (kept as a
    singleton dim) and is compared element-wise with ``target``. Class 1 is
    positive and class 0 negative; any other value is counted in neither.
    Counting runs on the logits' device and only four integers reach the host.
    """
    prediction = logits.detach().argmax(dim=1, keepdim=True)
    target = target.to(prediction.device)
    predicted_positive, predicted_negative = prediction == 1, prediction == 0
    actual_positive, actual_negative = target == 1, target == 0
    return (
        int((predicted_positive & actual_positive).sum().item()),
        int((predicted_positive & actual_negative).sum().item()),
        int((predicted_negative & actual_positive).sum().item()),
        int((predicted_negative & actual_negative).sum().item()),
    )


def _segmentation_metrics(TP, FP, FN, TN):
    """Turn pixel confusion counts into the metric dict (keys = _METRIC_NAMES)."""
    dice = _safe_ratio(2 * TP, 2 * TP + FN + FP)
    return {
        "Acc": _safe_ratio(TP + TN, TP + FP + FN + TN),  # accuracy
        "SE": _safe_ratio(TP, TP + FN),  # sensitivity == recall
        "SP": _safe_ratio(TN, TN + FP),  # specificity
        "PC": _safe_ratio(TP, TP + FP),  # precision
        "F1": dice,  # F1 == Dice for binary masks
        "JS": _safe_ratio(TP, TP + FN + FP),  # Jaccard similarity
        "DC": dice,  # Dice coefficient
    }


def _metrics_line(tag, loss, metrics):
    """Format a log line such as '[Training] Loss: 0.1234, Acc: ..., DC: ...'."""
    return ('%s Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
            % (tag, loss, *(metrics[name] for name in _METRIC_NAMES)))


def _poly_learning_rate(base_learning_rate, epoch, number_epochs, power=0.9):
    """Polynomial decay: base_learning_rate * (1 - epoch / number_epochs) ** power."""
    return base_learning_rate * (1 - epoch / number_epochs) ** power


def evaluation(data_iterator, net, loss_function, device):
    """Evaluate ``net`` on ``data_iterator`` without gradients and print the results.

    Each batch is ``(X, s, _)``: the inputs, the integer masks (``s[:, 0]`` is
    the loss target, and ``s`` is compared with the argmax prediction), and an
    ignored third item.

    Returns a dict holding the sample-weighted mean loss ('loss') followed by
    'Acc', 'SE', 'SP', 'PC', 'F1', 'JS', 'DC', computed from pixel counts
    accumulated over all batches. An empty iterator gives loss 0.0 instead of
    raising ZeroDivisionError. The module's previous train/eval mode is
    restored on exit.
    """
    was_training = net.training
    net.eval()
    net = net.to(device)
    total_loss, number = 0.0, 0
    TP = FP = FN = TN = 0
    try:
        with torch.no_grad():
            for X, s, _ in tqdm(data_iterator, dynamic_ncols=True, leave=False, desc="test"):
                assert not net.training
                X = X.to(device)
                s = s.to(device)
                s_hat = net(X)
                total_loss += loss_function(s_hat, s[:, 0, :, :]).float().item() * s.shape[0]
                tp, fp, fn, tn = _confusion_counts(s_hat, s)
                TP, FP, FN, TN = TP + tp, FP + fp, FN + fn, TN + tn
                number += s.shape[0]
    finally:
        net.train(was_training)
    mean_loss = _safe_ratio(total_loss, number)
    metrics = _segmentation_metrics(TP, FP, FN, TN)
    print(_metrics_line('[Validation]', mean_loss, metrics))
    return {"loss": mean_loss, **metrics}


def train(net, train_iterator, test_iterator, loss_function, number_epochs, number_epochs_decay,
          optimizer, learning_rate, device, model_save_path, net_name=None):
    """Train ``net`` for ``number_epochs`` epochs, evaluating after each one.

    During the last ``number_epochs_decay`` epochs, the learning rate for the
    next epoch is set to ``learning_rate * (1 - epoch / number_epochs) ** 0.9``,
    always relative to the initial rate. Whenever the test JS or DC improves,
    the whole module is saved with ``torch.save`` as ``<epoch>.pth`` in
    ``model_save_path``. ``net_name`` labels that log line; when omitted, the
    global ``config["net_name"]`` is used, as before.

    Returns a dict of per-epoch lists: 'train_loss', 'train_Acc', ...,
    'train_DC', followed by the same keys with a 'test_' prefix.

    Raises ValueError if ``train_iterator`` yields no batches in an epoch.
    """
    net.train()
    net = net.to(device)
    print("training on", device)
    base_learning_rate = learning_rate
    history = {
        "%s_%s" % (split, key): []
        for split in ("train", "test")
        for key in ("loss",) + _METRIC_NAMES
    }
    JS_score, DC_score = 0.0, 0.0
    for epoch in tnrange(1, number_epochs + 1, dynamic_ncols=True, desc="epoch"):
        assert net.training
        train_loss, number, start_time = 0.0, 0, time.time()
        TP = FP = FN = TN = 0
        for X, s, _ in tqdm(train_iterator, dynamic_ncols=True, leave=False, desc="train"):
            assert net.training
            X = X.to(device)
            s = s.to(device)
            s_hat = net(X)
            loss = loss_function(s_hat, s[:, 0, :, :]).float()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * s.shape[0]
            tp, fp, fn, tn = _confusion_counts(s_hat, s)
            TP, FP, FN, TN = TP + tp, FP + fp, FN + fn, TN + tn
            number += s.shape[0]
        if number == 0:
            raise ValueError("train_iterator produced no batches; check the dataset size against batch_size/drop_last.")
        train_metrics = {"loss": train_loss / number, **_segmentation_metrics(TP, FP, FN, TN)}
        print('Epoch [%d/%d]' % (epoch, number_epochs))
        print(_metrics_line('[Training]', train_metrics["loss"], train_metrics))
        if epoch > (number_epochs - number_epochs_decay):
            # Compute the rate from the initial value. Multiplying the running
            # rate by the poly factor every epoch would compound the decay.
            learning_rate = _poly_learning_rate(base_learning_rate, epoch, number_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            print('Decay learning rate to lr: {}.'.format(learning_rate))

        test_metrics = evaluation(test_iterator, net, loss_function, device)

        for key, value in train_metrics.items():
            history["train_" + key].append(value)
        for key, value in test_metrics.items():
            history["test_" + key].append(value)

        if test_metrics['JS'] > JS_score or test_metrics['DC'] > DC_score:
            JS_score, DC_score = test_metrics['JS'], test_metrics['DC']
            name = net_name if net_name is not None else config["net_name"]
            print('Epoch %d : Best %s model JS score %.4f and DC score %.4f.' % (epoch, name, test_metrics['JS'], test_metrics['DC']))
            # NOTE(review): this pickles the whole module rather than its
            # state_dict; it is kept as-is for compatibility with existing loaders.
            torch.save(net, os.path.join(model_save_path, str(epoch) + ".pth"))
        print('Time %.1f sec' % (time.time() - start_time))
    return history

In [None]:
data_file_path = os.path.abspath("../data/data_192_save_as_resampled_qu/npy/npy_data_patients_xyz.xlsx")
config.update(
    {
        "batch_size": 64,
        "learning_rate": 0.0002,
        "number_epochs": 20,
        "number_epochs_decay": 3,
        "test_size": 0.1,
    }
)
data = pd.read_excel(data_file_path)
# Stratify on every column from index 2 onwards (the label column), so both
# splits keep the label distribution.
train_data, test_data = train_test_split(
    data.values,
    test_size=config["test_size"],
    random_state=42,
    stratify=data.values[:, 2:],
)
train_dataset = My_Dataset(train_data)
test_dataset = My_Dataset(test_data)
train_iterator = torch.utils.data.DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
test_iterator = torch.utils.data.DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)
# With drop_last=True, a training set smaller than one batch yields no batches
# at all; fail here with a clear message instead of during training.
if len(train_iterator) == 0:
    raise ValueError(
        "Training set has %d slices, fewer than batch_size=%d; no training batch would be produced."
        % (len(train_dataset), config["batch_size"])
    )
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=config["learning_rate"])
model_save_path = os.path.abspath("../model/" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
os.makedirs(model_save_path, exist_ok=True)
# Record the exact split and configuration next to the checkpoints.
pd.DataFrame(train_data, columns=data.columns).to_excel(os.path.join(model_save_path, 'train_data.xlsx'), sheet_name="train_data", index=False)
pd.DataFrame(test_data, columns=data.columns).to_excel(os.path.join(model_save_path, 'test_data.xlsx'), sheet_name="test_data", index=False)
pd.DataFrame.from_dict(config, orient='index').to_excel(os.path.join(model_save_path, 'config.xlsx'), sheet_name="config")
xlsx_path = os.path.join(model_save_path, config["net_name"] + ".xlsx")
torch.cuda.empty_cache()

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config.update({"device": device})
print(config["device"])
net = net.to(device)
# Wrap only once: re-running this cell must not nest
# DataParallel(DataParallel(...)), which would change the module structure
# (and the layout of any checkpoint saved from it).
if not isinstance(net, nn.DataParallel):
    net = nn.DataParallel(net)
temporary_dictionary = train(
    net,
    train_iterator,
    test_iterator,
    loss_function,
    config["number_epochs"],
    config["number_epochs_decay"],
    optimizer,
    config["learning_rate"],
    config["device"],
    model_save_path,
)
torch.cuda.empty_cache()

In [None]:
# Per-epoch training/test statistics returned by train(), one column per metric.
statistics = pd.DataFrame.from_dict(temporary_dictionary, orient='columns')
sheet_name = config["net_name"] + "_statistics"
if not os.path.exists(xlsx_path):
    statistics.to_excel(xlsx_path, sheet_name=sheet_name)
else:
    # ExcelWriter.save() was removed in pandas 2.0; the context manager saves
    # and closes the workbook. if_sheet_exists="new" adds a suffixed sheet
    # instead of raising when this sheet name is already present (e.g. when
    # the cell is re-run).
    with pd.ExcelWriter(xlsx_path, mode="a", engine="openpyxl", if_sheet_exists="new") as writer:
        statistics.to_excel(writer, sheet_name=sheet_name)