In [1]:
import time
import os
import numpy as np
import random
import functools

import torch
from torch.autograd import Variable
import torch.utils.data as data

from collections import OrderedDict
from subprocess import call
import math

from tqdm import tqdm

from torchvision import transforms as tf
from torchvision import models
from PIL import Image


In [2]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
# Working directory that contains T2R_Dataset/ and checkpoint/ (both are referenced
# with relative paths below). Set the PIX2PIX_WORKDIR environment variable to run on
# another machine instead of editing this machine-specific default.
project_dir = os.environ.get('PIX2PIX_WORKDIR', "D:/2021/2학기 수업/CV/Pix2Pix/hhd/")
os.chdir(project_dir)
print(os.getcwd())

D:\2021\2학기 수업\CV\Pix2Pix\hhd


In [4]:
data_root_dir = './T2R_Dataset'  # dataset root: train_A/ + train_B/ for training, test_A/ for testing
checkpoint_dir = './checkpoint'  # checkpoints and iter.txt are kept in checkpoint_dir/name
name = 'thermal2RGB'  # experiment name (sub-folder of checkpoint_dir)

gpu_ids = [0]  # NOTE(review): not referenced in the visible cells

# Logging / checkpoint cadence, presumably in training steps (consumed by the training loop)
print_freq = 100  # later replaced by lcm(print_freq, batch_size)
save_latest_freq = 1000
save_epoch_freq = 10  # NOTE(review): name suggests a period in epochs; confirm in the training loop
display_freq = 100

load_size = 512  # images are resized to load_size x load_size when resize_or_crop contains 'resize'
crop_size = 256  # crop side length; only used when resize_or_crop contains 'crop'
nThreads = 0  # DataLoader num_workers (0 = load in the main process)
no_flip = True  # True disables the random flip augmentation
serial_batches = True  # True disables DataLoader shuffling
resize_or_crop = 'resize'  # one of 'resize', 'resize_and_crop', 'scale_width', 'scale_width_and_crop', 'none'
gray_only = False  # True: open thermal images as-is; False: replicate their first channel into 3 channels
normalize = False # normalize input tensors with mean = std = 0.5 per channel
is_train = True
input_nc = 3  # generator input channels
output_nc = 3  # generator output channels
label_nc = 0  # NOTE(review): not referenced in the visible cells
norm = 'instance'  # normalization layer type: 'instance' or 'batch'

# Loss switches: setting a flag to True disables that term
no_lsgan = False  # False: least-squares (MSE) GAN loss; True: BCE-with-logits GAN loss
no_l1_loss = False
no_vgg_loss = False
no_ganFeat_loss = False  # also controls whether D exposes intermediate features
no_gan_loss = False

pool_size = 0  # capacity of the fake-image history pool fed to D; 0 disables the pool
niter_decay = 50  # each update_learning_rate() call subtracts lr / niter_decay
niter = 100  # the training loop runs niter + niter_decay epochs
lambda_feat = 10.0  # weight of the GAN feature-matching and VGG losses
ndf = 64  # channel width of the first discriminator conv
nef = 16  # NOTE(review): not referenced in the visible cells
ngf = 64  # NOTE(review): not referenced; generator widths are hard-coded
n_layers_D = 3  # number of stride-2 convs in each discriminator
num_D = 2  # number of discriminator scales

train_epochs = 10  # NOTE(review): not referenced; the epoch count comes from niter + niter_decay
batch_size = 4
lr = 0.0005  # initial Adam learning rate for both G and D

continue_train = False  # resume networks and iter.txt from a previous run
load_pretrain = ''  # checkpoint directory to resume from; '' means checkpoint_dir/name
which_epoch = 'latest'  # checkpoint prefix to load: '<which_epoch>_net_<G|D>.pth'

In [5]:
# Least common multiple
def lcm(a, b):
    """Return the least common multiple of the integers ``a`` and ``b``.

    Integer floor division keeps the result an ``int``; the previous true division
    returned a float (``lcm(100, 4) == 100.0``), which leaked into ``print_freq``.
    Returns 0 when either argument is 0.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)

In [6]:
# Resume bookkeeping file: a single "epoch,iter" line read with np.loadtxt below.
iter_path = os.path.join(os.path.join(checkpoint_dir, name), 'iter.txt')
print(iter_path)

./checkpoint\thermal2RGB\iter.txt


In [7]:
# Starting (epoch, iteration). When resuming, read them from iter.txt, presumably
# written by the training loop as "epoch,iter".
start_epoch, epoch_iter = 1, 0
if continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter = ',', dtype = int)
    except (OSError, ValueError) as err:
        # Missing or malformed file: restart from epoch 1, but report why instead of
        # silently swallowing every exception (including KeyboardInterrupt).
        print('Could not read %s (%s); starting from epoch 1.' % (iter_path, err))
        start_epoch, epoch_iter = 1, 0


In [8]:
# Make print_freq a common multiple of batch_size (lcm is symmetric in its arguments).
print_freq = lcm(batch_size, print_freq)

In [9]:
# File suffixes (case-sensitive) that make_dataset treats as images.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]

def is_image_file(filename):
    """Return True when ``filename`` ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

In [10]:
def make_dataset(in_dir):
    """Recursively collect the paths of image files under ``in_dir``.

    Directories are visited in the sorted order of the ``os.walk`` tuples; within a
    directory, files keep the order ``os.walk`` reports. Raises AssertionError when
    ``in_dir`` is not a directory.
    """
    assert os.path.isdir(in_dir)
    return [
        os.path.join(root, fname)
        for root, _subdirs, fnames in sorted(os.walk(in_dir))
        for fname in fnames
        if is_image_file(fname)
    ]

