In [27]:
# ! pip install -r ./requirements.txt


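### 配对数据训练 (Paired-data training)

Supervised training of a denoising generator on paired low-dose / full-dose CT images with an MSE loss. The checkpoint with the best mean test-set PSNR is saved, and training stops early once PSNR stops improving.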

In [28]:
import os
import numpy as np
from skimage.io import imread
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,random_split
import matplotlib.pyplot as plt
import itertools
import logging

from ldct.utils import *
from ldct.train import Trainer
from ldct.data import *
from ldct.loss import *
from ldct.net.generator import *
from ldct.net.red_cnn import *
from ldct.net.discriminator import Discriminator_Patch


In [None]:
# Paired datasets: full-dose ("fd") targets and low-dose ("qd") inputs live
# under one shared root, split into "train" and "test" sub-directories.
DATA_ROOT = "../../LDCT_SLL_1000_200"


def _split_dirs(split):
    """Return the (full-dose dir, low-dose dir) pair for a dataset split name."""
    return f"{DATA_ROOT}/{split}/fd/", f"{DATA_ROOT}/{split}/qd/"


# Training split, built with crop_size=(256, 256)
fd_path, ld_path = _split_dirs("train")
train_set = LDCT_Dataset(fd_path, ld_path, crop_size=(256, 256))

# Test split, built with crop_size=(512, 512)
fd_path, ld_path = _split_dirs("test")
test_set = LDCT_Dataset(fd_path, ld_path, crop_size=(512, 512))

len(train_set), len(test_set)

In [None]:
batch_size = 4
num_workers = 4
# Reshuffle the training samples every epoch. DataLoader defaults to
# shuffle=False, which would feed the batches in the same fixed order each epoch.
train_loader = DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
# Evaluation loader: one sample per batch, fixed order.
test_loader = DataLoader(test_set, 1)

# Sanity check: peek at one training batch. Batches unpack as (low-dose, full-dose),
# matching how the training loop consumes them.
img_ld, img_fd = next(iter(train_loader))
img_fd.shape, img_ld.shape

In [None]:
# Save path: checkpoints and the training log are written here.
# makedirs also creates the parent "./model" directory (os.mkdir would fail if it
# is missing) and tolerates re-running the cell.
save_path = "./model/temp"
os.makedirs(save_path, exist_ok=True)

# Logger: writes to <save_path>/log.log, overwriting any previous log ("w" mode).
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)
# Re-running this cell reuses the same named logger. Close the previous run's
# handlers rather than just dropping them, so their open log files are released.
for old_handler in list(logger.handlers):
    old_handler.close()
    logger.removeHandler(old_handler)
fh = logging.FileHandler(os.path.join(save_path,'log.log'),"w")
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s')
fh.setFormatter(formatter)
logger.addHandler(fh)

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (Generator_Unet_SEECA / model_param_count come from the ldct star-imports)
model = Generator_Unet_SEECA().to(device)
logger.info(f"param_count:\t{model_param_count(model)/1024/1024:.3f} M")

# Optimizer
optimizer = torch.optim.Adam(model.parameters(),lr=3e-4)
# Loss: pixel-wise MSE between the prediction and the full-dose target
criterion = torch.nn.MSELoss()

# Training settings
epoches = 1000
patient = 5          # early-stopping patience: consecutive epochs without PSNR improvement
rmse_best = np.inf   # np.Inf alias was removed in NumPy 2.0; not updated in this cell
psnr_best = 0
id_lamb = 5          # not used in this cell
gan_lamb = 10        # not used in this cell

# LR scheduler: per-epoch multiplier supplied by the ldct LambdaLR helper's .step
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(epoches).step)

# e_count = epochs of patience remaining; reset whenever test PSNR improves.
e_count = patient
for e in range(epoches):
    # ---- Train one epoch ----
    # Restore training mode (validation below switches the model to eval mode).
    model.train()
    loss_list = []
    for ldct, ndct in tqdm(train_loader):
        ldct, ndct = ldct.to(device), ndct.to(device)
        pred = model(ldct)
        loss = criterion(pred, ndct)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

    # Mean training loss over the epoch
    loss = np.mean(loss_list)
    # Update the learning rate once per epoch
    lr_scheduler.step()

    # ---- Validate: mean PSNR over the test set ----
    # eval() puts layers such as BatchNorm/Dropout into inference behaviour, and
    # no_grad() avoids building an autograd graph for every test image.
    model.eval()
    psnr_list = []
    with torch.no_grad():
        for img_ld, img_fd in tqdm(test_loader):
            img_pred = model(img_ld.to(device)).cpu()
            # Scaled by 255 before PSNR — presumably images are in [0, 1]; TODO confirm
            psnr_list.append(psnr(img_pred*255, img_fd*255).item())
    psnr_val = np.mean(psnr_list)

    # ---- Checkpoint on improvement, otherwise consume patience ----
    if psnr_val > psnr_best:
        psnr_best = psnr_val
        e_count = patient
        # Saves the whole pickled module (not a state_dict), so loading it needs
        # the ldct model classes importable.
        torch.save(model, os.path.join(save_path,'best_GAB.pth'))
    else:
        # Only non-improving epochs use up patience, so training stops after
        # exactly `patient` consecutive epochs without improvement.
        e_count -= 1

    info = f"{e}/{epoches}\tloss: {loss}\tPSNR/Best: {psnr_val}/{psnr_best}"
    print(info)
    logger.info(info)

    if e_count < 1:
        break





