In [1]:
#从训练好的模型进行推断，以生成想要的图像
from torchdiffeq import odeint_adjoint as odeint
import torch
import torch.nn as nn
import numpy as np
from Utils.Utls import *
from Utils.Loss import *
import os
import random
import time
from Network.DynamicNet import DynamicNet

In [2]:
# Configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_root_path = "./Mri_data/"
subject_path = "002_S_4654"  # subject ID; also the sub-folder name under data_root_path
seq_length = 9      # number of visits (scans) expected for this subject
divide_length = 7   # the first `divide_length` visits are used for training, the rest for testing
# Output folder for the results (warped images and deformation fields). Derived from
# the subject ID so that switching subjects does not write into another subject's folder.
result_save_path = os.path.join("./result-save", subject_path)

In [3]:
# Time conversion: map a "YYYY-MM-DD" date string onto a single day count.
def cal_time(visit_time):
    """Convert a ``"YYYY-MM-DD"`` visit date into an approximate day count.

    The value is ``year * 365 + month * 30 + day``: leap years are ignored and
    every month counts as 30 days. The formula is kept unchanged on purpose,
    because it defines the ODE time axis. Presumably the model was trained with
    the same convention, so changing it here would shift the inference time
    points. Note that the approximation maps some consecutive dates to the same
    value (e.g. Jan 31 and Feb 1 both give 30*1 + 31 == 30*2 + 1).

    Args:
        visit_time: date string whose first three ``-``-separated fields are
            year, month and day. Leading zeros are fine (``int("05") == 5``).
            Any extra fields are ignored.

    Returns:
        int: the approximate day count.

    Raises:
        IndexError: if fewer than three ``-``-separated fields are present.
        ValueError: if a field is not an integer.
    """
    fields = visit_time.split("-")
    year, month, day = int(fields[0]), int(fields[1]), int(fields[2])
    return year * 365 + month * 30 + day


#加载特定subject的图像和时间List
def load_imgs_and_time(subject):
    time_List = os.listdir(os.path.join(data_root_path,subject))
    time_List = sorted(time_List,key = lambda x:cal_time(x))
    img_list = []
    for t in time_List:
        img_list.append(load_nii(imgPath=os.path.join(data_root_path,subject,t,"t1.nii.gz")))
    
    return img_list,time_List

In [4]:
imgs, times = load_imgs_and_time(subject_path)
# Check the configuration against what is actually on disk. The evaluation loop
# runs over `seq_length` time points, and the train/test split is at `divide_length`.
assert len(imgs) == seq_length, "expected %d visits, found %d" % (seq_length, len(imgs))
assert 0 < divide_length < len(imgs), "divide_length must leave at least one train and one test visit"

# Convert visit dates into years elapsed since the first visit.
times = [cal_time(t) / 365.0 for t in times]
start_time = times[0]
times = [t - start_time for t in times]

print(len(imgs))
print(imgs[0].shape)
print(np.min(imgs[0]))
print(np.max(imgs[0]))
print(times)

# Split into training and test visits.
train_List = imgs[0:divide_length]
test_List = imgs[divide_length:]
train_times = times[0:divide_length]
test_times = times[divide_length:]

print(len(test_times))

9
(144, 176, 144)
0.0
1.0
[0.0, 0.22465753424648938, 0.5424657534247217, 1.030136986301386, 2.05479452054783, 4.038356164383458, 5.076712328767144, 6.128767123287616, 7.15068493150693]
2


In [5]:

# 3-D volume shape (144, 176, 144 for this subject, per the output above).
im_shape = imgs[0].shape


def _to_batched_tensor(volume):
    """Convert a numpy volume to a float tensor on `device`, adding batch and channel axes."""
    return torch.from_numpy(volume).to(device).float().unsqueeze(0).unsqueeze(0)


# Build from the numpy `imgs` rather than from train_List/test_List themselves, so
# re-running this cell does not try to convert tensors (or take a 5-D shape) a second time.
train_List = [_to_batched_tensor(img) for img in imgs[:divide_length]]
test_List = [_to_batched_tensor(img) for img in imgs[divide_length:]]
img_List = train_List + test_List

print(len(img_List))

9


In [6]:
# Network v: the right-hand side of the ODE (passed as `func` to odeint below).
# These hyper-parameters presumably have to match the training configuration,
# otherwise the saved state_dict will not fit the layer shapes.
Network = DynamicNet(
    img_sz=im_shape,
    smoothing_kernel="AK",
    smoothing_win=15,
    smoothing_pass=1,
    ds=2,
    bs=32,
).to(device)

In [7]:
# Restore the trained weights from this subject's epoch-300 checkpoint.
model_epoch = 300
savePath = os.path.join("./model-save", subject_path, "epoch-%d.pkl" % model_epoch)
# map_location lets a checkpoint saved from a GPU load on the CPU fallback `device`.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
Network.load_state_dict(torch.load(savePath, map_location=device))
Network.eval()

