
Commit

adding more features to the code
akanimax committed Nov 28, 2018
1 parent dbd0e39 commit 3f526db
Showing 3 changed files with 344 additions and 107 deletions.
pro_gan_pytorch/CustomLayers.py (74 changes: 37 additions & 37 deletions)
@@ -376,75 +376,75 @@ def forward(self, x):


 class ConDisFinalBlock(th.nn.Module):
-    """ Final block for the Conditional Discriminator """
+    """ Final block for the Conditional Discriminator
+        Uses the Projection mechanism from the paper -> https://arxiv.org/pdf/1802.05637.pdf
+    """

-    def __init__(self, in_channels, in_latent_size, out_latent_size, use_eql):
+    def __init__(self, in_channels, num_classes, use_eql):
         """
         constructor of the class
         :param in_channels: number of input channels
-        :param in_latent_size: size of the input latent vectors
-        :param out_latent_size: size of the transformed latent vectors
+        :param num_classes: number of classes for conditional discrimination
         :param use_eql: whether to use equalized learning rate
         """
-        from torch.nn import LeakyReLU
+        from torch.nn import LeakyReLU, Embedding

         super(ConDisFinalBlock, self).__init__()

         # declare the required modules for forward pass
         self.batch_discriminator = MinibatchStdDev()
         if use_eql:
-            self.compressor = _equalized_linear(c_in=in_latent_size, c_out=out_latent_size)
             self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1, bias=True)
-            self.conv_2 = _equalized_conv2d(in_channels + out_latent_size,
-                                            in_channels, (1, 1), bias=True)
-            self.conv_3 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)
+            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4), bias=True)

             # final conv layer emulates a fully connected layer
-            self.conv_4 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
+            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
         else:
-            from torch.nn import Conv2d, Linear
-            self.compressor = Linear(in_features=in_latent_size,
-                                     out_features=out_latent_size, bias=True)
+            from torch.nn import Conv2d
             self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
-            self.conv_2 = Conv2d(in_channels + out_latent_size,
-                                 in_channels, (1, 1), bias=True)
-            self.conv_3 = Conv2d(in_channels, in_channels, (4, 4), bias=True)
+            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), bias=True)

             # final conv layer emulates a fully connected layer
-            self.conv_4 = Conv2d(in_channels, 1, (1, 1), bias=True)
+            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

+        # we also need an embedding matrix for the label vectors
+        self.label_embedder = Embedding(num_classes, in_channels)
+
         # leaky_relu:
         self.lrelu = LeakyReLU(0.2)

-    def forward(self, x, latent_vector):
+    def forward(self, x, labels):
         """
         forward pass of the FinalBlock
         :param x: input
-        :param latent_vector: latent vector for conditional discrimination
+        :param labels: samples' labels for conditional discrimination
+                       Note that these are pure integer labels [Batch_size x 1]
         :return: y => output
         """
         # minibatch_std_dev layer
-        y = self.batch_discriminator(x)
+        y = self.batch_discriminator(x)  # [B x C x 4 x 4]

         # define the computations
-        y = self.lrelu(self.conv_1(y))
-        # apply the latent vector here:
-        compressed_latent_vector = self.compressor(latent_vector)
-        cat = th.unsqueeze(th.unsqueeze(compressed_latent_vector, -1), -1)
-        cat = cat.expand(
-            compressed_latent_vector.shape[0],
-            compressed_latent_vector.shape[1],
-            y.shape[2],
-            y.shape[3]
-        )
-        y = th.cat((y, cat), dim=1)
+        y = self.lrelu(self.conv_1(y))  # [B x C x 4 x 4]

-        y = self.lrelu(self.conv_2(y))
-        y = self.lrelu(self.conv_3(y))
+        # obtain the computed features
+        y = self.lrelu(self.conv_2(y))  # [B x C x 1 x 1]

-        # fully connected layer
-        y = self.conv_4(y)  # This layer has linear activation
+        # embed the labels
+        labels = self.label_embedder(labels)  # [B x C]

-        # flatten the output raw discriminator scores
-        return y.view(-1)
+        # compute the inner product with the label embeddings
+        y_ = th.squeeze(th.squeeze(y, dim=-1), dim=-1)  # [B x C]
+        projection_scores = (y_ * labels).sum(dim=-1)  # [B]
+
+        # normal discrimination score
+        y = self.lrelu(self.conv_3(y))  # [B x 1 x 1 x 1]
+
+        # calculate the total score
+        final_score = y.view(-1) + projection_scores
+
+        # return the output raw discriminator scores
+        return final_score


class DisGeneralConvBlock(th.nn.Module):
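The new block implements the projection trick from Miyato & Koyama's "cGANs with Projection Discriminator" (the paper linked in the docstring): instead of concatenating a conditioning vector into the feature maps, the discriminator score is f(x, y) = psi(phi(x)) + e(y)^T phi(x), where phi is the convolutional feature extractor (conv_1, conv_2), psi is the final 1x1 convolution (conv_3), and e is the learned label embedding. Below is a minimal, self-contained sketch of that scoring, assuming 4x4 input feature maps already reduced to [B x C x 1 x 1] and integer labels of shape [B]; all names are illustrative, not part of the repository.

import torch as th

B, C, NUM_CLASSES = 4, 512, 10

features = th.randn(B, C, 1, 1)                # output of conv_2: [B x C x 1 x 1]
labels = th.randint(0, NUM_CLASSES, (B,))      # integer class ids: [B]

embedder = th.nn.Embedding(NUM_CLASSES, C)     # e(y): maps a class id to a C-vector
final_conv = th.nn.Conv2d(C, 1, (1, 1), bias=True)  # psi: the 1x1 "fully connected" conv

# projection term: inner product of flattened features and label embeddings
flat = features.squeeze(-1).squeeze(-1)                  # [B x C]
projection = (flat * embedder(labels)).sum(dim=-1)       # [B]

# unconditional term: raw score from the final conv
unconditional = final_conv(features).view(-1)            # [B]

score = unconditional + projection                       # [B]
print(score.shape)                                       # torch.Size([4])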

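For context, a hedged usage sketch of the updated block, assuming the package's import path and the use_eql=False branch (which relies only on plain torch layers). Note that the embedding-based inner product only broadcasts to [B] when labels arrive as a flat LongTensor of shape [B]; with the [Batch_size x 1] shape mentioned in the docstring, the product would broadcast to [B x B x C] instead.

import torch as th
from pro_gan_pytorch.CustomLayers import ConDisFinalBlock

block = ConDisFinalBlock(in_channels=512, num_classes=10, use_eql=False)

x = th.randn(4, 512, 4, 4)           # 4x4 feature maps from the previous block
labels = th.randint(0, 10, (4,))     # flat integer class ids: [B]

scores = block(x, labels)            # raw (pre-loss) discriminator scores: [B]
print(scores.shape)                  # torch.Size([4])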