# 两期神经网络训练

In [None]:
import math
import os
import json

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import trange

import dataset
from Nets import TS_Net, TS_DenseNet 
from Loss_Func import SSIM_Loss, MSE_Loss
from Model_Test import ModelTest

## 1、载入数据

In [None]:
# Root folders of the paired training data: LDCT images are fed to the network,
# NDCT images are the targets the losses compare against.
LDCT_path = r'E:\NBIA\Sampling\Dataset\LDCT5set'
NDCT_path = r'E:\NBIA\Sampling\Dataset\NDCT5set'

# Two separate pipelines are handed to the dataset: tensor conversion and
# normalization (0.1225 / 0.1188 are presumably mean / std -- TODO confirm
# against dataset.My_Normalize).
my_totensor = dataset.My_ToTensor()
my_normalize = dataset.My_Normalize(0.1225, 0.1188)
transform = dataset.My_Compose([my_totensor])
normalize = dataset.My_Compose([my_normalize])

train_set = dataset.Mydataset(
    LDCT_root=LDCT_path,
    NDCT_root=NDCT_path,
    train=True,
    transform=transform,
    normalize=normalize,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=16,
    num_workers=4,
    shuffle=True,
)
print(train_set.len)

## 2、构建神经网络

In [None]:
# Instantiate the network (TS_Net comes from the project-local Nets module)
# and print its layer structure. Later cells use its sub-modules net.F and net.S.
net = TS_Net()
print(net)

## 3、定义损失函数和优化器

In [None]:
# One loss / optimizer pair per sub-network: the MSE pair drives net.S and the
# SSIM pair is bound to net.F.
criterion_MSE = nn.MSELoss(reduction = 'mean') # loss function: mean squared error
criterion_SSIM = SSIM_Loss()
optimizer_MSE = optim.Adam(net.S.parameters(), lr = 0.001) # Adam over the parameters of net.S
optimizer_SSIM = optim.Adam(net.F.parameters(), lr = 0.001) # Adam over the parameters of net.F

## 4、模型训练
### a.参数加载

In [None]:
# ------ training settings ------ #
model_name = 'TS_Net'
train_epoch_num = 5     # epochs to run in this session
First_train = False     # True: start from scratch, False: resume from a checkpoint
train_adjust = False    # True: take the loss history from the JSON dump instead of the .npy files

# ---------------------------- divider ---------------------------- #
# Locations of the persisted loss histories (per batch and per epoch).
loss_path_MSE = r'./{0}/{0}_Loss_MSE.npy'.format(model_name)
loss_path_SSIM = r'./{0}/{0}_Loss_SSIM.npy'.format(model_name)
epochs_loss_path_MSE = r'./{0}/{0}_Epochs_Loss_MSE.npy'.format(model_name)
epochs_loss_path_SSIM = r'./{0}/{0}_Epochs_Loss_SSIM.npy'.format(model_name)

if not os.path.isdir('./{}'.format(model_name)):
    os.mkdir('./{}'.format(model_name))

if not First_train:
    # Resume: restore both sub-networks and both optimizers from the checkpoint
    # of the epoch number typed in by the user.
    model_num = int(input('The num of the last Epoch: '))
    model_path = r'./{0}/checkpoint/{0}_epoch_{1}.ckpt'.format(model_name, model_num)
    checkpoint = torch.load(model_path)
    net.F.load_state_dict(checkpoint['net.F'])
    net.S.load_state_dict(checkpoint['net.S'])

    restored_optimizers = (
        (optimizer_MSE, 'optimizer_MSE'),
        (optimizer_SSIM, 'optimizer_SSIM'),
    )
    for restored, checkpoint_key in restored_optimizers:
        restored.load_state_dict(checkpoint[checkpoint_key])
        # Move every tensor held in the optimizer state to the GPU.
        for param_state in restored.state.values():
            for entry_name, entry in param_state.items():
                if torch.is_tensor(entry):
                    param_state[entry_name] = entry.cuda()
    epoch_now = checkpoint['epoch']

    # Reload the loss histories as plain lists so they can keep growing.
    all_epoch_loss_MSE = list(np.load(loss_path_MSE))
    epochs_loss_MSE = list(np.load(epochs_loss_path_MSE))
    all_epoch_loss_SSIM = list(np.load(loss_path_SSIM))
    epochs_loss_SSIM = list(np.load(epochs_loss_path_SSIM))
    if train_adjust:
        # Override the histories with the ones stored in the JSON dump of that epoch.
        file_path = r'./{0}/{0}-{1}.json'.format(model_name,model_num)
        with open(file_path, 'r') as history_file:
            history = json.load(history_file)
        all_epoch_loss_MSE  = history['all_epoch_loss_MSE']
        all_epoch_loss_SSIM = history['all_epoch_loss_SSIM']
        epochs_loss_SSIM = history['epochs_loss_SSIM']
        epochs_loss_MSE  = history['epochs_loss_MSE']
    print('Epoch:', epoch_now)
    print(len(all_epoch_loss_MSE))
