Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pix2Pix model #533

Merged
merged 18 commits into from Mar 4, 2021
2 changes: 2 additions & 0 deletions pl_bolts/models/gans/__init__.py
@@ -1,7 +1,9 @@
# Public GAN models re-exported at package level.
from pl_bolts.models.gans.basic.basic_gan_module import GAN  # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN  # noqa: F401
from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix  # noqa: F401

__all__ = [
    "GAN",
    "DCGAN",
    "Pix2Pix",
]
Empty file.
167 changes: 167 additions & 0 deletions pl_bolts/models/gans/pix2pix/components.py
@@ -0,0 +1,167 @@
import torch
from torch import nn


def center_crop(image, new_shape):
    """Crop the central window of size ``new_shape[-2:]`` from ``image``.

    Only the last two dimensions of ``image`` (height, width) are cropped;
    any leading dimensions are passed through unchanged.

    Args:
        image: tensor whose last two dimensions are (height, width).
        new_shape: target shape; only its last two entries (height, width)
            are used.

    Returns:
        A slice of ``image`` whose last two dimensions are ``new_shape[-2:]``,
        or the full extent along an axis when the requested size exceeds it.
    """
    h, w = image.shape[-2:]
    n_h, n_w = new_shape[-2:]
    cy, cx = h // 2, w // 2
    # Clamp at 0 so an oversized target cannot produce a negative start index,
    # which would otherwise slice from the end of the axis.
    ymin = max(cy - n_h // 2, 0)
    xmin = max(cx - n_w // 2, 0)
    ymax, xmax = ymin + n_h, xmin + n_w
    # Height is the second-to-last axis and width the last one, so the y-range
    # must come first; swapping them breaks every non-square crop.
    return image[..., ymin:ymax, xmin:xmax]


class ConvBlock(nn.Module):
    """3x3 convolution followed by optional batch norm, optional dropout and LeakyReLU(0.2).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        use_dropout: if True, apply ``nn.Dropout`` before the activation.
        use_bn: if True, apply ``nn.BatchNorm2d`` right after the convolution.
    """

    def __init__(self, in_channels, out_channels, use_dropout=False, use_bn=True):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)

        # The optional layers only exist when enabled; forward() checks the flags.
        self.use_bn = use_bn
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(out_channels)

        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x):
        """Apply conv -> [batch norm] -> [dropout] -> LeakyReLU to ``x``."""
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        return self.activation(x)


class UpSampleConv(nn.Module):
    """Decoder step: 2x bilinear upsample, merge a skip connection, halve the channels.

    Args:
        input_channels: channels of the tensor being upsampled; the output has
            ``input_channels // 2`` channels.
        use_dropout: if True, apply ``nn.Dropout`` after ``conv2`` and ``conv3``.
        use_bn: if True, apply ``nn.BatchNorm2d`` after ``conv2`` and ``conv3``.
    """

    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        # ``super(self)`` raises TypeError (its first argument must be a type),
        # which made the module impossible to instantiate.
        super().__init__()
        # NOTE(review): reviewers on the PR observed that the reference Pix2Pix
        # implementations upsample with nn.ConvTranspose2d; this block instead
        # uses nn.Upsample followed by plain convolutions.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)
        # conv2 sees the upsampled features concatenated with the skip tensor,
        # hence ``input_channels`` inputs again.
        self.conv2 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(input_channels // 2, input_channels // 2, kernel_size=2, padding=1)

        self.use_bn = use_bn
        if use_bn:
            # A single BatchNorm2d instance is applied after both conv2 and conv3.
            self.batchnorm = nn.BatchNorm2d(input_channels // 2)

        self.activation = nn.ReLU()

        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x, skip_con_x):
        """Upsample ``x``, concatenate the cropped ``skip_con_x`` and convolve.

        Args:
            x: tensor coming from the previous (coarser) layer.
            skip_con_x: skip-connection tensor; it is center-cropped to the
                spatial size of the upsampled ``x`` before concatenation.
        """
        x = self.upsample(x)
        x = self.conv1(x)
        skip_con_x = center_crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], dim=1)

        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)

        x = self.conv3(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x


class DownSampleConv(nn.Module):
    """Encoder step: two ConvBlocks that double the channels, then a 2x2 max pool.

    Args:
        in_channels: channels of the input; the output has ``in_channels * 2``.
        use_dropout: forwarded to both ``ConvBlock`` instances.
        use_bn: forwarded to both ``ConvBlock`` instances.
    """

    def __init__(self, in_channels, use_dropout=False, use_bn=True):
        # ``super(self)`` raises TypeError (its first argument must be a type),
        # which made the module impossible to instantiate.
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # NOTE: ``batchnorm`` and ``dropout`` below are not used by forward();
        # each ConvBlock owns its own layers. They are kept so the set of
        # registered submodules (and state_dict keys) stays unchanged.
        self.use_bn = use_bn
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(in_channels * 2)

        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

        self.conv_block1 = ConvBlock(in_channels, in_channels * 2, use_dropout, use_bn)
        self.conv_block2 = ConvBlock(in_channels * 2, in_channels * 2, use_dropout, use_bn)

    def forward(self, x):
        """Run both conv blocks on ``x`` and halve its spatial size with max pooling."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        return self.maxpool(x)


class Generator(nn.Module):
    """Encoder/decoder generator with skip connections between matching levels.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        hidden_channels: channels after the initial 1x1 convolution; level ``i``
            of the encoder receives ``hidden_channels * 2 ** i`` channels.
        depth: number of down-sampling (and up-sampling) steps.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=32, depth=6):
        # ``super(self)`` raises TypeError (its first argument must be a type),
        # which made the module impossible to instantiate.
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv_final = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.depth = depth
        self.sigmoid = nn.Sigmoid()

        # Encoding/contracting path: dropout is enabled on the first three levels only.
        self.contracting_layers = nn.ModuleList(
            [DownSampleConv(hidden_channels * 2 ** i, use_dropout=i < 3) for i in range(depth)]
        )

        # Upsampling/expanding path: level ``i`` consumes the output of encoder level ``i``.
        self.expanding_layers = nn.ModuleList(
            [UpSampleConv(hidden_channels * 2 ** (i + 1)) for i in range(depth)]
        )

    def forward(self, x):
        """Map ``x`` through the encoder and decoder; outputs are squashed by a sigmoid."""
        x = self.conv1(x)
        # skips[i] is the input of encoder level i, reused by decoder level i.
        skips = [x]

        for layer in self.contracting_layers:
            x = layer(x)
            skips.append(x)

        # Walk the decoder from the deepest level back to the first one.
        for i in reversed(range(self.depth)):
            x = self.expanding_layers[i](x, skips[i])

        x = self.conv_final(x)
        return self.sigmoid(x)


