__DataLoader__

In [1]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class Generate_Dataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.data_path = glob.glob(os.path.join(path,'input/*.npy'))  #读取data文件夹下所有.npy格式文件

    def __getitem__(self, index):
        data_path = self.data_path[index]
        data = np.load(data_path)      #读取输入数据
        tensor_data = torch.from_numpy(data)
        
        label_path = data_path.replace('input', 'label')
        label = np.load(label_path)    #读取标签数据
        tensor_label = torch.from_numpy(label)

        return tensor_data, tensor_label

    def __len__(self):
        return len(self.data_path)
# if __name__ == '__main__':
#     top_dataset = Generate_Dataset('./topfiles/')

#     #print("读入数据个数为：", len(top_dataset))
#     train_loader = torch.utils.data.DataLoader(dataset=top_dataset,
#                                                batch_size=1,
#                                                shuffle=True)
#     for data, label in train_loader:
#         print(data.shape)
#         print(label.shape)


__unet3d_parts__

In [2]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

class DoubleConv3d_init(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv3d_init, self).__init__()
        self.double_conv3d_init = nn.Sequential(
            nn.Conv3d(in_channels, out_channels=32, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, out_channels, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.double_conv3d_init(input)


class DoubleConv3d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv3d, self).__init__()
        self.double_conv3d = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.BatchNorm3d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.double_conv3d(input)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down,self).__init__()
        self.maxpool_conv3d = nn.Sequential(
            nn.MaxPool3d(kernel_size=2, stride=2, padding=0),
            DoubleConv3d(in_channels, out_channels)
        )

    def forward(self, input):
        return self.maxpool_conv3d(input)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up3d = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.conv = DoubleConv3d(in_channels, out_channels)

    def forward(self, input, x):  #x是接收的从encoder传过来的融合数据
        #print('input',input.shape)
        #print('x',x.shape)
        x1 = self.up3d(input)
        #print('x1',x1.shape)
        diffY = torch.tensor(x1.size()[3] - x.size()[3])
        diffX = torch.tensor(x1.size()[4] - x.size()[4])#特征融合部分
        diffZ = torch.tensor(x1.size()[2] - x.size()[2])
        #if x1.size()[3] > x.size()[3]:
        x3 = F.pad(x, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2))
        #print('x3',x3.shape)
        output = torch.cat([x1, x3], dim = 1)
        return self.conv(output)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv,self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))

    def forward(self, input):
        return self.conv(input)


__unet3d_model__

In [3]:

