correcting new code
akanimax committed Nov 19, 2018
1 parent 1c6ef91 · commit ebe47f8
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions pro_gan_pytorch/CustomLayers.py
@@ -1,7 +1,8 @@
""" Module containing custom layers """
import torch as th
import copy

import torch as th


# extending Conv2D and Deconv2D layers for equalized learning rate logic
class _equalized_conv2d(th.nn.Module):
@@ -26,7 +27,7 @@ def __init__(self, c_in, c_out, k_size, stride=1, pad=0, initializer='kaiming',
th.nn.init.xavier_normal_(self.conv.weight)

self.use_bias = bias

if self.use_bias:
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
self.scale = (th.mean(self.conv.weight.data ** 2)) ** 0.5
@@ -70,7 +71,7 @@ def __init__(self, c_in, c_out, k_size, stride=1, pad=0, initializer='kaiming',
th.nn.init.xavier_normal_(self.deconv.weight)

self.use_bias = bias

if self.use_bias:
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
self.scale = (th.mean(self.deconv.weight.data ** 2)) ** 0.5
@@ -113,7 +114,7 @@ def __init__(self, c_in, c_out, initializer='kaiming', bias=True):
th.nn.init.xavier_normal_(self.linear.weight)

self.use_bias = bias

if self.use_bias:
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
self.scale = (th.mean(self.linear.weight.data ** 2)) ** 0.5
@@ -134,17 +135,26 @@ def forward(self, x):
return x + self.bias.view(1, -1).expand_as(x)
return x
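
# For context (not part of this diff): the "equalized learning rate" trick from
# the ProGAN paper rescales each layer's weights by a per-layer constant so all
# layers receive updates of comparable magnitude; the layers above store that
# constant as self.scale, the root-mean-square of the freshly initialised
# weights, e.g. (a hypothetical sketch):
#     lin = _equalized_linear(512, 256)
#     lin.scale    # scalar RMS of the initial weight tensor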

#----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# Pixelwise feature vector normalization.
# reference: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120

class PixelwiseNorm(nn.Module):
class PixelwiseNorm(th.nn.Module):
    def __init__(self):
        super(PixelwiseNorm, self).__init__()
    def forward(self, x):
        y = torch.mean(x.pow(2.), dim=1, keepdim=True) + 1e-8 # [N1HW]

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: y => pixel normalized activations
        """
        y = th.mean(x.pow(2.), dim=1, keepdim=True) + alpha # [N1HW]
        return x.div(y.sqrt())
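
# A minimal usage sketch (not part of this diff), assuming a standard
# [N, C, H, W] activation tensor: PixelwiseNorm divides every pixel's feature
# vector by the root of its mean squared value across channels, so
# th.mean(out ** 2, dim=1) lands close to 1:
#     norm = PixelwiseNorm()
#     feats = th.randn(4, 512, 8, 8)   # hypothetical activations, [N, C, H, W]
#     out = norm(feats)                # same shape, normalized per pixel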


# ==========================================================
# Layers required for Building The generator and
# discriminator
@@ -159,7 +169,6 @@ def __init__(self, in_channels, use_eql):
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import LeakyReLU
from torch.nn.functional import local_response_norm

super(GenInitialBlock, self).__init__()

@@ -209,7 +218,6 @@ def __init__(self, in_channels, out_channels, use_eql):
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import LeakyReLU, Upsample
from torch.nn.functional import local_response_norm

super(GenGeneralConvBlock, self).__init__()

@@ -228,8 +236,7 @@ def __init__(self, in_channels, out_channels, use_eql):
padding=1, bias=True)

# Pixelwise feature vector normalization operation
self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2,
beta=0.5, k=1e-8)
self.pixNorm = PixelwiseNorm()

# leaky_relu:
self.lrelu = LeakyReLU(0.2)
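
# A side note (not part of this diff): the replaced lambda,
#     local_response_norm(x, 2 * x.shape[1], alpha=2, beta=0.5, k=1e-8),
# appears to compute the same quantity as PixelwiseNorm, i.e.
#     x / th.sqrt(th.mean(x ** 2, dim=1, keepdim=True) + 1e-8),
# because a window of 2 * C channels spans the whole channel axis and
# alpha / n times the windowed sum of squares reduces to the channel mean;
# the explicit module states that intent directly.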
@@ -285,8 +292,8 @@ def __init__(self, averaging='all'):
self.n = int(self.averaging[5:])
else:
assert self.averaging in \
['all', 'flat', 'spatial', 'none', 'gpool'],\
'Invalid averaging mode %s' % self.averaging
['all', 'flat', 'spatial',
'none', 'gpool'], 'Invalid averaging mode %s' % self.averaging

# calculate the std_dev in such a way that it doesn't result in 0
# otherwise 0 norm operation's gradient is nan
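
# A hedged sketch (not part of this diff; the rest of this block is collapsed):
# the usual way to keep the statistic away from an exact zero is a small
# epsilon under the square root, along these lines (names are illustrative):
#     y = x - th.mean(x, dim=0, keepdim=True)
#     y = th.sqrt(th.mean(y ** 2, dim=0, keepdim=True) + 1e-8)
# because the gradient of th.sqrt blows up as its argument goes to 0.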
