In [1]:
# ! pip install -r ./requirements.txt


### 非配对数据训练

In [2]:
import os
import numpy as np
from skimage.io import imread
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,random_split
import matplotlib.pyplot as plt
import itertools
import logging

from ldct.utils import *
from ldct.data import *
from ldct.loss import *
from ldct.net.generator import *
from ldct.net.discriminator import *


In [3]:
# Training split: fd/ and qd/ folders (presumably full-dose and quarter-dose
# scans -- confirm against the dataset layout), sampled as 256x256 crops.
fd_path, ld_path = (
    "../../LDCT_SLL_1000_200/train/fd/",
    "../../LDCT_SLL_1000_200/train/qd/",
)
train_set = LDCT_CycleGan_Dataset(fd_path, ld_path, crop_size=(256, 256))

# Test split: same folder layout, read through LDCT_Dataset with 512x512 crops.
# Note that fd_path / ld_path are rebound here and point at the test folders
# from this line on.
fd_path, ld_path = (
    "../../LDCT_SLL_1000_200/test/fd/",
    "../../LDCT_SLL_1000_200/test/qd/",
)
test_set = LDCT_Dataset(fd_path, ld_path, crop_size=(512, 512))

# Cell output: number of samples in each split.
(len(train_set), len(test_set))

(1000, 200)

In [4]:
# Loader configuration for training.
batch_size = 4
num_workers = 4

# shuffle=True: DataLoader defaults to shuffle=False, which would feed the
# samples in the same order (and the same batch grouping) every epoch.
train_loader = DataLoader(
    train_set,
    batch_size,
    shuffle=True,
    num_workers=num_workers,
)
# Evaluation runs one image at a time, in dataset order.
test_loader = DataLoader(test_set, 1)

# Sanity check: pull a single training batch and display the tensor shapes.
img_ld, img_fd = next(iter(train_loader))
img_fd.shape, img_ld.shape

(torch.Size([4, 1, 256, 256]), torch.Size([4, 1, 256, 256]))

In [5]:
# Save path: directory that receives the log file and the model checkpoints.
save_path = "./model/temp"
# makedirs (unlike mkdir) also creates the missing "./model" parent directory,
# and exist_ok=True makes the cell safe to re-run.
os.makedirs(save_path, exist_ok=True)

# Logger: INFO-level records are written to <save_path>/log.log; mode "w"
# truncates the file each time this cell runs.
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)
# Close handlers left over from a previous run of this cell before dropping
# them, so the previously opened log file handle is not leaked.
for stale_handler in list(logger.handlers):
    stale_handler.close()
logger.handlers = []
fh = logging.FileHandler(os.path.join(save_path, 'log.log'), "w")
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)

# Device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Models: two generators (A->B and B->A) and one patch discriminator per domain.
G_AB = Generator_Unet_SEECA().to(device)
G_BA = Generator_Unet_SEECA().to(device)
D_A = Discriminator_Patch().to(device)
D_B = Discriminator_Patch().to(device)
# Both generators (and both discriminators) are built by the same constructor
# call, so only one of each is logged.
logger.info(f"G_param_count:\t{model_param_count(G_AB)/1024/1024:.3f} M")
logger.info(f"D_param_count:\t{model_param_count(D_A)/1024/1024:.3f} M")

# Optimisers: one Adam instance drives both generators, another both
# discriminators, each with its own learning rate and betas.
lr_G = 1e-4
lr_D = 2e-4
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=lr_G,
    betas=(0.5, 0.9),
)
optimizer_D = torch.optim.Adam(
    itertools.chain(D_A.parameters(), D_B.parameters()),
    lr=lr_D,
    betas=(0.9, 0.999),
)

# Losses: MSE on discriminator outputs, L1 for the cycle and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
# criterion_identity = L1_Perc_Loss(x,y).to(device)

# Training settings
epoches = 1000       # upper bound on the number of epochs
patient = 5          # early-stopping budget consumed by the training loop
rmse_best = np.inf   # np.inf instead of np.Inf: the alias was removed in NumPy 2.0
psnr_best = 0        # best validation PSNR seen so far
id_lamb = 5          # weight of the identity term in the generator loss
gan_lamb = 10        # weight of the adversarial term in the generator loss

# Learning-rate schedulers. The bare `LambdaLR(epoches)` is a different object
# from torch.optim.lr_scheduler.LambdaLR -- presumably a helper exposing a
# `.step(epoch)` multiplier, supplied by one of the star imports (TODO confirm).
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(epoches).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(epoches).step)