else:
    # Fresh run: nothing trained yet, empty histories.
    epoch_now = 0
    all_epoch_loss_MSE = []
    epochs_loss_MSE = []
    all_epoch_loss_SSIM = []
    epochs_loss_SSIM = []

# Exponential learning-rate decay (x0.75 per scheduler step); last_epoch tells the
# scheduler how many epochs have already been completed.
scheduler_MSE = optim.lr_scheduler.ExponentialLR(optimizer_MSE, gamma = 0.75, last_epoch = epoch_now-1, verbose=True)
scheduler_SSIM = optim.lr_scheduler.ExponentialLR(optimizer_SSIM, gamma = 0.75, last_epoch = epoch_now-1, verbose=True)

### b.训练

In [None]:
# Move the model to the GPU and switch it to training mode.
net.cuda()
net.train()

for epoch in range(train_epoch_num):
    # 1-based global epoch number, counting epochs finished in earlier sessions.
    epoch_id = epoch + epoch_now + 1
    dataiter = iter(train_loader)
    loss_list_MSE = []    # per-batch MSE losses of this epoch
    loss_list_SSIM = []   # per-batch SSIM losses of this epoch

    for batch_idx in trange(len(train_loader)):
        # Fetch the next (LDCT, NDCT) batch. The builtin next() works on every
        # PyTorch version, whereas the iterator's .next() alias is not available
        # in newer releases.
        LDCT_img, NDCT_img = next(dataiter)
        LDCT_img = LDCT_img.cuda()
        NDCT_img = NDCT_img.cuda()

        # Reset the gradients of both sub-networks.
        optimizer_MSE.zero_grad()
        optimizer_SSIM.zero_grad()

        F_x, S_x = net(LDCT_img)
        F_x = F_x.cpu()

        # MSE is computed on the GPU between S_x and the NDCT target; the SSIM
        # loss is evaluated on CPU copies of F_x and the target.
        MSE_loss = criterion_MSE(S_x, NDCT_img)
        NDCT_img = NDCT_img.cpu()
        SSIM_loss = criterion_SSIM(F_x, NDCT_img)

        # Back-propagate both losses.
        MSE_loss.backward()
        SSIM_loss.backward()

        # Adam weight update for net.S.
        optimizer_MSE.step()
        # NOTE(review): the SSIM optimizer step is disabled, so net.F is not
        # updated here even though SSIM_loss is back-propagated -- confirm this
        # is intentional.
        # optimizer_SSIM.step()

        # Record the batch losses.
        loss_list_MSE.append(MSE_loss.item())
        loss_list_SSIM.append(SSIM_loss.item())

    mean_loss_MSE = sum(loss_list_MSE)/len(loss_list_MSE)
    mean_loss_SSIM = sum(loss_list_SSIM)/len(loss_list_SSIM)

    # Persist the per-epoch mean losses.
    epochs_loss_MSE.append(mean_loss_MSE)
    epochs_loss_SSIM.append(mean_loss_SSIM)
    np.save(epochs_loss_path_MSE, np.array(epochs_loss_MSE))
    np.save(epochs_loss_path_SSIM, np.array(epochs_loss_SSIM))

    # Persist the per-batch losses accumulated over all epochs.
    all_epoch_loss_MSE += loss_list_MSE
    all_epoch_loss_SSIM += loss_list_SSIM
    np.save(loss_path_MSE, np.array(all_epoch_loss_MSE))
    np.save(loss_path_SSIM, np.array(all_epoch_loss_SSIM))

    # Report the epoch means once (MSE first, SSIM second).
    print('[epoch %d]: %.10f %.10f' % (epoch_id, mean_loss_MSE, mean_loss_SSIM))

    # 2x3 overview: top row MSE, bottom row SSIM; columns are this epoch's
    # batches, all batches so far, and the per-epoch means.
    plt.figure(figsize=(15,3))
    plt.subplot(231)
    plt.title('Epoch{}'.format(epoch_id))
    plt.plot(loss_list_MSE, color='brown')
    plt.subplot(232)
    plt.title('All_Loss_MSE')
    plt.plot(all_epoch_loss_MSE, color='brown')
    plt.subplot(233)
    plt.title('Epochs_Loss_MSE')
    plt.plot(epochs_loss_MSE, color='brown')
    plt.subplot(234)
    plt.title('Epoch{}'.format(epoch_id))
    plt.plot(loss_list_SSIM, color='teal')
    plt.subplot(235)
    plt.title('All_Loss_SSIM')
    plt.plot(all_epoch_loss_SSIM, color='teal')
    plt.subplot(236)
    plt.title('Epochs_Loss_SSIM')
    plt.plot(epochs_loss_SSIM, color='teal')
    plt.tight_layout(pad=0, w_pad=1.5, h_pad=0.5)
    plt.show()

    print('Saving epoch %d model ...' % epoch_id)

    # Checkpoint both sub-networks and both optimizers together with the epoch number.
    state = {
        'net.F': net.F.state_dict(),
        'net.S': net.S.state_dict(),
        'optimizer_MSE': optimizer_MSE.state_dict(),
        'optimizer_SSIM': optimizer_SSIM.state_dict(),
        'epoch': epoch_id,
    }
    os.makedirs('./{}/checkpoint'.format(model_name), exist_ok=True)
    torch.save(state, './{0}/checkpoint/{0}_epoch_{1}.ckpt'.format(model_name, epoch_id))

    # Decay the learning rates for the next epoch.
    scheduler_MSE.step()
    scheduler_SSIM.step()

