# DeepFill

## Imports

In [None]:


import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import torch.nn as nn
from torch.nn import functional as F
import torch.nn.init as init 
from torch.nn import Parameter
# Run on the default CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Dataset Class

In [None]:
class CelebA(Dataset):
    """Inpainting dataset: every ``.jpg`` file in ``rootDir`` paired with a random free-form mask.

    Each item is ``(img, mask)``:
      * ``img``  -- float32 tensor (3, image_size, image_size), RGB scaled to [0, 1];
      * ``mask`` -- float32 tensor (1, image_size, image_size), 1.0 on the brush
        strokes to be filled and 0.0 elsewhere. The channel axis lets a batch of
        masks (B, 1, H, W) broadcast against / concatenate with (B, 3, H, W) images.

    Args:
        rootDir: directory containing the images (not searched recursively).
        image_size: side length every image and mask is resized / drawn to.
    """

    def __init__(self, rootDir, image_size=256):
        self.rootDir = rootDir
        self.image_size = image_size
        # Sorted so that an index always refers to the same file
        # (os.listdir order is arbitrary).
        self.images = sorted(
            filename
            for filename in os.listdir(self.rootDir)
            if filename.lower().endswith(".jpg")
        )
        self.images_len = len(self.images)

    def __len__(self):
        return self.images_len

    def __getitem__(self, index):
        # load image
        filename = self.images[index % self.images_len]
        img_path = os.path.join(self.rootDir, filename)
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size))

        # random free-form mask, same spatial size as the image
        mask = self.random_ff_mask(self.image_size)

        # convert to torch: HWC uint8 -> CHW float in [0, 1]
        img = (
            torch.from_numpy(img.astype(np.float32) / 255.0)
            .permute(2, 0, 1)
            .contiguous()
        )
        # (H, W) -> (1, H, W) so it batches to (B, 1, H, W)
        mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0).contiguous()

        return (img, mask)

    def random_ff_mask(
        self, mask_size, max_vertex=30, max_length=40, max_angle=4, max_brush_width=10
    ):
        """Draw a random free-form (brush-stroke) mask.

        Up to ``max_vertex - 1`` strokes are drawn; each stroke is a polyline of
        1-5 segments with random angle, length and brush width.

        Returns:
            float32 array of shape (mask_size, mask_size): 1.0 on strokes, 0.0 elsewhere.
        """
        mask = np.zeros((mask_size, mask_size), np.float32)
        num_strokes = np.random.randint(max_vertex)
        for i in range(num_strokes):
            start_x = int(np.random.randint(mask_size))
            start_y = int(np.random.randint(mask_size))
            for _ in range(1 + np.random.randint(5)):
                angle = 0.01 + np.random.randint(max_angle)
                if i % 2 == 0:
                    angle = 2 * np.pi - angle
                length = 10 + np.random.randint(max_length)
                brush_width = 5 + np.random.randint(max_brush_width)
                # OpenCV drawing functions expect plain Python ints for points
                # and thickness; int() truncates toward zero like astype(int32).
                end_x = int(start_x + length * np.sin(angle))
                end_y = int(start_y + length * np.cos(angle))
                cv2.line(
                    mask, (start_x, start_y), (end_x, end_y), 1.0, int(brush_width)
                )
                start_x, start_y = end_x, end_y
        return mask

## Custom layers

In [None]:
## Spectral norm implementation taken from
## https://github.com/avalonstrel/GatedConvolution_pytorch/blob/master/models/spectral.py

