In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# import torchvision.models as models
# vgg19 = models.vgg19(pretrained=True)
# print(vgg19)

In [None]:
class ConvBlock(nn.Module):
  """Conv2d -> InstanceNorm2d -> LeakyReLU(0.2).

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels produced by the convolution.
    kernel: square kernel size. Padding is kernel // 2, so for odd kernels
      the output spatial size is ceil(H / stride) x ceil(W / stride).
    stride: convolution stride.
  """

  def __init__(self, in_channels, out_channels, kernel=3, stride=1):
    super().__init__()
    self.conv1 = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
    )
    self.inst_norm = nn.InstanceNorm2d(out_channels)
    self.lrelu = nn.LeakyReLU(0.2)

  def forward(self, x):
    """Apply the convolution, then instance norm, then leaky ReLU."""
    return self.lrelu(self.inst_norm(self.conv1(x)))

# cb = ConvBlock(3, 6)
# noise_images = torch.randn((64, 3, 256, 256))
# out = cb(noise_images)
# print(out.shape)

In [None]:
class DepthwiseSeparableConv(nn.Module):
  """Depthwise-separable convolution followed by InstanceNorm2d and LeakyReLU(0.2).

  The depthwise stage is an nn.Conv2d with groups=in_channels and
  in_channels * multiplier output channels. Each input channel is therefore
  convolved with its own `multiplier` filters, which is PyTorch's "depthwise
  convolution with a depthwise multiplier". The pointwise stage is a 1x1
  convolution that mixes channels down to out_channels.

  References:
  https://medium.com/@zurister/depth-wise-convolution-and-depth-wise-separable-convolution-37346565d4ec
  https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843
  """

  def __init__(self, in_channels, out_channels, multiplier=4, kernel=3, stride=2):
    """
    Args:
      in_channels: channels of the input tensor.
      out_channels: channels produced by the pointwise 1x1 convolution.
      multiplier: filters per input channel in the depthwise stage. The
        depthwise output width is in_channels * multiplier, so Conv2d's
        requirement that channel counts be divisible by `groups` always holds.
      kernel: depthwise kernel size (odd sizes expected).
      stride: depthwise stride.
    """
    super(DepthwiseSeparableConv, self).__init__()
    # Derive padding from the kernel rather than fixing it at 1, so every odd
    # kernel size keeps "same"-style padding (identical to before for kernel=3).
    padding = kernel // 2
    self.depthwise = nn.Conv2d(in_channels, in_channels*multiplier, kernel_size=kernel, stride=stride, padding=padding, groups=in_channels)
    self.pointwise = nn.Conv2d(in_channels*multiplier, out_channels, kernel_size=1, stride=1)
    self.inst_norm = nn.InstanceNorm2d(out_channels)
    self.lrelu = nn.LeakyReLU(0.2)

  def forward(self, x):
    """ @param x: [N x in_channels x H x W]
        @returns: [N x out_channels x ceil(H/stride) x ceil(W/stride)] tensor
                  (for odd kernel sizes)
    """
    x = self.depthwise(x)
    x = self.pointwise(x)
    x = self.inst_norm(x)
    x = self.lrelu(x)
    return x

# dwsc = DepthwiseSeparableConv(128,128,kernel=3).cuda()
# noise_inputs = torch.randn((64,128,256,256)).cuda()
# out = dwsc(noise_inputs)
# print(out.shape)



In [None]:
class DSConv(nn.Module):
  """DepthwiseSeparableConv followed by a 1x1 ConvBlock.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels of the output tensor.
    kernel: kernel size of the depthwise stage, forwarded to
      DepthwiseSeparableConv.
    stride: stride of the depthwise stage.
  """

  def __init__(self, in_channels, out_channels, kernel=3, stride=2):
    super(DSConv, self).__init__()
    # Pass `kernel` through so the constructor argument actually takes effect.
    self.depthwise = DepthwiseSeparableConv(in_channels, out_channels, kernel=kernel, stride=stride)
    self.conv_block1 = ConvBlock(out_channels, out_channels, kernel=1, stride=1)

  def forward(self, x):
    """ @param x: [N x in_channels x H x W]
        @returns: [N x out_channels x H' x W'] tensor; for the default
                  kernel=3, H' = ceil(H/stride) and W' = ceil(W/stride)
    """
    x = self.depthwise(x)
    # x is [N x out_channels x H' x W'] (spatial size already reduced by stride)
    x = self.conv_block1(x)
    return x

