In [1]:
import argparse
import time
import pickle
from options.test_options import TestOptions
from data import CreateDataLoader2
from models import create_model
import numpy as np
import os
import torch

from PIL import Image
import matplotlib.pyplot as plt 
import cv2
import SimpleITK as sitk
from skimage import filters, exposure
import nibabel as nib

In [2]:
def resample_image(ori_img, target_img, target_size=(256, 256),
                   target_spacing=(2.34375, 2.34375, 4.0)):
    """Resample ``ori_img`` onto a grid anchored at ``target_img``.

    The output grid has the following properties:
    - Origin and direction come from ``target_img``.
    - In-plane size is ``target_size``.
    - Voxel spacing is ``target_spacing``.
    - The number of slices along z is the same as in ``ori_img``.

    Intensities are B-spline interpolated and returned as float32.

    Args:
        ori_img: SimpleITK image to resample.
        target_img: SimpleITK image whose origin and direction define the
            output grid.
        target_size: (x, y) output size in voxels.
        target_spacing: (x, y, z) output voxel spacing. According to the
            original author's note, the default doubles an in-plane CT spacing
            of 1.171875 and keeps a 4.0 slice spacing.

    Returns:
        The resampled SimpleITK image with pixel type sitkFloat32.

    NOTE(review): the z size is copied from ``ori_img`` while the z spacing is
    forced to ``target_spacing[2]``. If ``ori_img``'s slice spacing differs,
    the output covers a different physical z extent. Confirm this is intended.
    """
    # Only x/y are resized; the slice count along z is kept.
    output_size = (*target_size, ori_img.GetSize()[2])

    resampler = sitk.ResampleImageFilter()
    # Seed all output-grid parameters from ori_img, then override them below.
    resampler.SetReferenceImage(ori_img)
    resampler.SetSize(output_size)
    resampler.SetOutputOrigin(target_img.GetOrigin())
    resampler.SetOutputDirection(target_img.GetDirection())
    resampler.SetOutputSpacing(tuple(target_spacing))
    # A default-constructed sitk.Transform is the identity.
    resampler.SetTransform(sitk.Transform())
    resampler.SetOutputPixelType(sitk.sitkFloat32)
    # B-spline *interpolation* of intensities (no registration is performed).
    resampler.SetInterpolator(sitk.sitkBSpline)

    return resampler.Execute(ori_img)

def median_image_itk(ori_img, radius=1):
    """Apply a SimpleITK median filter to ``ori_img``.

    Args:
        ori_img: SimpleITK image to filter.
        radius: neighbourhood radius passed to ``MedianImageFilter.SetRadius``.

    Returns:
        The median-filtered SimpleITK image.
    """
    median_filter = sitk.MedianImageFilter()
    median_filter.SetRadius(radius)
    return median_filter.Execute(ori_img)


def postprocess(img, minn, top=99.95, y_axis=0.90, ct=False):
    """Clip outliers and normalise an image volume to [-1, 1].

    Values are clipped to ``[minn, maxx]``. ``maxx`` is the ``top``-th
    percentile of the voxel values; the default 99.95 drops extreme outliers.

    * ``ct=True``: linear map of [minn, maxx] onto [-1, 1].
    * ``ct=False``: two-segment map split at the Otsu threshold ``middle``.
      [minn, middle] maps to [-1, 2*y_axis - 1] and [middle, maxx] maps to
      [2*y_axis - 1, 1], so the low-intensity segment gets a ``y_axis``
      share of the output range.

    Args:
        img: SimpleITK image, or a NumPy array of voxel values.
        minn: lower clip value (this notebook uses -1024 for CT, 0 for PET).
        top: percentile used as the upper clip value.
        y_axis: fraction of the output range given to [minn, middle].
        ct: use the linear (True) or the two-segment (False) mapping.

    Returns:
        ``(img, maxx)`` when ``ct`` is True, otherwise ``(img, middle, maxx)``.

    Raises:
        ValueError: if the clip range is empty (``maxx <= minn``), e.g. for a
            constant image; the normalisation would divide by zero.
    """
    if not isinstance(img, np.ndarray):
        img = sitk.GetArrayFromImage(img)
    maxx = np.percentile(img, top)
    if maxx <= minn:
        raise ValueError(
            f"empty clip range: percentile {top} = {maxx} <= minn = {minn}")
    img = img.clip(minn, maxx)

    if ct:
        img = 2 * (img - minn) / (maxx - minn) - 1
        return img, maxx

    # Otsu threshold that splits the two segments. Use one bin per unit of the
    # clipped range, and at least 2 bins. int(maxx) alone is only right when
    # minn == 0 and fails for maxx < 2.
    middle = filters.threshold_otsu(img, nbins=max(int(maxx - minn), 2))

    low = (np.clip(img, minn, middle) - minn) / (middle - minn) * y_axis * 2 - 1
    high = ((np.clip(img, middle, maxx) - middle) / (maxx - middle)
            * (1 - y_axis) + y_axis) * 2 - 1

    # np.where selects a branch; mask * value would spread NaN/inf from the
    # branch that is not selected.
    img = np.where(img >= middle, high, low)
    return img, middle, maxx

