Skip to content

Commit

Permalink
yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Feb 24, 2021
1 parent 7e90f80 commit 66317f4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
21 changes: 15 additions & 6 deletions pl_bolts/models/gans/pix2pix/components.py
Expand Up @@ -3,8 +3,18 @@


class UpSampleConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True,
dropout=False):

def __init__(
self,
in_channels,
out_channels,
kernel=4,
strides=2,
padding=1,
activation=True,
batchnorm=True,
dropout=False
):
super().__init__()
self.activation = activation
self.batchnorm = batchnorm
Expand Down Expand Up @@ -32,6 +42,7 @@ def forward(self, x):


class DownSampleConv(nn.Module):

def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
"""
Paper details:
Expand Down Expand Up @@ -61,7 +72,8 @@ def forward(self, x):


class Generator(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels=64):

def __init__(self, in_channels, out_channels):
"""
Paper details:
- Encoder: C64-C128-C256-C512-C512-C512-C512-C512
Expand Down Expand Up @@ -110,13 +122,10 @@ def forward(self, x):
skips_cons = list(reversed(skips_cons[:-1]))
decoders = self.decoders[:-1]

i = 0
for decoder, skip in zip(decoders, skips_cons):
x = decoder(x)
assert self.decoder_channels[i] == x.shape[1], f'{x.shape, self.decoder_channels[i]}'
# print(x.shape, skip.shape)
x = torch.cat((x, skip), axis=1)
i += 1

x = self.decoders[-1](x)
# print(x.shape)
Expand Down
7 changes: 2 additions & 5 deletions pl_bolts/models/gans/pix2pix/pix2pix_module.py
Expand Up @@ -14,11 +14,8 @@ def _weights_init(m):


class Pix2Pix(pl.LightningModule):
def __init__(self,
in_channels,
out_channels,
learning_rate=0.0002,
lambda_recon=200):

def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):

super().__init__()
self.save_hyperparameters()
Expand Down

0 comments on commit 66317f4

Please sign in to comment.