# Final 2x2 overview of the loss histories (top row MSE, bottom row SSIM),
# saved next to the model files.
plt.figure(figsize=(10,5))
plt.subplot(221)
plt.title('All_Loss_MSE')
plt.plot(all_epoch_loss_MSE, color='brown')
plt.subplot(222)
plt.title('Epochs_Loss_MSE')
plt.plot(epochs_loss_MSE, color='brown')
plt.subplot(223)
plt.title('All_Loss_SSIM')
plt.plot(all_epoch_loss_SSIM, color='teal')
plt.subplot(224)
plt.title('Epochs_Loss_SSIM')
plt.plot(epochs_loss_SSIM, color='teal')
plt.tight_layout(pad=1, w_pad=1.5, h_pad=0.5)
plt.savefig('./{}/LOSS-Epoch-{}.jpg'.format(model_name, epoch+epoch_now+1), dpi = 500, bbox_inches = 'tight', pad_inches = 0.25)
plt.show()

# Save the whole model object (not just a state dict); the test cells load this .pt file.
torch.save(net, './{0}/checkpoint/{0}_epoch_{1}.pt'.format(model_name, epoch+epoch_now+1))
print('Finished Training...')

# Per-epoch summary: mean MSE loss followed by mean SSIM loss.
for idx, (epoch_mse, epoch_ssim) in enumerate(zip(epochs_loss_MSE, epochs_loss_SSIM), start=1):
    print('Epoch{:3}: '.format(idx), epoch_mse, epoch_ssim)

### c.保存训练数据

In [None]:

# Gather the run description and the full loss histories, then dump them to
# ./<model_name>/<model_name>-<epoch>.json.
# NOTE(review): the 'self.F' entry of model_config reads 'self.S = ...', which
# looks like a copy-paste slip -- confirm the intended text before relying on it.
res_dict = {
    'model_name': model_name,
    'loss_decay': 0.75,   # same value as the gamma given to the LR schedulers
    'batch_size': 16,     # same value as the DataLoader batch size
    'model_config': {
        'self.F': 'self.S = DenseNet(growth_rate=16, block_config=(4, 8, 4), num_init_features=64, bn_size=4)',
        'self.S': 'self.S = DenseNet(growth_rate=16, block_config=(4, 8, 4), num_init_features=64, bn_size=4)',
    },
    'net': str(net),
    'epochs': '{}'.format(epoch+epoch_now+1),
    'epochs_loss_MSE': epochs_loss_MSE,
    'epochs_loss_SSIM': epochs_loss_SSIM,
    'all_epoch_loss_MSE': all_epoch_loss_MSE,
    'all_epoch_loss_SSIM': all_epoch_loss_SSIM,
}