# Early-stopping counter: reset to `patient` whenever validation PSNR improves,
# decremented once at the end of every epoch, and training stops when it drops
# below 1.
# NOTE(review): because the decrement also runs on improving epochs, training
# stops after `patient - 1` consecutive non-improving epochs -- confirm intended.
e_count = patient
for e in range(epoches):
    G_loss_list = []
    D_A_loss_list = []
    D_B_loss_list = []
    for real_A, real_B in tqdm(train_loader):
        real_A, real_B = real_A.to(device), real_B.to(device)

        # ---- Generator step ----
        G_AB.train()
        G_BA.train()
        D_A.eval()
        D_B.eval()
        fake_A = G_BA(real_B)
        fake_B = G_AB(real_A)
        fake_AB = G_AB(fake_A)   # B -> A -> B reconstruction
        fake_BA = G_BA(fake_B)   # A -> B -> A reconstruction

        # "Identity" loss: L1 between each translated image and the image it
        # was generated from.
        # NOTE(review): the usual CycleGAN identity term compares G_BA(real_A)
        # with real_A; as written this penalises any change made by the
        # generators -- confirm this is the intended regulariser.
        loss_identity_A = criterion_identity(fake_A, real_B)
        loss_identity_B = criterion_identity(fake_B, real_A)
        loss_identity = loss_identity_A + loss_identity_B

        # Adversarial loss: generators want the discriminators to output "real".
        # Targets are shaped from the discriminator output instead of a
        # hard-coded (N, 1, 30, 30), so other crop sizes / patch grids work too.
        pred_fake_A = D_A(fake_A)
        pred_fake_B = D_B(fake_B)
        loss_gan_A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
        loss_gan_B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_gan = loss_gan_A + loss_gan_B

        # Cycle-consistency loss: round-trip reconstructions vs. the originals.
        loss_cycle_A = criterion_cycle(fake_BA, real_A)
        loss_cycle_B = criterion_cycle(fake_AB, real_B)
        loss_cycle = loss_cycle_A + loss_cycle_B

        # Total generator loss
        loss_G = id_lamb*loss_identity + gan_lamb*loss_gan + loss_cycle
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        G_loss_list.append(loss_G.item())

        # ---- Discriminator step ----
        D_A.train()
        D_B.train()
        G_AB.eval()
        G_BA.eval()
        # Fakes are regenerated with the just-updated generators. no_grad keeps
        # the discriminator loss from back-propagating through the generators,
        # whose gradients would be discarded anyway.
        with torch.no_grad():
            fake_A = G_BA(real_B)
            fake_B = G_AB(real_A)

        pred_real_A, pred_fake_A = D_A(real_A), D_A(fake_A)
        pred_real_B, pred_fake_B = D_B(real_B), D_B(fake_B)
        loss_D_A = (
            criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
            + criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        )
        loss_D_B = (
            criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
            + criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        )
        loss_D = loss_D_A + loss_D_B
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        D_A_loss_list.append(loss_D_A.item())
        D_B_loss_list.append(loss_D_B.item())

    # Epoch-level mean of each loss
    G_loss = np.mean(G_loss_list)
    D_A_loss = np.mean(D_A_loss_list)
    D_B_loss = np.mean(D_B_loss_list)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D.step()

    # Validation: mean PSNR of G_AB on the test loader. eval() is set
    # explicitly rather than relying on the mode left by the last D step, and
    # no_grad avoids building an autograd graph for every full-size image.
    psnr_list = []
    G_AB.eval()
    with torch.no_grad():
        for img_ld, img_fd in tqdm(test_loader):
            denoised = G_AB(img_ld.to(device)).cpu()
            # Inputs are scaled by 255 before the metric, as in the final
            # evaluation cell.
            psnr_list.append(psnr(denoised*255, img_fd.cpu()*255).item())
    psnr_val = np.mean(psnr_list)

    # Checkpoints: the whole G_AB module is saved (not just its state_dict)
    # because the evaluation cell loads it back with a bare torch.load.
    if psnr_val > psnr_best:
        psnr_best = psnr_val
        e_count = patient
        torch.save(G_AB, os.path.join(save_path, 'best_GAB.pth'))
    torch.save(G_AB, os.path.join(save_path, 'latest_GAB.pth'))

    info = f"{e}/{epoches}\tG: {G_loss}\tDA: {D_A_loss}\tDB: {D_B_loss}\tPSNR/Best: {psnr_val}/{psnr_best}"
    print(info)
    logger.info(info)

    # Early stopping
    e_count -= 1
    if e_count < 1:
        break




# Final evaluation: reload the best generator checkpoint and report the mean
# PSNR / SSIM / GMSD / RMSE over the test loader.
# NOTE(review): the checkpoint is a fully pickled module. Recent PyTorch
# releases default torch.load to weights_only=True, in which case this call
# needs weights_only=False -- confirm against the installed version.
model = torch.load(os.path.join(save_path,"best_GAB.pth"),map_location=device)
model = model.eval()
psnr_list = []
ssim_list = []
gmsd_list = []
rmse_list = []
# Inference only: no_grad skips autograd bookkeeping, so no .detach() is needed.
with torch.no_grad():
    for img_ld, img_fd in tqdm(test_loader):
        # Keep the network output under its own name instead of overwriting
        # the low-dose input.
        denoised = model(img_ld.to(device)).cpu()
        reference = img_fd.cpu()

        # All metrics receive inputs scaled by 255.
        psnr_list.append(psnr(denoised*255, reference*255).item())
        ssim_list.append(ssim(denoised*255, reference*255).item())
        gmsd_list.append(gmsd(denoised*255, reference*255).item())
        rmse_list.append(rmse(denoised*255, reference*255).item())
print(f"{np.mean(psnr_list):.3f}\t{np.mean(ssim_list):.3f}\t{np.mean(gmsd_list):.3f}\t{np.mean(rmse_list):.3f}")



