# Reaging GANs

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import typing

# Utilities

## Convolutional Block

In [2]:
class ConvBlock(nn.Module):
    """Downsampling unit: 4x4 conv (reflect padding) -> BatchNorm -> LeakyReLU(0.2).

    With kernel 4 and padding 1 the output side is ``floor(side / 2)`` for the
    default ``stride=2``, and ``side - 1`` for ``stride=1``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the block.
        stride: convolution stride.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2) -> None:
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                padding_mode='reflect',
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        # The attribute is named ``conv`` so state_dict keys stay ``conv.<i>.*``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalisation and activation to ``x``."""
        out = self.conv(x)
        return out

## Transpose Convolutional Block

In [3]:
class TransposeConvBlock(nn.Module):
    """Upsampling unit: 4x4 transposed conv -> BatchNorm -> ReLU.

    With kernel 4 and padding 1 the default ``stride=2`` doubles the spatial
    size: ``out = (side - 1) * 2 - 2 + 4 = 2 * side``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the block.
        stride: transposed-convolution stride.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2) -> None:
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
        )
        norm = nn.BatchNorm2d(num_features=out_channels)
        # The attribute is named ``tran_conv`` so state_dict keys are unchanged.
        self.tran_conv = nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, normalise it and apply ReLU."""
        out = self.tran_conv(x)
        return out

# Discriminator

