In [1]:
from __future__ import print_function
import os, sys, gc, argparse, numpy as np

import torch
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data

In [2]:
# Mount Google Drive inside the Colab runtime so the project data is reachable.
from google.colab import drive

drive.mount('/content/gdrive')

# Project folder on the mounted drive; echoed so the cell output records it.
sDrive = "/content/gdrive/MyDrive/Datascience/ARProject/Tryon"
print(sDrive)

Mounted at /content/gdrive
/content/gdrive/MyDrive/Datascience/ARProject/Tryon


In [3]:
# set the default values to run the training model
def get_options(argv=None):
    """Build the training options namespace.

    All options are positional and are kept as strings (callers convert them
    with ``int``/``float`` where needed).  Every positional is declared with
    ``nargs="?"`` so that the declared ``default`` is actually honoured when a
    trailing argument is omitted; without it argparse treats the positional as
    required and silently ignores ``default``.

    Args:
        argv: Optional ``sys.argv``-style list (element 0 is the program name
            and is skipped).  When ``None`` the notebook's built-in Refine
            training configuration is used.

    Returns:
        argparse.Namespace with the attributes ``dataroot``, ``datamode``,
        ``stage``, ``data_list``, ``thread``, ``batch``, ``results``,
        ``epochs``, ``input_channel``, ``decay_epoch``, ``learn_rate``,
        ``critic``, ``display_count`` and ``save_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dataroot", nargs="?", type=str, default="data")
    parser.add_argument("datamode", nargs="?", default="train")
    parser.add_argument("stage", nargs="?", default="Stitch", help='Shape, Stitch, Refine')
    parser.add_argument("data_list", nargs="?", default="train_pairs.txt")
    # number of workers/thread to use for loading data
    parser.add_argument("thread", nargs="?", default="0")
    # batch size
    parser.add_argument('batch', nargs="?", type=str, default="1")
    parser.add_argument('results', nargs="?", type=str, default='results/Shape', help='save results')
    parser.add_argument("epochs", nargs="?", type=str, default="45")
    parser.add_argument("input_channel", nargs="?", type=str, default="6")
    parser.add_argument("decay_epoch", nargs="?", type=str, default="10")
    parser.add_argument('learn_rate', nargs="?", type=str, default="0.0002",
                        help='initial learning rate for adam')
    # Number of times after which to update Discriminator.
    parser.add_argument("critic", nargs="?", type=str, default="10")
    parser.add_argument("display_count", nargs="?", type=str, default="1000")
    parser.add_argument("save_model", nargs="?", type=str, default="2")

    if argv is None:
        # Configuration used by this notebook: Refine stage, 21 epochs.
        argv = ["", "Data", "train", "Refine", "train_pairs.txt", "0", "1", "results/",
                "21", "6", "10", "0.0002", "10", "500", "2"]
    opt = parser.parse_args(argv[1:])
    print("arguments are set for training the model")
    return opt

In [4]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from imgaug import augmenters as iaa
from torch.autograd import Variable


# Initialize kernel weights from a normal distribution. We are not using BatchNorm in final code.
def weights_init_normal(m):
    """Initialise a module in place; intended for use with ``nn.Module.apply``.

    Modules whose class name contains ``Conv2d`` get weights drawn from
    N(0.0, 0.02).  Modules whose class name contains ``BatchNorm2d`` get
    weights drawn from N(1.0, 0.02) and a zero bias.  Every other module is
    left untouched; note that the name ``ConvTranspose2d`` does not contain
    the substring ``Conv2d``, so transposed convolutions are not re-initialised.

    Uses the in-place ``normal_`` / ``constant_`` initialisers; the variants
    without the trailing underscore are deprecated.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# LambdaLR is used for learning-rate scheduling: its step method is passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR.
class LambdaLR():
    """Linear learning-rate decay factor.

    The factor stays at 1.0 until ``decay_start_epoch`` and then falls
    linearly, reaching 0.0 at ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        remaining = n_epochs - decay_start_epoch
        assert (remaining > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative LR factor for ``epoch`` (shifted by ``offset``)."""
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span


class commonFunctions:
    """Small grab-bag of filesystem and plotting helpers."""

    def __init__(self):
        super(commonFunctions, self).__init__()

    def name(self):
        """Return the class identifier string."""
        return 'commonFunctions'

    def createDir(self, path):
        """Create ``path`` (including parents) unless something already exists there."""
        if os.path.exists(path):
            return
        os.makedirs(path)

    def display_img(self, img, cmap=None):
        """Draw ``img`` on a new 12x10 matplotlib figure using colour map ``cmap``."""
        figure = plt.figure(figsize=(12, 10))
        axes = figure.add_subplot(111)
        axes.imshow(img, cmap)


