In [None]:
import torch
import argparse
import time
import pickle
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import numpy as np
import os

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim



In [None]:
def mkdirs(paths):
    """Create a single directory path, or every path in a list/tuple of paths.

    Args:
        paths: one path, or a list/tuple of paths. Each path is handed to
            ``mkdir``, which creates it if it is missing.
    """
    # A str is never a list/tuple, so no separate str check is needed. Tuples
    # are accepted as well so callers can pass either sequence type.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
def mkdir(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    ``exist_ok=True`` replaces the former ``os.path.exists`` pre-check. That
    pre-check raced when another process created the same directory between
    the check and the ``makedirs`` call.

    Raises:
        FileExistsError: if ``path`` already exists but is not a directory.
        OSError: for other filesystem failures (e.g. permission denied).
    """
    os.makedirs(path, exist_ok=True)

In [None]:
# Start from the parsed default training options, then apply this
# experiment's overrides (applied in the order they are listed).
opt = TrainOptions().parse()

_overrides = {
    'dataroot': '/opt/data/private/生成器测试/datasets/petct102',
    'model': 'pGAN',
    'name': 'AttU_Net',
    'which_model_netG': 'AttU_Net',
    # learning rates
    'lr': 0.0002,
    'lr2': 0.0002,
    'batchSize': 4,
    'which_direction': 'BtoA',
    'lambda_A': 100,
    'lambda_B': 0,
    'dataset_mode': 'aligned',
    'pool_size': 0,
    'output_nc': 1,
    'input_nc': 3,
    'loadSize': 256,
    'niter': 50,
    'niter_decay': 50,
    'save_epoch_freq': 25,
    'lambda_vgg': 100,
    'checkpoints_dir': 'checkpoints/',
    'pre_trained_transformer': 1,
    'gpu_ids': [0],
    'norm': "batch",
    'lambda_gdl': 0,
    'lambda_str': 0,
    # visdom display endpoint
    'display_server': "http://114.212.200.248",
    'display_port': 25809,
    'training': True,
}
for _key, _value in _overrides.items():
    setattr(opt, _key, _value)
del _key, _value, _overrides

# Pin the default CUDA device to the first configured GPU.
if opt.gpu_ids:
    torch.cuda.set_device(opt.gpu_ids[0])

args = vars(opt)

In [None]:
args = vars(opt)

# Render the option listing once, so the console copy and the on-disk copy
# are guaranteed to match.
option_lines = ['------------ Options -------------']
option_lines += ['%s: %s' % (str(key), str(value)) for key, value in sorted(args.items())]
option_lines.append('-------------- End ----------------')
print('\n'.join(option_lines))

# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
# Use explicit UTF-8: option values such as dataroot contain non-ASCII
# (Chinese) characters. These fail to encode if the locale's default
# encoding is ASCII (e.g. a POSIX/C locale in a container).
with open(file_name, 'wt', encoding='utf-8') as opt_file:
    opt_file.write('\n'.join(option_lines) + '\n')



In [None]:

# Instantiate the model from the options configured above. This presumably
# dispatches on opt.model ('pGAN') and opt.which_model_netG ('AttU_Net');
# see models.create_model.
model = create_model(opt)


In [None]:
number = 102      # total number of samples in the dataset
batch_size = 20   # samples per pickle chunk (NOT the training batch size)
# Chunk files are indexed 1, 1+batch_size, ...; only len(samples) is used.
# NOTE(review): for number=102 this yields chunks 1..81; confirm the last
# chunk also holds the remaining number % batch_size samples.
samples = np.arange(number - number % batch_size + 1)


def _count_slices(phase):
    """Return the summed axis-2 length of every ``<phase>/<phase>_<j>.pkl`` chunk.

    Chunks are read from ``opt.dataroot``, so the counts always describe the
    same dataset the data loaders use. Each chunk's shape is printed.

    NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    """
    total = 0
    for j in range(1, len(samples), batch_size):
        chunk_path = os.path.join(opt.dataroot, phase, '{}_{}.pkl'.format(phase, j))
        with open(chunk_path, 'rb') as f:
            data_ct = pickle.load(f)
        print(data_ct.shape)
        total += data_ct.shape[2]
    return total


train_size = _count_slices('train')
print(train_size)

val_size = _count_slices('val')
print(val_size)


In [None]:
# Number of usable slices in the training / validation sets. Each of the
# number // batch_size pickle chunks loses (input_nc - 1) slices, presumably
# because input_nc neighbouring slices are stacked per sample (TODO confirm).
k = opt.input_nc
slices_lost = (k - 1) * (number // batch_size)
dataset_size, dataset_size_val = train_size - slices_lost, val_size - slices_lost
samples = np.arange(number - number % batch_size + 1)
def print_log(logger, message):
    """Print ``message`` (flushed) and, when ``logger`` is truthy, append it as a line.

    Args:
        logger: a writable text stream, or None/falsy to only print.
        message: any object; it is written using its ``str()`` form.
    """
    print(message, flush=True)
    if not logger:
        return
    logger.write('%s\n' % (message,))

## logger ##
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
# Start a fresh log for this run; validation epochs append to it below.
with open(os.path.join(save_dir, 'log.txt'), 'w+') as logger:
    print_log(logger, opt.name)

n_epochs = opt.niter + opt.niter_decay
# Per-epoch, per-slice validation metrics. NaN means "not evaluated": either
# a non-validation epoch, or a slice skipped because its target is blank.
# This keeps unevaluated entries from dragging the averages toward zero.
L1_avg = np.full([n_epochs, dataset_size_val], np.nan)
psnr_avg = np.full([n_epochs, dataset_size_val], np.nan)


def _validate_epoch(epoch):
    """Evaluate the model on every validation chunk and record per-slice metrics.

    Fills row ``epoch - 1`` of ``L1_avg`` / ``psnr_avg`` and returns
    ``(mean_l1, mean_psnr, std_psnr)`` over the slices that were evaluated.
    Restores ``opt.phase`` / ``opt.batchSize`` afterwards, even on error.
    """
    saved_phase, saved_batch_size = opt.phase, opt.batchSize
    opt.phase, opt.batchSize = 'val', 1
    try:
        offset = 0  # global slice index of the first slice in the current chunk
        for j in range(1, len(samples), batch_size):
            data_loader_val = CreateDataLoader(opt, j)
            dataset_val = data_loader_val.load_data()
            n_chunk = len(data_loader_val)
            print('Validation images = %d' % dataset_size_val)
            for i, data_val in enumerate(dataset_val):
                model.set_input(data_val)
                model.test()

                # Map network outputs / targets from [-1, 1] back to [0, 1].
                fake_im = model.fake_B.detach().cpu().numpy() * 0.5 + 0.5
                real_im = model.real_B.detach().cpu().numpy() * 0.5 + 0.5

                # Blank targets are skipped and stay NaN (excluded from stats).
                if real_im.max() <= 0:
                    continue
                L1_avg[epoch - 1, offset + i] = np.abs(fake_im - real_im).mean()
                # skimage expects (image_true, image_test). data_range=1 matches
                # the [0, 1] scaling above, assuming [-1, 1] model outputs.
                psnr_avg[epoch - 1, offset + i] = psnr(real_im, fake_im, data_range=1.0)
            offset += n_chunk
    finally:
        opt.phase, opt.batchSize = saved_phase, saved_batch_size

    row_psnr = psnr_avg[epoch - 1]
    return np.nanmean(L1_avg[epoch - 1]), np.nanmean(row_psnr), np.nanstd(row_psnr)


visualizer = Visualizer(opt)
total_steps = 0
# Use the batch size configured in the options cell. This value was
# previously hard-coded to 4 here, which silently overrode the configuration.
train_batch_size = opt.batchSize
# Starts training
for epoch in range(opt.epoch_count, n_epochs + 1):
    epoch_start_time = time.time()
    iter_data_time = time.time()
    epoch_iter = 0

    # Training step: one data loader per pickle chunk.
    opt.phase = 'train'
    opt.batchSize = train_batch_size
    for j in range(1, len(samples), batch_size):
        data_loader = CreateDataLoader(opt, j)
        dataset = data_loader.load_data()
        print('Training images = %d' % dataset_size)

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            # Data-loading time for *this* batch. Measure it every iteration so
            # the value reported below is never stale.
            t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            model.set_input(data)
            model.optimize_parameters()

            # Save current images (real_A, real_B, fake_B)
            if epoch_iter % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            # Save current errors
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
            # Save model based on the number of iterations
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

            iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:
        print(opt.dataset_mode)
        l1_avg_loss, mean_psnr, std_psnr = _validate_epoch(epoch)
        with open(os.path.join(save_dir, 'log.txt'), 'a') as logger:
            print_log(logger, 'Epoch %3d   l1_avg_loss: %.5f   mean_psnr: %.3f  std_psnr:%.3f  ' %
                      (epoch, l1_avg_loss, mean_psnr, std_psnr))
            print_log(logger, '')

        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, n_epochs, time.time() - epoch_start_time))
    model.update_learning_rate()



In [None]:
# Print the network stored on model.netG, presumably the generator selected
# by opt.which_model_netG. For a torch module this shows its layer structure.
print(model.netG)