def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to (approximately) unit L2 norm.

    ``eps`` is added to the norm so an all-zero vector yields zeros instead of NaN.
    """
    denominator = v.norm() + eps
    return v / denominator


class SpectralNorm(nn.Module):
    """Spectral normalisation wrapper around ``module``'s parameter ``name``.

    On construction the parameter ``<name>`` of ``module`` is removed and replaced
    by three parameters registered on ``module``:
      * ``<name>_bar`` -- the raw, trainable weight;
      * ``<name>_u`` / ``<name>_v`` -- power-iteration vectors (``requires_grad=False``).
    Before every forward call ``<name>`` is recomputed as ``<name>_bar / sigma``,
    where ``sigma`` is the power-iteration estimate of the largest singular value
    of ``<name>_bar`` reshaped to (out_features, -1). Consequently ``<name>`` is a
    plain tensor attribute that does not exist until the first forward pass.

    Args:
        module: the layer to wrap (e.g. ``nn.Conv2d``).
        name: name of the parameter to normalise.
        power_iterations: power-iteration steps performed per forward call.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refine u/v by power iteration and set ``module.<name> = w_bar / sigma``."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.data.view(height, -1)  # gradient-free copy for power iteration
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat), u.data))
            u.data = l2normalize(torch.mv(w_mat, v.data))

        # sigma = u^T W v; computed on the live ``w`` so gradients flow into w_bar,
        # while u and v (requires_grad=False) act as constants.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if ``module`` already carries the ``_u``/``_v``/``_bar`` parameters."""
        return all(
            hasattr(self.module, self.name + suffix) for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace ``module.<name>`` by ``<name>_bar`` plus random unit vectors u, v."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original parameter so the optimizer trains ``<name>_bar`` instead.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        self._update_u_v()
        return self.module.forward(*args, **kwargs)

In [None]:
# Normal ConvBlock
class Conv2dLayer(nn.Module):
    """Explicit padding -> Conv2d (optionally spectral-normalised) -> norm -> activation.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation: forwarded to
            ``nn.Conv2d``.
        padding: amount of padding added by the ``pad_type`` layer. The
            convolution itself is built with ``padding=0`` so the input is padded
            exactly once.
        pad_type: "reflect", "replicate" or "zero".
        activation: "relu", "lrelu", "elu", "prelu", "selu", "tanh", "sigmoid"
            or "none".
        norm: "bn", "in" or "none" (layer norm is not implemented).
        sn: wrap the convolution in ``SpectralNorm`` when true. For backward
            compatibility the strings "True"/"False" are also accepted.

    Raises:
        ValueError: for an unsupported ``pad_type``, ``norm`` or ``activation``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super().__init__()
        # A non-empty string such as "False" is truthy; interpret it literally.
        if isinstance(sn, str):
            sn = sn.strip().lower() == "true"

        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            raise ValueError(f"Unsupported padding type: {pad_type}")

        # Initialize the normalization type
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            raise ValueError(f"Unsupported normalization: {norm}")

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU()
        elif activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # Initialize the conv layer; padding is already handled by self.pad.
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
        )
        self.conv2d = SpectralNorm(conv) if sn else conv

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

In [None]:
# Transpose ConvBlock
class TransposeConv2dLayer(nn.Module):
    """Upsampling block: nearest-neighbour resize by ``scale_factor``, then ``Conv2dLayer``.

    No transposed convolution is involved despite the name. All other arguments
    are forwarded unchanged to ``Conv2dLayer``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=False,
        scale_factor=2,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            pad_type=pad_type,
            activation=activation,
            norm=norm,
            sn=sn,
        )

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
        return self.conv2d(upsampled)

In [None]:
# Gated ConvBlock
class GatedConv2dLayer(nn.Module):
    """Gated convolution: ``activation(norm(conv(p) * sigmoid(mask_conv(p))))``, p = pad(x).

    Two parallel convolutions with identical hyper-parameters see the padded
    input; the sigmoid of the second acts as a learned soft gate on the first.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation: forwarded to
            both ``nn.Conv2d`` layers.
        padding: amount of padding added by the ``pad_type`` layer. The
            convolutions are built with ``padding=0`` so the input is padded
            exactly once.
        pad_type: "reflect", "replicate" or "zero".
        activation: "relu", "lrelu", "elu", "prelu", "selu", "tanh", "sigmoid"
            or "none".
        norm: "bn", "in" or "none" (layer norm is not implemented).
        sn: wrap both convolutions in ``SpectralNorm`` when true. For backward
            compatibility the strings "True"/"False" are also accepted.

    Raises:
        ValueError: for an unsupported ``pad_type``, ``norm`` or ``activation``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super().__init__()
        # A non-empty string such as "False" is truthy; interpret it literally.
        if isinstance(sn, str):
            sn = sn.strip().lower() == "true"

        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            raise ValueError(f"Unsupported padding type: {pad_type}")

        # Initialize the normalization type
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            raise ValueError(f"Unsupported normalization: {norm}")

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU()
        elif activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        def make_conv():
            # padding is already handled by self.pad
            conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
                dilation=dilation,
            )
            return SpectralNorm(conv) if sn else conv

        # feature branch and gating branch
        self.conv2d = make_conv()
        self.mask_conv2d = make_conv()

    def forward(self, x):
        x = self.pad(x)
        feature = self.conv2d(x)
        gate = torch.sigmoid(self.mask_conv2d(x))
        x = feature * gate
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

In [None]:
# Transpose GatedConvBlock
class TransposeGatedConv2dLayer(nn.Module):
    """Gated upsampling block: nearest-neighbour resize by ``scale_factor``, then ``GatedConv2dLayer``.

    No transposed convolution is involved despite the name. ``scale_factor`` is
    consumed here (``GatedConv2dLayer`` does not accept it); all other arguments
    are forwarded unchanged.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=False,
        scale_factor=2,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2dLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            pad_type=pad_type,
            activation=activation,
            norm=norm,
            sn=sn,
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
        return self.gated_conv2d(x)