def testpkl(path, path2, direction='axial', index=102):
    """Resample, normalise and pickle one CT/PET volume pair for testing.

    The function does the following:
    1. Reads ``<path>/CT/<index>.nii.gz`` and ``<path>/PET/<index>.nii.gz``,
       with the index zero-padded to 4 digits.
    2. Resamples both volumes onto the grid anchored at the CT.
    3. Normalises them with ``postprocess``.
    4. Writes two float32 arrays, CT then PET, into ``<path2>/test.pkl``.

    Args:
        path: dataset root containing ``CT/`` and ``PET/`` sub-folders.
        path2: existing output folder for ``test.pkl``.
        direction: label used only in the progress message.
        index: case number of the volume pair (the notebook uses 102).

    Returns:
        ``(CTmax, PETmiddle, PETmax)``, the values needed to undo the
        normalisation later.
    """
    ct_path = os.path.join(path, "CT", "%04d.nii.gz" % index)
    pet_path = os.path.join(path, "PET", "%04d.nii.gz" % index)

    sitkImageCT = sitk.ReadImage(ct_path)
    sitkImagePET = sitk.ReadImage(pet_path)

    # Resample both modalities onto the CT-anchored grid.
    resample_PET = resample_image(sitkImagePET, sitkImageCT)
    resample_CT = resample_image(sitkImageCT, sitkImageCT)

    # Normalise to [-1, 1], keeping the values needed to invert it.
    CT, CTmax = postprocess(resample_CT, minn=-1024, top=99.95, y_axis=0.90, ct=True)
    PET, PETmiddle, PETmax = postprocess(resample_PET, minn=0, top=99.95, y_axis=0.90)

    # SimpleITK arrays are indexed (z, y, x); store (y, x, z) so slices are last.
    CT = np.transpose(CT, (1, 2, 0))
    PET = np.transpose(PET, (1, 2, 0))

    print("第{}个影像 {} 已处理完成！".format(index, direction))

    # Two consecutive pickles in one file: CT first, then PET.
    with open(os.path.join(path2, 'test.pkl'), 'wb') as f:
        pickle.dump(CT.astype(np.float32), f)
        pickle.dump(PET.astype(np.float32), f)
    return CTmax, PETmiddle, PETmax
    
# Build test.pkl for case 102 and keep the de-normalisation constants.
path = '/opt/data/private/pGAN-cGAN-ct-pet-pkl/petct/'
for dire in ['axial']:  # ['axial', 'coronal', 'sagittal']
    # NOTE(review): this writes pictures/test/test.pkl. That is presumably
    # what the options cell (dataroot=pictures, phase='test') loads. The
    # later comparison cell, however, reads pictures/axial/test/test.pkl.
    # The per-direction layout would be os.path.join(<pictures>, dire, 'test').
    # Confirm which layout is intended.
    path2 = os.path.join('/opt/data/private/PET生成CT/pictures', 'test')
    os.makedirs(path2, exist_ok=True)
    CTmax, PETmiddle, PETmax = testpkl(path, path2, direction=dire)
print(f"这个影像的CTmax:{CTmax},PETmiddle{PETmiddle},PETmax{PETmax}")

第102个影像 axial 已处理完成！
这个影像的CTmax:1261.0,PETmiddle10854.66015625,PETmax30637.451684571417


In [3]:
import os
from PIL import Image
import numpy as np

def _visual_to_hu(im, CTmax, hu_min=-1024):
    """Linearly map a [0, 255] image to [hu_min, CTmax] and clip to that range."""
    mapped = im / 255 * (CTmax - hu_min) + hu_min
    return np.clip(mapped, hu_min, CTmax)


