diff --git a/pl_bolts/models/gans/__init__.py b/pl_bolts/models/gans/__init__.py
index 5cca383df1..751ba576ef 100644
--- a/pl_bolts/models/gans/__init__.py
+++ b/pl_bolts/models/gans/__init__.py
@@ -1,7 +1,9 @@
-from pl_bolts.models.gans.basic.basic_gan_module import GAN  # noqa: F401
-from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN  # noqa: F401
+from pl_bolts.models.gans.basic.basic_gan_module import GAN
+from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
+from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix
 
 __all__ = [
     "GAN",
     "DCGAN",
+    "Pix2Pix",
 ]
diff --git a/pl_bolts/models/gans/pix2pix/__init__.py b/pl_bolts/models/gans/pix2pix/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pl_bolts/models/gans/pix2pix/components.py b/pl_bolts/models/gans/pix2pix/components.py
new file mode 100644
index 0000000000..21edd0c937
--- /dev/null
+++ b/pl_bolts/models/gans/pix2pix/components.py
@@ -0,0 +1,153 @@
+import torch
+from torch import nn
+
+
+class UpSampleConv(nn.Module):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel=4,
+        strides=2,
+        padding=1,
+        activation=True,
+        batchnorm=True,
+        dropout=False
+    ):
+        super().__init__()
+        self.activation = activation
+        self.batchnorm = batchnorm
+        self.dropout = dropout
+
+        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)
+
+        if batchnorm:
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        if activation:
+            self.act = nn.ReLU(True)
+
+        if dropout:
+            self.drop = nn.Dropout2d(0.5)
+
+    def forward(self, x):
+        x = self.deconv(x)
+        if self.batchnorm:
+            x = self.bn(x)
+        if self.activation:
+            x = self.act(x)
+        if self.dropout:
+            x = self.drop(x)
+        return x
+
+
+class DownSampleConv(nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
+        """
+        Paper details:
+        - C64-C128-C256-C512-C512-C512-C512-C512
+        - All convolutions are 4×4 spatial filters applied with stride 2
+        - Convolutions in the encoder downsample by a factor of 2
+        """
+        super().__init__()
+        self.activation = activation
+        self.batchnorm = batchnorm
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
+
+        if batchnorm:
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        if activation:
+            self.act = nn.LeakyReLU(0.2)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.batchnorm:
+            x = self.bn(x)
+        if self.activation:
+            x = self.act(x)
+        return x
+
+
+class Generator(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        """
+        Paper details:
+        - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
+        - All convolutions are 4×4 spatial filters applied with stride 2
+        - Convolutions in the encoder downsample by a factor of 2
+        - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
+        """
+        super().__init__()
+
+        # encoder/downsample convs
+        self.encoders = [
+            DownSampleConv(in_channels, 64, batchnorm=False),  # bs x 64 x 128 x 128
+            DownSampleConv(64, 128),  # bs x 128 x 64 x 64
+            DownSampleConv(128, 256),  # bs x 256 x 32 x 32
+            DownSampleConv(256, 512),  # bs x 512 x 16 x 16
+            DownSampleConv(512, 512),  # bs x 512 x 8 x 8
+            DownSampleConv(512, 512),  # bs x 512 x 4 x 4
+            DownSampleConv(512, 512),  # bs x 512 x 2 x 2
+            DownSampleConv(512, 512, batchnorm=False),  # bs x 512 x 1 x 1
+        ]
+
+        # decoder/upsample convs
+        self.decoders = [
+            UpSampleConv(512, 512, dropout=True),  # bs x 512 x 2 x 2
+            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 4 x 4
+            UpSampleConv(1024, 512, dropout=True),  # bs x 512 x 8 x 8
+            UpSampleConv(1024, 512),  # bs x 512 x 16 x 16
+            UpSampleConv(1024, 256),  # bs x 256 x 32 x 32
+            UpSampleConv(512, 128),  # bs x 128 x 64 x 64
+            UpSampleConv(256, 64),  # bs x 64 x 128 x 128
+        ]
+        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
+        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
+        self.tanh = nn.Tanh()
+
+        self.encoders = nn.ModuleList(self.encoders)
+        self.decoders = nn.ModuleList(self.decoders)
+
+    def forward(self, x):
+        skips_cons = []
+        for encoder in self.encoders:
+            x = encoder(x)
+            skips_cons.append(x)
+
+        # the innermost 1x1 feature map has no skip connection
+        skips_cons = list(reversed(skips_cons[:-1]))
+        decoders = self.decoders[:-1]
+
+        for decoder, skip in zip(decoders, skips_cons):
+            x = decoder(x)
+            x = torch.cat((x, skip), dim=1)
+
+        x = self.decoders[-1](x)
+        x = self.final_conv(x)
+        return self.tanh(x)
+
+
+class PatchGAN(nn.Module):
+
+    def __init__(self, input_channels):
+        super().__init__()
+        self.d1 = DownSampleConv(input_channels, 64, batchnorm=False)
+        self.d2 = DownSampleConv(64, 128)
+        self.d3 = DownSampleConv(128, 256)
+        self.d4 = DownSampleConv(256, 512)
+        self.final = nn.Conv2d(512, 1, kernel_size=1)
+
+    def forward(self, x, y):
+        # score the (real or generated) image conditioned on the input image,
+        # stacked along the channel dimension
+        x = torch.cat([x, y], dim=1)
+        x0 = self.d1(x)
+        x1 = self.d2(x0)
+        x2 = self.d3(x1)
+        x3 = self.d4(x2)
+        xn = self.final(x3)
+        return xn
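A quick shape sanity check for the new components, as a sketch alongside the diff; it assumes 256×256 inputs, matching the layer comments above:

    import torch

    from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN

    gen = Generator(in_channels=3, out_channels=3)
    # the discriminator sees condition + image stacked on channels, hence 3 + 3 = 6
    disc = PatchGAN(input_channels=6)

    condition = torch.randn(1, 3, 256, 256)
    fake = gen(condition)
    logits = disc(fake, condition)

    print(fake.shape)    # torch.Size([1, 3, 256, 256])
    print(logits.shape)  # torch.Size([1, 1, 16, 16]) -- one logit per image patch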
diff --git a/pl_bolts/models/gans/pix2pix/pix2pix_module.py b/pl_bolts/models/gans/pix2pix/pix2pix_module.py
new file mode 100644
index 0000000000..33133b4d1a
--- /dev/null
+++ b/pl_bolts/models/gans/pix2pix/pix2pix_module.py
@@ -0,0 +1,78 @@
+import pytorch_lightning as pl
+import torch
+from torch import nn
+
+from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN
+
+
+def _weights_init(m):
+    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+        torch.nn.init.normal_(m.weight, 0.0, 0.02)
+    if isinstance(m, nn.BatchNorm2d):
+        torch.nn.init.normal_(m.weight, 0.0, 0.02)
+        torch.nn.init.constant_(m.bias, 0)
+
+
+class Pix2Pix(pl.LightningModule):
+
+    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.gen = Generator(in_channels, out_channels)
+        self.patch_gan = PatchGAN(in_channels + out_channels)
+
+        # initializing weights
+        self.gen = self.gen.apply(_weights_init)
+        self.patch_gan = self.patch_gan.apply(_weights_init)
+
+        self.adversarial_criterion = nn.BCEWithLogitsLoss()
+        self.recon_criterion = nn.L1Loss()
+
+    def forward(self, x):
+        # generate an image from a conditioned input
+        return self.gen(x)
+
+    def _gen_step(self, real_images, conditioned_images):
+        # Pix2Pix combines an adversarial loss with an L1 reconstruction loss
+        # First calculate the adversarial loss
+        fake_images = self.gen(conditioned_images)
+        disc_logits = self.patch_gan(fake_images, conditioned_images)
+        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
+
+        # calculate reconstruction loss
+        recon_loss = self.recon_criterion(fake_images, real_images)
+        lambda_recon = self.hparams.lambda_recon
+
+        return adversarial_loss + lambda_recon * recon_loss
+
+    def _disc_step(self, real_images, conditioned_images):
+        fake_images = self.gen(conditioned_images).detach()
+        fake_logits = self.patch_gan(fake_images, conditioned_images)
+
+        real_logits = self.patch_gan(real_images, conditioned_images)
+
+        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
+        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
+        return (real_loss + fake_loss) / 2
+
+    def configure_optimizers(self):
+        lr = self.hparams.learning_rate
+        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
+        disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
+        return disc_opt, gen_opt
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        real, condition = batch
+
+        loss = None
+        if optimizer_idx == 0:
+            loss = self._disc_step(real, condition)
+            self.log('PatchGAN Loss', loss)
+        elif optimizer_idx == 1:
+            loss = self._gen_step(real, condition)
+            self.log('Generator Loss', loss)
+
+        return loss
+
+
+if __name__ == '__main__':
+    pix2pix = Pix2Pix(3, 3)
+    print(pix2pix(torch.randn(1, 3, 256, 256)).shape)
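A minimal training sketch for the new module. Random tensors stand in for a real paired dataset here; in practice you would swap in aligned (target, condition) image pairs. Since configure_optimizers returns two optimizers, Lightning alternates training_step calls with optimizer_idx 0 (PatchGAN) and 1 (generator) automatically:

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pl_bolts.models.gans import Pix2Pix

    # stand-in data: 8 random 256x256 (real, condition) pairs, matching the
    # batch layout unpacked in training_step -- replace with a real paired dataset
    real = torch.randn(8, 3, 256, 256)
    condition = torch.randn(8, 3, 256, 256)
    train_loader = DataLoader(TensorDataset(real, condition), batch_size=4)

    model = Pix2Pix(in_channels=3, out_channels=3, learning_rate=2e-4, lambda_recon=200)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_loader)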
diff --git a/requirements.txt b/requirements.txt
index df25d2ae7f..4d3da28520 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
 torch>=1.6
-pytorch-lightning>=1.1.1, <1.2
+pytorch-lightning>=1.1.1, <1.2
\ No newline at end of file