## Weight Initialization Function

In [None]:
# Weight Initialization
def weights_init(net, init_type="kaiming", init_gain=0.02):
    """
    Initialize network weights in place.

    Parameters:
    net (nn.Module) -- network to be initialized; every submodule is visited via ``net.apply``
    init_type (str) -- initialization for convolution weights:
        "normal", "xavier", "kaiming" or "orthogonal"
    init_gain (float) -- std for "normal", gain for "xavier" / "orthogonal";
        not used by "kaiming"

    BatchNorm2d layers get weight ~ N(1, 0.02) and zero bias; Linear layers get
    weight ~ N(0, 0.01) and zero bias. Convolution biases are left untouched.

    Raises:
    NotImplementedError -- if ``init_type`` is unknown (raised when the first
        convolution is visited)
    """

    def init_conv_weight(weight):
        if init_type == "normal":
            init.normal_(weight.data, 0.0, init_gain)
        elif init_type == "xavier":
            init.xavier_normal_(weight.data, gain=init_gain)
        elif init_type == "kaiming":
            init.kaiming_normal_(weight.data, a=0, mode="fan_in")
        elif init_type == "orthogonal":
            init.orthogonal_(weight.data, gain=init_gain)
        else:
            raise NotImplementedError(
                f"Initialization method {init_type} is not implemented"
            )

    def init_func(m):
        classname = m.__class__.__name__
        if "Conv" in classname:
            # A conv wrapped by SpectralNorm (default name "weight") keeps its
            # trainable weight in "weight_bar"; "weight" is only recomputed from
            # it during forward. Initialise the raw parameter when present.
            weight = getattr(m, "weight_bar", None)
            if weight is None:
                weight = getattr(m, "weight", None)
            # Wrapper modules such as Conv2dLayer have no weight of their own.
            if weight is not None:
                init_conv_weight(weight)
        elif "BatchNorm2d" in classname:
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, 0.02)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif "Linear" in classname:
            init.normal_(m.weight.data, 0.0, 0.01)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    # now apply the initialization function here
    net.apply(init_func)

## Generator

