In [1]:
from torchdiffeq import odeint_adjoint as odeint
import torch
import torch.nn as nn
import numpy as np
from Utils.Utls import *
from Utils.Loss import *
import os
import random
import time
from Network.DynamicNet import DynamicNet


In [2]:
#固定随机种子，让结果可以重复(后续若要进行uncertainty计算，可将此部分去掉)
seed = 12345
np.random.seed(seed)
torch.manual_seed(seed) #CPU随机种子确定
torch.cuda.manual_seed(seed) #GPU随机种子确定
torch.cuda.manual_seed_all(seed) #所有的GPU设置种子
torch.backends.cudnn.benchmark = False #模型卷积层预先优化关闭
torch.backends.cudnn.deterministic = True #确定为默认卷积算法
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_root_path = "./Mri_data/"
subject_path = "002_S_4654"

In [4]:
#时间换算，将年月日格式统一变成天数
def cal_time(visit_time):
    separates = visit_time.split("-")
    year = int(separates[0])
    flag_month = separates[1].split("0")
    month = int(separates[1]) if flag_month[0] != "0" else int(flag_month[1])
    flag_day = separates[2].split("0")
    day = int(separates[2]) if flag_day[0] != "0" else int(flag_day[1])

    return year*365+month*30+day


#加载特定subject的图像和时间List
def load_imgs_and_time(subject):
    time_List = os.listdir(os.path.join(data_root_path,subject))
    time_List = sorted(time_List,key = lambda x:cal_time(x))
    img_list = []
    for t in time_List:
        img_list.append(load_nii(imgPath=os.path.join(data_root_path,subject,t,"t1.nii.gz")))
    
    return img_list,time_List

#返回 图像的list，时间的list

In [5]:
imgs,times =load_imgs_and_time(subject_path)
#计算时间
times = [cal_time(t)/365.0 for t in times]
start_time = times[0]
times = [t-start_time for t in times]

print(len(imgs))
print(imgs[0].shape)
print(np.min(imgs[0]))
print(np.max(imgs[0]))
print(times)

#示例subject002_S_4229的序列长度为9，这里我们按照  训练:测试=7:2 的比例 进行划分（默认是80%：20%，可根据不同subject进行调整）
train_List = imgs[0:7]
test_List =imgs[7:]
train_times = times[0:7]
test_times = times[7:]

9
(144, 176, 144)
0.0
1.0
[0.0, 0.22465753424648938, 0.5424657534247217, 1.030136986301386, 2.05479452054783, 4.038356164383458, 5.076712328767144, 6.128767123287616, 7.15068493150693]


In [6]:
im_shape =train_List[0].shape
#numpy转tensor,增加batch和channel维度，方便后续输入到模型
#144*176*144
train_List = [torch.from_numpy(img).to(device).float().unsqueeze(0).unsqueeze(0) for img in train_List]
test_List = [torch.from_numpy(img).to(device).float().unsqueeze(0).unsqueeze(0) for img in test_List]



#定义网络v(这里采用简化的版本)
Network = DynamicNet(img_sz=im_shape,
                    smoothing_kernel='AK',
                    smoothing_win=15,
                    smoothing_pass=1,
                    ds=2,
                    bs=32
                    ).to(device)

#Network = UNet_3D_ebd(n_channels=4,n_classes=3,trilinear=False).to(device)



In [7]:
#模型的保存路径
savePath = "./model-save/002_S_4654"
#定义优化器
optimizer = torch.optim.Adam(Network.parameters(), lr=0.005, amsgrad=True)
epoches = 300

# training loop
scale_factor = torch.tensor(im_shape).to(device).view(1, 3, 1, 1, 1) * 1.
ST = SpatialTransformer(im_shape).to(device)  # spatial transformer to warp image
grid = generate_grid3D_tensor(im_shape).unsqueeze(0).to(device)  # [-1,1] 1*3*144*176*144 (identity map)

#测试NCC计算
loss_NCC = NCC(win=21)
print(loss_NCC(train_List[0],train_List[1]))
print(torch.tensor([0.0]).to(device))




#用于记录每个epoch的数据
total_record = []