file_path = r'./{0}/{0}-{1}.json'.format(model_name,epoch+epoch_now+1)
with open(file_path, 'w') as f:
    json.dump(res_dict, f, ensure_ascii = False)

## 5、模型测试
### a.测试集测试

In [None]:
# Evaluate a saved model on one complete low-dose / full-dose DICOM series pair.
LDCT_path = r'E:\NBIA\LDCT-and-Projection-data\L123\08-23-2018-75696\1.000000-Low Dose Images-07574'
NDCT_path = r'E:\NBIA\LDCT-and-Projection-data\L123\08-23-2018-75696\1.000000-Full dose images-67226'
test_epoch = 30  # epoch number of the saved .pt model to evaluate
model_path = r'./{0}/checkpoint/{0}_epoch_{1}.pt'.format(model_name, test_epoch)
save_path = r'./{}/Test_Result'.format(model_name)
# ModelTest is a project-local helper; its behaviour (and the meaning of
# stage_num) is defined in Model_Test.py, not in this file.
test = ModelTest(model_name, test_epoch, model_path, LDCT_path, NDCT_path, save_path, stage_num=2)
test.run()

### b.单图测试

In [None]:
# Single-slice test: pick the epoch to load and the full-dose / low-dose DICOM pair.
test_epoch = 30
ND_path = r'E:\NBIA\L004\LDCT-and-Projection-data\L004\08-21-2018-84608\1.000000-Full dose images\1-44.dcm'
LD_path = r'E:\NBIA\L004\LDCT-and-Projection-data\L004\08-21-2018-84608\1.000000-Low Dose Images\1-44.dcm'

# Load the whole saved model object, put it in inference mode, then move it to the GPU.
saved_model_file = './{0}/checkpoint/{0}_epoch_{1}.pt'.format(model_name, test_epoch)
net_test = torch.load(saved_model_file)
net_test.eval().cuda()

import pydicom.filereader as dcmreader
import pydicom.dataset as dcmdt

def Window(WW, WL):
    """Return ``vmin``/``vmax`` keyword arguments for a display window.

    ``WW`` is the window width and ``WL`` the window level (centre); the
    resulting range is ``[WL - WW/2, WL + WW/2]``. The dict is meant to be
    unpacked into ``plt.imshow``.
    """
    half_width = WW / 2
    return {'vmin': WL - half_width, 'vmax': WL + half_width}
def array2tensor(image_array):
    """Turn a 2-D image array of shape (H, W) into a tensor of shape (1, 1, H, W).

    A trailing channel axis is appended and moved to the front, a leading
    batch axis is added, and the result is divided by 1.0 so that integer
    input comes back as a floating-point tensor.
    """
    channel_first = image_array[:, :, None].transpose((2, 0, 1))
    single_image = torch.from_numpy(channel_first).contiguous()
    batch = single_image.unsqueeze(0)
    return batch / 1.0

# Read the paired DICOM slices and keep their pixel data as int16 arrays.
LD_ds = dcmreader.dcmread(LD_path)
ND_ds = dcmreader.dcmread(ND_path)
LD_img = LD_ds.pixel_array.astype(np.int16)
ND_img = ND_ds.pixel_array.astype(np.int16)

# Same conversion + normalization steps as in the training data cell, chained in one pipeline.
my_totensor = dataset.My_ToTensor()
my_normalize = dataset.My_Normalize(0.1225, 0.1188)
loader = dataset.My_Compose([my_totensor, my_normalize])

Res_img = LD_img - ND_img
print(ND_img.shape)

# Raw preview: both slices shifted by -1024 and shown with window 300 / level 40,
# plus their raw difference.
plt.subplot(131)
plt.title('ND')
plt.imshow(ND_img-1024, cmap = plt.cm.Greys_r, **Window(300, 40))
plt.axis('off')
plt.subplot(132)
plt.title('LD')
plt.imshow(LD_img-1024, cmap = plt.cm.Greys_r, **Window(300, 40))
plt.axis('off')
plt.subplot(133)
plt.title('Res')
plt.imshow(Res_img, cmap = plt.cm.Greys_r)
plt.axis('off')
plt.show()

# Normalize both slices; the low-dose one gets a batch axis for the network,
# the full-dose one becomes a plain 2-D numpy array for comparison.
LD_img = loader(LD_img)
ND_img = loader(ND_img)
LD_img = torch.stack([LD_img])
ND_img = ND_img.squeeze().detach().numpy()