In [None]:
# Generator
# input: masked image + mask
# output: filled image
class GatedGenerator(nn.Module):
    """Two-stage (coarse -> refinement) gated-convolution inpainting generator.

    Both stages share one layout, built by ``_make_branch``, but have independent
    weights:

      * encoder -- downsample x4 with two stride-2 layers;
      * bottleneck -- including dilations 2/4/8/16;
      * decoder -- two x2 upsampling blocks;
      * output layer -- tanh.

    Args:
        in_channels: channels of each stage's input. ``forward`` concatenates the
            masked image with the mask along dim 1, so this must equal
            image channels + mask channels (4 for RGB + a 1-channel mask).
        latent_channels: base width of the network (e.g. 64).
        out_channels: channels of each stage's output image.
        pad_type, activation, norm: forwarded to the inner layers; the first and
            last layer of each stage use no normalisation.
    """

    def __init__(
        self, in_channels, latent_channels, out_channels, pad_type, activation, norm
    ):
        super().__init__()
        self.coarse = self._make_branch(
            in_channels, latent_channels, out_channels, pad_type, activation, norm
        )
        self.refinement = self._make_branch(
            in_channels, latent_channels, out_channels, pad_type, activation, norm
        )

    @staticmethod
    def _make_branch(
        in_channels, latent_channels, out_channels, pad_type, activation, norm
    ):
        """Build one encoder / dilated-bottleneck / decoder stage as an ``nn.Sequential``."""
        c1, c2, c4 = latent_channels, latent_channels * 2, latent_channels * 4

        def gated(cin, cout, kernel, stride, padding, dilation=1, layer_norm=norm):
            # The generator does not use spectral norm; pass it explicitly.
            return GatedConv2dLayer(
                cin,
                cout,
                kernel,
                stride,
                padding,
                dilation=dilation,
                pad_type=pad_type,
                activation=activation,
                norm=layer_norm,
                sn=False,
            )

        def upsample(cin, cout):
            return TransposeConv2dLayer(
                cin,
                cout,
                3,
                1,
                1,
                pad_type=pad_type,
                activation=activation,
                norm=norm,
                sn=False,
            )

        return nn.Sequential(
            # encoder
            gated(in_channels, c1, 7, 1, 3, layer_norm="none"),
            gated(c1, c2, 4, 2, 1),
            gated(c2, c4, 3, 1, 1),
            gated(c4, c4, 4, 2, 1),
            # bottleneck (padding == dilation keeps the spatial size for 3x3 kernels)
            gated(c4, c4, 3, 1, 1),
            gated(c4, c4, 3, 1, 1),
            gated(c4, c4, 3, 1, 2, dilation=2),
            gated(c4, c4, 3, 1, 4, dilation=4),
            gated(c4, c4, 3, 1, 8, dilation=8),
            gated(c4, c4, 3, 1, 16, dilation=16),
            gated(c4, c4, 3, 1, 1),
            gated(c4, c4, 3, 1, 1),
            # decoder
            upsample(c4, c2),
            gated(c2, c2, 3, 1, 1),
            upsample(c2, c1),
            GatedConv2dLayer(
                c1,
                out_channels,
                7,
                1,
                3,
                pad_type=pad_type,
                activation="tanh",
                norm="none",
                sn=False,
            ),
        )

    def forward(self, img, mask):
        """Run both stages.

        Args:
            img: image batch, concatenated with ``mask`` along dim 1 (presumably
                (B, 3, H, W) -- TODO confirm against the data pipeline).
            mask: hole mask broadcastable to ``img``; 1 marks pixels to fill
                (those pixels of ``img`` are discarded via ``img * (1 - mask)``).

        Returns:
            (coarse_out, refine_out): raw outputs of the two stages.
        """
        # Stage 1: fill the holes with 1.0 and predict a coarse completion.
        first_masked_img = img * (1 - mask) + mask
        coarse_input = torch.cat((first_masked_img, mask), 1)
        coarse_out = self.coarse(coarse_input)

        # Stage 2: paste the coarse prediction into the holes only, then refine.
        refine_masked_img = img * (1 - mask) + coarse_out * mask
        refine_input = torch.cat((refine_masked_img, mask), 1)
        refine_out = self.refinement(refine_input)
        return coarse_out, refine_out

## Discriminator

In [None]:
# Discriminator
# Input: generated image + mask or image + mask
# Output: a 1-channel map of per-patch scores, spatially downsampled by the five stride-2 layers
class PatchDiscriminator(nn.Module):
    """Spectral-normalised patch discriminator.

    ``forward(img, mask)`` concatenates image and mask along dim 1, so
    ``in_channels`` must equal image channels + mask channels. Every layer is a
    ``Conv2dLayer`` with ``sn=True``. One 7x7 stride-1 layer is followed by
    five 4x4 stride-2 layers. The last of these outputs a single channel with
    no activation and no normalisation (raw scores).

    Args:
        in_channels: channels after concatenating image and mask.
        latent_channels: base width of the network.
        pad_type, activation, norm: forwarded to the layers (the first and last
            layers use no normalisation; the last uses no activation).
    """

    def __init__(self, in_channels, latent_channels, pad_type, activation, norm):
        super().__init__()
        c1, c2, c4 = latent_channels, latent_channels * 2, latent_channels * 4

        def sn_conv(cin, cout, kernel, stride, padding, act=activation, layer_norm=norm):
            return Conv2dLayer(
                cin,
                cout,
                kernel,
                stride,
                padding,
                pad_type=pad_type,
                activation=act,
                norm=layer_norm,
                sn=True,
            )

        # Each block is assigned as a module (not wrapped in a tuple) so that it is
        # callable and its parameters are registered with this discriminator.
        self.block1 = sn_conv(in_channels, c1, 7, 1, 3, layer_norm="none")
        self.block2 = sn_conv(c1, c2, 4, 2, 1)
        self.block3 = sn_conv(c2, c4, 4, 2, 1)
        self.block4 = sn_conv(c4, c4, 4, 2, 1)
        self.block5 = sn_conv(c4, c4, 4, 2, 1)
        self.block6 = sn_conv(c4, 1, 4, 2, 1, act="none", layer_norm="none")

    def forward(self, img, mask):
        # concat the image and the mask
        x = torch.cat((img, mask), 1)
        for block in (
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.block5,
            self.block6,
        ):
            x = block(x)
        return x


## Perceptual Network

