In [2]:
import os
import copy
import time
import torch

import argparse

from lib.config.config import cfg_from_yaml, cfg, merge_dict_and_yaml
from lib.dataset.factory import get_dataset
from lib.model.factory import get_model


In [3]:

args = argparse.Namespace(
    data='LIDC256',
    tag='d2_multiview2500',
    dataroot='./data/LIDC-HDF5-256',
    dataset='train',
    valid_dataset='test',
    valid_datasetfile='./data/test.txt',
    datasetfile='./data/train.txt',
    ymlpath='./experiment/multiview2500/d2_multiview2500.yml',
    gpuid='0',
    dataset_class='align_ct_xray_views_std',
    model_class='MyMultiViewCTGAN',
    check_point=None,
    load_path=None,
    latest=False,
    verbose=False
)

args.gpu_ids = [int(i) for i in str(args.gpuid).split(',')]
args.epoch_count = 1

In [4]:
if args.ymlpath is not None:
  cfg_from_yaml(args.ymlpath)
# merge config with argparse
opt = copy.deepcopy(cfg)
opt = merge_dict_and_yaml(args.__dict__, opt)

In [5]:
datasetClass, augmentationClass, dataTestClass, collateClass = get_dataset(opt.dataset_class)
opt.data_augmentation = augmentationClass
dataset = datasetClass(opt)
print('DataSet is {}'.format(dataset.name))

dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=int(opt.nThreads),
collate_fn=collateClass)

dataset_size = len(dataloader)
print('#training images = %d' % dataset_size)

DataSet is AlignDataSet
#training images = 916


In [7]:
# get model
gan_model = get_model(opt.model_class)()
print('Model --{}-- will be Used'.format(gan_model.name))
gan_model.init_process(opt)
total_steps, epoch_count = gan_model.setup(opt)

# set to train
gan_model.train()

Model --multiView_CTGAN-- will be Used
initialize network parameters with normal
initialize network parameters with normal
GAN loss: LSGAN
---------- Networks initialized -------------
[Network G] Total number of parameters : 61.929 M
[Network D] Total number of parameters : 11.055 M
-----------------------------------------------


CTGAN(
  (netG): DataParallel(
    (module): MultiView_UNetLike_DenseDimensionNet(
      (view1Model): UNetLike_DenseDimensionNet(
        (encoder): UNetLikeEncoder(
          (encode_layers): ModuleList(
            (0): Sequential(
              (0): Sequential(
                (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (1): ReLU(inplace=True)
                (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
                (3): Dense2DBlock(
                  (layers): ModuleList(
                    (0): Sequential(
                      (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                      (1): ReLU(inplace=True)
                      (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
                      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                      (4): ReLU

In [6]:
for data in dataloader:
    print(len(data))
    ct_data, xray_data, file_path = data
    xra1_1, xray_2 = xray_data
    print(f"ct_data_shape: {ct_data.shape}")
    print(f"xray_1_shape: {xra1_1.shape}")
    print(f"xray_2_shape: {xray_2.shape}")
    break


3
ct_data_shape: torch.Size([1, 128, 128, 128])
xray_1_shape: torch.Size([1, 1, 128, 128])
xray_2_shape: torch.Size([1, 1, 128, 128])


In [None]:

# visualizer
from lib.utils.visualizer import Visualizer
visualizer = Visualizer(log_dir=os.path.join(gan_model.save_root, 'train_log'))

In [11]:
# train discriminator more
dataloader_iter_for_discriminator = iter(dataloader)

# train
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    iter_data_time = time.time()

    for epoch_i, data in enumerate(dataloader):
        iter_start_time = time.time()

        total_steps += 1

        gan_model.set_input(data)
        t0 = time.time()
        gan_model.optimize_parameters()
        t1 = time.time()
        
        # loss
        loss_dict = gan_model.get_current_losses()
        total_loss = visualizer.add_total_scalar('Total loss', loss_dict, step=total_steps)

        # if total_steps % opt.print_freq == 0:
        #     print('total step: {} timer: {:.4f} sec.'.format(total_steps, t1 - t0))
        #     print('epoch {}/{}, step{}:{} || total loss:{:.4f}'.format(epoch, opt.niter + opt.niter_decay,
        #                                                                 epoch_i, dataset_size, total_loss))
        #     print('||'.join(['{}: {:.4f}'.format(k, v) for k, v in loss_dict.items()]))
        #     # print('||'.join(['{}: {:.4f}'.format(k, v) for k, v in metrics_dict.items()]))
        #     print('')
            
        # if (opt.critic_times - 1) > 0:
        #     for critic_i in range(opt.critic_times - 1):
        #         try:
        #             data = next(dataloader_iter_for_discriminator)
        #             gan_model.set_input(data)
        #             gan_model.optimize_D()
        #         except:
        #             dataloader_iter_for_discriminator = iter(dataloader)
        # del(loss_dict)
            
        # # save model several epoch
        # if epoch % opt.save_epoch_freq == 0 and epoch >= opt.begin_save_epoch:
        #     print('saving the model at the end of epoch %d, iters %d' %
        #             (epoch, total_steps))
        #     gan_model.save_networks(epoch, total_steps)

        # print('End of epoch %d / %d \t Time Taken: %d sec' %
        #     (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        
        # gan_model.update_learning_rate(epoch)




OrderedDict([('D', 0.9780977368354797), ('G', 0.5818601846694946), ('idt', 3.1900830268859863), ('map_m', 2.992755889892578)])
OrderedDict([('D', 11.9713716506958), ('G', 3.6097984313964844), ('idt', 2.963563919067383), ('map_m', 2.689861297607422)])
OrderedDict([('D', 3.7774829864501953), ('G', 1.310998558998108), ('idt', 2.424431324005127), ('map_m', 2.4099273681640625)])
OrderedDict([('D', 3.077674150466919), ('G', 0.9861790537834167), ('idt', 2.2349932193756104), ('map_m', 1.513641119003296)])
OrderedDict([('D', 2.075761556625366), ('G', 0.844458281993866), ('idt', 1.7143666744232178), ('map_m', 1.8175649642944336)])
OrderedDict([('D', 1.0637397766113281), ('G', 0.5958153605461121), ('idt', 1.499995231628418), ('map_m', 1.5326350927352905)])
OrderedDict([('D', 1.1872154474258423), ('G', 0.6672679781913757), ('idt', 1.2400293350219727), ('map_m', 1.178246259689331)])
OrderedDict([('D', 0.7084864974021912), ('G', 0.4193398654460907), ('idt', 0.9729255437850952), ('map_m', 1.346901893

KeyboardInterrupt: 