Skip to content

Commit

Permalink
Fix bug in conditional discrimination: add an embedding_size parameter to ConditionalDiscriminator, pass it (instead of feature_size) to ConDisFinalBlock, and correct the argument order at the ConditionalDiscriminator call site
Browse files Browse the repository at this point in the history
  • Loading branch information
akanimax committed Oct 17, 2018
1 parent ee7cf00 commit 2612aed
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pro_gan_pytorch/PRO_GAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,14 @@ def forward(self, x, height, alpha):
class ConditionalDiscriminator(th.nn.Module):
""" Discriminator of the GAN """

def __init__(self, height=7, feature_size=512,
def __init__(self, height=7, feature_size=512, embedding_size=4096,
compressed_latent_size=128, use_eql=True):
"""
constructor for the class
:param height: total height of the discriminator (Must be equal to the Generator depth)
:param feature_size: size of the deepest features extracted
(Must be equal to Generator latent_size)
:param embedding_size: size of the embedding for conditional discrimination
:param compressed_latent_size: size of the compressed version
:param use_eql: whether to use equalized learning rate
"""
Expand All @@ -214,9 +215,10 @@ def __init__(self, height=7, feature_size=512,
self.use_eql = use_eql
self.height = height
self.feature_size = feature_size
self.embedding_size = embedding_size
self.compressed_latent_size = compressed_latent_size

self.final_block = ConDisFinalBlock(self.feature_size, self.feature_size,
self.final_block = ConDisFinalBlock(self.feature_size, self.embedding_size,
self.compressed_latent_size, use_eql=self.use_eql)

# create a module list of the other required general convolution blocks
Expand Down Expand Up @@ -489,7 +491,8 @@ def __init__(self, embedding_size, depth=7, latent_size=512, compressed_latent_s

# Create the Generator and the Discriminator
self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
self.dis = ConditionalDiscriminator(depth, embedding_size, compressed_latent_size,
self.dis = ConditionalDiscriminator(depth, latent_size,
embedding_size, compressed_latent_size,
use_eql=use_eql).to(device)

# state of the object
Expand Down

0 comments on commit 2612aed

Please sign in to comment.