minor bug-fixes and conditional module added
akanimax committed Jul 21, 2018
1 parent 0f3e62b commit ee7cf00
Showing 4 changed files with 519 additions and 30 deletions.
124 changes: 108 additions & 16 deletions pro_gan_pytorch/CustomLayers.py
@@ -94,21 +94,23 @@ def forward(self, x):
class _equalized_linear(th.nn.Module):
""" Linear layer using equalized learning rate """

def __init__(self, c_in, c_out, initializer='kaiming'):
def __init__(self, c_in, c_out, initializer='kaiming', bias=True):
"""
Linear layer from pytorch extended to include equalized learning rate
:param c_in: number of input channels
:param c_out: number of output channels
:param initializer: initializer to be used: one of "kaiming" or "xavier"
:param bias: whether to use bias with the linear layer
"""
super(_equalized_linear, self).__init__()
self.linear = th.nn.Linear(c_in, c_out, bias=False)
self.linear = th.nn.Linear(c_in, c_out, bias=bias)
if initializer == 'kaiming':
th.nn.init.kaiming_normal_(self.linear.weight,
a=th.nn.init.calculate_gain('linear'))
elif initializer == 'xavier':
th.nn.init.xavier_normal_(self.linear.weight)

self.use_bias = bias
self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
self.scale = (th.mean(self.linear.weight.data ** 2)) ** 0.5
self.linear.weight.data.copy_(self.linear.weight.data / self.scale)
@@ -124,7 +126,9 @@ def forward(self, x):
except RuntimeError:
dev_scale = self.scale
x = self.linear(x.mul(dev_scale))
return x + self.bias.view(1, -1).expand_as(x)
if self.use_bias:
return x + self.bias.view(1, -1).expand_as(x)
return x


# ==========================================================
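
The bias flag threaded through _equalized_linear above simply makes the explicit bias parameter optional. The forward pass itself relies on the equalized-learning-rate trick: the weights are divided by their RMS at construction time and that stored scale is re-applied to the input on every call. Below is a minimal, self-contained sketch of that idea; the class name and the simplification of keeping the scale as a plain float are illustrative and not part of this library.

```python
import torch as th

class EqualizedLinearSketch(th.nn.Module):
    """Illustrative re-implementation of an equalized-learning-rate linear layer."""

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.linear = th.nn.Linear(c_in, c_out, bias=False)
        th.nn.init.kaiming_normal_(self.linear.weight)
        self.use_bias = bias
        self.bias = th.nn.Parameter(th.zeros(c_out))
        # per-layer scale = RMS of the freshly initialized weights
        self.scale = (self.linear.weight.data ** 2).mean().sqrt().item()
        self.linear.weight.data /= self.scale

    def forward(self, x):
        y = self.linear(x * self.scale)  # re-apply the stored scale at runtime
        if self.use_bias:
            y = y + self.bias.view(1, -1).expand_as(y)
        return y
```
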
@@ -156,7 +160,7 @@ def __init__(self, in_channels, use_eql):
self.conv_2 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)

# Pixelwise feature vector normalization operation
self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2 * x.shape[1],
self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2,
beta=0.5, k=1e-8)

# leaky_relu:
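
The change from alpha=2 * x.shape[1] to alpha=2 above is one of the bug fixes: local_response_norm divides alpha by the window size, so with size = 2 * C the old call normalized each feature vector by the square root of the sum of squared channel activations rather than their mean. With alpha=2 the call reduces exactly to ProGAN's pixelwise feature-vector normalization, b = a / sqrt(mean(a^2) + 1e-8). A quick sanity-check sketch, assuming the current torch.nn.functional semantics of local_response_norm:

```python
import torch as th
from torch.nn.functional import local_response_norm

def pixel_norm(x, eps=1e-8):
    # divide each spatial feature vector by the root mean square of its channel activations
    return x / (x.pow(2).mean(dim=1, keepdim=True) + eps).sqrt()

x = th.randn(4, 16, 8, 8)
lrn = local_response_norm(x, size=2 * x.shape[1], alpha=2, beta=0.5, k=1e-8)
assert th.allclose(pixel_norm(x), lrn, atol=1e-5)
```
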
@@ -211,7 +215,7 @@ def __init__(self, in_channels, out_channels, use_eql):
padding=1, bias=True)

# Pixelwise feature vector normalization operation
self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2 * x.shape[1],
self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2,
beta=0.5, k=1e-8)

# leaky_relu:
@@ -230,6 +234,22 @@ def forward(self, x):
return y


class EMA(th.nn.Module):
def __init__(self, mu):
super(EMA, self).__init__()
self.mu = mu
self.shadow = {}

def register(self, name, val):
self.shadow[name] = val.clone()

def forward(self, name, x):
assert name in self.shadow
new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
self.shadow[name] = new_average.clone()
return new_average
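
The EMA module above keeps an exponential moving average ("shadow") of any tensors registered with it; in this formulation mu is the weight given to the incoming value, so a small mu yields a slowly moving average. A hypothetical usage pattern for smoothing generator weights is sketched below; generator and the choice mu=0.001 are placeholders, not taken from this commit.

```python
import copy
import torch as th

ema = EMA(mu=0.001)

# register every trainable generator parameter once, after construction
for name, param in generator.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)

# inside the training loop, after each optimizer.step(), fold the fresh
# weights into the shadow copy
with th.no_grad():
    for name, param in generator.named_parameters():
        if param.requires_grad:
            ema(name, param.data)

# at evaluation time, sample from a copy carrying the smoothed weights
ema_generator = copy.deepcopy(generator)
for name, param in ema_generator.named_parameters():
    if name in ema.shadow:
        param.data.copy_(ema.shadow[name])
```
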


class MinibatchStdDev(th.nn.Module):
def __init__(self, averaging='all'):
"""
@@ -317,16 +337,16 @@ def __init__(self, in_channels, use_eql):
# declare the required modules for forward pass
self.batch_discriminator = MinibatchStdDev()
if use_eql:
self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1)
self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4))
self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
# final conv layer emulates a fully connected layer
self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1))
self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
else:
from torch.nn import Conv2d
self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1)
self.conv_2 = Conv2d(in_channels, in_channels, (4, 4))
self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
# final conv layer emulates a fully connected layer
self.conv_3 = Conv2d(in_channels, 1, (1, 1))
self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

# leaky_relu:
self.lrelu = LeakyReLU(0.2)
@@ -345,7 +365,79 @@ def forward(self, x):
y = self.lrelu(self.conv_2(y))

# fully connected layer
y = self.lrelu(self.conv_3(y)) # final fully connected layer
y = self.conv_3(y) # This layer has linear activation

# flatten the output raw discriminator scores
return y.view(-1)


class ConDisFinalBlock(th.nn.Module):
""" Final block for the Conditional Discriminator """

def __init__(self, in_channels, in_latent_size, out_latent_size, use_eql):
"""
constructor of the class
:param in_channels: number of input channels
:param in_latent_size: size of the input latent vectors
:param out_latent_size: size of the transformed latent vectors
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import LeakyReLU

super(ConDisFinalBlock, self).__init__()

# declare the required modules for forward pass
self.batch_discriminator = MinibatchStdDev()
if use_eql:
self.compressor = _equalized_linear(c_in=in_latent_size, c_out=out_latent_size)
self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
self.conv_2 = _equalized_conv2d(in_channels + out_latent_size,
in_channels, (1, 1), bias=True)
self.conv_3 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
# final conv layer emulates a fully connected layer
self.conv_4 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
else:
from torch.nn import Conv2d, Linear
self.compressor = Linear(in_features=in_latent_size,
out_features=out_latent_size, bias=True)
self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
self.conv_2 = Conv2d(in_channels + out_latent_size,
in_channels, (1, 1), bias=True)
self.conv_3 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
# final conv layer emulates a fully connected layer
self.conv_4 = Conv2d(in_channels, 1, (1, 1), bias=True)

# leaky_relu:
self.lrelu = LeakyReLU(0.2)

def forward(self, x, latent_vector):
"""
forward pass of the ConDisFinalBlock
:param x: input
:param latent_vector: latent vector for conditional discrimination
:return: y => output
"""
# minibatch_std_dev layer
y = self.batch_discriminator(x)

# define the computations
y = self.lrelu(self.conv_1(y))
# apply the latent vector here:
compressed_latent_vector = self.compressor(latent_vector)
cat = th.unsqueeze(th.unsqueeze(compressed_latent_vector, -1), -1)
cat = cat.expand(
compressed_latent_vector.shape[0],
compressed_latent_vector.shape[1],
y.shape[2],
y.shape[3]
)
y = th.cat((y, cat), dim=1)

y = self.lrelu(self.conv_2(y))
y = self.lrelu(self.conv_3(y))

# fully connected layer
y = self.conv_4(y) # This layer has linear activation

# flatten the output raw discriminator scores
return y.view(-1)
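
In ConDisFinalBlock.forward above, the conditioning is applied by compressing the latent vector, broadcasting it over the spatial grid, and concatenating it channel-wise with the feature maps before the 1x1 convolution (conv_2). The shape bookkeeping is easy to lose track of, so here is a small standalone sketch with illustrative dimensions:

```python
import torch as th

batch, channels, h, w, latent = 4, 128, 4, 4, 32
features = th.randn(batch, channels, h, w)   # stands in for the conv_1 output
compressed = th.randn(batch, latent)         # stands in for the compressor output

spatial = compressed.unsqueeze(-1).unsqueeze(-1)    # (batch, latent, 1, 1)
spatial = spatial.expand(batch, latent, h, w)       # broadcast over the H x W grid
conditioned = th.cat((features, spatial), dim=1)    # (batch, channels + latent, H, W)
assert conditioned.shape == (batch, channels + latent, h, w)
```
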
@@ -366,12 +458,12 @@ def __init__(self, in_channels, out_channels, use_eql):
super(DisGeneralConvBlock, self).__init__()

if use_eql:
self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1)
self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1)
self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1, bias=True)
self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1, bias=True)
else:
from torch.nn import Conv2d
self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1)
self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1)
self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)
self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1, bias=True)

self.downSampler = AvgPool2d(2)

94 changes: 89 additions & 5 deletions pro_gan_pytorch/Losses.py
@@ -13,7 +13,21 @@ def __init__(self, device, dis):
def dis_loss(self, real_samps, fake_samps, height, alpha):
raise NotImplementedError("dis_loss method has not been implemented")

def gen_loss(self, fake_samps, height, alpha):
def gen_loss(self, real_samps, fake_samps, height, alpha):
raise NotImplementedError("gen_loss method has not been implemented")


class ConditionalGANLoss:
""" Base class for all losses """

def __init__(self, device, dis):
self.device = device
self.dis = dis

def dis_loss(self, real_samps, fake_samps, latent_vector, height, alpha):
raise NotImplementedError("dis_loss method has not been implemented")

def gen_loss(self, real_samps, fake_samps, latent_vector, height, alpha):
raise NotImplementedError("gen_loss method has not been implemented")
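
The base GANLoss.gen_loss signature now also receives real_samps, and the unconditional losses below accept it as an ignored underscore argument. This gives every loss the same interface, so a training loop can call any of them uniformly even when a particular generator objective (a relativistic one, for example) does need real samples. A hypothetical trainer-side call, where loss_fn, real, fake, depth and alpha are placeholder names:

```python
d_loss = loss_fn.dis_loss(real, fake, depth, alpha)
g_loss = loss_fn.gen_loss(real, fake, depth, alpha)
```
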


@@ -60,7 +74,6 @@ def __gradient_penalty(self, real_samps, fake_samps,
return penalty

def dis_loss(self, real_samps, fake_samps, height, alpha):

# define the (Wasserstein) loss
fake_out = self.dis(fake_samps, height, alpha)
real_out = self.dis(real_samps, height, alpha)
@@ -76,7 +89,7 @@ def dis_loss(self, real_samps, fake_samps, height, alpha):

return loss

def gen_loss(self, fake_samps, height, alpha):
def gen_loss(self, _, fake_samps, height, alpha):
# calculate the WGAN loss for generator
loss = -th.mean(self.dis(fake_samps, height, alpha))

@@ -92,7 +105,7 @@ def dis_loss(self, real_samps, fake_samps, height, alpha):
return 0.5 * (((th.mean(self.dis(real_samps, height, alpha)) - 1) ** 2)
+ (th.mean(self.dis(fake_samps, height, alpha))) ** 2)

def gen_loss(self, fake_samps, height, alpha):
def gen_loss(self, _, fake_samps, height, alpha):
return 0.5 * ((th.mean(self.dis(fake_samps, height, alpha)) - 1) ** 2)


@@ -107,7 +120,78 @@ def dis_loss(self, real_samps, fake_samps, height, alpha):
fake_scores = th.mean(sigmoid(self.dis(fake_samps, height, alpha)))
return 0.5 * (((real_scores - 1) ** 2) + (fake_scores ** 2))

def gen_loss(self, fake_samps, height, alpha):
def gen_loss(self, _, fake_samps, height, alpha):
from torch.nn.functional import sigmoid
scores = th.mean(sigmoid(self.dis(fake_samps, height, alpha)))
return 0.5 * ((scores - 1) ** 2)


# =============================================================
# Conditional versions of the Losses:
# =============================================================

class CondWGAN_GP(ConditionalGANLoss):

def __init__(self, device, dis, drift=0.001, use_gp=False):
super().__init__(device, dis)
self.drift = drift
self.use_gp = use_gp

def __gradient_penalty(self, real_samps, fake_samps, latent_vector,
height, alpha, reg_lambda=10):
"""
private helper for calculating the gradient penalty
:param real_samps: real samples
:param fake_samps: fake samples
:param latent_vector: used for conditional loss calculation
:param height: current depth in the optimization
:param alpha: current alpha for fade-in
:param reg_lambda: regularisation lambda
:return: tensor (gradient penalty)
"""
from torch.autograd import grad

batch_size = real_samps.shape[0]

# generate random epsilon
epsilon = th.rand((batch_size, 1, 1, 1)).to(self.device)

# create the merge of both real and fake samples
merged = (epsilon * real_samps) + ((1 - epsilon) * fake_samps)

# forward pass
op = self.dis.forward(merged, latent_vector, height, alpha)

# obtain gradient of op wrt. merged
gradient = grad(outputs=op, inputs=merged, create_graph=True,
grad_outputs=th.ones_like(op),
retain_graph=True, only_inputs=True)[0]

# calculate the penalty using these gradients
penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()

# return the calculated penalty:
return penalty

def dis_loss(self, real_samps, fake_samps, latent_vector, height, alpha):
# define the (Wasserstein) loss
fake_out = self.dis(fake_samps, latent_vector, height, alpha)
real_out = self.dis(real_samps, latent_vector, height, alpha)

loss = (th.mean(fake_out) - th.mean(real_out)
+ (self.drift * th.mean(real_out ** 2)))

if self.use_gp:
# calculate the WGAN-GP (gradient penalty)
fake_samps.requires_grad = True # turn on gradients for penalty calculation
gp = self.__gradient_penalty(real_samps, fake_samps,
latent_vector, height, alpha)
loss += gp

return loss

def gen_loss(self, _, fake_samps, latent_vector, height, alpha):
# calculate the WGAN loss for generator
loss = -th.mean(self.dis(fake_samps, latent_vector, height, alpha))

return loss
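
CondWGAN_GP mirrors the unconditional WGAN_GP above, with the same drift term and optional gradient penalty computed on random real/fake interpolates, but it routes the conditioning latent vector through every discriminator call. A hypothetical wiring into one training step is sketched below; cond_dis, cond_gen, real_batch, embedding, latent_size, depth and alpha are placeholder names, not identifiers from this repository.

```python
import torch as th

device = th.device("cuda" if th.cuda.is_available() else "cpu")
loss_fn = CondWGAN_GP(device, cond_dis, use_gp=True)

# discriminator step
noise = th.randn(real_batch.shape[0], latent_size).to(device)
fake_batch = cond_gen(noise, embedding, depth, alpha).detach()
d_loss = loss_fn.dis_loss(real_batch, fake_batch, embedding, depth, alpha)
d_loss.backward()

# generator step (the first argument of gen_loss is unused by this loss)
fake_batch = cond_gen(noise, embedding, depth, alpha)
g_loss = loss_fn.gen_loss(None, fake_batch, embedding, depth, alpha)
g_loss.backward()
```
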