class Discriminator(nn.Module):
    """Convolutional discriminator scoring a pair of images concatenated on the channel axis.

    Args:
        input_channels: channels of the concatenated ``(x, y)`` pair fed to
            the first convolution.
        hidden_channels: channels after the initial 1x1 convolution; each of
            the four down-sampling steps doubles this number.
    """

    def __init__(self, input_channels, hidden_channels=8):
        # The call was wrapped in backticks, which is a syntax error in Python 3.
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=1)
        # Batch norm is disabled on the first down-sampling step only.
        self.contract1 = DownSampleConv(hidden_channels, use_bn=False)
        self.contract2 = DownSampleConv(hidden_channels * 2)
        self.contract3 = DownSampleConv(hidden_channels * 4)
        self.contract4 = DownSampleConv(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        """Return a one-channel score map for the pair ``(x, y)``.

        No activation is applied here; the output is the raw result of the
        final 1x1 convolution.
        """
        out = torch.cat([x, y], dim=1)
        out = self.conv1(out)
        out = self.contract1(out)
        out = self.contract2(out)
        out = self.contract3(out)
        out = self.contract4(out)
        return self.final(out)
35 changes: 35 additions & 0 deletions pl_bolts/models/gans/pix2pix/pix2pix_module.py
@@ -0,0 +1,35 @@
import pytorch_lightning as pl
import torch
from torch import nn

from pl_bolts.models.gans.pix2pix.components import Generator, Discriminator


def weights_init(m):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    ``Conv2d`` / ``ConvTranspose2d`` weights and ``BatchNorm2d`` weights are
    drawn from a normal distribution with mean 0.0 and std 0.02; ``BatchNorm2d``
    biases are set to 0. Any other module type is left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None; skip them
        # instead of crashing inside the init helpers.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


class Pix2Pix(pl.LightningModule):
    """LightningModule bundling the Pix2Pix generator, discriminator and loss criteria.

    Args:
        in_channels: channels of the generator input.
        out_channels: channels of the generator output.
        hidden_channels: hidden width of the generator.
        depth: number of down/up-sampling levels of the generator.
        learning_rate: stored on the module as ``learning_rate``.
        lambda_recon: stored on the module as ``lambda_recon``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 hidden_channels: int = 32,
                 depth: int = 6,
                 learning_rate: float = 0.0002,
                 lambda_recon: int = 200):
        # nn.Module refuses submodule assignment before Module.__init__ has run,
        # so the parent initialiser must be called first.
        super().__init__()

        self.gen = Generator(in_channels, out_channels, hidden_channels, depth)
        # Discriminator.forward concatenates its two inputs on the channel axis,
        # so its first convolution must accept the sum of both channel counts
        # (presumably the generator input and the real/generated output).
        self.disc = Discriminator(in_channels + out_channels, hidden_channels=8)
        self.learning_rate = learning_rate

        # Initializing weights (both networks are built before either is re-initialised).
        self.gen.apply(weights_init)
        self.disc.apply(weights_init)

        self.adv_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()

        self.lambda_recon = lambda_recon