In [None]:
# Perceptual Network
# VGG-16 conv4_3 features
class PerceptualNet(nn.Module):
    """VGG-16 feature extractor truncated at conv4_3 (output taken before its ReLU).

    The layer layout and therefore the ``features.<index>`` state-dict keys
    follow the VGG-16 ``features`` stack: every 3x3 conv is followed by a ReLU,
    and a 2x2 max-pool closes each stage. Convolutions sit at indices
    0, 2, 5, 7, 10, 12, 14, 17, 19, 21. Keeping the indices aligned lets
    pretrained weights be copied in by key name.
    """

    # Output channels of each 3x3 conv up to conv4_3; "M" marks a 2x2 max-pool.
    _CFG = (64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512)

    def __init__(self):
        super(PerceptualNet, self).__init__()
        layers = []
        in_ch = 3
        for item in self._CFG:
            if item == "M":
                layers.append(nn.MaxPool2d(2, 2))
            else:
                layers.append(nn.Conv2d(in_ch, item, 3, 1, 1))
                layers.append(nn.ReLU(inplace=True))
                in_ch = item
        # Stop at conv4_3 itself, without its trailing ReLU.
        layers.pop()
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        return x

## Train Function

In [None]:
# train function
def trainFn(
    gen,
    disc,
    perceptNet,
    opt_g,
    opt_d,
    train_loader,
    epoch,
    L1Loss,
    lambda_l1,
    lambda_perceptual,
    lambda_gan,
):
    """Run one training epoch of the two-stage inpainting GAN.

    For every ``(img, mask)`` batch the discriminator is updated once with the
    WGAN-style critic loss ``mean(D(fake)) - mean(D(real))``. The generator is
    then updated once with
    ``lambda_l1 * (L1_coarse + L1_refine) + lambda_perceptual * L_percept + lambda_gan * (-mean(D(fake)))``.

    Args:
        gen: generator; ``gen(img, mask)`` returns ``(coarse_out, refine_out)``.
        disc: discriminator; ``disc(img, mask)`` returns realness scores.
        perceptNet: feature extractor for the perceptual loss (not updated here).
        opt_g, opt_d: optimizers of the generator and the discriminator.
        train_loader: iterable of ``(img, mask)`` batches; mask == 1 marks holes.
        epoch: current epoch index (currently unused; kept for the call signature).
        L1Loss: loss module used for the L1 and perceptual terms.
        lambda_l1, lambda_perceptual, lambda_gan: loss weights.

    Returns:
        dict with the epoch-mean ``"loss_D"`` and ``"loss_G"`` as floats
        (NaN when ``train_loader`` yields no batches).
    """
    # Move batches to wherever the generator lives.
    device = next(gen.parameters()).device
    total_d = total_g = 0.0
    num_batches = 0

    for img, mask in train_loader:
        img = img.to(device)
        mask = mask.to(device)

        # generate output from generator
        coarse_out, refine_out = gen(img, mask)

        # keep known pixels, take the predictions inside the holes
        coarse_final_img = img * (1 - mask) + coarse_out * mask
        refine_final_img = img * (1 - mask) + refine_out * mask

        # ---- Discriminator step (generator output detached) ----
        fake_scores = disc(refine_final_img.detach(), mask)
        real_scores = disc(img, mask)
        loss_D = torch.mean(fake_scores) - torch.mean(real_scores)
        opt_d.zero_grad()
        loss_D.backward()
        opt_d.step()

        # ---- Generator step ----
        coarse_L1Loss = L1Loss(coarse_final_img, img)
        refine_L1Loss = L1Loss(refine_final_img, img)

        # Not detached: the adversarial gradient must reach the generator.
        # Gradients this leaves on disc's parameters are cleared by the next
        # opt_d.zero_grad().
        gan_loss = -torch.mean(disc(refine_final_img, mask))

        real_featureMaps = perceptNet(img)
        fake_featureMaps = perceptNet(refine_final_img)
        percept_loss = L1Loss(real_featureMaps, fake_featureMaps)

        loss_G = (
            lambda_l1 * coarse_L1Loss
            + lambda_l1 * refine_L1Loss
            + lambda_perceptual * percept_loss
            + lambda_gan * gan_loss
        )
        opt_g.zero_grad()
        loss_G.backward()
        opt_g.step()

        total_d += loss_D.item()
        total_g += loss_G.item()
        num_batches += 1

    if num_batches == 0:
        return {"loss_D": float("nan"), "loss_G": float("nan")}
    return {"loss_D": total_d / num_batches, "loss_G": total_g / num_batches}