tensor(0.3241, device='cuda:0')
tensor([0.], device='cuda:0')


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
#训练部分
for i in range(epoches):
        #开始计时
        s_t = time.time()

        all_phi = odeint(func = Network, y0 = grid, t=torch.tensor(train_times).to(device),method="rk4",rtol=1e-3,atol=1e-5).to(device)
      
        #速度场（相邻两个时间点的状态相减，具体理解可参考euler法的定义）
        all_v = all_phi[1:] - all_phi[:-1]
        all_phi = (all_phi + 1.) / 2. * scale_factor  # [-1, 1] -> voxel spacing  恢复到标准的坐标系
        grid_voxel = (grid + 1.) / 2. * scale_factor  # [-1, 1] -> voxel spacing
        #记录各种loss
        total_loss = 0.0
        epoch_loss = []
        epoch_loss_NCC = []
        epoch_loss_MSE = []
        epoch_loss_v = []
        epoch_loss_J = []
        epoch_folding = []
        epoch_loss_df = []
        epoch_loss_bdr = []
        
        seq_length = 7

        #对每一个时间点的预测进行loss计算
        for n in range(1,seq_length):
            phi = all_phi[n]
            df = phi - grid_voxel  # with grid -> without grid（此处的df是offset）
            warped_moving, df_with_grid = ST(train_List[0], df, return_phi=True)
            # similarity loss（NCC）
            loss_sim = loss_NCC(warped_moving, train_List[n])
            epoch_loss_NCC.append(loss_sim.clone().detach().cpu())
            
            #loss_ncc = loss_NCC(warped_moving,train_List[n])
            #epoch_loss_NCC.append(loss_ncc.clone().detach().cpu())
            
            loss_mse = MSE(warped_moving,train_List[n])
            epoch_loss_MSE.append(loss_mse.clone().detach().cpu())

            warped_moving = warped_moving.squeeze(0).squeeze(0)
            # V magnitude loss
            loss_v = 0.00005 * magnitude_loss(all_v)
            epoch_loss_v.append(loss_v.clone().detach().cpu())
            # neg Jacobian loss
            loss_J = 0.000001 * neg_Jdet_loss1(df_with_grid)
            epoch_loss_J.append(loss_J.clone().detach().cpu())
            
            #folding
            folding = calculate_folding(df,device)
            epoch_folding.append(folding)
            
            
            # phi dphi/dx loss
            loss_df = 0.05 * smoothloss_loss(df)
            epoch_loss_df.append(loss_df.clone().detach().cpu())
            #bdr loss
            loss_bdr = 0.0001*boundary_loss(df)
            epoch_loss_bdr.append(loss_bdr.clone().detach().cpu())
            #各项loss求和
            loss = loss_sim + loss_df + loss_bdr
            #+ loss_v + loss_J + loss_df
            epoch_loss.append(loss.clone().detach().cpu())
            #各个时间点的loss求和
            total_loss = total_loss + loss

        optimizer.zero_grad()
        total_loss = total_loss/(seq_length-1)
        total_loss.backward()
        optimizer.step()
        
        #结束计时
        e_t = time.time()

        print("Iteration: {0} loss_NCC: {1:.3e}  loss_v: {2:.3e} loss_J: {3:.3e} loss_df: {4:.3e} total_loss: {5:.3e} time_cost: {6:.3e} loss_MSE: {7:.3e} loss_bdr: {8:.3e} folding: {9:.3e}"
              .format(i + 1, 
                      np.mean(epoch_loss_NCC),
                      np.mean(epoch_loss_v),
                      np.mean(epoch_loss_J),
                      np.mean(epoch_loss_df),
                      total_loss.item(),
                      e_t-s_t,
                      np.mean(epoch_loss_MSE),
                      np.mean(epoch_loss_bdr),
                      np.mean(epoch_folding)
                      )
                      )
        #保存每个epoch的记录
        epoch_record = {"Iteration":i+1,
                        "loss_NCC":np.mean(epoch_loss_NCC),
                        "loss_v":np.mean(epoch_loss_v),
                        "loss_J":np.mean(epoch_loss_J),
                        "folding":np.mean(epoch_folding),
                        "loss_df":np.mean(epoch_loss_df),
                        "loss_bdr":np.mean(epoch_loss_bdr),
                        "total_loss":total_loss.item(),
                        "time_cost":e_t-s_t}
        total_record.append(epoch_record)
        #写入日志
        #log_writer.add_scalar(tag="loss_NCC",scalar_value=epoch_record["loss_NCC"],global_step=i+1)
        #log_writer.add_scalar(tag="loss_v",scalar_value=epoch_record["loss_v"],global_step=i+1)
        #log_writer.add_scalar(tag="loss_J",scalar_value=epoch_record["loss_J"],global_step=i+1)
        #log_writer.add_scalar(tag="folding",scalar_value=epoch_record["folding"],global_step=i+1)
        #log_writer.add_scalar(tag="loss_df",scalar_value=epoch_record["loss_df"],global_step=i+1)
        #log_writer.add_scalar(tag="loss_bdr",scalar_value=epoch_record["loss_df"],global_step=i+1)
        #log_writer.add_scalar(tag="total_loss",scalar_value=epoch_record["total_loss"],global_step=i+1)
        #log_writer.add_scalar(tag="time_cost",scalar_value=epoch_record["time_cost"],global_step=i+1) 
        
        #每50个epoch保存一次模型
        if (i+1)%50 == 0:
            torch.save(Network.state_dict(),os.path.join(savePath,"epoch-%d.pkl"%(i+1)))
        #log_writer.close()




Iteration: 1 loss_NCC: 6.208e-01  loss_v: 3.268e-07 loss_J: 6.678e+00 loss_df: 1.752e-01 total_loss: 8.483e-01 time_cost: 1.332e+00 loss_MSE: 1.979e-02 loss_bdr: 5.240e-02 folding: 1.683e-01
Iteration: 2 loss_NCC: 5.343e-01  loss_v: 1.743e-08 loss_J: 1.083e-01 loss_df: 1.807e-02 total_loss: 5.581e-01 time_cost: 1.285e+00 loss_MSE: 9.127e-03 loss_bdr: 5.660e-03 folding: 5.844e-02
Iteration: 3 loss_NCC: 6.950e-01  loss_v: 3.326e-06 loss_J: 1.827e+02 loss_df: 1.599e+00 total_loss: 2.771e+00 time_cost: 1.286e+00 loss_MSE: 3.688e-02 loss_bdr: 4.765e-01 folding: 2.388e-01
Iteration: 4 loss_NCC: 6.604e-01  loss_v: 1.200e-06 loss_J: 4.161e+01 loss_df: 5.963e-01 total_loss: 1.436e+00 time_cost: 1.289e+00 loss_MSE: 2.850e-02 loss_bdr: 1.798e-01 folding: 2.145e-01
Iteration: 5 loss_NCC: 6.453e-01  loss_v: 1.138e-06 loss_J: 4.030e+01 loss_df: 5.860e-01 total_loss: 1.410e+00 time_cost: 1.291e+00 loss_MSE: 2.664e-02 loss_bdr: 1.789e-01 folding: 2.114e-01
Iteration: 6 loss_NCC: 4.643e-01  loss_v: 1.5