# ds = DSConv(128,128).cuda()
# noise_inputs = torch.randn((64,128,256,256)).cuda()
# out = ds(noise_inputs)
# print(out.shape)


In [None]:
class InverseResidualBlock(nn.Module):
  """Inverted residual block: expand, depthwise-separable conv, project, add skip.

  Stages, in order:
    1. 1x1 ConvBlock expanding in_channels -> middle_channels.
    2. DepthwiseSeparableConv (kernel 3, stride 1) to middle_channels // 2.
    3. 1x1 Conv2d projecting back to in_channels, then InstanceNorm2d.
  The block input is added to the result of stage 3.
  """

  def __init__(self, in_channels, middle_channels=512):
    super().__init__()
    reduced_channels = middle_channels // 2
    self.conv_block = ConvBlock(in_channels, middle_channels, kernel=1, stride=1)
    self.dconv = DepthwiseSeparableConv(middle_channels, reduced_channels, kernel=3, stride=1)
    self.conv = nn.Conv2d(reduced_channels, in_channels, kernel_size=1, stride=1)
    self.inst_norm = nn.InstanceNorm2d(in_channels)

  def forward(self, x):
    """ @param x: [N x C x H x W]
        @returns: [N x C x H x W] tensor (branch output + x)
    """
    branch = self.dconv(self.conv_block(x))
    branch = self.inst_norm(self.conv(branch))
    return branch + x

# irb = InverseResidualBlock(64,128).cuda()
# noise_inputs = torch.randn((64,64,64,64)).cuda()
# out = irb(noise_inputs)
# print(out.shape)

In [None]:
class DownConv(nn.Module):
  """Down-sampling block: the sum of two branches that each halve H and W.

  Branch 1 is a stride-2 DSConv. Branch 2 is a nearest-neighbour resize to the
  same spatial size followed by a stride-1 DSConv.
  """

  def __init__(self, in_channels, out_channels):
    super(DownConv, self).__init__()
    self.dconv1 = DSConv(in_channels, out_channels, kernel=3, stride=2)
    self.dconv2 = DSConv(in_channels, out_channels, kernel=3, stride=1)

  def forward(self, x):
    """ @param x: [N x in_channels x H x W]
        @returns: [N x out_channels x ceil(H/2) x ceil(W/2)]
                  (exactly H/2 x W/2 for even sizes)
    """
    residual = self.dconv1(x)
    # residual: [N x out_channels x ceil(H/2) x ceil(W/2)]

    # Resize to exactly the strided branch's spatial size. A fixed
    # scale_factor=0.5 gives floor(H/2) rows while the stride-2 conv gives
    # ceil(H/2), so odd inputs would not add up. For even sizes the result
    # is the same nearest-neighbour resize as scale_factor=0.5.
    x = F.interpolate(x, size=residual.shape[-2:])
    x = self.dconv2(x)
    # x: same shape as residual

    return x + residual

# dc = DownConv(128, 128)
# inputs = torch.randn((64,128,128,128))
# out = dc(inputs)
# print(out.shape)

In [None]:
class UpConv(nn.Module):
  """Up-sampling block: 2x nearest-neighbour resize followed by a stride-1 DSConv."""

  def __init__(self, in_channels, out_channels):
    super(UpConv, self).__init__()
    self.dconv1 = DSConv(in_channels, out_channels, kernel=3, stride=1)

  def forward(self, x):
    """ @param x: [N x in_channels x H x W]
        @returns: [N x out_channels x 2*H x 2*W]
        @raises ValueError: if x is not a 4-D tensor
    """
    # Explicit rank check; the resize below would otherwise treat a 3-D
    # tensor as 1-D data instead of failing.
    if x.dim() != 4:
      raise ValueError(f"UpConv expects a 4-D [N x C x H x W] tensor, got shape {tuple(x.shape)}")
    x = F.interpolate(x, scale_factor=2)
    x = self.dconv1(x)
    return x

# uc = UpConv(128,128).cuda()
# noise_inputs = torch.randn((64,128,4,4)).cuda()
# out = uc(noise_inputs)
# print(out.shape)