class ImgAugTransform:
    """Callable transform that applies a fixed imgaug pipeline to one image.

    The pipeline currently holds a single affine rotation (``rotate=40``)
    with symmetric border filling.
    """

    def __init__(self):
        rotation = iaa.Affine(rotate=40, mode='symmetric')
        self.aug = iaa.Sequential([rotation])

    def __call__(self, img):
        """Convert ``img`` to a numpy array and return the augmented array."""
        pixels = np.array(img)
        return self.aug.augment_image(pixels)


class ImgAugTransformStitching:
    """Callable transform applying the same imgaug pipeline to three images.

    The pipeline is a rotation (``rotate=40``) followed by a translation
    (``translate_percent`` x=0.2, y=0.1), both with symmetric border filling.
    Both augmenters use fixed parameters, so the three images receive the
    same geometric change even though each is augmented by a separate call.
    """

    def __init__(self):
        # The unused ``sometimes`` helper and the commented-out experiments
        # were dropped; only the two active augmenters remain.
        self.aug = iaa.Sequential([
            iaa.Affine(rotate=40, mode='symmetric'),
            iaa.Affine(
                translate_percent={"x": 0.2, "y": 0.1},
                mode='symmetric'
            ),
        ])

    def __call__(self, img, img1, img2):
        """Return a 3-tuple with the augmented versions of the inputs, in order."""
        img = np.array(img)
        img1 = np.array(img1)
        img2 = np.array(img2)

        return self.aug.augment_image(img), self.aug.augment_image(img1), self.aug.augment_image(img2)

class ImgAugTransformRefine:
    """Callable transform applying one fixed translation to three images.

    The pipeline is a single affine translation (``translate_percent``
    x=0.2, y=0.1) with symmetric border filling.  Its parameters are fixed,
    so the three images are shifted identically.
    """

    def __init__(self):
        # The unused ``sometimes`` helper from the original was removed.
        self.aug = iaa.Sequential([
            iaa.Affine(
                translate_percent={"x": 0.2, "y": 0.1},
                mode='symmetric'
            ),
        ])

    def __call__(self, img, img1, img2):
        """Return a 3-tuple with the augmented versions of the inputs, in order."""
        img = np.array(img)
        img1 = np.array(img1)
        img2 = np.array(img2)

        return self.aug.augment_image(img), self.aug.augment_image(img1), self.aug.augment_image(img2)

# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
# ReplayBuffer was first introduced in the above-mentioned paper. Its effect has been supported mathematically in
# the more recent ICLR paper ProbGAN. The replay buffer uses data the Discriminator has already seen as a prior.
# See page 5 of the paper, just above the Theory section:
# Hence we propose to maintain a subset of discriminators by subsampling the whole sequence of discriminators.

class ReplayBuffer():
    """Fixed-capacity pool of previously generated samples.

    While the pool is not full every incoming sample is stored and returned
    unchanged.  Once it is full, each incoming sample has a 50% chance of
    being swapped with a randomly chosen stored sample (the stored one is
    returned) and a 50% chance of being returned directly without storing.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same size."""
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)

            # Pool still has room: remember the sample and pass it through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            # Pool full: coin flip between replaying an old sample and passing through.
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)
        return Variable(torch.cat(batch_out))


In [5]:
import numpy as np
import os
import os.path as osp
import matplotlib.pyplot as plt
import json
import random
import torch.utils.data as data
import torchvision.transforms as transforms

from skimage.filters import threshold_otsu
from PIL import Image
import torchvision.transforms.functional as TF
from PIL import ImageDraw

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")