def back_images(i, visuals, image_path, CTmax, aspect_ratio=1.0,
                image_dir="/opt/data/private/PET生成CT/pictures/results"):
    """Map the model's real and fake CT visuals from [0, 255] back to CT values.

    Each image is rescaled linearly from [0, 255] to [-1024, CTmax] and then
    clipped to that range.

    Args:
        i: slice index (unused; kept for call compatibility).
        visuals: mapping with at least ``'real_B'`` and ``'fake_B'`` arrays,
            presumably from ``model.get_current_visuals()``.
        image_path: unused; kept for call compatibility.
        CTmax: upper CT value used when the volume was normalised.
        aspect_ratio: unused; kept for call compatibility.
        image_dir: results folder, created if missing (nothing is written here).

    Returns:
        ``(real_ct, fake_ct)`` as float arrays.

    Raises:
        KeyError: if ``'real_B'`` or ``'fake_B'`` is missing from ``visuals``.
    """
    os.makedirs(image_dir, exist_ok=True)

    # Index by key instead of iterating, so the result does not depend on the
    # iteration order of `visuals`.
    real_ct = _visual_to_hu(visuals['real_B'], CTmax)
    fake_ct = _visual_to_hu(visuals['fake_B'], CTmax)
    return real_ct, fake_ct


In [4]:
# Test-time options: pGAN model with the FCT_res generator, B->A direction.
opt = TestOptions().parse()

option_overrides = {
    # data
    'dataroot': '/opt/data/private/PET生成CT/pictures',
    'dataset_mode': 'aligned',
    'phase': 'test',
    'which_direction': 'BtoA',
    'input_nc': 3,
    'output_nc': 1,
    'how_many': 800,
    # model
    'model': 'pGAN',
    'name': 'FCT_res',
    'which_model_netG': 'FCT_res',
    'norm': "batch",
    'gpu_ids': [0],
    # output locations
    'results_dir': 'pictures/',
    'checkpoints_dir': 'checkpoints/',
    # display server
    'display_server': "http://114.212.200.248",
    'display_port': 25809,
    # test code only supports a single worker, batch size 1 and no shuffling
    'nThreads': 1,
    'batchSize': 1,
    'serial_batches': True,
}
for option_name, option_value in option_overrides.items():
    setattr(opt, option_name, option_value)

if opt.gpu_ids:
    torch.cuda.set_device(opt.gpu_ids[0])

In [5]:
# Run the FCT_res model slice by slice and collect real/synthetic CT volumes.
model = create_model(opt)
data_loader = CreateDataLoader2(opt)
dataset = data_loader.load_data()

web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))

# Size the volumes from the loader instead of a hard-coded 264 slices, which
# raised IndexError for any volume with more slices.
n_slices = min(len(data_loader), opt.how_many)
ct_real = np.zeros((n_slices, 256, 256))
ct_fake = np.zeros((n_slices, 256, 256))

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()
    # Build a labelled copy rather than mutating the list the model returned.
    img_path = model.get_image_paths()
    img_path = [img_path[0] + str(i)] + list(img_path[1:])
    print('%04d: process image... %s' % (i, img_path))

    im, im_faked = back_images(i, visuals, img_path, CTmax, aspect_ratio=opt.aspect_ratio)
    ct_real[i, :, :] = im[:, :, 0]
    ct_fake[i, :, :] = im_faked[:, :, 0]


pGAN


