# [Re-Aging GAN: Toward Personalized Face Age Transformation](https://openaccess.thecvf.com/content/ICCV2021/papers/Makhmudkhujaev_Re-Aging_GAN_Toward_Personalized_Face_Age_Transformation_ICCV_2021_paper.pdf)

This kernel implements the paper linked above.

> The paper doesn't provide enough detail to reproduce the entire method, so I will draw some information from [Lifespan Age Transformation Synthesis](https://arxiv.org/pdf/2003.09764.pdf) (LATS).

## Proposed Method
The authors propose the following solution.

### Overview

- Let $\mathcal{X}$ and $\mathcal{Y}$ be the sets of images and possible ages, respectively.
- Let $x\in\mathcal{X}$ be a face image and $y'\in\mathcal{Y}$ a randomly drawn target age.
- We need to train a single generator $G$ that generates a face image $x'$ of age $y'$ while preserving the identity in $x$.

### Identity Encoder
- Given an image $x$ for age transformation, the identity encoder $Enc$ extracts its identity-related features $f_{id} = Enc(x)$.
    - These features capture local facial details as well as general information about the face shape.
    - They are needed to generate a face with the same shape and appearance.
- The encoder should focus only on facial details, so the authors compute a face mask and feed the masked image to the encoder.
    - This separates the background, which is kept unchanged, from the face region, which is transformed.

> At the architecture level, the identity encoder consists of an image-to-feature convolutional layer followed by downsampling residual blocks.
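A minimal PyTorch sketch of such an encoder is shown below. The channel widths (32 up to 256), the use of three downsampling blocks, and the internal design of the hypothetical `DownResBlock` are all assumptions, since the exact layer configuration is not given here. The face mask is multiplied into the input before encoding, as described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownResBlock(nn.Module):
    """Hypothetical residual block that halves the spatial resolution."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1)  # 1x1 projection for the skip path

    def forward(self, x):
        h = self.conv1(F.leaky_relu(x, 0.2))
        h = self.conv2(F.leaky_relu(h, 0.2))
        # Downsample both paths by 2 and add them
        return F.avg_pool2d(h, 2) + F.avg_pool2d(self.skip(x), 2)


class IdentityEncoder(nn.Module):
    """Sketch of Enc: image-to-feature conv, then downsampling residual blocks."""

    def __init__(self, base_ch: int = 32, n_down: int = 3):
        super().__init__()
        self.from_rgb = nn.Conv2d(3, base_ch, 3, padding=1)  # image-to-feature layer
        blocks, ch = [], base_ch
        for _ in range(n_down):
            blocks.append(DownResBlock(ch, ch * 2))
            ch *= 2
        self.blocks = nn.Sequential(*blocks)
        self.out_channels = ch

    def forward(self, x, mask):
        # mask is 1 on the face and 0 on the background, so only the face is encoded
        return self.blocks(self.from_rgb(x * mask))


enc = IdentityEncoder()
x = torch.randn(2, 3, 64, 64)
mask = torch.ones(2, 1, 64, 64)
f_id = enc(x, mask)
print(f_id.shape)  # torch.Size([2, 256, 8, 8])
```

Each block doubles the channels and halves the resolution, so a 64×64 input yields an 8×8 identity feature map in this configuration.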

### Age Modulator
- The age modulator $AM$ is a CNN.
- Our age input space, $\mathcal{Z}$, is represented by a $50 \times n$ element vector, where $n$ is the number of age classes. When the input age class is $i$, we generate a vector $z_i \in \mathcal{Z}$ as
$$z_i = \mathbb{1}_i + v,~~~~v \sim \mathcal{N}(0, 0.2^2 \cdot I)$$
where $\mathbb{1}_i$ is a vector whose elements $50i$ through $50(i+1)-1$ are ones and whose other elements are zeros.
- The modulator takes the identity features $f_{id}$ from the encoder and, given the target age $y'$, outputs a reshaped version $f_{aw} = AM(f_{id}, y')$, where $f_{aw}$ is an age-aware feature vector.
- To embed the target age into $AM$, __conditional batch normalization (CBN)__ layers are added.
    - CBN encodes label information into the network by predicting the scale and shift of each batch-normalization layer from the label.
- $AM$ itself downsamples the features through layers equipped with CBN, producing a compact feature vector that is used to modulate the decoder layers.
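The age encoding and the CBN-based modulator can be sketched as follows. `make_age_vector` follows the $z_i = \mathbb{1}_i + v$ definition with blocks of 50 ones. `ConditionalBatchNorm2d` and `AgeModulator` are hypothetical names, and their layer counts, kernel sizes, output size, and the `(1 + gamma)` parameterization are assumptions rather than the paper's exact design. Here the modulator receives the encoded vector $z$ of the target age $y'$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_age_vector(age_class, n_classes, block=50, std=0.2):
    """z_i = 1_i + v: ones at indices [block*i, block*(i+1)), plus N(0, std^2) noise."""
    one_hot = F.one_hot(age_class, n_classes).float()      # (B, n)
    ones_block = one_hot.repeat_interleave(block, dim=1)   # (B, block * n)
    return ones_block + std * torch.randn_like(ones_block)


class ConditionalBatchNorm2d(nn.Module):
    """BatchNorm whose per-channel scale and shift are predicted from the age vector z."""

    def __init__(self, num_features, cond_dim):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Linear(cond_dim, num_features)
        self.beta = nn.Linear(cond_dim, num_features)

    def forward(self, h, z):
        gamma = self.gamma(z)[:, :, None, None]
        beta = self.beta(z)[:, :, None, None]
        return (1 + gamma) * self.bn(h) + beta


class AgeModulator(nn.Module):
    """Hypothetical AM: strided convs + CBN, then global pooling to a compact f_aw."""

    def __init__(self, in_ch, cond_dim, out_dim=256, n_down=2):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_ch, in_ch, 4, stride=2, padding=1) for _ in range(n_down)]
        )
        self.norms = nn.ModuleList(
            [ConditionalBatchNorm2d(in_ch, cond_dim) for _ in range(n_down)]
        )
        self.fc = nn.Linear(in_ch, out_dim)

    def forward(self, f_id, z):
        h = f_id
        for conv, norm in zip(self.convs, self.norms):
            h = torch.relu(norm(conv(h), z))  # each step halves H and W
        return self.fc(h.mean(dim=(2, 3)))    # global average pool -> (B, out_dim)


n_classes = 6  # example value
y_target = torch.tensor([0, 3])
z = make_age_vector(y_target, n_classes)             # (2, 300)
am = AgeModulator(in_ch=256, cond_dim=z.shape[1])
f_aw = am(torch.randn(2, 256, 8, 8), z)              # (2, 256)
```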

### Decoder
- The decoder takes the identity features $f_{id}$ and the age-aware features $f_{aw}$ and produces the age-transformed face image $x'=Dec(f_{id}, f_{aw})$.
    - The age-aware features guide the decoding process by modulating the original (un-reshaped) identity features.
- The features are modulated by __adaptive instance normalization (AdaIN)__ layers.
- Finally, the output is masked again to remove irrelevant content generated outside the face region, and the original background is added back.
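The two decoder operations above, AdaIN modulation and background compositing, can be sketched as below. The `(1 + gamma)` parameterization, the layer sizes, and the stand-in decoder output are assumptions. Compositing keeps the generated pixels inside the face mask and the input pixels outside it.

```python
import torch
import torch.nn as nn


class AdaIN(nn.Module):
    """Instance-normalize decoder features, then scale/shift them using f_aw."""

    def __init__(self, num_features, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.affine = nn.Linear(style_dim, 2 * num_features)  # predicts gamma and beta

    def forward(self, h, f_aw):
        gamma, beta = self.affine(f_aw).chunk(2, dim=1)
        return (1 + gamma[:, :, None, None]) * self.norm(h) + beta[:, :, None, None]


def composite(generated, original, mask):
    """Keep the generated face where mask == 1 and the original background where mask == 0."""
    return mask * generated + (1 - mask) * original


adain = AdaIN(num_features=256, style_dim=256)
h = adain(torch.randn(2, 256, 8, 8), torch.randn(2, 256))

x = torch.randn(2, 3, 64, 64)          # input image
decoded = torch.randn(2, 3, 64, 64)    # stand-in for the raw decoder output
mask = torch.zeros(2, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 1.0         # hypothetical face region
x_prime = composite(decoded, x, mask)
```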

### Discriminator
- The discriminator is a multi-task classifier: its last fully connected layer has one output branch per age class.
- Each branch performs binary classification, learning to decide whether an image of its age domain is real ($x$) or fake ($x'$).
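A multi-branch discriminator can be sketched as a shared feature extractor, a final linear layer with one real/fake logit per age class, and a lookup of the branch that matches each sample's age label. The convolutional backbone here is a small hypothetical stand-in, not the paper's architecture.

```python
import torch
import torch.nn as nn


class MultiBranchDiscriminator(nn.Module):
    def __init__(self, n_classes, ch=64):
        super().__init__()
        # Hypothetical small backbone shared by all branches
        self.features = nn.Sequential(
            nn.Conv2d(3, ch, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        # Last fully connected layer: one real/fake logit per age class
        self.fc = nn.Linear(ch * 2, n_classes)

    def forward(self, x, age_class):
        logits = self.fc(self.features(x))  # (B, n_classes)
        # D_y(x): select, for each sample, the branch of its own age class
        return logits.gather(1, age_class.unsqueeze(1)).squeeze(1)


D = MultiBranchDiscriminator(n_classes=6)
x = torch.randn(4, 3, 64, 64)
y = torch.tensor([0, 1, 2, 5])
scores = D(x, y)  # (4,) one logit per sample, from the branch of its age class
```

Only the selected branch receives a gradient for each sample, so every branch specializes in its own age domain.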

## Optimization

- The goal of the model is to produce images that preserve the identity of the input while accurately reflecting the target age.
- To this end, three losses are introduced:
    - Adversarial Loss
    - Reconstruction Loss
    - Cycle Consistency Loss
- The framework takes three inputs:
    - an input image $x$,
    - its corresponding age label $y$, and
    - a randomly sampled target age $y'$ to which the input should be transformed.
- $G$ produces the age-transformed image $x'$, the reconstructed image $x_{rec}$, and the cycle-reconstructed image $x_{cycle}$:

$$x'=G(x,y'),~~~~x_{rec}=G(x,y),~~~~x_{cycle}=G(x',y)$$
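The three generator passes can be sketched as below. `ToyGenerator` is a hypothetical stand-in for $G$ (it only adds a learned per-age color offset) so that the snippet runs on its own, and sampling $y'$ uniformly over the age classes is an assumption about how the target age is drawn.

```python
import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in for G(x, y): adds a learned per-age offset to the image."""

    def __init__(self, n_classes):
        super().__init__()
        self.age_offset = nn.Embedding(n_classes, 3)

    def forward(self, x, age_class):
        return x + self.age_offset(age_class)[:, :, None, None]


def generator_passes(G, x, y, n_classes):
    y_prime = torch.randint(0, n_classes, y.shape)  # random target age y'
    x_prime = G(x, y_prime)   # x' = G(x, y')
    x_rec = G(x, y)           # x_rec = G(x, y)
    x_cycle = G(x_prime, y)   # x_cycle = G(x', y)
    return y_prime, x_prime, x_rec, x_cycle


G = ToyGenerator(n_classes=6)
x = torch.randn(2, 3, 8, 8)
y = torch.tensor([1, 4])
y_prime, x_prime, x_rec, x_cycle = generator_passes(G, x, y, n_classes=6)
```

Note that $x_{cycle}$ feeds the generated image $x'$ back into $G$, so its loss back-propagates through both generator passes.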

### Adversarial Loss

Each output branch of the discriminator corresponds to a particular age domain, so the adversarial loss is conditioned on the age class. It is formulated as:

$$\mathcal{L}_{adv}(G,D) = \mathbb{E}_{x, y}[\log D_y(x)] + \mathbb{E}_{x, y'}[\log(1 - D_{y'}(x'))]$$

To understand this loss function, it helps to look at the LATS paper, which formulates it as:

$$\mathcal{L}_{adv}(G,D) = \mathbb{E}_{x, s}[\log D_s(x)] + \mathbb{E}_{x, t}[\log(1 - D_{t}(x'))]$$

where $s$ is the source age class and $t$ is the target age class.

LATS defines $D_i$ as the $i^{th}$ output of the discriminator. So $D_s(x)$ scores the real image with the branch of its own (source) age class, and $D_t(x')$ scores the generated image with the branch of the target age class; in the notation above, $s = y$ and $t = y'$.
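Assuming the selected branch outputs a logit and $D(\cdot) = \sigma(\text{logit})$, the loss terms can be computed stably with `logsigmoid`, using $\log(1 - \sigma(a)) = \log\sigma(-a)$. The non-saturating generator loss at the end is a common substitute for the minimax term and is not confirmed by the paper.

```python
import torch
import torch.nn.functional as F


def adversarial_terms(real_logits, fake_logits):
    """real_logits: logits of D_y(x); fake_logits: logits of D_{y'}(x')."""
    log_d_real = F.logsigmoid(real_logits)              # log D_y(x)
    log_one_minus_d_fake = F.logsigmoid(-fake_logits)   # log(1 - D_{y'}(x'))
    l_adv = log_d_real.mean() + log_one_minus_d_fake.mean()
    d_loss = -l_adv                                     # D maximizes L_adv
    g_loss_minimax = log_one_minus_d_fake.mean()        # only this term depends on G
    g_loss_nonsat = -F.logsigmoid(fake_logits).mean()   # common non-saturating variant
    return l_adv, d_loss, g_loss_minimax, g_loss_nonsat


real_logits = torch.tensor([2.0, 1.5])
fake_logits = torch.tensor([-1.0, 0.5])
l_adv, d_loss, g_mm, g_ns = adversarial_terms(real_logits, fake_logits)
```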

### Reconstruction Loss

While training $G$, we have to consider the case where the age $y$ of the input image and the target age $y'$ belong to the same age group ($y = y'$). In this case, the generated image, which is then $x_{rec} = G(x, y)$, should be as close as possible to $x$:

$$\mathcal{L}_{rec}(G) = \| x - x_{rec}\|_1$$

### Cycle-Consistency Loss

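Minimizing the adversarial and reconstruction losses trains $G$ to generate images that are realistic and accurate in terms of target age. However, this alone does not guarantee that $x'$ keeps the identity of $x$.

To enforce this, $x'$ is transformed back to the source age $y$, giving $x_{cycle} = G(x', y)$. Ideally, $x_{cycle}$ should be as similar as possible to $x$:

$$\mathcal{L}_{cyc}(G) = \| x - x_{cycle}\|_1$$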
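Both image-space terms can be computed directly from the generator outputs; here the cycle term is taken as an $L_1$ distance, mirroring the reconstruction loss. The `reduction` switch is an illustrative choice: `"sum"` matches the $\|\cdot\|_1$ norm literally, while `"mean"` (the usual choice in practice, and what `torch.nn.functional.l1_loss` does by default) only rescales it by the number of elements.

```python
import torch


def l1_distance(a, b, reduction="mean"):
    """||a - b||_1 with 'sum', or its per-element average with 'mean'."""
    diff = (a - b).abs()
    return diff.sum() if reduction == "sum" else diff.mean()


def reconstruction_and_cycle_losses(x, x_rec, x_cycle):
    l_rec = l1_distance(x, x_rec)      # L_rec = ||x - x_rec||_1
    l_cyc = l1_distance(x, x_cycle)    # L_cyc = ||x - x_cycle||_1
    return l_rec, l_cyc


x = torch.randn(2, 3, 8, 8)
l_rec, l_cyc = reconstruction_and_cycle_losses(x, x + 0.1, x.clone())
```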

### Full Objective

$$\min_G\max_D\left[\lambda_{adv}\mathcal{L}_{adv}(G, D) + \lambda_{rec}\mathcal{L}_{rec}(G) + \lambda_{cyc}\mathcal{L}_{cyc}(G)\right]$$
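Since $\mathcal{L}_{rec}$ and $\mathcal{L}_{cyc}$ do not depend on $D$, and only the second term of $\mathcal{L}_{adv}$ depends on $G$, the min-max objective splits into two alternating updates, one per network:

$$\max_D \; \lambda_{adv}\left(\mathbb{E}_{x,y}[\log D_y(x)] + \mathbb{E}_{x,y'}[\log(1 - D_{y'}(x'))]\right)$$

$$\min_G \; \lambda_{adv}\,\mathbb{E}_{x,y'}[\log(1 - D_{y'}(G(x,y')))] + \lambda_{rec}\,\mathcal{L}_{rec}(G) + \lambda_{cyc}\,\mathcal{L}_{cyc}(G)$$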

## Experimental Setup (Prescribed hyperparameters)

- They trained their model with a __batch size of 8__ for __30 epochs__ on a single __NVIDIA Titan RTX GPU__.
- As an optimizer, they use __Adam__ with momentum parameters $\beta_1=0.0$ and $\beta_2 = 0.99$ and a __learning rate of $10^{-4}$__.
- They also apply R1 regularization to the discriminator.
- They also use a learning rate scheduler for both the generator and the discriminator.
- For the first $10$ epochs, they train the model with $\lambda_{rec}=10$, $\lambda_{cyc} = 1$, and $\lambda_{adv} = 1$ for the reconstruction, cycle-consistency, and adversarial losses, respectively. Thereafter, they reduce $\lambda_{rec}$ to $1$, which they report leads to better results.
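These settings can be sketched in PyTorch as follows. The Adam betas, learning rate, and $\lambda$ schedule come from the list above. The R1 weight `gamma=1.0` is a placeholder, since its value is not given, and the learning-rate scheduler is omitted because its type is not specified. `loss_weights` and `r1_penalty` are hypothetical helper names.

```python
import torch
import torch.nn as nn


def make_optimizers(G, D, lr=1e-4):
    # Adam with beta1 = 0.0 and beta2 = 0.99, as reported
    opt_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.99))
    opt_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.99))
    return opt_g, opt_d


def loss_weights(epoch):
    """lambda_rec = 10 for epochs 0-9, then 1; lambda_cyc = lambda_adv = 1 throughout."""
    return {"rec": 10.0 if epoch < 10 else 1.0, "cyc": 1.0, "adv": 1.0}


def r1_penalty(real_logits, x_real, gamma=1.0):
    """R1: (gamma / 2) * E[||grad_x D(x)||^2] on real images (gamma is a placeholder)."""
    grad, = torch.autograd.grad(real_logits.sum(), x_real, create_graph=True)
    return 0.5 * gamma * grad.pow(2).flatten(1).sum(1).mean()


# Usage with tiny stand-in networks
G = nn.Linear(4, 4)
D = nn.Sequential(nn.Flatten(), nn.Linear(12, 1))
opt_g, opt_d = make_optimizers(G, D)
x_real = torch.randn(3, 3, 2, 2, requires_grad=True)
penalty = r1_penalty(D(x_real).squeeze(1), x_real)  # added to the discriminator loss
w = loss_weights(epoch=0)                           # {"rec": 10.0, "cyc": 1.0, "adv": 1.0}
```

The penalty is computed only on real images and added to the discriminator's loss; `create_graph=True` keeps it differentiable with respect to the discriminator's weights.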

```python
%load_ext watermark
%watermark -a "Aneesh Aparajit G" -p torch,torchvision,numpy,cv2,matplotlib,scipy,PIL
```

```
Author: Aneesh Aparajit G

torch      : 1.12.1
torchvision: 0.13.1
numpy      : 1.21.5
cv2        : 4.6.0
matplotlib : 3.5.2
scipy      : 1.7.3
PIL        : 9.2.0
```

