In [1]:
# If you need a new visdom server for visualization, run this command in a terminal:
#   python -m visdom.server
# then open http://localhost:8097 in a browser.

In [2]:
from models import loss
import argparse
import os
from util import util
import torch
import models
from models import networks
import data
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import torch.nn as nn


Using TensorFlow backend.


In [7]:
# Plain-dict stand-in for the command-line options; every key is one option
# name and every value is the setting used for this run.
opt = {
    # ---------------- BASE OPTIONS ----------------
    # path to images (should have subfolders trainA, trainB, valA, valB, etc)
    "dataroot": "./datasets/ct_prostate_aligned/",

    # name of the experiment. It decides where to store samples and models
    "name": "new",

    # gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU
    "gpu_ids": [0, 1],

    # models are saved here
    "checkpoints_dir": "./checkpoints_seg",

    # chooses which model to use. [cycle_gan | pix2pix | test | colorization]
    "model": "pix2pix",

    # number of input image channels: 3 for RGB and 1 for grayscale
    "input_nc": 3,

    # number of output image channels: 3 for RGB and 1 for grayscale
    "output_nc": 3,

    # number of gen filters in the last conv layer
    "ngf": 64,

    # number of discrim filters in the first conv layer
    "ndf": 64,

    # specify discriminator architecture [basic | n_layers | pixel].
    # The basic model is a 70x70 PatchGAN.
    "netD": "basic",

    # specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]
    # this is set in the script, changing is likely bad
    "netG": "unet_256",

    # only used if netD==n_layers
    "n_layers_D": 3,

    # instance normalization or batch normalization [instance | batch | none]
    "norm": "batch",

    # network initialization [normal | xavier | kaiming | orthogonal]
    "init_type": "normal",

    # scaling factor for normal, xavier and orthogonal
    "init_gain": 0.02,

    # if True, no dropout for the generator
    "no_dropout": False,

    # chooses how datasets are loaded. [unaligned | aligned | single | colorization]
    # change to "unaligned" if data not already in single image pairs [AB]
    "dataset_mode": "aligned",

    # AtoB(building to facade) or BtoA(facade to building)
    "direction": "AtoB",

    # if true, takes images in order to make batches, otherwise takes them randomly
    "serial_batches": True,

    # number of threads for loading data
    "num_threads": 4,

    # input batch size
    "batch_size": 6,

    # scale images to this size
    "load_size": 512,

    # then crop to this size
    "crop_size": 512,

    # Maximum number of samples allowed per dataset.
    # If the dataset directory contains more than max_dataset_size, only a subset is loaded.
    "max_dataset_size": float('inf'),

    # scaling and cropping of images at load time
    # [resize_and_crop | crop | scale_width | scale_width_and_crop | none]
    "preprocess": "resize_and_crop",

    # if True, do not flip the images for data augmentation.
    # This used to be the string "store_true" (an argparse action name, not a
    # value); that string is truthy, so True keeps the same effective setting.
    "no_flip": True,

    # display window size for both visdom and HTML
    "display_winsize": 256,

    # which epoch to load? set to latest to use latest cached model
    "epoch": "1",

    # which iteration to load? if load_iter > 0,
    # the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]
    "load_iter": 0,

    # if specified, print more debugging information
    "verbose": True,

    # customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}
    "suffix": "",

    # ---------------- TRAIN OPTIONS ----------------
    "isTrain": True,

    # frequency of showing training results on screen
    # (the training loop's iteration counter advances by batch_size per step)
    "display_freq": 6,

    # if positive, display all images in a single visdom web panel
    # with certain number of images per row.
    "display_ncols": 4,

    # window id of the web display
    "display_id": 1,

    # visdom server of the web display
    "display_server": "http://localhost",

    # visdom display environment name (default is "main")
    "display_env": "main",

    # visdom port of the web display
    "display_port": 8097,

    # frequency of saving training results to html
    "update_html_freq": 100,

    # frequency of showing training results on console
    "print_freq": 100,

    # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/
    "no_html": True,

    # --- network saving and loading parameters ---

    # frequency of saving the latest results
    "save_latest_freq": 400,

    # frequency of saving checkpoints at the end of epochs
    "save_epoch_freq": 1,

    # whether saves model by iteration
    "save_by_iter": True,

    # continue training: load the latest model
    "continue_train": False,

    # the starting epoch count, we save the model by <epoch_count>,
    # <epoch_count>+<save_latest_freq>, ...
    "epoch_count": 1,

    # train, val, test, etc
    "phase": "train",

    # number of epochs at the starting learning rate
    # (the training loop runs epochs epoch_count .. niter + niter_decay)
    "niter": 100,

    # number of epochs to linearly decay learning rate to zero
    "niter_decay": 100,

    # momentum term of adam
    "beta1": 0.5,

    # initial learning rate for adam
    "lr": 0.0002,

    # the type of GAN objective. [vanilla | lsgan | wgangp | dice].
    # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.
    "gan_mode": 'lsgan',

    # balance classes in dice loss
    "balance": True,

    # the size of image buffer that stores previously generated images
    # this is set in the script, changing is likely bad
    "pool_size": 0,

    # learning rate policy. [linear | step | plateau | cosine]
    "lr_policy": "step",

    # lambdas used for different models, only L1 is used in this case
    # NOTE(review): lambda_L1 is 0.0 here; if the model scales its L1 term by
    # this weight the term is switched off -- confirm this is intended.
    'lambda_A': 10.0,
    'lambda_L1': 0.0,
    'lambda_B': 10.0,
    'lambda_identity': 0.5,

    # multiply by a gamma every lr_decay_iters iterations
    "lr_decay_iters": 25,

    # number of classes to predict
    "num_classes": 6,
}
class Map(dict):
    """
    Dictionary whose entries can also be read, written and deleted as attributes.

    Reading an attribute that is not a key returns None instead of raising, so
    unset options simply look "empty". Dunder names are the exception: they
    raise AttributeError so that protocol probes (copy, pickle, ...) do not
    mistake None for an implemented hook.

    Entries are mirrored into the instance __dict__ by __setitem__ and removed
    by __delitem__; other dict mutators (update, pop, setdefault, ...) do not
    go through those methods and therefore do not refresh the mirror.

    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    m.first_name   # 'Eduardo'
    m.unknown      # None
    """
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        # Re-assign every initial entry through __setitem__ so that it is also
        # mirrored into __dict__ (dict.__init__ does not call __setitem__).
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v

        for k, v in kwargs.items():
            self[k] = v

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup has already failed.
        # Special method names must report "missing": returning None for e.g.
        # __setstate__ makes copy/pickle try to call None and crash.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary entry.
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        # Keep the attribute view in sync with the dictionary entry.
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        # Attribute deletion removes the dictionary entry (KeyError if absent).
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]
# Wrap the options dict in Map so entries can be read as attributes
# (e.g. opt.epoch_count, opt.batch_size) as well as by key.
opt = Map(opt)