class UNet3D(nn.Module):
    def __init__(self,in_channels, n_classes):
        super(UNet3D, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes

        #Encoder
        self.inc = DoubleConv3d_init(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)

        #Decoder
        self.up1 = Up(768, 256)
        self.up2 = Up(384, 128)
        self.up3 = Up(192, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, input):
        out1 = self.inc(input)
        print('out1.shape:',out1.shape)
        out2 = self.down1(out1)
        print('out2.shape:',out2.shape)
        out3 = self.down2(out2)
        print('out3.shape:',out3.shape)
        out4 = self.down3(out3)
        print('out4.shape:',out4.shape)
        out5 = self.up1(out4, out3)
        print('out5.shape:',out5.shape)
        out6 = self.up2(out5, out2)
        print('out6.shape:',out6.shape)
        out7 = self.up3(out6, out1)
        print('out7.shape:',out7.shape)
        logits = self.outc(out7)
        print('logits.shape:',logits.shape)
        return logits

# if __name__ == '__main__':
#     net = UNet3D(in_channels =3 ,n_classes=3)
#     print(net)
#     para = list(net.parameters())
#     print('parameters:', para)

__train__

In [4]:
from torch import optim
import torch.nn as nn
import torch


def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):

    train_dataset = Generate_Dataset(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.MSELoss()
    best_loss = float('inf')

    for epoch in range(epochs):
        print('epoch:',epoch)
        net.train()
        for data, label in train_loader:
            optimizer.zero_grad()

            data = data.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            pred = net(data)
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            if loss < best_loss:
                best_loss = loss
                torch.save(net.state_dict(), 'best_model.pth')
            loss.backward()
            optimizer.step()


if __name__ == '__main__':
    device = torch.device('cude' if torch.cuda.is_available() else 'cpu')
    net = UNet3D(1, 1)
    net.to(device=device)

    data_path = "../../../dataset/dataset/"
    train_net(net, device, data_path)



epoch: 0
out1.shape: torch.Size([1, 64, 120, 120, 120])
out2.shape: torch.Size([1, 128, 60, 60, 60])
out3.shape: torch.Size([1, 256, 30, 30, 30])
out4.shape: torch.Size([1, 512, 15, 15, 15])
out5.shape: torch.Size([1, 256, 30, 30, 30])
out6.shape: torch.Size([1, 128, 60, 60, 60])
out7.shape: torch.Size([1, 64, 120, 120, 120])
logits.shape: torch.Size([1, 1, 120, 120, 120])
Loss/train 0.7271496057510376
out1.shape: torch.Size([1, 64, 120, 120, 120])
out2.shape: torch.Size([1, 128, 60, 60, 60])
out3.shape: torch.Size([1, 256, 30, 30, 30])
out4.shape: torch.Size([1, 512, 15, 15, 15])
out5.shape: torch.Size([1, 256, 30, 30, 30])


——————————————————————————————————————————————————————————————————————————————————

__DataLoder__

In [1]:
import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


class Generate_Dataset(Dataset):
    
    def __init__(self, path):
        self.path = path
        self.data_path = glob.glob(os.path.join(path,'input/*.npy'))  #读取data文件夹下所有.npy格式文件

    def __getitem__(self, index):
        data_path = self.data_path[index]
        data = np.load(data_path)      #读取输入数据
        tensor_data = torch.from_numpy(data)
        
        label_path = data_path.replace('input', 'label')
        label = np.load(label_path)    #读取标签数据
        tensor_label = torch.from_numpy(label)

        return tensor_data, tensor_label

    def __len__(self):
        return len(self.data_path)
# if __name__ == '__main__':
#     top_dataset = Generate_Dataset('./topfiles/')

#     #print("读入数据个数为：", len(top_dataset))
#     train_loader = torch.utils.data.DataLoader(dataset=top_dataset,
#                                                batch_size=1,
#                                                shuffle=True)
#     for data, label in train_loader:
#         print(data.shape)
#         print(label.shape)


__UNet__

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super(UNet, self).__init__()
        self.training = training
        self.encoder1=   nn.Conv3d(in_channel, 32, 3, stride=1, padding=1)  # b, 16, 10, 10
        self.encoder2=   nn.Conv3d(32, 64, 3, stride=1, padding=1)  # b, 8, 3, 3
        self.encoder3=   nn.Conv3d(64, 128, 3, stride=1, padding=1)
        self.encoder4=   nn.Conv3d(128, 256, 3, stride=1, padding=1)
        self.encoder5=   nn.Conv3d(256, 512, 3, stride=1, padding=1)
        
        self.decoder1 = nn.Conv3d(512, 256, 3, stride=1,padding=1)  # b, 16, 5, 5
        self.decoder2 =   nn.Conv3d(256, 128, 3, stride=1, padding=1)  # b, 8, 15, 1
        self.decoder3 =   nn.Conv3d(128, 64, 3, stride=1, padding=1)  # b, 1, 28, 28
        self.decoder4 =   nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.decoder5 =   nn.Conv3d(32, 2, 3, stride=1, padding=1)
        
        self.map4 = nn.Sequential(
            nn.Conv3d(2, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'),
            nn.Softmax(dim =1)
        )

        # 128*128 尺度下的映射
        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'),
            nn.Softmax(dim =1)
        )

        # 64*64 尺度下的映射
        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'),
            nn.Softmax(dim =1)
        )

        # 32*32 尺度下的映射
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(16, 32, 32), mode='trilinear'),
            nn.Softmax(dim =1)
        )

    def forward(self, x):

        out = F.relu(F.max_pool3d(self.encoder1(x),2,2))
        t1 = out
        out = F.relu(F.max_pool3d(self.encoder2(out),2,2))
        t2 = out
        out = F.relu(F.max_pool3d(self.encoder3(out),2,2))
        t3 = out
        out = F.relu(F.max_pool3d(self.encoder4(out),2,2))
        # t4 = out
        # out = F.relu(F.max_pool3d(self.encoder5(out),2,2))
        
        # t2 = out
        # out = F.relu(F.interpolate(self.decoder1(out),scale_factor=(2,2,2),mode ='trilinear'))
        # print(out.shape,t4.shape)
        output1 = self.map1(out)
        out = F.relu(F.interpolate(self.decoder2(out),scale_factor=(2,2,2),mode ='trilinear'))
        out = torch.add(out,t3)
        output2 = self.map2(out)
        print('output2:',output2.shape)
        out = F.relu(F.interpolate(self.decoder3(out),scale_factor=(2,2,2),mode ='trilinear'))
        print('out:',out.shape)
        out = torch.add(out,t2)
        output3 = self.map3(out)
        out = F.relu(F.interpolate(self.decoder4(out),scale_factor=(2,2,2),mode ='trilinear'))
        out = torch.add(out,t1)
        
        out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2,2),mode ='trilinear'))
        output4 = self.map4(out)
        # print(out.shape)
        # print(output1.shape,output2.shape,output3.shape,output4.shape)
        if self.training is True:
            return output1, output2, output3, output4
        else:
            return output4