In [11]:
def gen_ther_color_pil(ther_img_path):
    """Load a thermal image and return it as a 3-channel PIL image.

    A multi-channel input keeps only its first channel; the resulting single
    channel is replicated three times along the last axis.
    """
    pixels = np.asarray(Image.open(ther_img_path))
    single = pixels[:, :, 0] if pixels.ndim == 3 else pixels
    return Image.fromarray(np.stack((single,) * 3, axis=-1))

In [12]:
def get_params(size):
    """Draw random augmentation parameters for an image of ``size`` = (w, h).

    The crop position is sampled in the coordinate frame the image will have after
    the configured resize, so that a ``crop_size`` crop fits. Returns
    ``{'crop_pos': (x, y), 'flip': bool}``. Reads the globals resize_or_crop,
    load_size and crop_size.
    """
    w, h = size
    if resize_or_crop == 'resize_and_crop':
        new_w, new_h = load_size, load_size
    elif resize_or_crop == 'scale_width_and_crop':
        new_w, new_h = load_size, load_size * h // w
    else:
        new_w, new_h = w, h

    x = random.randint(0, np.maximum(0, new_w - crop_size))
    y = random.randint(0, np.maximum(0, new_h - crop_size))
    return {'crop_pos': (x, y), 'flip': random.random() > 0.5}

In [13]:
def _scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width`` pixels wide, preserving aspect ratio (height as in get_params)."""
    ow, oh = img.size
    if ow == target_width:
        return img
    return img.resize((target_width, target_width * oh // ow), method)


def _crop(img, pos, size):
    """Crop a ``size`` x ``size`` patch with top-left corner ``pos``; no-op unless the image is larger."""
    ow, oh = img.size
    x1, y1 = pos
    if ow > size or oh > size:
        return img.crop((x1, y1, x1 + size, y1 + size))
    return img


def _flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is true."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def get_transform(params, normalize = False, method = Image.BICUBIC):
    """Build the PIL -> tensor preprocessing pipeline for one sample.

    Steps depend on the global ``resize_or_crop``: an optional resize
    ('resize*' -> load_size x load_size, 'scale_width*' -> width load_size),
    an optional crop_size crop at ``params['crop_pos']``, an optional flip
    (training only, unless no_flip), ToTensor, and optional (0.5, 0.5)
    normalization. The crop/scale/flip helpers were previously referenced
    but never defined, so every mode except plain 'resize' raised NameError.

    Args:
        params: dict from get_params with 'crop_pos' and 'flip'.
        normalize: apply Normalize(mean=0.5, std=0.5) per channel.
        method: PIL resampling filter for resizing.
    """
    transform_list = []
    if 'resize' in resize_or_crop:
        osize = [load_size, load_size]
        transform_list.append(tf.Resize(osize, method))
    elif 'scale_width' in resize_or_crop:
        transform_list.append(tf.Lambda(lambda img: _scale_width(img, load_size, method)))

    if 'crop' in resize_or_crop:
        transform_list.append(tf.Lambda(lambda img: _crop(img, params['crop_pos'], crop_size)))

    if is_train and not no_flip:
        transform_list.append(tf.Lambda(lambda img: _flip(img, params['flip'])))

    transform_list += [tf.ToTensor()]

    if normalize:
        transform_list += [tf.Normalize((0.5, 0.5, 0.5),
                                        (0.5, 0.5, 0.5))]
    return tf.Compose(transform_list)

In [14]:
class BaseDataset(data.Dataset):
    """Minimal dataset base: subclasses configure themselves in ``initialize``."""

    def __init__(self):
        super().__init__()

    def initialize(self):
        """Hook for subclasses; does nothing here."""
        return None

In [15]:
class CustomDataset(BaseDataset):
    """Paired thermal (A) -> RGB (B) dataset.

    Training reads ``<root>/train_A`` and ``<root>/train_B``; testing reads
    ``<root>/test_A`` only. A[i] is paired with B[i] after sorting, so both
    folders must hold the same number of images in corresponding order.
    Reads the module-level options gray_only, normalize and batch_size.
    """

    def initialize(self, data_root_dir, is_train):
        super(CustomDataset, self).__init__()
        self.root = data_root_dir
        # Remember the mode used to build the path lists; __getitem__ previously read
        # the global ``is_train`` and could disagree with what was loaded here.
        self.is_train = is_train

        dir_A = '_A'
        self.dir_A = os.path.join(data_root_dir, 'train'+dir_A if is_train else 'test'+dir_A)
        self.A_paths = self._list_images(self.dir_A)

        if is_train:
            dir_B = '_B'
            self.dir_B = os.path.join(data_root_dir, 'train'+dir_B)
            # Filter B exactly like A; filtering only A shifted every A/B pair whenever a
            # '.ipynb_checkpoints' copy existed in train_B.
            self.B_paths = self._list_images(self.dir_B)
            if len(self.B_paths) != len(self.A_paths):
                raise ValueError('Unpaired dataset: %d images in %s but %d in %s'
                                 % (len(self.A_paths), self.dir_A, len(self.B_paths), self.dir_B))

        self.dataset_size = len(self.A_paths)

        # Range of the random gamma applied to thermal inputs during training.
        self.thm_gamma_low = 0.5
        self.thm_gamma_high = 1.5

        contrast_param = (0.3, 1)
        self.colorjitter1 = tf.ColorJitter(contrast = contrast_param)

    @staticmethod
    def _list_images(directory):
        """Sorted image paths under ``directory``, excluding Jupyter '.ipynb_checkpoints' copies."""
        return sorted(path for path in make_dataset(directory) if 'ipynb_checkpoints' not in path)

    def __getitem__(self, index):
        """Return ``{'label': A, 'path': A_path}`` plus ``'image': B`` when training."""
        A_path = self.A_paths[index]

        if gray_only:
            A = Image.open(A_path)
        else:
            A = gen_ther_color_pil(A_path)
        params = get_params(A.size)
        A_tensor = get_transform(params, normalize)(A)
        A_tensor = self.colorjitter1(A_tensor)

        if not self.is_train:
            return {'label': A_tensor, 'path': A_path}

        # Training-time augmentation: random gamma on the thermal input, clamped to [0, 1].
        thm_random_gamma = np.random.uniform(self.thm_gamma_low, self.thm_gamma_high)
        A_tensor = torch.clamp(A_tensor ** thm_random_gamma, 0, 1)

        B = Image.open(self.B_paths[index])
        # Reuse A's params so any crop/flip is applied identically to the target.
        B_tensor = get_transform(params, normalize)(B)

        return {'label': A_tensor, 'image': B_tensor, 'path': A_path}

    def __len__(self):
        # Drop the trailing partial batch.
        return len(self.A_paths) // batch_size * batch_size
        

In [16]:
class BaseDataLoader():
    """Minimal data-loader base; subclasses build their loader in ``initialize``."""

    def __init__(self):
        pass

    def initialize(self):
        """Hook for subclasses; does nothing here."""
        pass

    def load_data(self):
        """Return the underlying loader; the base class has none, so ``None``.

        ``self`` was missing, so calling this on an instance raised TypeError.
        """
        return None

In [17]:
def CreateDataset(data_root_dir, is_train):
    """Build a CustomDataset and initialize it for the requested split."""
    ds = CustomDataset()
    ds.initialize(data_root_dir, is_train)
    return ds

In [18]:
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps a CustomDataset in a torch DataLoader configured from the global options."""

    def initialize(self, data_root_dir, is_train):
        super(CustomDatasetDataLoader, self).initialize()
        self.dataset = CreateDataset(data_root_dir, is_train)
        loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=not serial_batches,   # serial_batches=True keeps file order
            num_workers=int(nThreads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_kwargs)

    def load_data(self):
        """Return the torch DataLoader."""
        return self.dataloader

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

