Commit aed3638

finish discounting loss l1
Zhaoyi-Yan committed Jan 21, 2019
1 parent b8b91f2 commit aed3638
Showing 3 changed files with 44 additions and 45 deletions.
24 changes: 18 additions & 6 deletions models/shift_net/shiftnet_model.py
@@ -14,13 +14,16 @@ def name(self):
     def create_random_mask(self):
         if self.mask_type == 'random':
             if self.opt.mask_sub_type == 'fractal':
-                mask = util.create_walking_mask () # create an initial random mask.
+                mask = util.create_walking_mask() # create an initial random mask.

             elif self.opt.mask_sub_type == 'rect':
-                mask = util.create_rand_mask ()
+                mask, rand_t, rand_l = util.create_rand_mask(self.opt)
+                self.rand_t = rand_t
+                self.rand_l = rand_l
+                return mask

             elif self.opt.mask_sub_type == 'island':
-                mask = util.wrapper_gmask (self.opt)
+                mask = util.wrapper_gmask(self.opt)
         return mask

     def initialize(self, opt):
@@ -90,7 +93,7 @@ def initialize(self, opt):
         # define loss functions
         self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
         self.criterionL1 = torch.nn.L1Loss()
-        self.criterionL1_mask =
+        self.criterionL1_mask = util.Discounted_L1(opt).to(self.device) # make weights/buffers transfer to the correct device

         # initialize optimizers
         self.schedulers = []
@@ -126,7 +129,7 @@ def set_input(self, input):
             self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
                                    int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1
         elif self.opt.mask_type == 'random':
-            self.mask_global = self.create_random_mask().type_as(self.mask_global)
+            self.mask_global = self.create_random_mask().type_as(self.mask_global).view_as(self.mask_global)
         else:
             raise ValueError("Mask_type [%s] not recognized." % self.opt.mask_type)

@@ -252,7 +255,16 @@ def backward_G(self):
         # If we change the mask to 'center with random position', we can replace loss_G_L1_m with 'Discounted L1'.
         self.loss_G_L1, self.loss_G_L1_m = 0, 0
         self.loss_G_L1 += self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
-        self.loss_G_L1_m += self.criterionL1(self.fake_B*self.mask_global.float(), self.real_B*self.mask_global.float())*self.opt.mask_weight
+        # Calculate the mask reconstruction loss.
+        # When mask_type is 'center' or 'random_with_rect', we can add an additional mask-region reconstruction loss (traditional L1).
+        # Only when 'discounting' is 1 does this mask-region loss become the 'discounting L1' instead of normal L1.
+        if self.opt.mask_type == 'center' and self.opt.mask_sub_type == 'rect':
+            mask_patch_fake = self.fake_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
+                                          self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
+            mask_patch_real = self.real_B[:, :, self.rand_t:self.rand_t+self.opt.fineSize//2-2*self.opt.overlap, \
+                                          self.rand_l:self.rand_l+self.opt.fineSize//2-2*self.opt.overlap]
+            # Use the discounting L1 loss on the cropped mask region.
+            self.loss_G_L1_m += self.criterionL1_mask(mask_patch_fake, mask_patch_real)*self.opt.mask_weight

         if self.wgan_gp:
             self.loss_G = self.loss_G_L1 + self.loss_G_L1_m - self.loss_G_GAN * self.opt.gan_weight
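To see how the pieces of this file fit together: util.create_rand_mask (changed in util/util.py below) samples a top-left corner (rand_t, rand_l), and backward_G crops a square of side fineSize//2 - 2*overlap from both fake_B and real_B at that corner before applying the discounted L1. The following is a minimal sketch of the bounds arithmetic only; fineSize=256 is an assumed value (it is not set anywhere in this diff), while overlap=4 is the default from options/base_options.py.

import numpy as np

fine_size, overlap = 256, 4                # fineSize assumed; overlap default from base_options.py
h = w = fine_size
patch = fine_size // 2 - 2 * overlap       # 120: side length of the cropped mask region

# Sampling range used by the new create_rand_mask below.
maxt = h - overlap - h // 2                # 124: exclusive upper bound for rand_t
rand_t = np.random.randint(overlap, maxt)  # rand_t falls in [4, 123]

# Even at the largest possible rand_t, the crop stays inside the image.
assert (maxt - 1) + patch <= h             # 123 + 120 = 243 <= 256
print(patch, maxt, rand_t)
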
3 changes: 2 additions & 1 deletion options/base_options.py
@@ -63,7 +63,8 @@ def initialize(self, parser):
                                  're_avg_gan (Relativistic average Standard GAN), '
                                  're_avg_hinGan (Relativistic average HingeGAN), WARNING: wgan_gp should never be used here.')
         parser.add_argument('--gan_weight', type=float, default=0.2, help='the weight of gan loss')
-        parser.add_argument('--mask_weight', type=float, default=5.0, help='the weight of mask part')
+        parser.add_argument('--mask_weight', type=float, default=400, help='the weight of the mask part; try different mask_weight values')
+        parser.add_argument('--discounting', type=int, default=1, help='the loss type of the mask part: whether to use discounting l1 loss or normal l1')
         parser.add_argument('--use_spectral_norm', type=int, default=1, help='whether to add spectral norm in basic D')
         parser.add_argument('--overlap', type=int, default=4, help='the overlap for center mask')
         parser.add_argument('--show_flow', type=int, default=0, help='show the flow information. WARNING: set display_freq a large number as it is quite slow when showing flow')
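For reference, a minimal sketch of how the two changed options behave when parsed; the names and defaults mirror the add_argument calls above, and the rest is plain argparse boilerplate.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mask_weight', type=float, default=400, help='the weight of the mask part')
parser.add_argument('--discounting', type=int, default=1, help='1: discounting l1 loss, 0: normal l1')
opt = parser.parse_args([])              # no CLI arguments: fall back to the defaults
print(opt.mask_weight, opt.discounting)  # 400.0 1

Note that the codebase toggles features with integer flags (as with --use_spectral_norm) rather than store_true, so passing --discounting 0 switches the mask-region loss back to plain L1.
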
62 changes: 24 additions & 38 deletions util/util.py
@@ -164,20 +164,18 @@ def create_gMask(gMask_opts, limit_cnt=1):
     mask_global = mask.expand(1, 1, mask.size(0), mask.size(1))
     return mask_global

-def create_rand_mask(h=256, w=256, mask_size=64, overlap=0.25):
-    # Create a square mask with random position.
+def create_rand_mask(opt):
+    h, w = opt.fineSize, opt.fineSize
     mask = np.zeros((h, w))
-    positions = []
-    step = int(overlap * mask_size)
-    for y in range(0, h-mask_size+1, step):
-        for x in range(0, w-mask_size+1, step):
-            positions.append([y, x])
-    arr = np.array(range(len(positions)))
-    idx = np.random.choice(arr)
-    pos = positions[idx]
-    y, x = pos
-    mask[y:y + mask_size, x:x + mask_size] = 1
-    mask = mask[np.newaxis, ...][np.newaxis, ...]
-    return torch.ByteTensor(mask).cuda()
+    maxt = h - opt.overlap - h // 2
+    maxl = w - opt.overlap - w // 2
+    rand_t = np.random.randint(opt.overlap, maxt)
+    rand_l = np.random.randint(opt.overlap, maxl)
+
+    mask[rand_t:rand_t+opt.fineSize//2-2*opt.overlap, rand_l:rand_l+opt.fineSize//2-2*opt.overlap] = 1
+
+    return torch.ByteTensor(mask), rand_t, rand_l

 action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
 def random_walk(canvas, ini_x, ini_y, length):
@@ -460,34 +458,26 @@ def make_color_wheel():

 #https://github.com/WonwoongCho/Generative-Inpainting-pytorch/blob/master/util.py#L77-L134
 class Discounted_L1(nn.Module):
-    def __init__(self, opt, size_average=True, reduce=True):
+    def __init__(self, opt):
         super(Discounted_L1, self).__init__()
-        self.reduce = reduce
-        self.discounting_mask = spatial_discounting_mask(128, 128, 0.9)
-        self.size_average = size_average
+        # Register discounting template as a buffer
+        self.register_buffer('discounting_mask', torch.tensor(spatial_discounting_mask(opt.fineSize//2 - opt.overlap * 2, opt.fineSize//2 - opt.overlap * 2, 0.9, opt.discounting)))
+        self.L1 = nn.L1Loss()

     def forward(self, input, target):
         self._assert_no_grad(target)
-        return self._pointwise_loss(lambda a, b: torch.abs(a - b), torch._C._nn.l1_loss,
-                                    input, target, self.discounting_mask, self.size_average, self.reduce)
+        input_tmp = input * self.discounting_mask
+        target_tmp = target * self.discounting_mask
+        return self.L1(input_tmp, target_tmp)

     def _assert_no_grad(self, variable):
         assert not variable.requires_grad, \
             "nn criterions don't compute the gradient w.r.t. targets - please " \
             "mark these variables as volatile or not requiring gradients"

-    def _pointwise_loss(self, lambd, lambd_optimized, input, target, discounting_mask, size_average=True, reduce=True):
-        if target.requires_grad:
-            d = lambd(input, target)
-            d = d * discounting_mask
-            if not reduce:
-                return d
-            return torch.mean(d) if size_average else torch.sum(d)
-        else:
-            return lambd_optimized(input, target, size_average, reduce)


-def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
+def spatial_discounting_mask(mask_width, mask_height, discounting_gamma, discounting=1):
"""Generate spatial discounting mask constant.
Spatial discounting mask is first introduced in publication:
Generative Image Inpainting with Contextual Attention, Yu et al.
Expand All @@ -496,9 +486,9 @@ def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
"""
     gamma = discounting_gamma
     shape = [1, 1, mask_width, mask_height]
-    if True:
+    if discounting:
         print('Use spatial discounting l1 loss.')
-        mask_values = np.ones((mask_width, mask_height))
+        mask_values = np.ones((mask_width, mask_height), dtype='float32')
         for i in range(mask_width):
             for j in range(mask_height):
                 mask_values[i, j] = max(
@@ -508,10 +498,6 @@ def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
         mask_values = np.expand_dims(mask_values, 1)
         mask_values = mask_values
     else:
-        mask_values = np.ones(shape)
-    # it will be extended along the batch dimension suitably
-    mask_values = torch.from_numpy(mask_values).float()
-    # check for multi-gpu compatibility.
-    if torch.cuda.is_available():
-        mask_values = mask_values.cuda()
+        mask_values = np.ones(shape, dtype='float32')

     return mask_values
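
To make the discounting template concrete, here is a self-contained sketch in the spirit of the class above. The weighting formula is the standard one from Yu et al. as cited in the docstring (the folded lines are not visible in this diff, so details may differ), and DiscountedL1 is a simplified stand-in for util.Discounted_L1, which takes opt instead of explicit sizes.

import numpy as np
import torch
import torch.nn as nn

def discounting_template(width, height, gamma=0.9):
    # Weights are ~1 at the patch border and decay toward the center:
    # border pixels are easiest to infer from surrounding context, so
    # reconstruction errors there are penalized the most.
    m = np.ones((width, height), dtype='float32')
    for i in range(width):
        for j in range(height):
            m[i, j] = max(gamma ** min(i, width - i), gamma ** min(j, height - j))
    return m[np.newaxis, np.newaxis]  # shape [1, 1, width, height]

class DiscountedL1(nn.Module):
    def __init__(self, width, height):
        super().__init__()
        # A buffer (not a Parameter) moves with .to(device) but gets no gradients.
        self.register_buffer('discounting_mask', torch.from_numpy(discounting_template(width, height)))
        self.L1 = nn.L1Loss()

    def forward(self, input, target):
        return self.L1(input * self.discounting_mask, target * self.discounting_mask)

crit = DiscountedL1(120, 120)  # 120 = fineSize//2 - 2*overlap for fineSize=256, overlap=4
fake, real = torch.rand(1, 3, 120, 120), torch.rand(1, 3, 120, 120)
print(crit(fake, real).item())

Registering the template with register_buffer, rather than storing a plain tensor attribute, is what lets the new .to(self.device) call in shiftnet_model.py move it to the correct device automatically.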