## Main Function

In [None]:
# main function
def main():
    """Build the DeepFill networks, optimizers and data loader, then train.

    All paths and hyper-parameters are hard-coded below.
    """

    # Device
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CONSTANTS for generator / discriminator.
    # Both networks receive the RGB image concatenated with the 1-channel mask,
    # hence 3 + 1 input channels.
    IN_CHANNELS = 4
    LATENT_CHANNELS = 64
    OUT_CHANNELS = 3
    PAD_TYPE = "zero"
    NORM = "in"
    ACTIVATION = "lrelu"

    # Learning rate constants
    LR_G = 1e-4
    LR_D = 4e-4
    BETA1 = 0.5
    BETA2 = 0.999
    lr_decrease_factor = 0.5
    lr_decrease_epoch = 10

    # All necessary paths
    ROOTDIR_PATH = r"G:\DeepFill\Data\data256x256"
    VGG16_PATH = "./Data/vgg16_pretrained.pth"

    # Training constants
    BATCH_SIZE = 32
    NUM_EPOCHS = 1

    # Loss Constants
    lambda_l1 = 100
    lambda_perceptual = 10
    lambda_gan = 1

    # CREATE GENERATOR -> make gen object and initialize weights
    gen = GatedGenerator(
        in_channels=IN_CHANNELS,
        latent_channels=LATENT_CHANNELS,
        out_channels=OUT_CHANNELS,
        pad_type=PAD_TYPE,
        activation=ACTIVATION,
        norm=NORM,
    ).to(DEVICE)
    weights_init(net=gen)

    # CREATE DISCRIMINATOR -> make disc object and initialize weights
    disc = PatchDiscriminator(
        in_channels=IN_CHANNELS,
        latent_channels=LATENT_CHANNELS,
        pad_type=PAD_TYPE,
        activation=ACTIVATION,
        norm=NORM,
    ).to(DEVICE)
    weights_init(net=disc)

    def load_dict(process_net, pretrained_net):
        """Copy matching entries of a pretrained network into ``process_net``.

        ``pretrained_net`` may be an ``nn.Module`` or a state dict. Only keys that
        exist in ``process_net`` with the same tensor shape are copied; all other
        entries of ``process_net`` keep their current values.
        """
        if isinstance(pretrained_net, dict):
            pretrained_dict = pretrained_net
        else:
            pretrained_dict = pretrained_net.state_dict()
        process_dict = process_net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in process_dict and v.shape == process_dict[k].shape
        }
        process_dict.update(pretrained_dict)
        process_net.load_state_dict(process_dict)
        return process_net

    # Instantiate Perceptual net (frozen)
    perceptNet = PerceptualNet().to(DEVICE)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # files. On PyTorch >= 2.6 loading a whole pickled module additionally needs
    # weights_only=False; a plain state dict loads with the default.
    vgg16 = torch.load(VGG16_PATH, map_location=DEVICE)
    load_dict(perceptNet, vgg16)
    perceptNet.eval()
    for param in perceptNet.parameters():
        param.requires_grad = False

    # Optimizers
    opt_g = torch.optim.Adam(gen.parameters(), lr=LR_G, betas=(BETA1, BETA2))
    opt_d = torch.optim.Adam(disc.parameters(), lr=LR_D, betas=(BETA1, BETA2))

    # Loss functions
    L1Loss = nn.L1Loss()

    # Learning rate scheduler
    def adjust_lr(lr_in, optimizer, epoch, decrease_factor, lr_decrease_epoch):
        """
        Set the lr to lr_in * decrease_factor ** (epoch // lr_decrease_epoch)
        """
        lr = lr_in * (decrease_factor ** (epoch // lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    # Initialize training data (CelebA is defined in this notebook)
    train_dataset = CelebA(ROOTDIR_PATH)
    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)

    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch}")
        trainFn(
            gen,
            disc,
            perceptNet,
            opt_g,
            opt_d,
            train_loader,
            epoch,
            L1Loss,
            lambda_l1,
            lambda_perceptual,
            lambda_gan,
        )
        # Decrease learning rate for the next epoch
        adjust_lr(LR_G, opt_g, epoch + 1, lr_decrease_factor, lr_decrease_epoch)
        adjust_lr(LR_D, opt_d, epoch + 1, lr_decrease_factor, lr_decrease_epoch)

        # TODO: FUNCTION TO SAVE MODEL

        # TODO: SAMPLE IMAGES TO SAVE IN A FOLDER