In [19]:
def CreateDataLoader(data_root_dir, is_train):
    """Return a CustomDatasetDataLoader already initialized for ``data_root_dir``."""
    loader = CustomDatasetDataLoader()
    loader.initialize(data_root_dir, is_train)
    return loader

In [20]:
data_loader = CreateDataLoader(data_root_dir, is_train)
dataset = data_loader.load_data()            # torch DataLoader iterated by the training loop
dataset_size = len(data_loader.dataset)      # sample count, truncated to a multiple of batch_size
print('#training images = %d' % dataset_size)

#training images = 2660


In [21]:
class Vgg19(torch.nn.Module):
    """Pretrained VGG-19 feature extractor split into five consecutive slices.

    ``forward`` returns the output of each slice, covering layers [0, 2), [2, 7),
    [7, 12), [12, 21) and [21, 30) of ``vgg19(pretrained=True).features``.
    Parameters are frozen unless ``requires_grad`` is True.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        slices = [torch.nn.Sequential() for _ in bounds]
        # Register slice1..slice5 first (same order as before), then fill them.
        for idx, seq in enumerate(slices, start=1):
            setattr(self, 'slice%d' % idx, seq)
        for seq, (start, stop) in zip(slices, bounds):
            for layer_idx in range(start, stop):
                seq.add_module(str(layer_idx), features[layer_idx])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        feats = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = seq(h)
            feats.append(h)
        return feats

In [22]:
class BaseModel(torch.nn.Module):
    """Common base for the GAN models: option bookkeeping plus checkpoint save/load.

    Checkpoints are named ``<epoch_label>_net_<network_label>.pth`` inside
    ``save_dir`` (``checkpoint_dir/name`` after ``initialize``). The remaining
    methods are no-op hooks for subclasses.
    """

    def initialize(self, is_train):
        """Record the mode and the checkpoint directory from the global options."""
        self.isTrain = is_train
        self.Tensor = torch.Tensor
        self.save_dir = os.path.join(checkpoint_dir, name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    def save_network(self, network, network_label, epoch_label):
        """Save ``network``'s state_dict as CPU tensors, creating save_dir if needed.

        The network is moved to the CPU for saving and back to ``device`` afterwards
        when CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        # torch.save does not create missing directories.
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if torch.cuda.is_available():
            network.to(device)

    def load_network(self, network, network_label, epoch_label, save_dir = ''):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

        Strategy: exact state_dict load; otherwise match keys by position (same
        tensors under renamed modules); otherwise copy every same-named,
        same-shaped tensor and print which top-level modules stayed uninitialized.

        Args:
            save_dir: directory to read from; '' means ``self.save_dir``.
        Raises:
            FileNotFoundError: the checkpoint for the generator ('G') is missing.
                A missing checkpoint for any other network is only reported.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(save_dir or self.save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s not exist!' % save_path)
            if network_label == 'G':
                # Previously ``raise('...')``, which raises a TypeError instead.
                raise FileNotFoundError('Generator must exist!')
            return

        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
            return
        except RuntimeError:
            pass  # key or shape mismatch; try the fallbacks below

        model_dict = network.state_dict()
        try:
            # Same tensors under different names: pair keys by position.
            preparams = {model_k: pretrained_dict[pre_k]
                         for pre_k, model_k in zip(pretrained_dict.keys(), model_dict.keys())}
            network.load_state_dict(preparams)
            return
        except RuntimeError:
            pass

        print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
        for k, v in pretrained_dict.items():
            # Guard: checkpoint keys absent from the network previously raised KeyError.
            if k in model_dict and v.size() == model_dict[k].size():
                model_dict[k] = v

        not_initialized = set()
        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                not_initialized.add(k.split('.')[0])

        print(sorted(not_initialized))
        network.load_state_dict(model_dict)

    def update_learning_rate(self):
        """Hook for subclasses; ``self`` was missing, so instance calls raised TypeError."""
        pass

In [23]:
def get_norm_layer(norm_type = 'instance'):
    """Return a factory ``f(num_features)`` for a 2-D normalization layer.

    'batch'    -> BatchNorm2d with learnable affine parameters
    'instance' -> InstanceNorm2d without affine parameters
    Any other name raises NotImplementedError.
    """
    choices = (
        ('batch', functools.partial(torch.nn.BatchNorm2d, affine=True)),
        ('instance', functools.partial(torch.nn.InstanceNorm2d, affine=False)),
    )
    for key, factory in choices:
        if norm_type == key:
            return factory
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [24]:
class SIGGRAPHGenerator(torch.nn.Module):
    """Encoder/decoder generator with additive skip connections.

    Encoder: model1 keeps full resolution (64 ch); model2, model3 and model4 each
    start with a 4x4 stride-2 conv that halves H and W (128, 256, 512 ch).
    Bottleneck: model5/model6 use 3x3 convs with dilation 2 (padding 2) and model7
    plain 3x3 convs, all at 1/8 resolution with 512 channels.
    Decoder: model8up/model9up/model10up are 4x4 stride-2 transposed convs that
    double H and W; each result is summed with a 3x3-conv projection of the encoder
    feature at the same resolution (model3short8, model2short9, model1short10).
    Output: 1x1 conv to ``output_nc`` channels, followed by Tanh when ``use_tanh``.

    Input H and W must be divisible by 8 so the skip sums line up.
    ``use_noise`` is stored but not used by ``forward``.
    """
    def __init__(self, input_nc, output_nc, norm_layer = torch.nn.BatchNorm2d, use_noise = False, use_tanh = True):
        """
        Args:
            input_nc: channels of the input image.
            output_nc: channels of the generated image.
            norm_layer: factory ``f(num_channels)`` appended at the end of each stage.
            use_noise: stored on the module; not used by forward.
            use_tanh: append Tanh to the output layer.
        """
        super(SIGGRAPHGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.use_noise = use_noise
        use_bias = True
        
        # Conv1: full resolution, 64 channels (normalization after the activations)
        model1=[torch.nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model1+=[torch.nn.ReLU(True),]
        model1+=[torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model1+=[torch.nn.ReLU(True),]
        model1+=[norm_layer(64),]

        # Conv2: 4x4 stride-2 conv halves H, W -> 128 channels
        model2=[torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),]
        model2+=[torch.nn.ReLU(True),]
        model2+=[torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model2+=[torch.nn.ReLU(True),]
        model2+=[norm_layer(128),]

        # Conv3: halves H, W again -> 256 channels
        model3=[torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=use_bias),]
        model3+=[torch.nn.ReLU(True),]
        model3+=[torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model3+=[torch.nn.ReLU(True),]
        model3+=[torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model3+=[torch.nn.ReLU(True),]
        model3+=[norm_layer(256),]

        # Conv4: halves H, W again (1/8 resolution) -> 512 channels
        model4=[torch.nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=use_bias),]
        model4+=[torch.nn.ReLU(True),]
        model4+=[torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model4+=[torch.nn.ReLU(True),]
        model4+=[torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model4+=[torch.nn.ReLU(True),]
        model4+=[norm_layer(512),]

        # Conv5: dilated 3x3 convs (dilation 2, padding 2) keep the resolution
        model5=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model5+=[torch.nn.ReLU(True),]
        model5+=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model5+=[torch.nn.ReLU(True),]
        model5+=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model5+=[torch.nn.ReLU(True),]
        model5+=[norm_layer(512),]

        # Conv6: same dilated layout as Conv5
        model6=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model6+=[torch.nn.ReLU(True),]
        model6+=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model6+=[torch.nn.ReLU(True),]
        model6+=[torch.nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias),]
        model6+=[torch.nn.ReLU(True),]
        model6+=[norm_layer(512),]

        # Conv7: undilated 3x3 convs at 1/8 resolution
        model7=[torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model7+=[torch.nn.ReLU(True),]
        model7+=[torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model7+=[torch.nn.ReLU(True),]
        model7+=[torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model7+=[torch.nn.ReLU(True),]
        model7+=[norm_layer(512),]

        # Conv8: transposed conv doubles H, W (back to 1/4) -> 256 channels
        model8up=[torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]

        # projection of the Conv3 output, added to model8up's output
        model3short8=[torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]

        model8=[torch.nn.ReLU(True),]
        model8+=[torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model8+=[torch.nn.ReLU(True),]
        model8+=[torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model8+=[torch.nn.ReLU(True),]
        model8+=[norm_layer(256),]

        # Conv9: doubles H, W (back to 1/2) -> 128 channels
        model9up=[torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),]

        # projection of the Conv2 output, added to model9up's output
        model2short9=[torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]

        model9=[torch.nn.ReLU(True),]
        model9+=[torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]
        model9+=[torch.nn.ReLU(True),]
        model9+=[norm_layer(128),]

        # Conv10: doubles H, W (back to full resolution), 128 channels
        model10up=[torch.nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias),]

        # projection of the 64-channel Conv1 output to 128 channels, added to model10up's output
        model1short10=[torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias),]

        model10=[torch.nn.ReLU(True),]
        model10+=[torch.nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias),]
        model10+=[torch.nn.LeakyReLU(negative_slope=.2),]

        # regression output: 1x1 conv to output_nc channels (Tanh bounds it to [-1, 1])
        model_out=[torch.nn.Conv2d(128, self.output_nc, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias),]
        if(use_tanh):
            model_out+=[torch.nn.Tanh()]

        # Registration order below fixes the state_dict key order, which the
        # positional fallback in load_network relies on; keep it stable.
        self.model1 = torch.nn.Sequential(*model1)
        self.model2 = torch.nn.Sequential(*model2)
        self.model3 = torch.nn.Sequential(*model3)
        self.model4 = torch.nn.Sequential(*model4)
        self.model5 = torch.nn.Sequential(*model5)
        self.model6 = torch.nn.Sequential(*model6)
        self.model7 = torch.nn.Sequential(*model7)
        self.model8up = torch.nn.Sequential(*model8up)
        self.model8 = torch.nn.Sequential(*model8)
        self.model9up = torch.nn.Sequential(*model9up)
        self.model9 = torch.nn.Sequential(*model9)
        self.model10up = torch.nn.Sequential(*model10up)
        self.model10 = torch.nn.Sequential(*model10)
        self.model3short8 = torch.nn.Sequential(*model3short8)
        self.model2short9 = torch.nn.Sequential(*model2short9)
        self.model1short10 = torch.nn.Sequential(*model1short10)

        self.model_out = torch.nn.Sequential(*model_out)

    def forward(self, input_A):
        """Map ``input_A`` to an ``output_nc``-channel image of the same H and W.

        Shapes below are per sample, batch dim omitted; assumes input_A is
        (N, input_nc, H, W) with H, W divisible by 8 — TODO confirm with callers.
        """
        
        conv1_2 = self.model1(input_A)#(input_nc,H,W) -> (64,H,W)
        conv2_2 = self.model2(conv1_2)#(64,H,W) -> (128,H/2,W/2)
        conv3_3 = self.model3(conv2_2)#(128,H/2,W/2) -> (256,H/4,W/4)
        conv4_3 = self.model4(conv3_3)#(256,H/4,W/4) -> (512,H/8,W/8)
        conv5_3 = self.model5(conv4_3)#(512,H/8,W/8) -> (512,H/8,W/8)
        conv6_3 = self.model6(conv5_3)#(512,H/8,W/8) -> (512,H/8,W/8)
        conv7_3 = self.model7(conv6_3)#(512,H/8,W/8) -> (512,H/8,W/8)

        # Decoder: upsample x2 and add the projected encoder feature of matching size.
        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) #(256,H/4,W/4)
        conv8_3 = self.model8(conv8_up)

        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) #(128,H/2,W/2)
        conv9_3 = self.model9(conv9_up)
        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) #(128,H,W)
        conv10_2 = self.model10(conv10_up)
        
        out_reg = self.model_out(conv10_2) #(output_nc,H,W)

        return out_reg

In [25]:
def weights_init(m):
    """Initializer meant for ``Module.apply``.

    Modules whose class name contains 'Conv' (including ConvTranspose2d) get
    weights ~ N(0, 0.02). Modules whose class name contains 'BatchNorm2d' get
    weights ~ N(1, 0.02) and zero bias. Conv biases and every other module are
    left untouched.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [26]:
def define_G(input_nc, output_nc, norm = 'instance', use_noise = False):
    """Build the SIGGRAPH generator with the chosen normalization and print it.

    Weight initialization and device placement are left to the caller.
    """
    netG = SIGGRAPHGenerator(
        input_nc,
        output_nc,
        get_norm_layer(norm_type = norm),
        use_noise = use_noise,
        use_tanh = True,
    )
    print(netG)
    return netG

In [27]:
class NLayerDiscriminator(torch.nn.Module):
    """Fully convolutional discriminator producing a 1-channel score map.

    Layout: one stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
    conv/norm/LeakyReLU stages (channels doubling, capped at 512), one stride-1
    conv/norm/LeakyReLU stage, and a final stride-1 conv to 1 channel, with an
    optional Sigmoid. All convs are 4x4 with padding 2.

    With ``getIntermFeat`` each stage is registered as ``model0``, ``model1``, ...
    and ``forward`` returns the list of stage outputs (used for feature matching);
    otherwise a single ``model`` Sequential returns only the final map.
    NOTE(review): with getIntermFeat and use_sigmoid, forward iterates only
    n_layers + 2 stages, so the Sigmoid stage is registered but never applied.
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=torch.nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))  # = 2 for kw = 4
        # Each inner list is one stage; stages become model<i> when getIntermFeat.
        sequence = [[torch.nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), torch.nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                torch.nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), torch.nn.LeakyReLU(0.2, True)
            ]]

        # Stride-1 stage: adds depth without further downsampling.
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            torch.nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            torch.nn.LeakyReLU(0.2, True)
        ]]

        # Final 1-channel score map.
        sequence += [[torch.nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[torch.nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), torch.nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = torch.nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return the list of stage outputs (getIntermFeat) or the final score map."""
        if self.getIntermFeat:
            res = [input]
            # n_layers + 2 stages: the optional Sigmoid stage is not included.
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)

In [28]:
class MultiscaleDiscriminator(torch.nn.Module):
    """``num_D`` NLayerDiscriminators applied to progressively downsampled inputs.

    ``forward`` returns a list with one entry per scale. Each entry is the list of
    stage outputs (getIntermFeat) or a one-element list holding the final score map.
    The first entry uses the discriminator registered with the highest index on the
    full-resolution input; each next entry uses the next lower index on an input
    average-pooled by a further factor of 2.
    """
    def __init__(self, input_nc, ndf = 64, n_layers = 3, norm_layer = torch.nn.BatchNorm2d,
                 use_sigmoid = False, num_D = 3, getIntermFeat = False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            #netD = netD.to(device)
            # Only the sub-modules of netD are registered here (names scale<i>_layer<j>
            # or layer<i>); with getIntermFeat just n_layers + 2 stages are copied, so a
            # Sigmoid stage, if any, is dropped.
            if getIntermFeat:
                for j in range(n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
            else:
                setattr(self, 'layer'+str(i), netD.model)
        # 3x3 average pool with stride 2 halves H and W between scales.
        self.downsample = torch.nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        
    def singleD_forward(self, model, input):
        """Run one scale: a list of stages (returns every stage output) or a single module."""
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]
    
    def forward(self, input):
        """Return per-scale outputs, full resolution first."""
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))

            result.append(self.singleD_forward(model, input_downsampled))
            # No need to downsample after the last scale.
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result
        

In [29]:
def define_D(input_nc, ndf, n_layers_D, norm = 'instance', use_sigmoid = False, num_D = 1, getIntermFeat=False):
    """Build a MultiscaleDiscriminator with the chosen normalization and print it.

    Weight initialization and device placement are left to the caller.
    """
    netD = MultiscaleDiscriminator(
        input_nc,
        ndf,
        n_layers_D,
        norm_layer=get_norm_layer(norm_type = norm),
        use_sigmoid=use_sigmoid,
        num_D=num_D,
        getIntermFeat=getIntermFeat,
    )
    print(netD)
    return netD

In [30]:
# Losses
class GANLoss(torch.nn.Module):
    """Adversarial loss over (multi-scale) discriminator outputs.

    ``use_lsgan`` selects MSE against a constant target (LSGAN); otherwise
    BCE-with-logits, so discriminator outputs must be raw logits. Targets are
    built with the shape, dtype and device of each prediction, so the loss runs
    on CPU as well as GPU.

    Args:
        use_lsgan: MSE (True) or BCEWithLogits (False).
        target_real_label, target_fake_label: constant target values.
        tensor: kept for backward compatibility only. It used to force
            ``torch.cuda.FloatTensor`` targets, which failed without CUDA.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.cuda.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        # Loss modules hold no parameters, so no device placement is needed.
        self.loss = torch.nn.MSELoss() if use_lsgan else torch.nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target shaped like ``input`` (real or fake label).

        The tensor is cached per label and rebuilt whenever the prediction's
        shape, dtype or device changes. The old cache compared only element
        counts and could return a wrongly shaped target.
        """
        cached = self.real_label_var if target_is_real else self.fake_label_var
        if (cached is None or cached.shape != input.shape
                or cached.dtype != input.dtype or cached.device != input.device):
            value = self.real_label if target_is_real else self.fake_label
            cached = torch.full_like(input, value, requires_grad=False)
            if target_is_real:
                self.real_label_var = cached
            else:
                self.fake_label_var = cached
        return cached

    def forward(self, input, target_is_real):
        """Score the final prediction(s) against the real/fake target.

        ``input`` is either a list of per-scale lists, in which case only the last
        element of each (the final map) is scored and the losses are summed, or a
        single sequence whose last element is the prediction.
        """
        if isinstance(input[0], list):
            return sum(self.loss(scale[-1], self.get_target_tensor(scale[-1], target_is_real))
                       for scale in input)
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))

class VGGLoss(torch.nn.Module):
    """Weighted L1 distance between the Vgg19 slice outputs of two images.

    Single-channel inputs are tiled to three channels first. Slice weights grow
    with depth (1/32 up to 1). Target features are detached, so gradients flow
    only through ``x``.
    NOTE(review): inputs reach VGG without ImageNet mean/std normalization.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().to(device)
        self.criterion = torch.nn.L1Loss().to(device)
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    @staticmethod
    def _as_three_channels(t):
        """Repeat a 1-channel (N, 1, H, W) tensor to 3 channels; other tensors pass through."""
        return torch.cat([t, t, t], dim=1) if t.shape[1] == 1 else t

    def forward(self, x, y):
        x_feats = self.vgg(self._as_three_channels(x))
        y_feats = self.vgg(self._as_three_channels(y))
        total = 0
        for weight, fx, fy in zip(self.weights, x_feats, y_feats):
            total += weight * self.criterion(fx, fy.detach())
        return total

class HuberLoss(torch.nn.Module):
    """Scaled Huber (smooth-L1) loss, summed over dim 1 with keepdim=True.

    Per element, with d = |in0 - in1|:
        0.5 * d**2 / delta   if d < delta
        d - 0.5 * delta      otherwise
    """

    def __init__(self, delta=.01):
        super(HuberLoss, self).__init__()
        self.delta = delta

    def __call__(self, in0, in1):
        abs_diff = torch.abs(in0 - in1)
        # 1.0 where the quadratic branch applies, 0.0 elsewhere (in in0's dtype).
        inside = (abs_diff < self.delta).to(dtype=in0.dtype)
        quadratic = .5 * (abs_diff ** 2)
        per_elem = quadratic * inside / self.delta + (abs_diff - .5 * self.delta) * (1 - inside)
        return torch.sum(per_elem, dim=1, keepdim=True)

In [31]:
class ImagePool():
    """Buffer of previously generated images that may be returned in place of new ones.

    With ``pool_size == 0``, ``query`` returns its input unchanged.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing ``images`` with images stored from earlier calls.

        Until the pool is full, each image is stored and returned as-is. Afterwards,
        with probability 0.5 a random stored image is returned and replaced by the new
        one; otherwise the new image is returned.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for img in images.data:
            img = img.unsqueeze(0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(img)
                picked.append(img)
                continue
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.pool_size - 1)
                picked.append(self.images[slot].clone())
                self.images[slot] = img
            else:
                picked.append(img)
        return Variable(torch.cat(picked, 0))

In [32]:
class Pix2PixHDModel(BaseModel):
    """Conditional GAN: SIGGRAPH generator plus multi-scale discriminator.

    ``forward`` returns ``[losses, fake_image_or_None]``. ``losses`` is the list
    (G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake, G_L1) with disabled terms removed
    by ``loss_filter``. All options come from the module-level configuration.
    """

    def init_loss_filter(self, use_gan_loss, use_gan_feat_loss, use_vgg_loss, use_l1_loss):
        """Return a function that keeps only the enabled losses, in a fixed order."""
        flags = (use_gan_loss, use_gan_feat_loss, use_vgg_loss, use_gan_loss, use_gan_loss, use_l1_loss)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake, g_l1):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake, g_l1), flags) if f]
        return loss_filter

    def initialize(self, is_train):
        """Create the networks, optionally load checkpoints, and set up losses and optimizers."""
        BaseModel.initialize(self, is_train)
        if resize_or_crop != 'none' or not is_train:
            torch.backends.cudnn.benchmark = True
        self.isTrain = is_train

        # Generator network
        netG_input_nc = input_nc
        self.netG = define_G(netG_input_nc, output_nc, norm = norm, use_noise = False)
        self.netG.apply(weights_init)

        # Discriminator network: scores (input, output) pairs concatenated on channels.
        if self.isTrain:
            # The non-LSGAN loss is BCEWithLogitsLoss, which applies the sigmoid itself;
            # adding a Sigmoid layer (previously use_sigmoid = no_lsgan) applied it twice.
            use_sigmoid = False
            netD_input_nc = input_nc + output_nc
            self.netD = define_D(netD_input_nc, ndf, n_layers_D, norm, use_sigmoid,
                                 num_D, not no_ganFeat_loss)
            self.netD.apply(weights_init)

        # Load checkpoints for testing or when resuming training.
        if not self.isTrain or continue_train:
            pretrained_path = '' if not self.isTrain else load_pretrain
            self.load_network(self.netG, 'G', which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            self.fake_pool = ImagePool(pool_size)
            self.old_lr = lr

            self.loss_filter = self.init_loss_filter(not no_gan_loss, not no_ganFeat_loss, not no_vgg_loss, not no_l1_loss)
            self.criterionGAN = GANLoss(use_lsgan = not no_lsgan, tensor = torch.cuda.FloatTensor)
            self.criterionFeat = torch.nn.L1Loss().to(device)
            self.criterionSmoothL1 = HuberLoss(delta=1. / 110.0)

            if not no_vgg_loss:
                self.criterionVGG = VGGLoss()

            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake','G_L1')

            self.optimizer_G = torch.optim.Adam(list(self.netG.parameters()), lr = lr, betas = (0.5, 0.999))
            self.optimizer_D = torch.optim.Adam(list(self.netD.parameters()), lr = lr, betas = (0.5, 0.999))

    def encode_input(self, label_map, real_image = None, infer = False):
        """Detach the inputs from any upstream graph and move them to ``device``.

        ``infer`` is accepted for backward compatibility. The ``volatile`` flag it
        used to set has no effect in current PyTorch; ``inference`` uses
        ``torch.no_grad()`` instead.
        """
        input_label = label_map.detach().to(device)
        if real_image is not None:
            real_image = real_image.detach().to(device)
        return input_label, real_image

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run D on (input, detached test image), optionally through the fake-image pool."""
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            return self.netD(self.fake_pool.query(input_concat))
        return self.netD(input_concat)

    def forward(self, label, image, infer=False):
        """Generate from ``label`` and compute all enabled losses against ``image``."""
        input_label, real_image = self.encode_input(label, image)

        # Fake Generation
        fake_image = self.netG(input_label)

        loss_G_GAN = 0
        loss_D_real = 0
        loss_D_fake = 0
        pred_fake = pred_real = None

        if not no_gan_loss:
            # D loss on fakes (detached, optionally from the history pool)
            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # D loss on reals
            pred_real = self.discriminate(input_label, real_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # G adversarial loss: no detach, so gradients reach the generator
            pred_fake = self.netD(torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        loss_G_L1 = 0
        if not no_l1_loss:
            loss_G_L1 = 10 * torch.mean(self.criterionSmoothL1(fake_image.float(), real_image.float()))

        # GAN feature matching loss. It needs D's outputs, which are only computed when
        # the GAN loss is enabled; previously no_gan_loss=True crashed here with NameError.
        loss_G_GAN_Feat = 0
        if not no_ganFeat_loss and pred_fake is not None:
            feat_weights = 4.0 / (n_layers_D + 1)
            D_weights = 1.0 / num_D
            for i in range(num_D):
                # Skip the last entry (the final score map); match intermediate features only.
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * lambda_feat

        # Only return the fake image when asked, to save memory/bandwidth
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake, loss_G_L1), None if not infer else fake_image ]

    def inference(self, label, image=None):
        """Generate an image from ``label`` without tracking gradients."""
        input_label, _ = self.encode_input(label, image, infer=True)
        with torch.no_grad():
            fake_image = self.netG(input_label)
        return fake_image

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)

    def update_fixed_params(self):
        """Recreate the generator optimizer over all generator parameters at the base lr."""
        params = list(self.netG.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.999))

    def update_learning_rate(self):
        """Decrease both optimizers' learning rate by lr / niter_decay, never below 0.

        The old body read the global ``lr`` and then assigned a local ``lr`` in the
        same function, which raised UnboundLocalError on every call.
        """
        lrd = lr / niter_decay
        new_lr = max(self.old_lr - lrd, 0.0)
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = new_lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = new_lr
        self.old_lr = new_lr