__train__

In [None]:
from torch import optim
import torch.nn as nn
import torch
from tqdm import tqdm

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    train_dataset = Generate_Dataset(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    best_loss = float('inf')

    for epoch in tqdm(range(epochs)):
        print('epoch:',epoch)
        net.train()
        for data, label in train_loader:
            optimizer.zero_grad()

            data = data.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            pred = net(data)
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            if loss < best_loss:
                best_loss = loss
                torch.save(net.state_dict(), 'best_model.pth')
            loss.backward()
            optimizer.step()


if __name__ == '__main__':
    device = torch.device('cude' if torch.cuda.is_available() else 'cpu')
    net = UNet3D(1, 1)
    net.to(device=device)

    data_path = "../../../dataset/dataset/"
    train_net(net, device, data_path)


In [6]:
from tqdm import tqdm
import time
epoches= 1000
for i in tqdm(range(epoches)):
    time.sleep(0.01)

100%|██████████| 1000/1000 [00:15<00:00, 64.03it/s]


__train（UNet）__

In [None]:
from dataset.dataset_lits_val import Val_Dataset
from dataset.dataset_lits_train import Train_Dataset

from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
import config

from models import UNet, ResUNet , KiUNet_min, SegNet

from utils import logger, weights_init, metrics, common, loss
import os
import numpy as np
from collections import OrderedDict

def val(model, val_loader, loss_func, n_labels):
    model.eval()
    val_loss = metrics.LossAverage()
    val_dice = metrics.DiceAverage(n_labels)
    with torch.no_grad():
        for idx,(data, target) in tqdm(enumerate(val_loader),total=len(val_loader)):
            data, target = data.float(), target.long()
            target = common.to_one_hot_3d(target, n_labels)
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss=loss_func(output, target)
            
            val_loss.update(loss.item(),data.size(0))
            val_dice.update(output, target)
    val_log = OrderedDict({'Val_Loss': val_loss.avg, 'Val_dice_liver': val_dice.avg[1]})
    if n_labels==3: val_log.update({'Val_dice_tumor': val_dice.avg[2]})
    return val_log

def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    print("=======Epoch:{}=======lr:{}".format(epoch,optimizer.state_dict()['param_groups'][0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for idx, (data, target) in tqdm(enumerate(train_loader),total=len(train_loader)):
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target,n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss0 = loss_func(output[0], target)
        loss1 = loss_func(output[1], target)
        loss2 = loss_func(output[2], target)
        loss3 = loss_func(output[3], target)

        loss = loss3  +  alpha * (loss0 + loss1 + loss2)
        loss.backward()
        optimizer.step()
        
        train_loss.update(loss3.item(),data.size(0))
        train_dice.update(output[3], target)

    val_log = OrderedDict({'Train_Loss': train_loss.avg, 'Train_dice_liver': train_dice.avg[1]})
    if n_labels==3: val_log.update({'Train_dice_tumor': train_dice.avg[2]})
    return val_log

if __name__ == '__main__':
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    if not os.path.exists(save_path): os.mkdir(save_path)
    device = torch.device('cpu' if args.cpu else 'cuda')
    # data info
    train_loader = DataLoader(dataset=Train_Dataset(args),batch_size=args.batch_size,num_workers=args.n_threads, shuffle=True)
    val_loader = DataLoader(dataset=Val_Dataset(args),batch_size=1,num_workers=args.n_threads, shuffle=False)

    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels,training=True).to(device)

    model.apply(weights_init.init_model)#初始化训练参数
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    common.print_network(model)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
 
    loss = loss.TverskyLoss()

    log = logger.Train_Logger(save_path,"train_log")

    best = [0,0] # 初始化最优模型的epoch和performance
    trigger = 0  # early stop 计数器
    alpha = 0.4 # 深监督衰减系数初始值
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train_log = train(model, train_loader, optimizer, loss, args.n_labels, alpha)
        val_log = val(model, val_loader, loss, args.n_labels)
        log.update(epoch,train_log,val_log)

        # Save checkpoint.
        state = {'net': model.state_dict(),'optimizer':optimizer.state_dict(),'epoch': epoch}
        torch.save(state, os.path.join(save_path, 'latest_model.pth'))
        trigger += 1
        if val_log['Val_dice_liver'] > best[1]:
            print('Saving best model')
            torch.save(state, os.path.join(save_path, 'best_model.pth'))
            best[0] = epoch
            best[1] = val_log['Val_dice_liver']
            trigger = 0
        print('Best performance at Epoch: {} | {}'.format(best[0],best[1]))

        # 深监督系数衰减
        if epoch % 30 == 0: alpha *= 0.8

        # early stopping
        if args.early_stop is not None:
            if trigger >= args.early_stop:
                print("=> early stopping")
                break
        torch.cuda.empty_cache()    

In [None]:
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import config
from utils import logger,common
from dataset.dataset_lits_test import Test_Datasets,to_one_hot_3d
import SimpleITK as sitk
import os
import numpy as np
from models import ResUNet
from utils.metrics import DiceAverage
from collections import OrderedDict


def predict_one_img(model, img_dataset, args):
    dataloader = DataLoader(dataset=img_dataset, batch_size=1, num_workers=0, shuffle=False)
    model.eval()
    test_dice = DiceAverage(args.n_labels)
    target = to_one_hot_3d(img_dataset.label, args.n_labels)
    
    with torch.no_grad():
        for data in tqdm(dataloader,total=len(dataloader)):
            data = data.to(device)
            output = model(data)
            # output = nn.functional.interpolate(output, scale_factor=(1//args.slice_down_scale,1//args.xy_down_scale,1//args.xy_down_scale), mode='trilinear', align_corners=False) # 空间分辨率恢复到原始size
            img_dataset.update_result(output.detach().cpu())

    pred = img_dataset.recompone_result()
    pred = torch.argmax(pred,dim=1)

    pred_img = common.to_one_hot_3d(pred,args.n_labels)
    test_dice.update(pred_img, target)
    
    test_dice = OrderedDict({'Dice_liver': test_dice.avg[1]})
    if args.n_labels==3: test_dice.update({'Dice_tumor': test_dice.avg[2]})
    
    pred = np.asarray(pred.numpy(),dtype='uint8')
    if args.postprocess:
        pass # TO DO
    pred = sitk.GetImageFromArray(np.squeeze(pred,axis=0))

    return test_dice, pred

if __name__ == '__main__':
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    device = torch.device('cpu' if args.cpu else 'cuda')
    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels,training=False).to(device)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
    ckpt = torch.load('{}/best_model.pth'.format(save_path))
    model.load_state_dict(ckpt['net'])

    test_log = logger.Test_Logger(save_path,"test_log")
    # data info
    result_save_path = '{}/result'.format(save_path)
    if not os.path.exists(result_save_path):
        os.mkdir(result_save_path)
    
    datasets = Test_Datasets(args.test_data_path,args=args)
    for img_dataset,file_idx in datasets:
        test_dice,pred_img = predict_one_img(model, img_dataset, args)
        test_log.update(file_idx, test_dice)
        sitk.WriteImage(pred_img, os.path.join(result_save_path, 'result-'+file_idx+'.gz'))