# Inference without gradient tracking; the second network output is used as the prediction.
with torch.no_grad():
    LD_img = LD_img.cuda()
    F, S = net_test(LD_img)
    pre = S.cpu().squeeze().detach().numpy()
    LD_img = LD_img.cpu().squeeze().detach().numpy()

# Top row: images mapped back with (x*0.1188 + 0.1225)*4096 - 1024 and shown in
# the 300/40 window (presumably undoing the normalization -- TODO confirm).
windowed_panels = (
    (231, 'NDCT', ND_img),
    (232, 'LDCT', LD_img),
    (233, 'Predicted', pre),
)
for position, caption, normalized in windowed_panels:
    plt.subplot(position)
    plt.axis('off')
    plt.title(caption)
    plt.imshow(((normalized*0.1188)+0.1225)*4096 - 1024,
               cmap = plt.cm.Greys_r, **Window(300, 40))

# Bottom row: pairwise differences in normalized units.
difference_panels = (
    (234, 'LDCT - NDCT', LD_img - ND_img),
    (235, 'LDCT - Predicted', LD_img - pre),
    (236, 'Predicted - NDCT', pre - ND_img),
)
for position, caption, difference in difference_panels:
    plt.subplot(position)
    plt.axis('off')
    plt.title(caption)
    plt.imshow(difference, cmap = plt.cm.Greys_r)

# Normalization constants used above: (0.1225, 0.1188)

# Save each normalized image on its own, then the whole 2x3 figure.
image_outputs = (
    ('./{}/{}-1 ND.jpg', ND_img),
    ('./{}/{}-2 LD.jpg', LD_img),
    ('./{}/{}-3 pre.jpg', pre),
    ('./{}/{}-4 pre_Noise.jpg', LD_img - pre),
    ('./{}/{}-5 rea_Noise.jpg', LD_img - ND_img),
)
for path_template, image in image_outputs:
    plt.imsave(path_template.format(model_name, test_epoch), image, cmap = plt.cm.Greys_r)
plt.savefig('./{}/Epoch-{}.jpg'.format(model_name, test_epoch), dpi = 500, bbox_inches = 'tight', pad_inches = 0.25)

def mseloss(x, y):
    """Return the mean squared error between two arrays of the same shape.

    The squared differences are summed and divided by the actual number of
    elements, so the result is a true mean for any image size (for 512x512
    inputs it equals the previous hard-coded ``512*512`` divisor).

    Args:
        x: array-like of values.
        y: array-like broadcast-compatible with ``x``.

    Returns:
        The mean of ``(x - y) ** 2`` as a numpy scalar.

    Raises:
        ValueError: if the inputs contain no elements.
    """
    diff = np.asarray(x) - np.asarray(y)
    if diff.size == 0:
        raise ValueError('mseloss() requires non-empty inputs')
    return np.sum(diff ** 2) / diff.size

# Error of the low-dose input and of the prediction against the full-dose slice,
# each printed with 10*log10(2*2/loss) (a PSNR-style figure; the peak value of 2
# is presumably tied to the normalized intensity range -- TODO confirm).
LDloss = mseloss(LD_img, ND_img)
preloss = mseloss(pre, ND_img)
for label, loss_value in (('LD:  ', LDloss), ('pre: ', preloss)):
    print(label, loss_value, 10*np.log10(2*2/loss_value))

# Map the prediction back with (x*0.1188 + 0.1225)*4096, clip to [0, 4096]
# and store it as uint16 pixel values.
pn = np.clip(((pre*0.1188)+0.1225)*4096, 0, 4096).astype(np.uint16)

# Re-save the two source slices with explicit series descriptions.
for dicom_ds, description in ((ND_ds, 'NDCT'), (LD_ds, 'LDCT')):
    dicom_ds.SeriesDescription = description
ND_ds.save_as("./{}/NDCT.dcm".format(model_name))
LD_ds.save_as("./{}/LDCT.dcm".format(model_name))

# Reuse the full-dose dataset (and so its header) as the container for the
# predicted image: swap in the new pixel data and dimensions, then save.
ND_ds.SeriesDescription = 'Predicted'
ND_ds.PixelData = pn.tobytes()
ND_ds.Rows, ND_ds.Columns = pn.shape
ND_ds.save_as("./{}/Predicted.dcm".format(model_name))