In [None]:
def _crossed_multiple(count, step, freq):
    """Return True if `count` reached or passed a multiple of `freq` in its last increment.

    `count` is the running total after being increased by `step`. When `freq`
    is a multiple of `step` (and `count` is a multiple of `step`) this is the
    same as `count % freq == 0`. When it is not -- e.g. step=6, freq=100 -- the
    plain modulo test only matches at common multiples of both numbers (300),
    whereas this test fires once for every multiple of `freq` that is passed.
    """
    return count // freq > (count - step) // freq


dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset)    # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)

model = create_model(opt)      # create a model given opt.model and other options
model.setup(opt)               # regular setup: load and print networks; create schedulers
visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
total_iters = 0                # the total number of training iterations (advances by batch_size per step)
dice_array = []                # per-epoch dice summary lines; nothing appends to it at the moment

# outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()  # timer for entire epoch
    iter_data_time = time.time()    # timer for data loading per iteration
    epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
    dice = []

    # NOTE(review): the loop variable `data` rebinds the name of the imported
    # `data` package for the rest of the notebook session; the name is kept so
    # later cells that read the last batch keep working.
    for i, data in enumerate(dataset):  # inner loop within one epoch
        iter_start_time = time.time()   # timer for computation per iteration
        # Time spent waiting for this batch; computed on every iteration so the
        # value reported below always belongs to the iteration being printed.
        t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_iters += opt.batch_size
        epoch_iter += opt.batch_size
        model.set_input(data)         # unpack data from dataset and apply preprocessing
        model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

        # display images on visdom and save images to a HTML file
        if _crossed_multiple(total_iters, opt.batch_size, opt.display_freq):
            save_result = _crossed_multiple(total_iters, opt.batch_size, opt.update_html_freq)
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        # print training losses and save logging information to the disk
        if _crossed_multiple(total_iters, opt.batch_size, opt.print_freq):
            losses = model.get_current_losses()
            t_comp = (time.time() - iter_start_time) / opt.batch_size
            visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
            if opt.display_id > 0:
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        # cache our latest model every <save_latest_freq> iterations
        if _crossed_multiple(total_iters, opt.batch_size, opt.save_latest_freq):
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
            model.save_networks(save_suffix)

        iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    # TODO: per-epoch dice reporting is disabled; re-enable once `dice` is filled in the inner loop.
    #print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    #print('Dice Score:' +str((sum(dice) / len(dice))))
    #dice_array.append(str(epoch)+': '+str((sum(dice) / len(dice)))+'\n')

    model.update_learning_rate()                     # update learning rates at the end of every epoch.

    # Rewrite the dice log after every epoch. dice_array is never appended to
    # in this cell, so the file is currently (re)created empty.
    with open('newfile.txt','w+') as f:
        f.writelines(dice_array)

dataset [AlignedDataset] was created
The number of training images = 59




initialize network with normal
initialize network with normal
model [Pix2PixModel] was created
---------- Networks initialized -------------
DataParallel(
  (module): UnetGenerator(
    (model): UnetSkipConnectionBlock(
      (model): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): UnetSkipConnectionBlock(
          (model): Sequential(
            (0): LeakyReLU(negative_slope=0.2, inplace)
            (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
            (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (3): UnetSkipConnectionBlock(
              (model): Sequential(
                (0): LeakyReLU(negative_slope=0.2, inplace)
                (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    