In [4]:
class Discriminator(nn.Module):
    '''Conditional Pix2Pix-style discriminator for age transformation.

    The forward pass receives the input image with its age group and a
    candidate output image with its target age group. Each age group is
    looked up in its own ``nn.Embedding`` (``input_size ** 2`` entries) and
    reshaped into one extra image plane. The network therefore sees
    ``2 * in_channels + 2`` stacked channels and returns a single-channel map
    of unnormalised scores (no final activation is applied).

    The target image may show a different person than the input, so this
    score is meant to be combined with further losses such as a
    reconstruction loss.

    Args:
        in_channels: channels per image.
        features: channel widths. ``features[0]`` is used by the initial conv;
            every later entry adds a ``ConvBlock``. The last of these uses
            stride 1, the others stride 2.
        input_size: side length of the square images (also fixes the
            embedding size).
        num_age_groups: number of discrete age groups per embedding.
    '''
    def __init__(self, in_channels: int = 3, features: tuple = (64, 128, 256, 512), input_size: int = 128, num_age_groups: int = 3) -> None:
        super().__init__()
        # Two images plus two single-channel age planes, concatenated along dim 1.
        stacked_channels = 2 * in_channels + 2
        self.inital = nn.Sequential(
            nn.Conv2d(stacked_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(negative_slope=0.2),
        )

        last_index = len(features) - 1
        blocks = []
        prev_width = features[0]
        for idx, width in enumerate(features[1:], start=1):
            # Only the final ConvBlock keeps stride 1.
            block_stride = 1 if idx == last_index else 2
            blocks.append(ConvBlock(in_channels=prev_width, out_channels=width, stride=block_stride))
            prev_width = width
        blocks.append(nn.Conv2d(in_channels=prev_width, out_channels=1, kernel_size=4, stride=1, padding=1, padding_mode='reflect'))
        self.model = nn.Sequential(*blocks)

        self.input_size = input_size
        plane_size = input_size ** 2
        self.input_embed = nn.Embedding(num_embeddings=num_age_groups, embedding_dim=plane_size)
        self.output_embed = nn.Embedding(num_embeddings=num_age_groups, embedding_dim=plane_size)

    def forward(self, input_tensor: typing.Tuple[torch.Tensor, torch.Tensor], output_tensor: typing.Tuple[torch.Tensor, torch.Tensor], _print: bool = False) -> torch.Tensor:
        '''Score an (input image, output image) pair under the given age groups.

        Args:
        ----
            - input_tensor : Tuple[torch.Tensor, torch.Tensor]
                - input image batch, presumably (B, in_channels, input_size, input_size)
                - integer input age-group indices, one per sample
            - output_tensor : Tuple[torch.Tensor, torch.Tensor]
                - output image batch, same layout as the input image
                - integer output age-group indices, one per sample
            - _print : if True, print intermediate tensor shapes

        The four tensors are concatenated along the channel axis in this order:
        input image, input-age plane, output image, output-age plane. Two
        separate embedding tables are used for the input age and the output
        age (a single shared table would also work).
        '''
        src_img, src_age = input_tensor
        tgt_img, tgt_age = output_tensor

        def log(message: str) -> None:
            if _print:
                print(message)

        log(f'X_img: {src_img.shape}, X_age: {src_age.shape}')
        log(f'y_img: {tgt_img.shape}, y_age: {tgt_age.shape}')

        side = self.input_size
        src_plane = self.input_embed(src_age).reshape(-1, 1, side, side)
        tgt_plane = self.output_embed(tgt_age).reshape(-1, 1, side, side)
        log(f'X_age_embed: {src_plane.shape}, y_age_embed: {tgt_plane.shape}')

        src = torch.cat([src_img, src_plane], dim=1)
        tgt = torch.cat([tgt_img, tgt_plane], dim=1)
        log(f'X: {src.shape}, y: {tgt.shape}')

        stacked = torch.cat([src, tgt], dim=1)
        log(f'x: {stacked.shape}')

        hidden = self.inital(stacked)
        log(f'x: {hidden.shape}')
        return self.model(hidden)

In [5]:
def test():
    """Smoke-test the Discriminator on one random 256x256 image pair.

    Prints the architecture and the intermediate shapes, then checks that the
    score map has the expected shape (1, 1, 30, 30).
    """
    X_img = torch.randn((1, 3, 256, 256))
    X_age = torch.tensor([[2]])

    y_img = torch.randn((1, 3, 256, 256))
    y_age = torch.tensor([[1]])

    model = Discriminator(input_size=256)
    print(model)

    print('\n[PROCESSING THE DATA...]\n')

    # Call the module (not .forward) so any registered hooks run. Gradients
    # are not needed for a shape check.
    with torch.no_grad():
        z = model((X_img, X_age), (y_img, y_age), _print=True)

    print(f'z: {z.shape}')
    # 256 -> 128 (initial conv) -> 64 -> 32 -> 31 (stride-1 block) -> 30 (final conv)
    assert z.shape == (1, 1, 30, 30), f'unexpected discriminator output shape {tuple(z.shape)}'


test()

Discriminator(
  (inital): Sequential(
    (0): Conv2d(8, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0

# Generator

In [6]:
class Generator(nn.Module):
    """U-Net style generator conditioned on a source age and a target age.

    Each age-group index is looked up in its own ``nn.Embedding`` of size
    ``input_size * input_size`` and reshaped into a single extra image plane.
    Both planes are concatenated to the image, so the first layer receives
    ``in_channels + 2`` channels (5 for RGB): the image channels, then the
    source-age plane, then the target-age plane.

    The encoder halves the resolution eight times (seven ``ConvBlock`` stages
    plus the bottleneck), and the decoder mirrors it with skip connections.
    The image side length must therefore equal ``input_size`` and be a
    positive multiple of 256.

    NOTE: the default ``input_size=128`` does not meet this requirement,
    because the encoder would reach 1x1 before the bottleneck. Pass
    ``input_size=256`` (or another multiple of 256) explicitly.

    Args:
        in_channels: channels of the input image. The generated image has the
            same number of channels.
        num_age_groups: number of discrete age groups per embedding.
        input_size: side length of the square input images.
    """

    def __init__(self, in_channels: int = 3, num_age_groups: int = 3, input_size: int = 128) -> None:
        super(Generator, self).__init__()

        self.input_embed = nn.Embedding(num_embeddings=num_age_groups, embedding_dim=input_size*input_size)
        self.output_embed = nn.Embedding(num_embeddings=num_age_groups, embedding_dim=input_size*input_size)
        self.input_size = input_size

        # Encoder: every stage halves H and W.
        self.init_down = ConvBlock(in_channels=in_channels + 2, out_channels=64)  # in_channels + 2 -> 64
        self.down1 = ConvBlock(in_channels=64, out_channels=128)   # 64 -> 128
        self.down2 = ConvBlock(in_channels=128, out_channels=256)  # 128 -> 256
        self.down3 = ConvBlock(in_channels=256, out_channels=512)  # 256 -> 512
        self.down4 = ConvBlock(in_channels=512, out_channels=512)  # 512 -> 512
        self.down5 = ConvBlock(in_channels=512, out_channels=512)  # 512 -> 512
        self.down6 = ConvBlock(in_channels=512, out_channels=512)  # 512 -> 512

        # Bottleneck: one more halving, no normalisation.
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(negative_slope=0.2)
        )

        # Decoder: every stage doubles H and W. Inputs after up1 are the
        # previous output concatenated with the mirrored encoder activation.
        self.up1 = TransposeConvBlock(in_channels=512, out_channels=512)   # bottleneck 512 -> 512
        self.up2 = TransposeConvBlock(in_channels=1024, out_channels=512)  # u1 + d7 = 1024 -> 512
        self.up3 = TransposeConvBlock(in_channels=1024, out_channels=512)  # u2 + d6 = 1024 -> 512
        self.up4 = TransposeConvBlock(in_channels=1024, out_channels=512)  # u3 + d5 = 1024 -> 512
        self.up5 = TransposeConvBlock(in_channels=1024, out_channels=256)  # u4 + d4 = 1024 -> 256
        self.up6 = TransposeConvBlock(in_channels=512, out_channels=128)   # u5 + d3 = 512 -> 128
        self.up7 = TransposeConvBlock(in_channels=256, out_channels=64)    # u6 + d2 = 256 -> 64
        self.final_up = nn.Sequential(
            # Output as many channels as the input image, so the result can be
            # fed back in and compared with images of the same layout.
            nn.ConvTranspose2d(in_channels=128, out_channels=in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )  # u7 + d1 = 128 -> in_channels

    def forward(self, x: torch.Tensor, input_age: torch.Tensor, output_age: torch.Tensor, _print: bool = False) -> torch.Tensor:
        """Translate the images ``x`` from ``input_age`` to ``output_age``.

        Args:
            x: image batch, presumably (B, in_channels, input_size, input_size).
            input_age: integer age-group indices, one per sample
                (shape (B,) or (B, 1)).
            output_age: integer target age-group indices, same layout as
                ``input_age``.
            _print: if True, print intermediate tensor shapes.

        Returns:
            Tensor with ``in_channels`` channels at the input resolution,
            squashed to (-1, 1) by ``tanh``.

        Raises:
            ValueError: if ``input_size`` is not a positive multiple of 256,
                or if the spatial size of ``x`` differs from
                ``(input_size, input_size)``.
        """
        def log(message: str) -> None:
            if _print:
                print(message)

        log(f'x: {x.shape}\ninput_age: {input_age.shape}\noutput_age: {output_age.shape}')

        size = self.input_size
        # Fail early with a clear message; otherwise the skip connections or
        # the reflect padding break deep inside the network.
        if size <= 0 or size % 256 != 0:
            raise ValueError(f'Generator requires input_size to be a positive multiple of 256, got {size}')
        if tuple(x.shape[-2:]) != (size, size):
            raise ValueError(f'expected images of spatial size {(size, size)}, got {tuple(x.shape[-2:])}')

        inp_embed = self.input_embed(input_age).reshape(input_age.shape[0], 1, size, size)
        out_embed = self.output_embed(output_age).reshape(output_age.shape[0], 1, size, size)

        x = torch.cat([x, inp_embed, out_embed], dim=1)
        log(f'x: {x.shape}')

        # Encoder: keep every stage's output (d1 ... d7) for the skip connections.
        encoder = (self.init_down, self.down1, self.down2, self.down3,
                   self.down4, self.down5, self.down6)
        skips = []
        h = x
        for idx, down in enumerate(encoder, start=1):
            h = down(h)
            log(f'd{idx}: {h.shape}')
            skips.append(h)

        h = self.bottle_neck(h)
        log(f'bottleneck: {h.shape}')

        # Decoder: up1 sees only the bottleneck. up2..up7 see the previous
        # output concatenated with d7..d2 respectively.
        h = self.up1(h)
        log(f'u1: {h.shape}')
        decoder = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for idx, (up, skip) in enumerate(zip(decoder, reversed(skips[1:])), start=2):
            h = up(torch.cat([h, skip], dim=1))
            log(f'u{idx}: {h.shape}')

        # The final stage fuses the last decoder output with d1.
        return self.final_up(torch.cat([h, skips[0]], dim=1))

In [7]:
def test():
    """Smoke-test the Generator on one random 256x256 image.

    Prints the architecture and the intermediate shapes, then checks that the
    generated image has the same shape as the input, (1, 3, 256, 256).
    """
    X_img = torch.randn((1, 3, 256, 256))
    X_age = torch.tensor([[2]])
    y_age = torch.tensor([[1]])

    model = Generator(input_size=256)
    print(model)

    print('\n[PROCESSING THE DATA...]\n')

    # Call the module (not .forward) so any registered hooks run. Gradients
    # are not needed for a shape check.
    with torch.no_grad():
        z = model(X_img, X_age, y_age, _print=True)

    print(f'z: {z.shape}')
    assert z.shape == X_img.shape, f'unexpected generator output shape {tuple(z.shape)}'


test()

Generator(
  (input_embed): Embedding(3, 65536)
  (output_embed): Embedding(3, 65536)
  (init_down): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(5, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down1): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down2): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down3): ConvBlock(
    (conv): Sequential(