class refineDataSetExtract():
    """Map-style dataset (``__getitem__`` / ``__len__``) for the Refine stage.

    Reads (person image, cloth image) name pairs from ``train_pairs.txt``
    under the data root and, per sample, loads the person image, the cloth,
    the shaped cloth, the diff image, the human-parsing map and the
    nested-UNet mask from sub-folders of ``<root>/train``.

    The data root is built from the current working directory plus
    ``/gdrive/MyDrive/Datascience/ARProject/Tryon/data/``.
    # NOTE(review): this presumably relies on the notebook running with
    # cwd ``/content`` (where Drive is mounted) - confirm.
    """

    def __init__(self, height=0):
        """Set up paths and transforms and load the pair list.

        Args:
            height: size handed to ``transforms.Resize`` in ``self.transform``.
        """
        super(refineDataSetExtract, self).__init__()
        # base setting
        path_ = os.getcwd()
        self.root = path_ + '/gdrive/MyDrive/Datascience/ARProject/Tryon/data/'
        self.datamode = 'train'  # train or test or self-define
        self.data_list = "train_pairs.txt"
        self.fine_height = height
        self.fine_width = 128  # not referenced elsewhere in this class
        self.radius = 3  # not referenced elsewhere in this class
        self.data_path = osp.join(self.root, self.datamode)
        # ``transforms.Scale`` was the old alias of ``Resize`` and has been
        # removed from torchvision; ``Resize`` is already used in this class.
        self.transform = transforms.Compose(
            (transforms.Resize(self.fine_height), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)))

        self.transform_input = transforms.Compose(
            [ImgAugTransform(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # load data list
        self.im_names, self.c_names = self._read_pairs(osp.join(self.root, self.data_list))
        self.rotate = ImgAugTransformRefine()

    @staticmethod
    def _read_pairs(list_path):
        """Parse a whitespace-separated "<image> <cloth>" file into two name lists.

        Blank lines are skipped so that a trailing newline in the list file
        does not abort loading.
        """
        im_names = []
        c_names = []
        with open(list_path, 'r') as f:
            for line in f:
                fields = line.strip()
                if not fields:
                    continue
                im_name, c_name = fields.split()
                im_names.append(im_name)
                c_names.append(c_name)
        return im_names, c_names

    def name(self):
        """Return the dataset identifier string."""
        return "refineDataSetExtract"

    def transformData(self, src, mask, target, cloth, wrap, diff, head):
        """Resize the inputs to 128x128 and normalise all but ``head``.

        ``src``, ``mask``, ``target``, ``cloth``, ``wrap`` and ``diff`` are
        converted with ``to_tensor`` and normalised with mean/std 0.5 on three
        channels.  ``head`` is only resized (the caller passes it already
        transformed).

        Returns:
            Tuple ``(src, mask, target, cloth, wrap, diff, head)``.
        """
        # Resize
        resize = transforms.Resize(size=(128, 128))
        src = resize(src)  # Source with missing cloth
        mask = resize(mask)  # mask of the missing cloth
        target = resize(target)  # target/ Ground truth
        cloth = resize(cloth)  # Cloth ground truth, how it should look before applying
        wrap = resize(wrap)  # skeleton
        diff = resize(diff)
        head = resize(head)

        src = TF.to_tensor(src)
        mask = TF.to_tensor(mask)
        target = TF.to_tensor(target)
        cloth = TF.to_tensor(cloth)
        wrap = TF.to_tensor(wrap)
        diff = TF.to_tensor(diff)

        src = TF.normalize(src, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        mask = TF.normalize(mask, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        target = TF.normalize(target, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        cloth = TF.normalize(cloth, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        wrap = TF.normalize(wrap, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        diff = TF.normalize(diff, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return src, mask, target, cloth, wrap, diff, head

    def _otsu_binarize(self, image):
        """Resize a PIL image to 256x192 and threshold it with Otsu's method.

        Returns a boolean numpy array in height-width-channel layout.
        """
        loader2 = transforms.Compose([transforms.Resize((256, 192)), transforms.ToTensor()])
        image = loader2(image).float()
        better_contrast = image.permute(1, 2, 0).detach().cpu().numpy()
        better_contrast[better_contrast > 1] = 1

        thresh = threshold_otsu(better_contrast)
        return better_contrast > thresh

    def get_binary_from_img(self, image_name):
        """Binarise an in-memory array (cast to uint8); returns a boolean numpy array."""
        return self._otsu_binarize(Image.fromarray(np.uint8(image_name)))

    def get_binary(self, image_name):
        """Binarise the image file at path ``image_name``; returns a boolean numpy array."""
        return self._otsu_binarize(Image.open(image_name))

    def __getitem__(self, index):
        """Load and preprocess sample ``index``.

        Returns:
            Tuple ``(source, mask, style_, target, targ, wrap, diff, cloth, head)``.
        """
        c_name = self.c_names[index]
        im_name = self.im_names[index]

        # person image and the per-sample auxiliary images
        im = plt.imread(osp.join(self.data_path, 'image', im_name))
        cm = plt.imread(osp.join(self.data_path, 'cloth', c_name))
        wrap = plt.imread(osp.join(self.data_path, 'image_shaped_cloth', im_name))
        diff = plt.imread(osp.join(self.data_path, 'changed_diff', im_name))

        # load parsing image
        parse_name = im_name.replace('.jpg', '.png')
        im_parse = Image.open(osp.join(self.data_path, 'image-parse', parse_name))
        parse_array = np.array(im_parse)

        # Label ids grouped into a head mask and a cloth mask.
        # NOTE(review): the meaning of the individual ids depends on the
        # parsing model's label map, which is not visible here - confirm.
        parse_head = (parse_array == 1).astype(np.float32) + \
                     (parse_array == 2).astype(np.float32) + \
                     (parse_array == 4).astype(np.float32) + \
                     (parse_array == 13).astype(np.float32)

        parse_cloth = (parse_array == 5).astype(np.float32) + (parse_array == 6).astype(np.float32) + (
                    parse_array == 7).astype(np.float32) + (parse_array == 9).astype(np.float32) + (
                                  parse_array == 15).astype(np.float32) + (parse_array == 3).astype(np.float32) + (
                                  parse_array == 14).astype(np.float32)

        pcm = self.get_binary_from_img(parse_cloth)
        phead = self.get_binary_from_img(parse_head)
        im_h = im * phead - (1 - phead)

        # Keep only the cloth pixels; pixels that end up 0 are painted white.
        source = im * pcm
        source[source == 0] = 255
        mask = plt.imread(osp.join(self.data_path, 'nested_unet_msk', im_name))

        unet_mask = self.get_binary(osp.join(self.data_path, 'nested_unet_msk', im_name))
        outside_mask = source * (1 - unet_mask)
        outside_mask[outside_mask == 0] = 255

        inside_mask = source * (unet_mask)
        inside_mask[inside_mask == 0] = 255

        src_img = Image.fromarray(np.uint8(outside_mask))
        mask = Image.fromarray(np.uint8(mask))
        style = Image.fromarray(np.uint8(inside_mask))
        target = Image.fromarray(np.uint8(source))
        cloth = Image.fromarray(np.uint8(cm))
        wrap = Image.fromarray(np.uint8(wrap))
        diff = Image.fromarray(np.uint8(diff))
        head = Image.fromarray(np.uint8(im_h))

        style_ = self.transform(style)
        cloth = self.transform(cloth)
        head = self.transform(head)

        resize = transforms.Resize(size=(128, 128))
        cloth = resize(cloth)  # Cloth ground truth, how it should look before applying
        style_ = resize(style_)

        source, mask, target, targ, wrap, diff, head = self.transformData(src_img, mask, target, style, wrap, diff, head)
        return source, mask, style_, target, targ, wrap, diff, cloth, head

    def __len__(self):
        """Number of (image, cloth) pairs listed in the data list."""
        return len(self.im_names)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# ResNet module to process the incoming filters. We use Instance Norm in place of the traditional BatchNorm.
# BatchNorm doesn't play any significant role since our batch is very small. Another thing we observed is
# that the feature maps don't face covariate shift in the ResNet block, as the datasets are very close to each other.
# Removing the Norm from the ResNet block doesn't affect the model result.

class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with instance norm and a skip connection.

    The skip connection adds the input to the branch output, so the two must
    have matching shapes.
    """

    def __init__(self, in_features, out_features):
        super(ResidualBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, out_features, 3),
            nn.InstanceNorm2d(out_features),
        )

    def forward(self, x):
        branch = self.conv_block(x)
        return x + branch


class ConvBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with instance norm, no skip connection.

    The first convolution keeps ``in_features`` channels; the second maps to
    ``out_features``.  Spatial size is preserved.
    """

    def __init__(self, in_features, out_features):
        super(ConvBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, out_features, 3),
            nn.InstanceNorm2d(out_features),
        )

    def forward(self, x):
        features = self.conv_block(x)
        return features


# Over the course of GAN history we inferred that if we replace the unknown region with noise, then GANs can effectively
# generate the missing regions (effectively implies to generate something).
# We didn't test with different kinds of noise, or their effects, in detail.
class NoiseInjection(nn.Module):
    """Add learnably-scaled Gaussian noise to a feature map inside a masked region.

    ``weight`` is a per-channel scale initialised to zero, so the module is
    an identity until training moves the weight away from zero.
    """

    def __init__(self, channel):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, mask):
        """Return ``image + weight * noise * mask``.

        Args:
            image: 4-D feature map; dims 2 and 3 give the noise plane size.
            mask: 4-D tensor; only its first channel is used, repeated across
                all ``image`` channels.
        """
        # One noise plane shared by every sample and channel.  It is created
        # on the same device and dtype as ``image`` so the module also works
        # when the model is moved off the CPU.
        noise = torch.randn(1, 1, image.shape[2], image.shape[3],
                            device=image.device, dtype=image.dtype)
        mask = mask[:, :1, :, :].repeat(1, image.shape[1], 1, 1)
        return image + self.weight * noise * mask


def swish(x):
    """Swish activation: ``x * sigmoid(x)``.

    Uses ``torch.sigmoid``; ``torch.nn.functional.sigmoid`` is deprecated.
    """
    return x * torch.sigmoid(x)


def get_mean_var(c):
    """Per-sample, per-channel spatial mean and variance of a 4-D tensor.

    Both statistics are taken over the flattened spatial positions and are
    returned expanded to the shape of ``c``.  The variance uses the default
    of ``Tensor.var``.
    """
    batch, channels, height, width = c.size()
    flat = c.view(batch, channels, height * width)

    def _spread(stat):
        # (batch, channels) -> broadcast view with the same shape as ``c``
        return stat.view(batch, channels, 1, 1).expand_as(c)

    mean_map = _spread(flat.mean(2))
    var_map = _spread(flat.var(2))
    return mean_map, var_map


# model_ds downsamples the feature maps, we use stride = 2 to downsample feature maps instead of
# max pooling layer which is not learnable.
class model_ds(nn.Module):
    """Downsampling stage: stride-2 3x3 convolution, instance norm, ReLU."""

    def __init__(self, in_features, out_features):
        super(model_ds, self).__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        downsampled = self.conv_block(x)
        return downsampled


# model_up upsamples the feature maps again with a learnable layer. We didn't use any other method since
# nn.Upsample has no learnable weights; the other layer we could have tried is sub-pixel convolution, which also
# learns to upsample / downsample.
class model_up(nn.Module):
    """Upsampling stage: stride-2 3x3 transposed convolution, instance norm, ReLU."""

    def __init__(self, in_features, out_features):
        super(model_up, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        upsampled = self.conv_block(x)
        return upsampled


class transform_layer(nn.Module):
    """Fuse a feature map with conditioning ("style") information.

    Two modes are supported by ``forward``:

    * ``mode='C'``: the raw conditioning tensor (``input_nc`` channels) is
      resized to the feature-map size, encoded to 64 channels by
      ``vgg_block``, concatenated with ``x`` and passed through
      ``convblock_``.
    * any other mode: noise is injected into ``x`` inside ``mask``, ``style``
      (64 channels, expected by ``down_conv``) is mapped to ``in_features``
      channels, concatenated with ``x`` and passed through ``convblock``;
      the mapped style is added to the result.
    """

    def __init__(self, input_nc, in_features, out_features):
        super(transform_layer, self).__init__()
        self.channels = in_features

        self.convblock = ConvBlock(in_features + in_features, out_features)
        # Not used by forward(); kept so the module's parameter set and
        # state_dict keys stay the same as before.
        self.up_conv = nn.Conv2d(in_features * 2, in_features, 3, 1, 1)
        self.down_conv = nn.Sequential(
            nn.Conv2d(64, in_features // 4, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(in_features // 4, in_features // 2, 1, 1),
            nn.ReLU(),
            nn.Conv2d(in_features // 2, in_features, 1, 1),
            nn.ReLU()
        )
        self.noise = NoiseInjection(in_features)

        self.convblock_ = ConvBlock(in_features + 64, out_features)

        self.vgg_block = nn.Sequential(
            nn.Conv2d(input_nc, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 1, 1),
            nn.ReLU()
        )

    def forward(self, x, mask=None, style=None, mode='D'):
        """Apply the layer.

        Args:
            x: feature map to transform.
            mask: region for noise injection (required unless ``mode='C'``).
            style: conditioning tensor (see class docstring for the expected
                channel count per mode).
            mode: ``'C'`` for the conditioning branch, anything else for the
                noise/style branch.

        Returns:
            ``(out, encoded_style)`` when ``mode='C'``, otherwise ``out``.
        """
        # Resize to the feature map's own height AND width.  The original
        # used the height for both, which only worked for square maps.
        target_size = (x.shape[2], x.shape[3])
        if mode == 'C':
            # F.interpolate replaces the deprecated F.upsample;
            # align_corners=False is the behaviour F.upsample defaulted to.
            style = F.interpolate(style, size=target_size, mode='bilinear', align_corners=False)

            style = self.vgg_block(style)
            concat = torch.cat([x, style], 1)

            out = (self.convblock_(concat))
            return out, style
        else:
            mask = F.interpolate(mask, size=target_size, mode='bilinear', align_corners=False)
            x = self.noise(x, mask)

            style = self.down_conv(style)
            concat = torch.cat([x, style], 1)

            out = (self.convblock(concat) + style)
            return out


class transform_up_layer(nn.Module):
    """Decoder-side fusion of a decoder feature map with an encoder feature map.

    The encoder map ``y`` (``in_features * 2`` channels) is reduced to
    ``in_features`` channels by ``up_conv``, concatenated with ``x`` and
    passed through a ``ConvBlock``.
    """

    def __init__(self, in_features, out_features, diff=False):
        super(transform_up_layer, self).__init__()
        self.channels = in_features

        # With ``diff`` the fusion block is sized for one extra group of
        # ``in_features`` channels.
        fused_channels = in_features * 2
        if diff == True:
            fused_channels = in_features * 2 + in_features
        self.convblock = ConvBlock(fused_channels, out_features)

        self.up_conv = nn.Sequential(
            nn.Conv2d(in_features * 2, in_features, 3, 1, 1),
            nn.ReLU()
        )

    def forward(self, x, y, mode="down"):
        """Return ``convblock(cat([x, up_conv(y)]))``; ``mode`` is not used."""
        reduced = self.up_conv(y)
        fused = torch.cat([x, reduced], 1)
        return self.convblock(fused)


class GeneratorCoarse(nn.Module):
    """Encoder-decoder generator conditioned on the concatenated input images.

    The encoder applies a residual stage followed by a ``transform_layer``
    (mode ``'C'``) at each of six resolutions (named 128 down to 4 after the
    sizes obtained for a 128x128 input).  The decoder upsamples back with
    ``model_up`` stages; the three coarsest decoder stages are fused with the
    matching encoder features through ``transform_up_layer``.

    Args:
        input_nc: total number of channels of the concatenated conditioning
            images (one extra noise channel is added internally).
        output_nc: number of output image channels.
        n_residual_blocks: accepted for interface compatibility; not used.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=1):
        super(GeneratorCoarse, self).__init__()
        in_features = 64

        # Stem: +1 input channel for the noise plane added in forward().
        self.model_input_cloth = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc + 1, in_features, 7),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True)
        )

        # ---------------- Encoder ----------------
        self.block128 = nn.Sequential(
            ResidualBlock(in_features, in_features)
        )
        self.block128_transform = transform_layer(input_nc, in_features, in_features)

        self.block64 = nn.Sequential(
            model_ds(in_features, in_features * 2),
            ResidualBlock(in_features * 2, in_features * 2)
        )
        self.block64_transform = transform_layer(input_nc, in_features * 2, in_features * 2)

        self.block32 = nn.Sequential(
            model_ds(in_features * 2, in_features * 4),
            ResidualBlock(in_features * 4, in_features * 4)
        )
        self.block32_transform = transform_layer(input_nc, in_features * 4, in_features * 4)

        self.block16 = nn.Sequential(
            model_ds(in_features * 4, in_features * 8),
            ResidualBlock(in_features * 8, in_features * 8)
        )
        self.block16_transform = transform_layer(input_nc, in_features * 8, in_features * 8)
        self.block8 = nn.Sequential(
            model_ds(in_features * 8, in_features * 8),
            ResidualBlock(in_features * 8, in_features * 8)
        )
        self.block8_transform = transform_layer(input_nc, in_features * 8, in_features * 8)
        self.block4 = nn.Sequential(
            model_ds(in_features * 8, in_features * 8),
            ResidualBlock(in_features * 8, in_features * 8)
        )
        self.block4_transform = transform_layer(input_nc, in_features * 8, in_features * 8)

        # ---------------- Decoder ----------------
        self.block4_up = nn.Sequential(
            nn.Conv2d(in_features * 8, in_features * 4, 3, 1, 1),
            ResidualBlock(in_features * 4, in_features * 4)
        )
        self.block4_up_transform = transform_up_layer(in_features * 4, in_features * 8)

        self.block8_up = nn.Sequential(
            model_up(in_features * 8, in_features * 4),
            ResidualBlock(in_features * 4, in_features * 4)
        )
        self.block8_up_transform = transform_up_layer(in_features * 4, in_features * 8)

        self.block16_up = nn.Sequential(
            model_up(in_features * 8, in_features * 4),
            ResidualBlock(in_features * 4, in_features * 4)
        )
        self.block16_up_transform = transform_up_layer(in_features * 4, in_features * 8)

        self.block32_up = nn.Sequential(
            model_up(in_features * 8, in_features * 4),
            ResidualBlock(in_features * 4, in_features * 4)
        )
        # The three *_up_transform modules below are not called by forward();
        # they are kept so parameter sets / state_dict keys stay unchanged.
        self.block32_up_transform = transform_up_layer(in_features * 2, in_features * 4, True)

        self.block64_up = nn.Sequential(
            model_up(in_features * 4, in_features * 2),
            ResidualBlock(in_features * 2, in_features * 2)
        )
        self.block64_up_transform = transform_up_layer(in_features, in_features * 2, True)

        self.block128_up = nn.Sequential(
            model_up(in_features * 2, in_features),
            ResidualBlock(in_features, in_features)
        )

        self.block128_up_transform = transform_up_layer(in_features // 2, in_features, True)

        self.model_output = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, 7),
            nn.Tanh()
        )

    def _conv_layer_set(self, in_c, out_c):
        """Build and return a 7x7 reflection-padded conv + instance norm + ReLU stem.

        The original built the module but never returned it.
        """
        model_input = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_c, out_c, 7, padding=0),
            nn.InstanceNorm2d(out_c),
            nn.ReLU(inplace=True)
        )
        return model_input

    # forward function
    def forward(self, src, *input):
        """Generate an image from ``src`` and any extra conditioning tensors.

        The extra tensors come first and ``src`` last in the channel-wise
        concatenation that forms the conditioning ("style") tensor.

        Returns:
            ``(out, s_128, s_64, s_32, s_16, s_8, s_4)`` - the generated image
            followed by the encoded conditioning at each encoder resolution.
        """
        conds = list(input)
        conds.append(src)

        style = torch.cat(conds, 1)
        # One noise channel per sample, on the same device/dtype as the
        # inputs.  The original hard-coded batch size 1 on the CPU, which
        # made the concatenation fail for larger batches or GPU inputs.
        noise = torch.randn(src.shape[0], 1, src.shape[2], src.shape[3],
                            device=src.device, dtype=src.dtype)
        y = torch.cat([noise, style], 1)

        y = self.model_input_cloth(y)

        ############## Encoder #######################

        y128 = self.block128(y)
        y128, s_128 = self.block128_transform(x=y128, style=style, mode="C")

        y64 = self.block64(y128)
        y64, s_64 = self.block64_transform(x=y64, style=style, mode="C")

        y32 = self.block32(y64)
        y32, s_32 = self.block32_transform(x=y32, style=style, mode="C")

        y16 = self.block16(y32)
        y16, s_16 = self.block16_transform(x=y16, style=style, mode="C")

        y8 = self.block8(y16)
        y8, s_8 = self.block8_transform(x=y8, style=style, mode="C")

        y4 = self.block4(y8)
        y4, s_4 = self.block4_transform(x=y4, style=style, mode="C")

        ############## Decoder #######################

        y4u = self.block4_up(y4)
        y4u = self.block4_up_transform(y4u, y4)

        y8u = self.block8_up(y4u)
        y8u = self.block8_up_transform(y8u, y8)

        y16u = self.block16_up(y8u)
        y16u = self.block16_up_transform(y16u, y16)

        # The three finest stages upsample without encoder fusion.
        y32u = self.block32_up(y16u)

        y64u = self.block64_up(y32u)

        y128u = self.block128_up(y64u)

        out = self.model_output(y128u)

        return out, s_128, s_64, s_32, s_16, s_8, s_4


class Discriminator(nn.Module):
    """Fully convolutional discriminator for 3-channel images.

    Eight 3x3 convolutions (every second one with stride 2) with instance
    norm and swish activations are followed by a 1x1 convolution to a single
    channel.  The resulting map is averaged over all spatial positions and
    squashed with a sigmoid, giving one score per sample.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.bn2 = nn.InstanceNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn3 = nn.InstanceNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn4 = nn.InstanceNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.bn5 = nn.InstanceNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.bn6 = nn.InstanceNorm2d(256)
        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.bn7 = nn.InstanceNorm2d(512)
        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.bn8 = nn.InstanceNorm2d(512)

        # Replaced original paper FC layers with FCN
        # NOTE(review): padding=1 on a 1x1 kernel adds a border whose values
        # depend only on the bias; left as-is because changing it would
        # alter the scores - confirm whether it is intended.
        self.conv9 = nn.Conv2d(512, 1, 1, stride=1, padding=1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores in (0, 1)."""
        x = swish(self.conv1(x))

        x = swish(self.bn2(self.conv2(x)))
        x = swish(self.bn3(self.conv3(x)))
        x = swish(self.bn4(self.conv4(x)))
        x = swish(self.bn5(self.conv5(x)))
        x = swish(self.bn6(self.conv6(x)))
        x = swish(self.bn7(self.conv7(x)))
        x = swish(self.bn8(self.conv8(x)))

        x = self.conv9(x)
        # Global average over the score map, then sigmoid.  torch.sigmoid
        # replaces the deprecated F.sigmoid.
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)


In [7]:
def trainRefineModel(opt,netG, netD):
    """Adversarially train the Refine-stage generator.

    Per iteration the generator is updated with
    ``10 * MSE(D(G(diff, wrap)), 1) + 10 * L1(G(diff, wrap), target)`` and the
    discriminator loss is the mean of the real and (replay-buffered) fake MSE
    terms.  Preview images are written every 50 iterations and the generator
    weights every ``opt.save_model`` epochs, both under
    ``<cwd>/gdrive/MyDrive/Datascience/ARProject/Tryon/<opt.results><opt.stage>``.

    Fixes compared with the previous version:
      * batches are obtained by iterating the loader (iterator ``.next()``
        is Python-2 style and absent from current DataLoader iterators);
      * LR schedulers are stepped once per epoch - the decay lambda is
        defined in epochs, so stepping per iteration drove the learning rate
        to zero and below after ``n_epochs`` iterations;
      * the logged loss is a true running mean of Python floats, so no
        autograd graphs are kept alive between iterations;
      * the checkpoint is written once at the end of a save epoch instead of
        on every iteration of it;
      * the adversarial targets follow the configured batch size;
      * the "Final" and "Head" previews are mapped from [-1, 1] to [0, 1]
        like the main preview grid.

    Args:
        opt: options namespace from ``get_options`` (string-valued fields).
        netG: generator; called as ``netG(diff, wrap)``.
        netD: discriminator; called on image batches.
    """
    print("training the model for different model types: %s" % (opt.stage))
    dataset = refineDataSetExtract(128)
    batchSize = int(opt.batch)
    train_loader = DataLoader(dataset,
                              batch_size=batchSize,
                              shuffle=False,
                              num_workers=int(opt.thread),
                              drop_last=True, pin_memory=True)

    n_epochs = int(opt.epochs)
    decay_epoch = int(opt.decay_epoch)
    lr = float(opt.learn_rate)
    critic_every = int(opt.critic)
    save_every = int(opt.save_model)
    nRow = 4

    # Output folder (loop-invariant, so computed once).
    root = os.getcwd() + '/gdrive/MyDrive/Datascience/ARProject/Tryon'
    save_dir = "{}/{}{}".format(root, opt.results, opt.stage)

    criterion_GAN = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                       lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D,
                                                       lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step)

    # Discriminator targets, shaped like its (batch, 1) output.
    target_real = torch.ones(batchSize, 1)
    target_fake = torch.zeros(batchSize, 1)

    fake_buffer = ReplayBuffer()
    print(n_epochs)
    for epoch in range(0, n_epochs):
        gc.collect()
        print(epoch)
        netG.train()
        avg_loss_g = 0.0
        avg_loss_d = 0.0
        n_batches = len(train_loader)
        for i, batch in enumerate(train_loader):
            src, mask, style_img, target, gt_cloth, wrap, diff, cloth, head = batch

            ############## Generator update ##############
            optimizer_G.zero_grad()

            gen_targ, _, _, _, _, _, _ = netG(diff, wrap)
            pred_fake = netD(gen_targ)

            loss_G = 10 * criterion_GAN(pred_fake, target_real) + 10 * criterion_identity(gen_targ, target)
            loss_G.backward()
            optimizer_G.step()

            ############## Discriminator update ##########
            optimizer_D.zero_grad()

            pred_real = netD(target)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss, on a sample drawn through the replay buffer.
            gen_targ = fake_buffer.push_and_pop(gen_targ)
            pred_fake = netD(gen_targ.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()

            # The discriminator is only stepped every `critic_every` iterations.
            # NOTE(review): gradients are zeroed every iteration, so the
            # skipped iterations do not contribute to that step - confirm
            # this is the intended schedule.
            if (i + 1) % critic_every == 0:
                optimizer_D.step()

            # Running means as plain floats (no graph retention).
            avg_loss_g += (loss_G.item() - avg_loss_g) / (i + 1)
            avg_loss_d += (loss_D.item() - avg_loss_d) / (i + 1)

            if (i + 1) % 50 == 0:
                print("Epoch: (%3d) (%5d/%5d) Loss: (%0.0003f) (%0.0003f)" % (
                epoch, i + 1, n_batches, avg_loss_g * 1000, avg_loss_d * 1000))

                # Map images from [-1, 1] back to [0, 1] before saving.
                pic = (torch.cat([wrap, diff, gen_targ, target, head], dim=0).detach() + 1) / 2.0
                pic1 = (target.detach() + 1) / 2.0
                pic2 = (head.detach() + 1) / 2.0

                os.makedirs(save_dir, exist_ok=True)
                save_image(pic, '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, epoch, i + 1, n_batches), nrow=nRow)
                save_image(pic1, '%s/Epoch_Final_(%d)_(%dof%d).jpg' % (save_dir, epoch, i + 1, n_batches))
                save_image(pic2, '%s/Epoch_Head_(%d)_(%dof%d).jpg' % (save_dir, epoch, i + 1, n_batches))

            # Cap the number of iterations per epoch (indices 0..500).
            if i == 500:
                break

        # Checkpoint once per save epoch.
        if (epoch + 1) % save_every == 0:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(netG.state_dict(), '{}/Gan_{}.pth'.format(save_dir, epoch))

        # Update learning rates (epoch-based schedule).
        lr_scheduler_G.step()
        lr_scheduler_D.step()

    print('traing refine complete')

In [8]:
# define main function to start executing the training model
def main():
    """Entry point: read the options, build the networks and run the chosen stage.

    The generator takes 9 input channels for the "Stitch" stage and
    ``input_channel`` channels otherwise.  The function always ends by
    calling ``sys.exit`` with a reminder message.
    """
    trn_options = get_options()
    print(trn_options)
    print("Model training started")

    results_dir = trn_options.results
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    stage = trn_options.stage
    generator_channels = 9 if stage == "Stitch" else int(trn_options.input_channel)
    netG = GeneratorCoarse(generator_channels, 3)
    netD = Discriminator()

    # initialize the weights of both networks
    for net in (netG, netD):
        net.apply(weights_init_normal)

    if stage in ("Shape", "Stitch", "Refine"):
        print("Training started for %s" % (stage))
        if stage == "Shape":
            trainShapeModel(trn_options, netG, netD)
        elif stage == "Stitch":
            trainStitchModel(trn_options, netG, netD)
        else:
            trainRefineModel(trn_options, netG, netD)

        print("Training completed for %s " % (stage))
    else:
        print("Please mention the Stage from [Shape, Stitch, Refine]")

    print('Finished training ')

    sys.exit("Please mention the next Stage from [Shape, Stitch, Refine]")


In [None]:
# Notebook/script entry point: run the full training pipeline defined in main().
# main() finishes with sys.exit(...), so this cell ends by raising SystemExit.
if __name__ == "__main__":
    main()

arguments are set for training the model
Namespace(batch='1', critic='10', data_list='train_pairs.txt', datamode='train', dataroot='Data', decay_epoch='10', display_count='500', epochs='21', input_channel='6', learn_rate='0.0002', results='results/', save_model='2', stage='Refine', thread='0')
Model training started
Training started for Refine
training the model for different model types: Refine
21
0