class InferenceModel(Pix2PixHDModel):
    """Pix2PixHDModel variant whose forward() runs the inference path.

    ``forward`` receives a single 2-element sequence. It unpacks it and hands
    the first element to ``inference`` as the label and the second as
    ``inference``'s ``image`` argument.
    """

    def forward(self, inp):
        label_map, second = inp
        return self.inference(label_map, second)

In [33]:
# Build only the model needed for the current mode. The old code always built an
# InferenceModel first and then discarded it when training.
model = Pix2PixHDModel() if is_train else InferenceModel()
model = model.to(device)
model.initialize(is_train)

# Cast every parameter to float32 on `device`. Previously this was hard-coded to
# torch.cuda.FloatTensor, which ignored `device` and crashed on CPU-only machines.
# NOTE(review): networks appear to be created inside initialize(), so the
# .to(device) above may not reach them; this explicit loop is kept after
# initialize() for that reason -- confirm.
for param in model.parameters():
    param.data = param.data.to(device=device, dtype=torch.float32)

SIGGRAPHGenerator(
  (model1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (model2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (model3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): Re

In [34]:
# Direct handles to the model's optimizers for the training loop.
# They go stale if model.update_fixed_params() later rebuilds model.optimizer_G.
optimizer_G = model.optimizer_G
optimizer_D = model.optimizer_D

In [35]:
# Global sample counter at the resume point (0 on a fresh run).
total_steps = epoch_iter + dataset_size * (start_epoch - 1)
print(total_steps)

# Phase offsets so the periodic display / print / save checks in the training
# loop keep the same cadence after resuming from a checkpoint.
display_delta, print_delta, save_delta = (
    total_steps % freq for freq in (display_freq, print_freq, save_latest_freq)
)

0


In [None]:
# Main training loop: `niter` epochs at the base learning rate, followed by
# `niter_decay` epochs during which update_learning_rate() is called once per epoch.
for epoch in tqdm(range(start_epoch, niter + niter_decay + 1)):
    epoch_start_time = time.time()
    if epoch != start_epoch:
        # Keep only the overshoot past dataset_size from the previous epoch.
        epoch_iter = epoch_iter % dataset_size
    # Name the batch `batch`, not `data`. `data` is the torch.utils.data module
    # imported at the top; shadowing it breaks any later (re-)run of cells that
    # use it.
    for batch in tqdm(dataset, leave=False):
        if total_steps % print_freq == print_delta:
            iter_start_time = time.time()
        total_steps += batch_size
        epoch_iter += batch_size

        save_fake = total_steps % display_freq == display_delta
        batch['label'] = batch['label'].to(device)
        batch['image'] = batch['image'].to(device)

        losses, generated = model(batch['label'], batch['image'], infer=save_fake)

        # Reduce each loss to a scalar; entries that are plain ints are kept as-is.
        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
        loss_dict = dict(zip(model.loss_names, losses))

        if not no_gan_loss:
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = (loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0)
                      + loss_dict.get('G_VGG', 0) + loss_dict.get('G_L1', 0))

            # update generator weights
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update discriminator weights; zero_grad first so that any gradients
            # loss_G.backward() may have left on netD's parameters are discarded
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

        if total_steps % print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / print_freq
            # These values used to be computed and then dropped; report them.
            # tqdm.write keeps the progress bars intact.
            message = '(epoch %d, iters %d, time %.3f) ' % (epoch, epoch_iter, t)
            message += ' '.join('%s: %.3f' % (k, v) for k, v in errors.items())
            tqdm.write(message)

        if total_steps % save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
            break

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, niter + niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
    if epoch % save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.save('latest')
        model.save(epoch)
        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

    ### linearly decay learning rate after certain iterations
    if epoch > niter:
        model.update_learning_rate()

  0%|                                                                                          | 0/150 [00:00<?, ?it/s]