In [None]:
class Generator(nn.Module):
  """Encoder -> inverted residual blocks -> decoder image-to-image generator.

  The encoder downsamples twice (DownConv) and the decoder upsamples twice
  (UpConv). The output therefore has the input's spatial size when H and W are
  divisible by 4.

  Args:
    in_channels: channels of the input images (used by the first ConvBlock).
    out_channels: channels produced by the final convolution.
  """

  def __init__(self, in_channels=3, out_channels=3):
    super(Generator, self).__init__()
    # encoder
    self.conv1 = ConvBlock(in_channels,64)
    self.conv2 = ConvBlock(64,64)
    self.down_conv1 = DownConv(64,128)
    self.conv3 = ConvBlock(128,128)
    self.dsconv1 = DepthwiseSeparableConv(128,128,stride=1)
    self.down_conv2 = DownConv(128,256)
    self.conv4 = ConvBlock(256,256)

    # residual stage: inverted residual blocks (irb) at 256 channels.
    # TODO: decide whether 8 blocks are needed; 4 are used for now.
    self.irb1 = InverseResidualBlock(256)
    self.irb2 = InverseResidualBlock(256)
    self.irb3 = InverseResidualBlock(256)
    self.irb4 = InverseResidualBlock(256)

    # decoder
    self.conv5 = ConvBlock(256,256)
    self.up_conv1 = UpConv(256, 128)
    self.dsconv2 = DepthwiseSeparableConv(128,128,stride=1)
    self.conv6 = ConvBlock(128,128)
    self.up_conv2 = UpConv(128, 64)
    self.conv7 = ConvBlock(64,64)
    self.conv8 = ConvBlock(64,64)
    self.final_conv_layer = nn.Conv2d(64,out_channels,kernel_size=3,stride=1, padding=1)

  def encode(self, x):
    """ @param x: [N x in_channels x H x W] images
        @returns: [N x 256 x H/4 x W/4] features (H, W divisible by 4)
    """
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.down_conv1(x)
    x = self.conv3(x)
    x = self.dsconv1(x)
    x = self.down_conv2(x)
    x = self.conv4(x)
    return x

  def decode(self, x):
    """ @param x: [N x 256 x H x W] features
        @returns: [N x out_channels x 4*H x 4*W] images
    """
    x = self.conv5(x)
    x = self.up_conv1(x)
    x = self.dsconv2(x)
    x = self.conv6(x)
    x = self.up_conv2(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.final_conv_layer(x)
    return x

  def residual_forward(self, x):
    """ a forward pass through the inverted residual blocks
        @param x: [N, 256, H, W] tensor
        @returns: [N, 256, H, W] tensor
    """
    x = self.irb1(x)
    x = self.irb2(x)
    x = self.irb3(x)
    x = self.irb4(x)
    return x

  def forward(self, x):
    """ @param x: [N x in_channels x H x W] images (H, W divisible by 4)
        @returns: [N x out_channels x H x W] images
    """
    x = self.encode(x)
    x = self.residual_forward(x)
    x = self.decode(x)
    return x


In [None]:
# Smoke test: push a small batch of random "images" through the generator and
# check the output shape matches the input. Falls back to CPU when no GPU is
# available, so the notebook still runs end to end (slowly) on CPU.
torch.manual_seed(0)  # reproducible weight init and inputs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = Generator().to(device)
noise_images = torch.randn((4, 3, 256, 256), device=device)
recon_images = g(noise_images)
print(recon_images.shape)
assert recon_images.shape == noise_images.shape

torch.Size([4, 3, 256, 256])




In [None]:
# How far is the (untrained) generator's output from its input?
# A signed sum lets positive and negative errors cancel, so report the mean
# absolute error instead. no_grad avoids building an autograd graph for what
# is purely a diagnostic.
with torch.no_grad():
  diff = noise_images - recon_images
  print(diff.abs().mean())

tensor(-44527.0508, device='cuda:0', grad_fn=<SumBackward0>)


In [None]:
import torch.optim as optim

# Sanity-check training: fit the generator to reproduce a fixed batch of noise
# (an identity mapping); the MSE is expected to decrease if the network is
# wired correctly. Reuses `g` and `noise_images` from the cells above.
NUM_STEPS = 1000
LEARNING_RATE = 1e-3
LOG_EVERY = 100  # print the loss every LOG_EVERY steps instead of every step

optimizer = optim.Adam(g.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

for step in range(1, NUM_STEPS + 1):
  optimizer.zero_grad()

  recon_images = g(noise_images)
  loss = criterion(recon_images, noise_images)
  loss.backward()
  optimizer.step()

  # .item() forces a device sync, so only call it when actually logging.
  if step == 1 or step % LOG_EVERY == 0:
    print(f"step {step:4d}/{NUM_STEPS}  loss {loss.item():.6f}")