initialization method [normal]
---------- Networks initialized -------------
FCTGenerator2(
  (model_1): Sequential(
    (0): residualUnit(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (convX): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
      (bnX): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (model_2): Sequential(
    (0): residualUnit(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, mo

In [None]:
# Wrap the collected volumes as SimpleITK images under new names. The NumPy
# arrays are not overwritten, so this cell can be re-run.
# NOTE(review): the (0, 2, 1) transpose swaps the two in-plane axes, and no
# spacing/origin is set (SimpleITK's defaults apply). Confirm both are
# intended before writing these images to disk.
ct_real_img = sitk.GetImageFromArray(np.transpose(ct_real, (0, 2, 1)))
ct_fake_img = sitk.GetImageFromArray(np.transpose(ct_fake, (0, 2, 1)))




In [None]:
# Load the two generators to compare: FCT_res (ours) and resnet_9blocks.
import copy

model = create_model(opt)

# Give the second model its own options object. Mutating `opt` in place left
# it pointing at resnet_9blocks for every later cell. It also changed any
# options the first model holds by reference.
opt2 = copy.deepcopy(opt)
opt2.name = 'resnet_9blocks'
opt2.which_model_netG = 'resnet_9blocks'
model2 = create_model(opt2)

In [None]:
# Run both generators over the axial test slices and convert the results and
# the reference volumes back to the original intensity scale.
with open('/opt/data/private/PET生成CT/pictures/axial/test/test.pkl', 'rb') as f:
    # Local file written by testpkl(); never unpickle untrusted data.
    axial_ct = pickle.load(f)
    axial_pet = pickle.load(f)

opt.dataroot = '/opt/data/private/PET生成CT/pictures/axial'
data_loader = CreateDataLoader2(opt)
dataset = data_loader.load_data()
visuals_axial = np.zeros([256, 256, len(data_loader)])
visuals_axial2 = np.zeros([256, 256, len(data_loader)])

# Constants matching the postprocess() calls in testpkl().
# NOTE(review): CTmax/PETmiddle/PETmax come from the testpkl() run for case
# 102. Confirm that this pkl holds the same case.
HU_MIN = -1024
PET_MIN = 0
PET_Y_AXIS = 0.90


def _uint8_to_hu(img):
    """Map a [0, 255] model output linearly to [HU_MIN, CTmax]."""
    return img / 255 * (CTmax - HU_MIN) + HU_MIN


for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals()

    model2.set_input(data)
    model2.test()
    visuals2 = model2.get_current_visuals()

    img_path = model.get_image_paths()
    img_path = [img_path[0] + '_' + str(i)] + list(img_path[1:])
    print('%04d: process image... %s' % (i, img_path))

    # CTmax is one scalar per volume (np.percentile in postprocess). Indexing
    # it as CTmax[i] raised IndexError.
    visuals_axial[:, :, i] = _uint8_to_hu(visuals['fake_B'].squeeze(2))
    visuals_axial2[:, :, i] = _uint8_to_hu(visuals2['fake_B'].squeeze(2))

# Undo the [-1, 1] normalisation of the reference volumes.
# CT was mapped linearly from [HU_MIN, CTmax].
axial_ct = (axial_ct + 1) / 2 * (CTmax - HU_MIN) + HU_MIN
# PET used a two-segment map split at PETmiddle, so a single linear inverse
# with PETmax is wrong. Invert each segment separately.
pet_u = (axial_pet + 1) / 2
axial_pet = np.where(
    pet_u < PET_Y_AXIS,
    pet_u / PET_Y_AXIS * (PETmiddle - PET_MIN) + PET_MIN,
    (pet_u - PET_Y_AXIS) / (1 - PET_Y_AXIS) * (PETmax - PETmiddle) + PETmiddle,
)

In [None]:
# picpath ='/opt/data/private/PET生成CT/pictures/'
# sitkImage1 = sitk.GetImageFromArray(np.transpose(axial_ct[:,:,1:-1],(2,0,1)))
# sitk.WriteImage(sitkImage1,os.path.join(picpath,picpath+'1.nii.gz'))

# sitkImage2 = sitk.GetImageFromArray(np.transpose(axial_pet[:,:,1:-1],(2,0,1)))
# sitk.WriteImage(sitkImage2,os.path.join(picpath,picpath+'2.nii.gz'))

# sitkImage3 = sitk.GetImageFromArray(np.transpose(visuals_axial,(2,0,1)))
# sitk.WriteImage(sitkImage3,os.path.join(picpath,picpath+'3.nii.gz'))

In [None]:
def pics2nii(picpath, niipath):
    """Stack every image in ``picpath`` into one volume and write it to ``niipath``.

    Files are read in sorted filename order, because ``os.listdir`` order is
    arbitrary and would scramble the slices. Sub-directories and existing
    ``.nii``/``.nii.gz`` files are skipped. Every image must decode to the
    same array shape, and the stack is cast to uint8.

    Args:
        picpath: folder of images readable by ``sitk.ReadImage``.
        niipath: output file path, or an existing directory in which
            ``<folder name>.nii.gz`` is written.

    Raises:
        ValueError: if ``picpath`` holds no image files, or if the images
            differ in shape (raised by ``np.stack``).
    """
    files = sorted(
        name for name in os.listdir(picpath)
        if os.path.isfile(os.path.join(picpath, name))
        and not name.endswith(('.nii', '.nii.gz'))
    )
    if not files:
        raise ValueError(f"no image files found in {picpath!r}")

    slices = [sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(picpath, name)))
              for name in files]
    niiimg = sitk.GetImageFromArray(np.stack(slices).astype(np.uint8))

    # Honour `niipath` (it was previously ignored). A directory gets a file
    # named after the source folder.
    if os.path.isdir(niipath):
        folder_name = os.path.basename(os.path.normpath(picpath))
        niipath = os.path.join(niipath, folder_name + '.nii.gz')
    sitk.WriteImage(niiimg, niipath)

In [None]:
def _show_plane(ax, volume, plane, title, cmap):
    """Draw a rotated plane of ``volume`` on ``ax`` (no ticks) and return the image."""
    section = volume[128, :, :] if plane == 'coronal' else volume[:, 128, :]
    image = ax.imshow(np.rot90(section), cmap=cmap)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    return image


# Real PET/CT and both synthetic CTs. "coronal" fixes the first array axis at
# 128 and "sagittal" fixes the second array axis at 128.
fig, ((ax1, ax2, ax3, ax4), (ax5, ax6, ax7, ax8)) = plt.subplots(2, 4, figsize=(10, 7))

panels = [
    (ax1, axial_pet, 'coronal', 'Real PET(coronal)', plt.cm.RdBu_r),
    (ax2, axial_pet, 'sagittal', 'Real PET(sagittal)', plt.cm.RdBu_r),
    (ax3, axial_ct, 'coronal', 'Real CT(coronal)', 'gray'),
    (ax4, axial_ct, 'sagittal', 'Real CT(sagittal)', 'gray'),
    (ax5, visuals_axial, 'coronal', 'Synthetic CT(coronal,ours)', 'gray'),
    (ax6, visuals_axial, 'sagittal', 'Synthetic CT(sagittal,ours)', 'gray'),
    (ax7, visuals_axial2, 'coronal', 'Synthetic CT(coronal,pGAN)', 'gray'),
    (ax8, visuals_axial2, 'sagittal', 'Synthetic CT(sagittal,pGAN)', 'gray'),
]
im1, im2, im3, im4, im5, im6, im7, im8 = [_show_plane(*panel) for panel in panels]

# Dashed marker at index 128 on the sagittal PET view.
ax2.axvline(x=128, color='yellow', linestyle='--', linewidth=3)

# cbar_ax = fig.add_axes([-0.01, 0.655, 0.02, 0.22]) # [left, bottom, width, height]
# fig.colorbar(im1, cax=cbar_ax)
# cbar_ax_jet = fig.add_axes([-0.01, 0.385, 0.02, 0.22])
# fig.colorbar(im3, cax=cbar_ax_jet)
# cbar_ax_jet = fig.add_axes([-0.01, 0.115, 0.02, 0.22])
# fig.colorbar(im5, cax=cbar_ax_jet)

# Subplot spacing. right > 1 places the right edge of the grid beyond the
# figure width.
plt.subplots_adjust(left=0.1, right=1.2, wspace=0)

plt.show()

In [None]:
# Intensity profile along the slice axis at in-plane array position (128, 128).
# It is a profile, not a distribution, so the title and x label say so.
fig_profile, ax_profile = plt.subplots(figsize=(10, 5))
ax_profile.plot(axial_ct[128, 128, :], linestyle='--', color='#b235e6', label='CT')
# ax_profile.plot(axial_pet[128, 128, :], linestyle='--', color='#7cd6cf', label='PET')
ax_profile.plot(visuals_axial[128, 128, :], linestyle='-', color='red',
                label='Synthetic CT(ours)')
ax_profile.plot(visuals_axial2[128, 128, :], linestyle='-', color='#70ad47',
                label='Synthetic CT(pgan)')

ax_profile.set_title('Pixel Value Profile along the Slice Axis')
ax_profile.set_xlabel('Slice Index')
ax_profile.set_ylabel('Pixel Value')
ax_profile.legend(loc='best')
plt.show()

In [None]:
# Save the central coronal plane of the synthetic CT volume as an 8-bit PNG.
# NOTE(review): `img_dir` and `visuals_numpy` were never defined in this
# notebook, so the cell raised NameError. This edit assumes the results folder
# used by back_images() and the synthetic CT volume `visuals_axial`. The file
# name still says PET. Confirm all three.
img_dir = "/opt/data/private/PET生成CT/pictures/results"
os.makedirs(img_dir, exist_ok=True)
visuals_numpy = visuals_axial

image_name = 'fake_PET_coronal.png'
save_path = os.path.join(img_dir, image_name)

# The volume holds values in [-1024, CTmax]. Casting those floats straight to
# uint8 does not preserve them, so rescale that window to [0, 255] first.
coronal = np.clip(visuals_numpy[128, :, :], -1024, CTmax)
coronal_u8 = ((coronal + 1024) / (CTmax + 1024) * 255).round().astype(np.uint8)

image_pil = Image.fromarray(coronal_u8)
image_pil.save(save_path)