DynamicNet(
  (enc_conv2): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), padding_mode=replicate)
  (enc_conv3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), padding_mode=replicate)
  (enc_conv4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), padding_mode=replicate)
  (enc_conv5): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), padding_mode=replicate)
  (enc_conv6): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), padding_mode=replicate)
  (lin1): Linear(in_features=1728, out_features=32, bias=True)
  (lin2): Linear(in_features=32, out_features=171072, bias=True)
  (relu): ReLU()
  (sk): AveragingKernel()
)

In [8]:
# Per-axis volume size, shaped (1, 3, 1, 1, 1) so it broadcasts over 3-channel
# grid/field tensors; multiplying by 1.0 promotes the integer sizes to floating point.
scale_factor = torch.tensor(im_shape, device=device).view(1, 3, 1, 1, 1) * 1.0
# Spatial transformer used to warp an image with a displacement field.
ST = SpatialTransformer(im_shape).to(device)
# Sampling grid in normalised [-1, 1] coordinates with a leading batch axis
# (1*3*144*176*144 for this subject).
grid = generate_grid3D_tensor(im_shape).to(device)[None]





  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
# Use the trained model for regression: integrate it from the [-1, 1] grid over all
# visit times, warp the baseline scan (visit 0) to every later visit, and score each result.
from skimage.metrics import peak_signal_noise_ratio  # imported once, not on every loop iteration

os.makedirs(result_save_path, exist_ok=True)  # save_nii below writes into this folder

# Pure inference: disable autograd so the large 3-D intermediates are not kept for backward.
with torch.no_grad():
    # all_phi[k] is the integrated grid at times[k] (leading axis = time points).
    # NOTE(review): "rk4" is a fixed-grid solver, so rtol/atol are presumably unused — confirm.
    all_phi = odeint(func=Network, y0=grid, t=torch.tensor(times).to(device),
                     method="rk4", rtol=1e-3, atol=1e-5)
    all_v = all_phi[1:] - all_phi[:-1]  # change between consecutive time points (not used in this cell)
    all_phi = (all_phi + 1.) / 2. * scale_factor  # [-1, 1] -> voxel spacing
    grid_voxel = (grid + 1.) / 2. * scale_factor  # [-1, 1] -> voxel spacing

    # One score per predicted visit n = 1 .. seq_length-1 (visit 0 is the source image).
    regression_MSE = []
    regression_PSNR = []

    for n in range(1, seq_length):
        phi = all_phi[n]
        df = phi - grid_voxel  # with grid -> without grid (displacement field)
        warped_moving, df_with_grid = ST(img_List[0], df, return_phi=True)
        loss_mse = MSE(warped_moving, img_List[n])

        # skimage's signature is (image_true, image_test) and it infers data_range from
        # image_true, so the ground-truth scan must come first. data_range=1.0 assumes
        # intensities in [0, 1] (the first scan prints min 0 / max 1).
        # NOTE(review): confirm every scan is normalised the same way.
        psnr = peak_signal_noise_ratio(img_List[n].cpu().numpy(),
                                       warped_moving.cpu().numpy(),
                                       data_range=1.0)
        print("psnr:", psnr)
        regression_PSNR.append(psnr)
        regression_MSE.append(loss_mse.clone().detach().cpu())
        warped_moving = warped_moving.squeeze(0).squeeze(0)

        # Save the deformation field and the warped image for this time point.
        save_nii(df.permute(2, 3, 4, 0, 1).cpu().numpy(), '%s/df-t%d.nii.gz' % (result_save_path, n))
        save_nii(warped_moving.cpu().numpy(), '%s/warped-t%d.nii.gz' % (result_save_path, n))

# regression_MSE[i] belongs to visit i+1, so the training visits 1..divide_length-1
# occupy indices 0..divide_length-2, and the test visits follow.
print("MSE 评估结果为：")
print(regression_MSE)
print("训练部分的平均MSE为：", np.mean(regression_MSE[0:divide_length-1]))
print("测试部分的平均MSE为：", np.mean(regression_MSE[divide_length-1:]))
print("整个序列上的平均MSE为：", np.mean(regression_MSE))



psnr: 29.795058727007188
psnr: 27.548580055743322
psnr: 30.39480238770873
psnr: 26.908640197759258
psnr: 30.456358493130047
psnr: 27.721646233110043
psnr: 24.738872237485133
psnr: 24.13506675441951
MSE 评估结果为：
[tensor(0.0010), tensor(0.0018), tensor(0.0009), tensor(0.0020), tensor(0.0009), tensor(0.0017), tensor(0.0034), tensor(0.0039)]
训练部分的平均MSE为： 0.0013912758
测试部分的平均MSE为： 0.0036087064
整个序列上的平均MSE为： 0.0019456334
