# Setup

In [1]:
# Colab-only setup: upload Kaggle API credentials, install the Kaggle CLI and
# download the first two datasets. Each dataset gets its own directory under
# the current working directory; the zip is removed after unpacking.
from google.colab import files
files.upload()  # Interactive prompt: upload your kaggle.json here.

!pip install -q kaggle
# Place the uploaded credentials in ~/.kaggle/ and make them owner-only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# NOTE(review): this uses the absolute /root path while the lines above use
# ~ — assumes the home directory is /root (as on Colab); confirm elsewhere.
!chmod 600 /root/.kaggle/kaggle.json

# Kaggle datasets here.

# Animal Faces.
# NOTE: plain `mkdir` (no -p) reports an error if the directory already
# exists, e.g. when this cell is re-run in the same session.
!mkdir animal-faces
%cd animal-faces/
!kaggle datasets download -d andrewmvd/animal-faces
!unzip -q animal-faces.zip
!rm animal-faces.zip
%cd ../

# Cat Faces by fferlito.
# The archive is unpacked here; the dataset-part*.tar files it contains are
# extracted further down in this cell.
!mkdir fferlito-cat-faces
%cd fferlito-cat-faces/
!kaggle datasets download -d vincenttu/catfacesdatasetfferlito
!unzip -q catfacesdatasetfferlito.zip
!rm catfacesdatasetfferlito.zip
%cd ../

import tarfile

# Unpack the three tar parts of the fferlito dataset into its directory.
# The context manager closes each archive even when extraction fails, which
# the previous open()/extractall()/close() sequence did not guarantee.
for part_number in (1, 2, 3):
    part_path = f"/content/fferlito-cat-faces/dataset-part{part_number}.tar"
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use it on archives from a source you trust.
    with tarfile.open(part_path) as tgz:
        tgz.extractall(path="/content/fferlito-cat-faces")

# The tar parts have been extracted above; remove them.
!rm fferlito-cat-faces/dataset-part1.tar
!rm fferlito-cat-faces/dataset-part2.tar
!rm fferlito-cat-faces/dataset-part3.tar

# The remaining datasets follow the same pattern: make a directory, download
# the zip into it, unpack quietly, delete the zip, return to the parent.

# Cat Faces by Spandan.
!mkdir spandan-cat-faces
%cd spandan-cat-faces/
!kaggle datasets download -d spandan2/cats-faces-64x64-for-generative-models
!unzip -q cats-faces-64x64-for-generative-models.zip
!rm cats-faces-64x64-for-generative-models.zip
%cd ../

# Cat Faces by waifuai (the Kaggle dataset itself is named cat2dog).
!mkdir waifuai-cat-faces
%cd waifuai-cat-faces/
!kaggle datasets download -d waifuai/cat2dog
!unzip -q cat2dog.zip
!rm cat2dog.zip
%cd ../

# CelebA.
!mkdir celeba
%cd celeba/
!kaggle datasets download -d zuozhaorui/celeba
!unzip -q celeba.zip
!rm celeba.zip
%cd ../

Saving kaggle.json to kaggle.json
/content/animal-faces
Downloading animal-faces.zip to /content/animal-faces
 98% 685M/696M [00:05<00:00, 154MB/s]
100% 696M/696M [00:05<00:00, 141MB/s]
/content
/content/fferlito-cat-faces
Downloading catfacesdatasetfferlito.zip to /content/fferlito-cat-faces
 97% 256M/265M [00:01<00:00, 136MB/s]
100% 265M/265M [00:01<00:00, 145MB/s]
/content
/content/spandan-cat-faces
Downloading cats-faces-64x64-for-generative-models.zip to /content/spandan-cat-faces
 96% 92.0M/96.0M [00:00<00:00, 88.3MB/s]
100% 96.0M/96.0M [00:00<00:00, 131MB/s] 
/content
/content/waifuai-cat-faces
Downloading cat2dog.zip to /content/waifuai-cat-faces
 91% 25.0M/27.4M [00:00<00:00, 43.8MB/s]
100% 27.4M/27.4M [00:00<00:00, 91.4MB/s]
/content
/content/celeba
Downloading celeba.zip to /content/celeba
 99% 2.63G/2.64G [00:18<00:00, 173MB/s]
100% 2.64G/2.64G [00:18<00:00, 152MB/s]
/content


In [2]:
# Upgrade albumentations to the latest release (quiet output).
# NOTE(review): the version is not pinned, so a re-run may install a
# different release than the one this notebook was developed against.
!pip install albumentations --upgrade -q

[K     |████████████████████████████████| 102 kB 9.0 MB/s 
[K     |████████████████████████████████| 47.6 MB 57 kB/s 
[?25h

In [3]:
from __future__ import print_function
from IPython.display import HTML
import argparse
import os
import gc
import random
import imageio
from glob import glob
from PIL import Image
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
#%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

import warnings
warnings.simplefilter('ignore')

In [4]:
# Install Weights & Biases (unpinned, quiet) and authenticate this session.
!pip install wandb -qqq
import wandb
# from wandb.keras import WandbCallback
# Interactive on first use; the captured output shows the API key being
# appended to the netrc file.
wandb.login()

[K     |████████████████████████████████| 1.7 MB 7.9 MB/s 
[K     |████████████████████████████████| 97 kB 5.8 MB/s 
[K     |████████████████████████████████| 140 kB 32.6 MB/s 
[K     |████████████████████████████████| 180 kB 47.8 MB/s 
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Utility Functions

## Checkpoint Directory Function

In [5]:
# Checkpoint path helper.
def return_ckpt_dir(GAN_player):
    """Build the checkpoint file path for one GAN player.

    Reads ``ckpt_dir`` and ``model_name`` from the global config object ``c``.

    Args:
        GAN_player: Tag appended to the model name (this notebook passes
            "D" and "G").

    Returns:
        A string of the form ``<ckpt_dir>/<model_name>_<GAN_player>.pth``.
    """
    # Naming convention followed throughout the notebook.
    path_template = '{}/{}_{}.pth'
    return path_template.format(c.ckpt_dir, c.model_name, GAN_player)

## Seed Function

In [6]:
def seed_everything(seed=999):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    print("Seeding everything...")

    # NOTE: the running interpreter fixed its hash seed at start-up, so this
    # export only matters for processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    # The CPU generator is seeded once (the duplicate call was redundant);
    # the CUDA calls are kept explicit for the current and for all devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Weight Init Function

(```mean = 0```, ```stdev = 0.02```)

In [7]:
def weights_init(m):
    """DCGAN-style initialiser, intended for use with ``module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight ~ N(0, 0.02), bias = 0.
    * Any module whose class name contains ``BatchNorm``: weight ~ N(1, 0.02),
      bias = 0.
    * Every other module is left untouched.

    Layers built with ``bias=False`` (or norm layers with ``affine=False``)
    have ``None`` parameters; those are skipped. Previously a bare ``except``
    caught the resulting AttributeError and printed a misleading "Unable to
    initialize" message even though the weights had already been initialised.

    Args:
        m: The ``nn.Module`` to initialise in place.
    """
    classname = m.__class__.__name__
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        weight_mean = 0.0
    elif classname.find('BatchNorm') != -1:
        weight_mean = 1.0
    else:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias.data, 0)

# Configurations

In [8]:
class Config:
    """Run-level settings: experiment names, paths, data-loading and device.

    All values are class attributes. Note that the class body also calls
    ``seed_everything()``, so the RNGs are seeded once, at the moment this
    class is defined (not when it is instantiated).
    """

    # --- Experiment tracking -------------------------------------------
    project_name = "DCGAN_TL2021"   # Project name.
    project_run_name = "CycleGAN"   # Name of this particular run.
    model_name = "CycleGAN"         # Used when naming checkpoint files.

    # --- Filesystem ------------------------------------------------------
    ckpt_dir = "."                  # Directory that checkpoints are written to.

    # --- Data loading ----------------------------------------------------
    workers = 0                     # Worker processes for the DataLoaders.
    total_images = 1000             # Size of the CelebA pool to sample from.
    num_images = 64                 # Images drawn from that pool.

    # --- Hardware --------------------------------------------------------
    ngpu = 1                        # Number of GPUs available; 0 means CPU mode.
    device = (
        torch.device("cuda:0")
        if torch.cuda.is_available() and ngpu > 0
        else torch.device("cpu")
    )

    # Executed at class-definition time.
    seed_everything()

Seeding everything...


In [9]:
# Global config handle; helper functions below read settings through `c`.
c = Config()

# Preprocessing

In [10]:
# Locations of the datasets unpacked by the setup cell (absolute paths under
# /content).

# Animal Faces (AFHQ) training split.
train_cat_path = r"/content/animal-faces/afhq/train/cat"  # passed to the dataset as path_c
train_dog_path = r"/content/animal-faces/afhq/train/dog"
train_wild_path = r"/content/animal-faces/afhq/train/wild"

# Animal Faces (AFHQ) validation split.
val_cat_path = r"/content/animal-faces/afhq/val/cat"  # passed to the dataset as path_c_1
val_dog_path = r"/content/animal-faces/afhq/val/dog"
val_wild_path = r"/content/animal-faces/afhq/val/wild"

celeba_path = r"/content/celeba/img_align_celeba/img_align_celeba"

# Every entry directly under the fferlito directory. The dataset class
# asserts this list has exactly three items — presumably the three
# directories extracted from the tar parts; TODO confirm nothing else is left
# in that folder.
fferlito_path = list(glob("/content/fferlito-cat-faces/*"))  # passed as path_c_fferlito
spandan_path = r"/content/spandan-cat-faces/cats/"  # passed as path_c_spandan
# Only the trainA and testA sub-folders of cat2dog are used.
waifuai_path = [os.path.join(r"/content/waifuai-cat-faces/cat2dog", p) for p in ["trainA", "testA"]]  # passed as path_c_waifuai

## Animal-Faces Cat & CelebA Preprocessing

In [11]:
class AFCatCelebADataset(Dataset):
    """Cat-face / human-face image pairs.

    Cat files are collected from ``path_c`` plus any of the optional extra
    sources; human files come from ``path_h``. The two file lists are zipped
    together, so the dataset length equals the length of the shorter list and
    item ``i`` always pairs the ``i``-th cat file with the ``i``-th human file.

    Args:
        path_c: Directory of cat images.
        path_h: Directory of human images.
        image_size: Sequence whose first two entries are the output height
            and width (any further entries are ignored).
        path_c_1: Optional second directory of cat images.
        path_c_fferlito: Optional list of exactly three directories.
        path_c_spandan: Optional directory; only its ``*.jpg`` files are used.
        path_c_waifuai: Optional list of exactly two directories.

    Each item is a dict with keys ``"cat_img"`` and ``"human_img"`` holding
    the transformed images.
    """

    def __init__(self, path_c, path_h, image_size,
                 path_c_1=None, path_c_fferlito=None, path_c_spandan=None, path_c_waifuai=None):
        super().__init__()

        self.path_c = list(glob(os.path.join(path_c, "*")))

        # Extra cat images from a second folder (the notebook passes the val split).
        if path_c_1:
            self.path_c.extend(glob(os.path.join(path_c_1, "*")))

        # Extra cat images from the fferlito dataset.
        if path_c_fferlito:
            assert isinstance(path_c_fferlito, list), "path_c_fferlito should be a list."
            assert len(path_c_fferlito) == 3, "path_c_fferlito should be of length 3."
            for path in path_c_fferlito:
                self.path_c.extend(glob(os.path.join(path, "*")))

        # Extra cat images from the spandan dataset.
        # Fixed: this used to read the module-level `spandan_path` instead of
        # the `path_c_spandan` argument, silently ignoring what the caller passed.
        if path_c_spandan:
            self.path_c.extend(glob(os.path.join(path_c_spandan, "*.jpg")))

        # Extra cat images from the waifuai dataset.
        if path_c_waifuai:
            assert isinstance(path_c_waifuai, list), "path_c_waifuai should be a list."
            assert len(path_c_waifuai) == 2, "path_c_waifuai should be of length 2."
            for path in path_c_waifuai:
                self.path_c.extend(glob(os.path.join(path, "*")))

        self.path_h = list(glob(os.path.join(path_h, "*")))

        # Bookkeeping for the alternative "method 1" pairing, which would
        # iterate over the longer list and index each list modulo its own
        # length. It is not used by __len__/__getitem__ below.
        self.length_dataset = max(len(self.path_c), len(self.path_h))
        self.c_length = len(self.path_c)
        self.h_length = len(self.path_h)

        self.image_size = image_size  # [H, W, ...]

        # "image0" is registered as an additional image target so the cat and
        # the human image of a pair go through the same pipeline call.
        self.transform = A.Compose([
            A.Resize(self.image_size[0], self.image_size[1]),
            A.CenterCrop(self.image_size[0], self.image_size[1]),

            A.HorizontalFlip(p=0.5),

            A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToTensorV2(),
        ], additional_targets={"image0": "image"},)

        # "Method 2" pairing (the one in use): zip stops at the shorter list.
        self.path = list(zip(self.path_c, self.path_h))

    def __len__(self):
        return len(self.path)

    @staticmethod
    def _load_rgb(path):
        """Read an image file as an H x W x 3 uint8 array.

        Converting to RGB keeps grayscale / palette / RGBA files compatible
        with the 3-channel Normalize step; the context manager releases the
        file handle as soon as the pixels are loaded.
        """
        with Image.open(path) as img:
            return np.asarray(img.convert("RGB"))

    def __getitem__(self, index):
        cat_path, human_path = self.path[index]

        aug_imgs = self.transform(image=self._load_rgb(cat_path),
                                  image0=self._load_rgb(human_path))

        return {"cat_img": aug_imgs["image"],
                "human_img": aug_imgs["image0"]}

## Fixed CelebA Image Dataset Preprocessing

In [12]:
class FixedCelebADataset(Dataset):
    """A small, fixed sample of images used to visualise training progress.

    ``num_images`` files are drawn without replacement from the first
    ``total_images`` files found in ``path``. The draw uses NumPy's global
    generator, so it follows whatever seed was set beforehand.

    Args:
        path: Directory containing the images.
        image_size: Sequence whose first two entries are the output height
            and width.
        num_images: Number of images to keep.
        total_images: Size of the candidate pool (a prefix of the file list).

    Each item is a dict with the single key ``"images"``.
    """

    def __init__(self, path, image_size, num_images, total_images):
        super().__init__()

        # glob does not promise any particular order; sorting makes the
        # candidate pool (and therefore the seeded selection) the same on
        # every machine and run.
        path = sorted(glob(os.path.join(path, "*")))
        assert total_images <= len(path), "total_images should be <= len(path)."
        assert num_images <= total_images, "num_images should be <= total_images."

        self.image_size = image_size
        self.num_images = num_images
        self.total_images = total_images

        # Indices into the first `total_images` entries, without repeats.
        rnd_choices = np.random.choice(total_images, num_images, replace=False)
        self.path = np.array(path)[rnd_choices]

        # Same preprocessing as training, minus the random flip.
        self.transform = A.Compose([
            A.Resize(self.image_size[0], self.image_size[1]),
            A.CenterCrop(self.image_size[0], self.image_size[1]),
            A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.path)

    def __getitem__(self, index):
        # Force RGB so the 3-channel Normalize always applies; the context
        # manager closes the file once the pixels are read.
        with Image.open(self.path[index]) as pil_img:
            img = np.asarray(pil_img.convert("RGB"))
        img = self.transform(image=img)["image"]

        return {"images": img,
                }

# Model Building

## UNet-DCGAN

Ref: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [13]:
class DCGANGenerator(nn.Module):
    """Image-to-image generator: DCGAN-style conv stacks with additive skips.

    Five strided convolutions compress the input to an ``nz``-channel
    bottleneck and five transposed convolutions expand it back; after each
    decoder stage the matching encoder activation is added (not concatenated).

    Args:
        nc: Number of image channels (input and output).
        ngf: Base number of feature maps.
        nz: Number of bottleneck channels. This used to be read from an
            undefined global ``nz`` (NameError on construction); it is now a
            parameter whose default matches the notebook's ``Hparams.nz``.

    The shape comments assume a 64 x 64 input, for which the bottleneck is
    1 x 1.
    """

    def __init__(self, nc, ngf, nz=100):
        super().__init__()

        # Encoder ("left" side of the U).
        self.unet_left_block1 = self._unet_left_block(nc, ngf, 4, 2, 1)
        self.unet_left_block2 = self._unet_left_block(ngf, ngf * 2, 4, 2, 1)
        self.unet_left_block3 = self._unet_left_block(ngf * 2, ngf * 4, 4, 2, 1)
        self.unet_left_block4 = self._unet_left_block(ngf * 4, ngf * 8, 4, 2, 1)
        self.unet_left_block5 = self._unet_left_block(ngf * 8, nz, 4, 1, 0)

        # Decoder ("right" side of the U), mirroring the encoder.
        self.unet_right_block1 = self._unet_right_block(nz, ngf * 8, 4, 1, 0)
        self.unet_right_block2 = self._unet_right_block(ngf * 8, ngf * 4, 4, 2, 1)
        self.unet_right_block3 = self._unet_right_block(ngf * 4, ngf * 2, 4, 2, 1)
        self.unet_right_block4 = self._unet_right_block(ngf * 2, ngf, 4, 2, 1)
        # Output stage: no normalisation, tanh squashing.
        self.unet_right_block5 = nn.Sequential(
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1),
            nn.Tanh()
        )

    def _unet_right_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Decoder stage: bias-free ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def _unet_left_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Encoder stage: bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        """Translate a batch of images.

        Args:
            input: Tensor of shape (B, nc, 64, 64).

        Returns:
            Tensor of shape (B, nc, 64, 64).
        """
        x = input

        # Encoder.
        x1 = self.unet_left_block1(x)   # (B x (ngf) x 32 x 32)
        x2 = self.unet_left_block2(x1)  # (B x (ngf*2) x 16 x 16)
        x3 = self.unet_left_block3(x2)  # (B x (ngf*4) x 8 x 8)
        x4 = self.unet_left_block4(x3)  # (B x (ngf*8) x 4 x 4)
        x5 = self.unet_left_block5(x4)  # (B x nz x 1 x 1)

        # Decoder with additive skip connections.
        z1 = self.unet_right_block1(x5)  # (B x (ngf*8) x 4 x 4)
        z1 = z1 + x4
        z2 = self.unet_right_block2(z1)  # (B x (ngf*4) x 8 x 8)
        z2 = z2 + x3
        z3 = self.unet_right_block3(z2)  # (B x (ngf*2) x 16 x 16)
        z3 = z3 + x2
        z4 = self.unet_right_block4(z3)  # (B x (ngf) x 32 x 32)
        z4 = z4 + x1
        z5 = self.unet_right_block5(z4)  # (B x (nc) x 64 x 64)
        # NOTE(review): the input is added after the tanh, so the result is
        # not confined to [-1, 1] — confirm this is intended.
        z5 = z5 + x

        return z5  # (B x (nc) x 64 x 64)

In [14]:
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator that scores a whole image with one probability.

    Args:
        nc: Number of input image channels.
        ndf: Base number of feature maps.

    Shape comments assume a 64 x 64 input.
    """

    def __init__(self, nc, ndf):
        super().__init__()

        # Stem: strided conv without normalisation.
        stem = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),  # B x (ndf) x 32 x 32
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three stages, each doubling the channels and halving the size:
        # (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4.
        stages = [
            self._block(ndf * width, ndf * width * 2, 4, 2, 1)
            for width in (1, 2, 4)
        ]

        # Head: collapse the 4 x 4 map to a single sigmoid score.
        head = [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),  # B x 1 x 1 x 1
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*stem, *stages, *head)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free Conv2d followed by BatchNorm2d and LeakyReLU(0.2)."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, input):
        """Map (B x nc x 64 x 64) images to (B x 1 x 1 x 1) scores."""
        scores = self.main(input)
        return scores

## PatchGAN 
Ref: https://github.com/znxlwm/pytorch-pix2pix/blob/3059f2af53324e77089bbcfc31279f01a38c40b8/network.py#L104

In [15]:
class PatchGANGenerator(nn.Module):
    """pix2pix-style U-Net generator sized for 64 x 64 RGB images.

    Six stride-2 convolutions take a 64 x 64 input down to 1 x 1, and six
    transposed convolutions bring it back up; each decoder stage after the
    first consumes its predecessor concatenated with the matching encoder
    activation. Layer names follow the referenced implementation, which is
    why the decoder numbering starts at ``deconv3``: the two deepest
    encoder/decoder levels of the reference are not instantiated here.

    Args:
        d: Base number of feature maps.
    """

    def __init__(self, d=64):
        super().__init__()

        # U-Net encoder.
        self.conv1 = nn.Conv2d(3, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm2d(d * 8)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm2d(d * 8)

        # U-Net decoder. Input widths are doubled from deconv4 onwards
        # because of the skip concatenations.
        self.deconv3 = nn.ConvTranspose2d(d * 8, d * 8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d * 8)
        self.deconv4 = nn.ConvTranspose2d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(d * 8)
        self.deconv5 = nn.ConvTranspose2d(d * 8 * 2, d * 4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm2d(d * 4)
        self.deconv6 = nn.ConvTranspose2d(d * 4 * 2, d * 2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm2d(d * 2)
        self.deconv7 = nn.ConvTranspose2d(d * 2 * 2, d, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm2d(d)
        self.deconv8 = nn.ConvTranspose2d(d * 2, 3, 4, 2, 1)

    def forward(self, input):
        """Translate a batch of images.

        Args:
            input: Tensor of shape (B, 3, 64, 64).

        Returns:
            Tensor of shape (B, 3, 64, 64) with values in [-1, 1].

        Shape comments below are for the default ``d=64``.
        """
        # Encoder: the activation is applied to the *input* of each conv, so
        # the stored skips e1..e5 are pre-activation tensors.
        e1 = self.conv1(input)  # (B, 64, 32, 32)
        e2 = self.conv2_bn(self.conv2(F.leaky_relu(e1, 0.2)))  # (B, 128, 16, 16)
        e3 = self.conv3_bn(self.conv3(F.leaky_relu(e2, 0.2)))  # (B, 256, 8, 8)
        e4 = self.conv4_bn(self.conv4(F.leaky_relu(e3, 0.2)))  # (B, 512, 4, 4)
        e5 = self.conv5_bn(self.conv5(F.leaky_relu(e4, 0.2)))  # (B, 512, 2, 2)
        e6 = self.conv6_bn(self.conv6(F.leaky_relu(e5, 0.2)))  # (B, 512, 1, 1)

        # Decoder. Dropout on the first stage uses training=True, i.e. it
        # stays active in eval mode as well.
        # NOTE(review): presumably deliberate (dropout as the noise source, as
        # in pix2pix) — confirm.
        d3 = F.dropout(self.deconv3_bn(self.deconv3(F.relu(e6))), 0.5, training=True)  # (B, 512, 2, 2)
        d3 = torch.cat([d3, e5], 1)  # (B, 1024, 2, 2)
        d4 = self.deconv4_bn(self.deconv4(F.relu(d3)))  # (B, 512, 4, 4)
        d4 = torch.cat([d4, e4], 1)  # (B, 1024, 4, 4)
        d5 = self.deconv5_bn(self.deconv5(F.relu(d4)))  # (B, 256, 8, 8)
        d5 = torch.cat([d5, e3], 1)  # (B, 512, 8, 8)
        d6 = self.deconv6_bn(self.deconv6(F.relu(d5)))  # (B, 128, 16, 16)
        d6 = torch.cat([d6, e2], 1)  # (B, 256, 16, 16)
        d7 = self.deconv7_bn(self.deconv7(F.relu(d6)))  # (B, 64, 32, 32)
        d7 = torch.cat([d7, e1], 1)  # (B, 128, 32, 32)
        d8 = self.deconv8(F.relu(d7))  # (B, 3, 64, 64)

        # torch.tanh replaces F.tanh (flagged as deprecated by the PyTorch
        # releases this notebook was written against) and matches the
        # CycleGAN generator in this file.
        return torch.tanh(d8)  # (B, 3, 64, 64)

In [16]:
class PatchGANDiscriminator(nn.Module):
    """Conditional patch discriminator (pix2pix style).

    The candidate image and its conditioning image are concatenated along the
    channel axis and reduced to a grid of per-patch probabilities.

    Args:
        d: Base number of feature maps.
    """

    def __init__(self, d=64):
        super().__init__()

        # 6 input channels = two stacked RGB images.
        self.conv1 = nn.Conv2d(6, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        # The last two convolutions use stride 1.
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 1, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(d * 8, 1, 4, 1, 1)

    def forward(self, input, label):
        """Score image patches.

        Args:
            input: Tensor of shape (B, 3, 64, 64).
            label: Tensor of shape (B, 3, 64, 64), concatenated with ``input``.

        Returns:
            Tensor of shape (B, 1, 6, 6) with values in [0, 1].

        Shape comments below are for the default ``d=64``.
        """
        x = torch.cat([input, label], 1)  # (B, 6, 64, 64)
        x = F.leaky_relu(self.conv1(x), 0.2)  # (B, 64, 32, 32)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)  # (B, 128, 16, 16)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)  # (B, 256, 8, 8)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)  # (B, 512, 7, 7)

        # torch.sigmoid replaces F.sigmoid (flagged as deprecated by the
        # PyTorch releases this notebook was written against) and matches the
        # CycleGAN discriminator in this file.
        # NOTE(review): get_criterions pairs "patchgan" with
        # BCEWithLogitsLoss, which applies a sigmoid itself; if this output is
        # fed to that loss the sigmoid is applied twice — confirm against the
        # PatchGAN training script.
        return torch.sigmoid(self.conv5(x))  # (B, 1, 6, 6)

## CycleGAN

Ref: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/GANs/CycleGAN/generator_model.py

In [17]:
class ConvBlock(nn.Module):
    """Convolution -> InstanceNorm2d -> optional ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        down: ``True`` builds a reflect-padded ``nn.Conv2d``; ``False``
            builds an ``nn.ConvTranspose2d``.
        use_act: ``True`` appends an in-place ReLU, ``False`` an Identity.
        **kwargs: Forwarded to the convolution (kernel_size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()

        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            # The transposed convolution is built without a padding_mode override.
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()

        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with an identity shortcut around them.

    The second ConvBlock has no activation, so the block output is
    ``x + block(x)`` with nothing applied after the sum.

    Args:
        channels: Channel count, unchanged by the block.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        with_act = ConvBlock(channels, channels, **conv_kwargs)
        without_act = ConvBlock(channels, channels, use_act=False, **conv_kwargs)
        self.block = nn.Sequential(with_act, without_act)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class CycleGANGenerator(nn.Module):
    """CycleGAN generator: 7x7 stem, two downsamplings, residual trunk,
    two upsamplings and a 7x7 output convolution with tanh.

    Args:
        img_channels: Channels of the input and output images.
        num_features: Width of the stem; the trunk runs at four times this.
        num_residuals: Number of residual blocks in the trunk.

    Shape comments assume a (B, 3, 64, 64) input and ``num_features=64``.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()

        nf = num_features

        # Stem keeps the resolution: (B, 64, 64, 64).
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, nf, kernel_size=7,
                      stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(nf),
            nn.ReLU(inplace=True),
        )

        # Each stride-2 block halves the resolution and doubles the width.
        self.down_blocks = nn.ModuleList([
            ConvBlock(nf, nf*2, kernel_size=3, stride=2, padding=1),    # (B, 128, 32, 32)
            ConvBlock(nf*2, nf*4, kernel_size=3, stride=2, padding=1),  # (B, 256, 16, 16)
        ])

        # Residual trunk; every block maps (B, 256, 16, 16) to the same shape.
        trunk = [ResidualBlock(nf*4) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*trunk)

        # Transposed convolutions undo the two downsamplings.
        self.up_blocks = nn.ModuleList([
            ConvBlock(nf*4, nf*2, down=False, kernel_size=3,
                      stride=2, padding=1, output_padding=1),  # (B, 128, 32, 32)
            ConvBlock(nf*2, nf*1, down=False, kernel_size=3,
                      stride=2, padding=1, output_padding=1),  # (B, 64, 64, 64)
        ])

        # Back to image channels: (B, 3, 64, 64).
        self.last = nn.Conv2d(nf*1, img_channels, kernel_size=7,
                              stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Map (B, 3, 64, 64) images to (B, 3, 64, 64) images in [-1, 1]."""
        x = self.initial(x)  # (B, 64, 64, 64)

        # Down (to (B, 256, 16, 16)) -> trunk -> up (to (B, 64, 64, 64)),
        # applied in exactly that order.
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            x = stage(x)

        logits = self.last(x)
        return torch.tanh(logits)  # (B, 3, 64, 64)

In [18]:
class CycleGANDiscriminatorBlock(nn.Module):
    """4x4 reflect-padded Conv2d -> InstanceNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class CycleGANDiscriminator(nn.Module):
    """Patch discriminator used by CycleGAN.

    Args:
        in_channels: Channels of the input image.
        features: Sequence of layer widths. The first entry is the width of
            the un-normalised stem; every following entry adds one
            ``CycleGANDiscriminatorBlock``. All blocks use stride 2 except
            the last one, which uses stride 1.

    Shape comments assume a (B, 3, 64, 64) input and the default widths.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()

        # The default is an immutable tuple (it used to be a shared list);
        # lists passed by callers are still accepted.
        features = list(features)

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )  # (B, 64, 32, 32)

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # 1: (B, 128, 16, 16)
            # 2: (B, 256, 8, 8)
            # 3: (B, 512, 7, 7)
            #
            # The stride-1 layer is chosen by position. Comparing widths
            # (feature == features[-1]) also matched earlier layers whenever
            # a width was repeated, e.g. features=[64, 128, 512, 512].
            stride = 1 if index == last_index else 2
            layers.append(CycleGANDiscriminatorBlock(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))  # (B, 1, 6, 6)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map (B, 3, 64, 64) images to (B, 1, 6, 6) patch probabilities."""
        x = self.initial(x)  # (B, 64, 32, 32)
        return torch.sigmoid(self.model(x))  # (B, 1, 6, 6)

# Hyperparameters

In [19]:
class Hparams:
    """Training hyperparameters shared by the DCGAN, PatchGAN and CycleGAN setups."""

    # --- Data ----------------------------------------------------------
    batch_size = 128        # Batch size during training.
    image_size = 64         # Images are resized to image_size x image_size.
    nc = 3                  # Channels in the training images (3 for colour).

    # --- DCGAN -----------------------------------------------------------
    nz = 100                # Size of the z latent vector.
    ngf = 64                # Generator feature-map size.
    ndf = 64                # Discriminator feature-map size.

    # --- Optimisation ----------------------------------------------------
    num_epochs = 90         # Number of training epochs.
    lr = 1e-5               # Learning rate for the optimizers.
    beta1 = 0.5             # Adam beta1.
    beta2 = 0.999           # Adam beta2.

    # --- PatchGAN --------------------------------------------------------
    d = 64                  # Feature-map size.
    l1_lambda = 100         # Weight of the auxiliary L1 loss (reduces blurriness).

    # --- CycleGAN --------------------------------------------------------
    lambda_cycle = 10       # Weight of the cycle loss.
    lambda_identity = 0     # Weight of the identity loss.
    num_residuals = 9       # Residual blocks in each generator.

In [20]:
# Global hyperparameter handle; the builder functions below read from `h`.
h = Hparams()

# Training

## Get Models Function

In [21]:
def get_model(model_type):
    """Build the generator(s) and discriminator(s) for one model family.

    Sizes come from the global hyperparameters ``h``; device and GPU count
    from the global config ``c``.

    Args:
        model_type: One of "dcgan", "patchgan" or "cyclegan".

    Returns:
        ``(generators, discriminators)``, two lists of modules already moved
        to ``c.device``:

        * "dcgan" / "patchgan": ``[netG]``, ``[netD]``.
        * "cyclegan": ``[gen_C, gen_H]``, ``[disc_C, disc_H]``.

    Raises:
        AssertionError: If ``model_type`` is not one of the supported names.
    """
    assert model_type in ["dcgan", "patchgan", "cyclegan"], f"{model_type} is not a valid model."

    def _wrap_and_init(net):
        # Handle multi-GPU if desired, then draw all conv / batch-norm
        # weights via weights_init (mean 0 or 1, stdev 0.02).
        if (c.device.type == 'cuda') and (c.ngpu > 1):
            net = nn.DataParallel(net, list(range(c.ngpu)))
        net.apply(weights_init)
        return net

    generators, discriminators = [], []

    if model_type == "dcgan" or model_type == "patchgan":
        # The generator is created and initialised before the discriminator
        # so the random stream is consumed in a fixed order.
        if model_type == "dcgan":
            netG = DCGANGenerator(h.nc, h.ngf).to(c.device)
        else:
            netG = PatchGANGenerator(h.d).to(c.device)
        netG = _wrap_and_init(netG)

        if model_type == "dcgan":
            # Fixed: the discriminator width is h.ndf (this used to pass the
            # generator width h.ngf; the two currently have the same value).
            netD = DCGANDiscriminator(h.nc, h.ndf).to(c.device)
        else:
            netD = PatchGANDiscriminator(h.d).to(c.device)
        netD = _wrap_and_init(netD)

        generators.append(netG)
        discriminators.append(netD)

    elif model_type == "cyclegan":
        # The CycleGAN networks keep PyTorch's default initialisation and are
        # not wrapped in DataParallel; weights_init is not applied to them.
        gen_C = CycleGANGenerator(img_channels=h.nc, num_residuals=h.num_residuals).to(c.device)
        gen_H = CycleGANGenerator(img_channels=h.nc, num_residuals=h.num_residuals).to(c.device)

        disc_C = CycleGANDiscriminator(in_channels=h.nc).to(c.device)
        disc_H = CycleGANDiscriminator(in_channels=h.nc).to(c.device)

        # Order matters to callers: C first, then H.
        generators.append(gen_C)
        generators.append(gen_H)
        discriminators.append(disc_C)
        discriminators.append(disc_H)

    return generators, discriminators

## Get Criterions, Optimizers, and Scalers Function

In [22]:
def get_criterions(model_type, generators, discriminators):
    """Create the losses, Adam optimizers and AMP grad scalers for a model family.

    Learning rate and betas come from the global hyperparameters ``h``.

    Args:
        model_type: One of "dcgan", "patchgan" or "cyclegan".
        generators: List returned by ``get_model`` (one entry, or
            ``[gen_C, gen_H]`` for "cyclegan").
        discriminators: List returned by ``get_model`` (one entry, or
            ``[disc_C, disc_H]`` for "cyclegan").

    Returns:
        ``(criterions, optimizerG, optimizerD, g_scaler, d_scaler)`` where
        ``criterions`` is:

        * "dcgan": ``[BCELoss]``
        * "patchgan": ``[BCEWithLogitsLoss, L1Loss]``
        * "cyclegan": ``[MSELoss, L1Loss]``

    Raises:
        AssertionError: If ``model_type`` is not one of the supported names.
    """
    assert model_type in ["dcgan", "patchgan", "cyclegan"], f"{model_type} is not a valid model."

    d_scaler = torch.cuda.amp.GradScaler()
    g_scaler = torch.cuda.amp.GradScaler()

    # Every optimizer in this notebook shares the same Adam settings.
    adam_settings = dict(lr=h.lr, betas=(h.beta1, h.beta2))

    if model_type == "dcgan":
        netG, netD = generators[0], discriminators[0]
        g_params, d_params = netG.parameters(), netD.parameters()
        criterions = [nn.BCELoss()]

    elif model_type == "patchgan":
        netG, netD = generators[0], discriminators[0]
        g_params, d_params = netG.parameters(), netD.parameters()
        # NOTE(review): BCEWithLogitsLoss expects raw logits, while
        # PatchGANDiscriminator.forward ends in a sigmoid — confirm which of
        # the two the PatchGAN training script relies on.
        criterions = [nn.BCEWithLogitsLoss(), nn.L1Loss()]

    elif model_type == "cyclegan":
        gen_C, gen_H = generators
        disc_C, disc_H = discriminators
        # One optimizer drives both generators, another both discriminators.
        g_params = list(gen_C.parameters()) + list(gen_H.parameters())
        d_params = list(disc_H.parameters()) + list(disc_C.parameters())
        criterions = [nn.MSELoss(), nn.L1Loss()]

    optimizerG = optim.Adam(g_params, **adam_settings)
    optimizerD = optim.Adam(d_params, **adam_settings)

    return criterions, optimizerG, optimizerD, g_scaler, d_scaler

## Create the Models, Criterions, Optimizers, Scalers, and DataLoaders

In [23]:
# Build everything the training scripts need. The statements are kept in this
# order on purpose: model construction and the fixed-image sampling both draw
# from the seeded random generators.

# Model type: ["dcgan", "patchgan", "cyclegan"].
model_type = "cyclegan"

# Get models.
generators, discriminators = get_model(model_type)

# Get the criterions, optimizers, and scalers.
criterions, optimizerG, optimizerD, g_scaler, d_scaler = get_criterions(model_type, generators, discriminators)

# Get the dataset and loaders. 
# image_size is passed as (H, W, nc); the dataset classes only read the first
# two entries.
train_cat_celeba_dataset = AFCatCelebADataset(path_c=train_cat_path, path_h=celeba_path, 
                               image_size=(h.image_size, h.image_size, h.nc),
                               path_c_1=val_cat_path,
                               path_c_fferlito=fferlito_path,
                               path_c_spandan=spandan_path,
                               path_c_waifuai=waifuai_path)

# drop_last=True: an incomplete final batch is skipped.
train_cat_celeba_loader = DataLoader(train_cat_celeba_dataset, 
                           batch_size=h.batch_size,
                           shuffle=True,
                           num_workers=c.workers,
                           drop_last=True)

# A fixed set of c.num_images CelebA images, used to visualise progress.
fixed_images_dataset = FixedCelebADataset(path=celeba_path, 
                                         image_size=(h.image_size, h.image_size, h.nc),
                                         num_images=c.num_images,
                                         total_images=c.total_images)

# batch_size equals the dataset size, so the loader yields exactly one batch,
# always in the same order.
fixed_images_loader = DataLoader(fixed_images_dataset,
                                batch_size=c.num_images,
                                shuffle=False,
                                num_workers=c.workers)

## Pick & Run Training Script

In [24]:
# Report which compute device this session has.
import platform
import torch

if not torch.cuda.is_available():
    # No CUDA device: report the CPU instead.
    print("\n[INFO] GPU not found. Using CPU: {}\n".format(platform.processor()))
else:
    print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

[INFO] Using GPU: Tesla P100-PCIE-16GB



In [25]:
# Call this if the run never finished.
# run.finish()

### Prime the Training Script

In [26]:
# Validation Method:
#     - Manual evaluation with no rating ✓

# Saving Methods:
#     - fixed images ✓
#     - systematic saving ✓
#     - last epoch saving ✓
#     - best val_loss saving

# Output folder for the progress animation written by the logging cell.
os.makedirs('gifs/', exist_ok=True)

# Start the Weights & Biases run that every training cell logs to.
run = wandb.init(project=c.project_name, name=f"{c.project_run_name}")

# Paths the training cells write discriminator / generator weights to.
d_ckpt_dir = return_ckpt_dir("D")
g_ckpt_dir = return_ckpt_dir("G")

# img_list collects one image grid per snapshot; data collects [epoch, iteration] pairs.
img_list = []
data = []

# Wandb table for fixed image visualization; row (0, 0) holds the untouched inputs.
my_table = wandb.Table(columns=["Epoch",
                                "I-th Iteration",
                                "Fixed Images"])
for fixed_images in fixed_images_loader:
    reference_grid = vutils.make_grid(fixed_images["images"], padding=2, normalize=True)
    img_list.append(reference_grid)
    # CHW grid -> HWC uint8 array so PIL can wrap it.
    grid_uint8 = (reference_grid.permute(1, 2, 0).numpy()*255).astype(np.uint8)
    my_table.add_data(0, 0, wandb.Image(Image.fromarray(grid_uint8)))

# Watch the distribution of gradients and weights for generators and discriminators.
if model_type in ("dcgan", "patchgan"):
    netG, netD = generators[0], discriminators[0]
    watched_models = (netG, netD)
elif model_type == "cyclegan":
    # NOTE(review): the CycleGAN training cell unpacks these lists as (C, H),
    # the reverse of the order used here -- confirm which order get_model returns.
    gen_H, gen_C = generators
    disc_H, disc_C = discriminators
    watched_models = (gen_H, gen_C, disc_H, disc_C)
else:
    watched_models = ()

# Each watched network gets its own idx, in the order listed above.
for watch_idx, watched in enumerate(watched_models):
    run.watch(watched, log="all", log_freq=10, idx=watch_idx)

[34m[1mwandb[0m: Currently logged in as: [33mvincenttu[0m (use `wandb login --relogin` to force relogin)


### CycleGAN Training Script

In [27]:
# Training time per epoch for b_size of 128 with 39 batches: 2.5 mins/epoch
# Number of batches of size 128: 406 ☠️
# Number of cat images with the added data: 52114 ☠️

In [28]:
# Continue CycleGAN_run0 with pretrained weights (couldn't save the optimizer 😭).
# gen_C.load_state_dict(torch.load("/content/gen_C_epoch14"))
# gen_H.load_state_dict(torch.load("/content/gen_H_epoch14"))
# disc_C.load_state_dict(torch.load("/content/disc_C_epoch14"))
# disc_H.load_state_dict(torch.load("/content/disc_H_epoch14"))

<All keys matched successfully>

In [29]:
# Checker: this cell implements the CycleGAN loop only. Asserting on equality
# also rejects unknown model_type values, which a dcgan/patchgan-only check
# would silently let through.
assert model_type == "cyclegan", f"{model_type} model_type should be 'cyclegan'."

# NOTE(review): the priming cell unpacks these lists as (H, C) while this cell
# uses (C, H). The loop below is self-consistent with the names bound here, but
# the checkpoint file names depend on this order -- confirm against get_model.
gen_C, gen_H = generators
disc_C, disc_H = discriminators
criterion, l1_loss = criterions

num_train_batches = len(train_cat_celeba_loader)

for epoch in range(1, h.num_epochs + 1):
    train_D_loss, train_G_loss = [], []

    print('Epoch: {:02d}/{:02d}'.format(epoch, h.num_epochs))
    print("TRAIN")

    # enumerate() hides the loader's length from tqdm, so pass the total
    # explicitly to get a real progress bar instead of a bare "0it" counter.
    loop = tqdm(enumerate(train_cat_celeba_loader), total=num_train_batches)
    for i, human_cat_data in loop:
        human_data = human_cat_data["human_img"].to(c.device)
        cat_data = human_cat_data["cat_img"].to(c.device)

        # Train Discriminators H and C.
        with torch.cuda.amp.autocast():
            # Fakes are detached for the discriminator pass so that errD does
            # not backpropagate into the generators.
            fake_human = gen_H(cat_data)
            D_H_real = disc_H(human_data)
            D_H_fake = disc_H(fake_human.detach())
            D_H_real_loss = criterion(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = criterion(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_cat = gen_C(human_data)
            D_C_real = disc_C(cat_data)
            D_C_fake = disc_C(fake_cat.detach())
            D_C_real_loss = criterion(D_C_real, torch.ones_like(D_C_real))
            D_C_fake_loss = criterion(D_C_fake, torch.zeros_like(D_C_fake))
            D_C_loss = D_C_real_loss + D_C_fake_loss

            # Combine discriminator losses.
            errD = (D_H_loss + D_C_loss)/2

        # Read the scalar once; it is reused for the running mean and the bar.
        d_loss_value = errD.item()
        train_D_loss.append(d_loss_value)

        optimizerD.zero_grad()
        d_scaler.scale(errD).backward()
        d_scaler.step(optimizerD)
        d_scaler.update()

        # Train Generators H and C.
        with torch.cuda.amp.autocast():
            # Adversarial loss for both generators (non-detached fakes, so the
            # gradient reaches the generators this time).
            D_H_fake = disc_H(fake_human)
            D_C_fake = disc_C(fake_cat)
            loss_G_H = criterion(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_C = criterion(D_C_fake, torch.ones_like(D_C_fake))

            # Cycle consistency loss: cat -> human -> cat and human -> cat -> human.
            cycle_cat = gen_C(fake_human)
            cycle_human = gen_H(fake_cat)
            cycle_cat_loss = l1_loss(cat_data, cycle_cat)
            cycle_human_loss = l1_loss(human_data, cycle_human)

            # # Identity loss (remove these for efficiency if you set lambda_identity=0).
            # identity_cat = gen_C(cat_data)
            # identity_human = gen_H(human_data)
            # identity_cat_loss = l1_loss(cat_data, identity_cat)
            # identity_human_loss = l1_loss(human_data, identity_human)

            # Combine generator losses.
            errG = (
                loss_G_C
                + loss_G_H
                + cycle_cat_loss * h.lambda_cycle
                + cycle_human_loss * h.lambda_cycle
                # + identity_human_loss * h.lambda_identity
                # + identity_cat_loss * h.lambda_identity
            )

        g_loss_value = errG.item()
        train_G_loss.append(g_loss_value)

        optimizerG.zero_grad()
        g_scaler.scale(errG).backward()
        g_scaler.step(optimizerG)
        g_scaler.update()

        loop.set_description('D_loss: {:.5f} | G_loss: {:.5f}'.format(d_loss_value, g_loss_value))
        loop.set_postfix(D_loss_mean=np.mean(train_D_loss), G_loss_mean=np.mean(train_G_loss))

        # Check how the generator is doing by saving G's output on a fixed image.
        # Snapshot every 200 iterations, plus the last iteration of every 10th
        # epoch (same grouping Python applied to the unparenthesised or/and).
        is_periodic_step = i % 200 == 0
        is_last_step_of_tenth_epoch = (epoch % 10 == 0) and (i == num_train_batches - 1)
        if is_periodic_step or is_last_step_of_tenth_epoch:
            with torch.no_grad():
                for fixed_images in fixed_images_loader:
                    fake = gen_C(fixed_images["images"].to(c.device)).detach().cpu()
                    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            data.append([epoch, i])
            # Only the most recent grid goes into the table row.
            snapshot_uint8 = (img_list[-1].permute(1, 2, 0).numpy()*255).astype(np.uint8)
            my_table.add_data(epoch, i, wandb.Image(Image.fromarray(snapshot_uint8)))

    # Every 5 epochs, upload network weights only (optimizer and scaler state
    # are not saved). The same two scratch paths are reused for both
    # directions: the C pair is logged before the H pair overwrites the files.
    # NOTE(review): assumes add_file/log_artifact capture the file contents
    # before the overwrite -- TODO confirm for the installed wandb version.
    if epoch % 5 == 0:
        torch.save(gen_C.state_dict(), g_ckpt_dir)
        torch.save(disc_C.state_dict(), d_ckpt_dir)

        artifact = wandb.Artifact(c.model_name, type='model')
        artifact.add_file(g_ckpt_dir, name=f"G_C_epoch{epoch}.pth")
        artifact.add_file(d_ckpt_dir, name=f"D_C_epoch{epoch}.pth")
        run.log_artifact(artifact)

        torch.save(gen_H.state_dict(), g_ckpt_dir)
        torch.save(disc_H.state_dict(), d_ckpt_dir)

        artifact = wandb.Artifact(c.model_name, type='model')
        artifact.add_file(g_ckpt_dir, name=f"G_H_epoch{epoch}.pth")
        artifact.add_file(d_ckpt_dir, name=f"D_H_epoch{epoch}.pth")
        run.log_artifact(artifact)

    # Per-epoch means of the per-batch losses.
    wandb.log({"epoch": epoch,
              "G_loss": np.mean(train_G_loss),
              "D_loss": np.mean(train_D_loss),
              })

Epoch: 01/90
TRAIN


0it [00:00, ?it/s]

Epoch: 02/90
TRAIN


0it [00:00, ?it/s]

Epoch: 03/90
TRAIN


0it [00:00, ?it/s]

Epoch: 04/90
TRAIN


0it [00:00, ?it/s]

Epoch: 05/90
TRAIN


0it [00:00, ?it/s]

Epoch: 06/90
TRAIN


0it [00:00, ?it/s]

Epoch: 07/90
TRAIN


0it [00:00, ?it/s]

Epoch: 08/90
TRAIN


0it [00:00, ?it/s]

Epoch: 09/90
TRAIN


0it [00:00, ?it/s]

Epoch: 10/90
TRAIN


0it [00:00, ?it/s]

Epoch: 11/90
TRAIN


0it [00:00, ?it/s]

Epoch: 12/90
TRAIN


0it [00:00, ?it/s]

Epoch: 13/90
TRAIN


0it [00:00, ?it/s]

Epoch: 14/90
TRAIN


0it [00:00, ?it/s]

Epoch: 15/90
TRAIN


0it [00:00, ?it/s]

Epoch: 16/90
TRAIN


0it [00:00, ?it/s]

Epoch: 17/90
TRAIN


0it [00:00, ?it/s]

Epoch: 18/90
TRAIN


0it [00:00, ?it/s]

Epoch: 19/90
TRAIN


0it [00:00, ?it/s]

Epoch: 20/90
TRAIN


0it [00:00, ?it/s]

Epoch: 21/90
TRAIN


0it [00:00, ?it/s]

Epoch: 22/90
TRAIN


0it [00:00, ?it/s]

Epoch: 23/90
TRAIN


0it [00:00, ?it/s]

Epoch: 24/90
TRAIN


0it [00:00, ?it/s]

Epoch: 25/90
TRAIN


0it [00:00, ?it/s]

Epoch: 26/90
TRAIN


0it [00:00, ?it/s]

Epoch: 27/90
TRAIN


0it [00:00, ?it/s]

Epoch: 28/90
TRAIN


0it [00:00, ?it/s]

Epoch: 29/90
TRAIN


0it [00:00, ?it/s]

Epoch: 30/90
TRAIN


0it [00:00, ?it/s]

Epoch: 31/90
TRAIN


0it [00:00, ?it/s]

KeyboardInterrupt: ignored

### PatchGAN Training Script

In [None]:
# Checker: this cell implements the PatchGAN loop only. Asserting on equality
# also rejects unknown model_type values, which a dcgan/cyclegan-only check
# would silently let through.
assert model_type == "patchgan", f"{model_type} model_type should be 'patchgan'."

netG = generators[0]
netD = discriminators[0]
criterion, l1_loss = criterions

num_train_batches = len(train_cat_celeba_loader)

for epoch in range(1, h.num_epochs + 1):
    train_D_loss, train_G_loss = [], []

    print('Epoch: {:02d}/{:02d}'.format(epoch, h.num_epochs))
    print("TRAIN")

    # enumerate() hides the loader's length from tqdm, so pass the total
    # explicitly to get a real progress bar instead of a bare iteration counter.
    loop = tqdm(enumerate(train_cat_celeba_loader), total=num_train_batches)
    for i, human_cat_data in loop:
        human_data = human_cat_data["human_img"].to(c.device)
        cat_data = human_cat_data["cat_img"].to(c.device)

        # Train discriminator. It scores (human, cat) pairs; the generated cat
        # is detached so errD does not backpropagate into the generator.
        with torch.cuda.amp.autocast():
            y_fake = netG(human_data)
            D_real = netD(human_data, cat_data)
            D_real_loss = criterion(D_real, torch.ones_like(D_real))
            D_fake = netD(human_data, y_fake.detach())
            D_fake_loss = criterion(D_fake, torch.zeros_like(D_fake))
            errD = (D_real_loss + D_fake_loss) / 2

        # Read the scalar once; it is reused for the running mean and the bar.
        d_loss_value = errD.item()
        train_D_loss.append(d_loss_value)

        netD.zero_grad()
        d_scaler.scale(errD).backward()
        d_scaler.step(optimizerD)
        d_scaler.update()

        # Train generator: adversarial term plus an L1 term weighted by h.l1_lambda.
        with torch.cuda.amp.autocast():
            D_fake = netD(human_data, y_fake)
            G_fake_loss = criterion(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, cat_data) * h.l1_lambda
            errG = G_fake_loss + L1

        g_loss_value = errG.item()
        train_G_loss.append(g_loss_value)

        optimizerG.zero_grad()
        g_scaler.scale(errG).backward()
        g_scaler.step(optimizerG)
        g_scaler.update()

        loop.set_description('D_loss: {:.5f} | G_loss: {:.5f}'.format(d_loss_value, g_loss_value))
        loop.set_postfix(D_loss_mean=np.mean(train_D_loss), G_loss_mean=np.mean(train_G_loss))

        # Check how the generator is doing by saving G's output on a fixed image.
        # Snapshot every 40 iterations, plus the last iteration of every 10th
        # epoch (same grouping Python applied to the unparenthesised or/and).
        is_periodic_step = i % 40 == 0
        is_last_step_of_tenth_epoch = (epoch % 10 == 0) and (i == num_train_batches - 1)
        if is_periodic_step or is_last_step_of_tenth_epoch:
            with torch.no_grad():
                for fixed_images in fixed_images_loader:
                    fake = netG(fixed_images["images"].to(c.device)).detach().cpu()
                    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            data.append([epoch, i])
            # Only the most recent grid goes into the table row.
            snapshot_uint8 = (img_list[-1].permute(1, 2, 0).numpy()*255).astype(np.uint8)
            my_table.add_data(epoch, i, wandb.Image(Image.fromarray(snapshot_uint8)))

    # Every 10 epochs, upload network weights only (optimizer and scaler state
    # are not saved).
    if epoch % 10 == 0:
        torch.save(netG.state_dict(), g_ckpt_dir)
        torch.save(netD.state_dict(), d_ckpt_dir)

        artifact = wandb.Artifact(c.model_name, type='model')
        artifact.add_file(g_ckpt_dir, name=f"G_epoch{epoch}.pth")
        artifact.add_file(d_ckpt_dir, name=f"D_epoch{epoch}.pth")
        run.log_artifact(artifact)

    # Per-epoch means of the per-batch losses.
    wandb.log({"epoch": epoch,
              "G_loss": np.mean(train_G_loss),
              "D_loss": np.mean(train_D_loss),
              })

### DCGAN Training Script

In [None]:
# Checker: this cell implements the DCGAN loop only. Asserting on equality
# also rejects unknown model_type values, which a patchgan/cyclegan-only check
# would silently let through.
assert model_type == "dcgan", f"{model_type} model_type should be 'dcgan'."

netG = generators[0]
netD = discriminators[0]
criterion = criterions[0]

num_train_batches = len(train_cat_celeba_loader)

for epoch in range(1, h.num_epochs + 1):
    train_D_loss, train_G_loss = [], []

    print('Epoch: {:02d}/{:02d}'.format(epoch, h.num_epochs))
    print("TRAIN")

    # enumerate() hides the loader's length from tqdm, so pass the total
    # explicitly to get a real progress bar instead of a bare iteration counter.
    loop = tqdm(enumerate(train_cat_celeba_loader), total=num_train_batches)
    for i, human_cat_data in loop:
        human_data = human_cat_data["human_img"].to(c.device)
        cat_data = human_cat_data["cat_img"].to(c.device)

        # Train discriminator: real cat batch first ...
        netD.zero_grad()
        output = netD(cat_data).view(-1)
        errD_real = criterion(output, torch.ones_like(output))
        errD_real.backward()
        D_x = output.mean().item()  # mean D output on real cats (diagnostic, not logged)

        # ... then generated cats. The generator input here is the human image
        # batch; the fake is detached so this backward pass stops at netD.
        fake = netG(human_data)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, torch.zeros_like(output))
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean D output on fakes before the D step (diagnostic)
        errD = errD_real + errD_fake
        optimizerD.step()

        # Read the scalar once; it is reused for the running mean and the bar.
        d_loss_value = errD.item()
        train_D_loss.append(d_loss_value)

        # Train generator: rescore the same fakes with the updated discriminator.
        netG.zero_grad()
        output = netD(fake).view(-1)
        errG = criterion(output, torch.ones_like(output))
        errG.backward()
        D_G_z2 = output.mean().item()  # mean D output on fakes after the D step (diagnostic)
        optimizerG.step()

        g_loss_value = errG.item()
        train_G_loss.append(g_loss_value)

        loop.set_description('D_loss: {:.5f} | G_loss: {:.5f}'.format(d_loss_value, g_loss_value))
        loop.set_postfix(D_loss_mean=np.mean(train_D_loss), G_loss_mean=np.mean(train_G_loss))

        # Check how the generator is doing by saving G's output on a fixed image.
        # Snapshot every 40 iterations, plus the last iteration of every 10th
        # epoch (same grouping Python applied to the unparenthesised or/and).
        is_periodic_step = i % 40 == 0
        is_last_step_of_tenth_epoch = (epoch % 10 == 0) and (i == num_train_batches - 1)
        if is_periodic_step or is_last_step_of_tenth_epoch:
            with torch.no_grad():
                for fixed_images in fixed_images_loader:
                    fake = netG(fixed_images["images"].to(c.device)).detach().cpu()
                    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            data.append([epoch, i])
            # Only the most recent grid goes into the table row.
            snapshot_uint8 = (img_list[-1].permute(1, 2, 0).numpy()*255).astype(np.uint8)
            my_table.add_data(epoch, i, wandb.Image(Image.fromarray(snapshot_uint8)))

    # Every 10 epochs, upload network weights only (optimizer state is not saved).
    if epoch % 10 == 0:
        torch.save(netG.state_dict(), g_ckpt_dir)
        torch.save(netD.state_dict(), d_ckpt_dir)

        artifact = wandb.Artifact(c.model_name, type='model')
        artifact.add_file(g_ckpt_dir, name=f"G_epoch{epoch}.pth")
        artifact.add_file(d_ckpt_dir, name=f"D_epoch{epoch}.pth")
        run.log_artifact(artifact)

    # Per-epoch means of the per-batch losses.
    wandb.log({"epoch": epoch,
              "G_loss": np.mean(train_G_loss),
              "D_loss": np.mean(train_D_loss),
              })

### Log and Garbage Collect

In [30]:
# Run this block if you ended the training script prematurely.

gif_path = 'gifs/ani.gif'

# Turn every stored grid (CHW float tensor) into an HWC uint8 frame.
img_frames = [(frame*255).permute(1, 2, 0).numpy().astype("uint8") for frame in img_list]

run.log({"Fixed Images Inspection": my_table})

# Write and log the animation only when there are frames. Logging it
# unconditionally would fail on a missing file, or upload a stale GIF left
# over from an earlier run.
if img_frames:
    imageio.mimsave(gif_path, img_frames)
    run.log({'Fixed Images Animation':
             [wandb.Image(gip_path_ := gif_path)]}) if False else run.log({'Fixed Images Animation':
             [wandb.Image(gif_path)]})

# Release the networks. Deleting only the short aliases frees nothing, because
# the `generators` / `discriminators` lists still reference the same modules,
# so drop those too. The optimizers are deleted as well, so the model-creation
# cell has to be re-run before training again either way.
# NOTE(review): hooks installed by run.watch may also keep references until
# run.finish() -- confirm if GPU memory is still held after this cell.
if model_type == "dcgan" or model_type == "patchgan":
    del netG, netD
elif model_type == "cyclegan":
    del gen_C, gen_H, disc_C, disc_H
del generators, discriminators
del optimizerD, optimizerG
gc.collect()
torch.cuda.empty_cache()

run.finish()

VBox(children=(Label(value=' 703.45MB of 703.45MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.9999974…

0,1
D_loss,█▅▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
G_loss,█▇▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███

0,1
D_loss,0.1513
G_loss,4.17722
epoch,30.0


# References

###### Datasets

- [Animal Faces](https://www.kaggle.com/andrewmvd/animal-faces) by Larxel
- [My Kaggle Dataset Version](https://www.kaggle.com/vincenttu/catfacesdatasetfferlito?select=dataset-part1) and [GitHub Version of Cat Faces Dataset](https://github.com/fferlito/Cat-faces-dataset) by fferlito
- [Cats faces 64x64 (For generative models)](https://www.kaggle.com/spandan2/cats-faces-64x64-for-generative-models) by Spandan
- [cat2dog](https://www.kaggle.com/waifuai/cat2dog) by waifuai
- [celeba](https://www.kaggle.com/zuozhaorui/celeba) by Zhuo Zhaorui
- [FFHQ Face Data Set](https://www.kaggle.com/greatgamedota/ffhq-face-data-set) by GreatGameDota (ported from NVLabs); **note** not used yet
- [celeba-hq](https://www.kaggle.com/lamsimon/celebahq) by Lam Simon; **note** not used yet

# Archive

## Archived Preprocessing

## Animal-Faces Dataset Preprocessing (archived)

Reference: https://www.kaggle.com/andrewmvd/animal-faces

Dataset Debrief:

---

- Expected shape of (512, 512, 3)


- train/cat: 5153 JPG files (assumed to share the expected shape)
- train/dog: 4739 JPG files (assumed to share the expected shape)
- train/wild (foxes): 4738 JPG files (assumed to share the expected shape)
- val/cat: 500 JPG files (assumed to share the expected shape)
- val/dog: 500 JPG files (assumed to share the expected shape)
- val/wild (foxes): 500 JPG files (assumed to share the expected shape)

In [None]:
# AFHQ training split: one directory per class.
train_cat_path, train_dog_path, train_wild_path = (
    r"/content/animal-faces/afhq/train/cat",
    r"/content/animal-faces/afhq/train/dog",
    r"/content/animal-faces/afhq/train/wild",
)

# AFHQ validation split: same class layout.
val_cat_path, val_dog_path, val_wild_path = (
    r"/content/animal-faces/afhq/val/cat",
    r"/content/animal-faces/afhq/val/dog",
    r"/content/animal-faces/afhq/val/wild",
)

### Checks and Assumptions

In [None]:
# Disabled sanity checks for the AFHQ folders (file counts, extensions, image
# shapes). They are wrapped in a string literal, so nothing below executes; as
# the cell's only expression, the string itself is what the notebook displays.
# NOTE(review): if re-enabled, `extension in img_path` is a substring test over
# the whole path, not a check of the file's extension.
'''

# Checking file counts.
assert len(list(os.listdir(train_cat_path))) == 5153, "There aren't 5153 train cat files!"
assert len(list(os.listdir(train_dog_path))) == 4739, "There aren't 4739 train dog files!"
assert len(list(os.listdir(train_wild_path))) == 4738, "There aren't 4738 train wild files!"

assert len(list(os.listdir(val_cat_path))) == 500, "There aren't 500 val cat files!"
assert len(list(os.listdir(val_dog_path))) == 500, "There aren't 500 val dog files!"
assert len(list(os.listdir(val_wild_path))) == 500, "There aren't 500 val wild files!"

# Checking the extensions.
def check_extensions(path, extension="jpg"):
  for idx, img_path in enumerate(glob(os.path.join(path, "*"))):
    assert extension in img_path, f"Index {idx} does not have the {extension} extension!"
  
for p in [train_cat_path, train_dog_path, train_wild_path, 
          val_cat_path, val_dog_path, val_wild_path]:
  check_extensions(p)

# Checking image sizes.
def check_img_size(path, expected_img_shape=(512, 512, 3)):
  for idx, img_path in enumerate(glob(os.path.join(path, "*"))):
    assert np.asarray(Image.open(img_path)).shape == expected_img_shape, f"Index {idx} does not have the expected shape of {expected_img_shape}!"

for p in [train_cat_path, train_dog_path, train_wild_path, 
          val_cat_path, val_dog_path, val_wild_path]:
  check_img_size(p)

'''

'\n\n# Checking file counts.\nassert len(list(os.listdir(train_cat_path))) == 5153, "There aren\'t 5153 train cat files!"\nassert len(list(os.listdir(train_dog_path))) == 4739, "There aren\'t 4739 train dog files!"\nassert len(list(os.listdir(train_wild_path))) == 4738, "There aren\'t 4738 train wild files!"\n\nassert len(list(os.listdir(val_cat_path))) == 500, "There aren\'t 500 val cat files!"\nassert len(list(os.listdir(val_dog_path))) == 500, "There aren\'t 500 val dog files!"\nassert len(list(os.listdir(val_wild_path))) == 500, "There aren\'t 500 val wild files!"\n\n# Checking the extensions.\ndef check_extensions(path, extension="jpg"):\n  for idx, img_path in enumerate(glob(os.path.join(path, "*"))):\n    assert extension in img_path, f"Index {idx} does not have the {extension} extension!"\n  \nfor p in [train_cat_path, train_dog_path, train_wild_path, \n          val_cat_path, val_dog_path, val_wild_path]:\n  check_extensions(p)\n\n# Checking image sizes.\ndef check_img_size(pa

### Building the Dataset

In [None]:
class AFCatDataset(Dataset):
    """Dataset over the image files found directly under one directory.

    Each item is a dict ``{"images": tensor}``: the image is loaded as RGB,
    resized to ``image_size[0] x image_size[1]``, normalized with mean 0.5 and
    std 0.5 per channel, and converted with ``ToTensorV2``.

    Args:
        path: Directory to scan; every entry matching ``path/*`` is treated
            as an image file.
        image_size: Sequence whose first two entries are passed to
            ``A.Resize`` / ``A.CenterCrop``.
        mode: ``"train"`` or ``"valid"``. Both modes use the same transform.

    Raises:
        AssertionError: If ``mode`` is neither ``"train"`` nor ``"valid"``.
    """

    def __init__(self, path, image_size, mode):
        super(AFCatDataset, self).__init__()

        # glob() returns files in an unspecified order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.path = sorted(glob(os.path.join(path, "*")))
        self.image_size = image_size
        assert mode in ['train', 'valid']
        self.mode = mode

        # The train and valid branches were identical, so one pipeline serves both.
        self.transform = A.Compose([
            A.Resize(self.image_size[0], self.image_size[1]),
            A.CenterCrop(self.image_size[0], self.image_size[1]),
            A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToTensorV2(),
        ])

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.path)

    def __getitem__(self, index):
        """Load, transform and return the image at ``index`` as ``{"images": tensor}``."""
        # The context manager closes the file handle promptly. convert("RGB")
        # guarantees three channels, which the 3-value Normalize above
        # requires (grayscale or RGBA files would otherwise break it).
        with Image.open(self.path[index]) as image_file:
            img = np.asarray(image_file.convert("RGB"))
        img = self.transform(image=img)["image"]

        return {"images": img,
                }

## CelebA Dataset Preprocessing (archived)

Reference: https://www.kaggle.com/zuozhaorui/celeba

Dataset Debrief:

---

- Expected shape of (218, 178, 3)


- img_align_celeba/img_align_celeba: ~203k files

In [None]:
# Directory holding the CelebA images (note the repeated folder name).
celeba_path = r"/content/celeba/img_align_celeba/img_align_celeba"

### Checks and Assumptions

In [None]:
# Disabled sanity checks for the CelebA folder (file count, extensions, image
# shapes). They are wrapped in a string literal, so nothing below executes; as
# the cell's only expression, the string itself is what the notebook displays.
# NOTE(review): the checks call check_extensions / check_img_size, which in the
# visible part of this notebook are only defined inside the disabled AFHQ
# checks string -- confirm they exist before re-enabling.
'''

# Checking file counts.
assert len(list(os.listdir(celeba_path))) == 202599, "There aren't 202599 files!"

# Checking the extensions.
for p in [celeba_path]:
  check_extensions(p)

# Checking image sizes.
for p in [celeba_path]:
  check_img_size(p, expected_img_shape=(218, 178, 3))

'''

'\n\n# Checking file counts.\nassert len(list(os.listdir(celeba_path))) == 202599, "There aren\'t 202599 files!"\n\n# Checking the extensions.\nfor p in [celeba_path]:\n  check_extensions(p)\n\n# Checking image sizes.\nfor p in [celeba_path]:\n  check_img_size(p, expected_img_shape=(218, 178, 3))\n\n'

### Building the Dataset

In [None]:
class CelebADataset(Dataset):
    """Dataset over the image files found directly under one directory.

    Each item is a dict ``{"images": tensor}``: the image is loaded as RGB,
    resized to ``image_size[0] x image_size[1]``, normalized with mean 0.5 and
    std 0.5 per channel, and converted with ``ToTensorV2``.

    Args:
        path: Directory to scan; every entry matching ``path/*`` is treated
            as an image file.
        image_size: Sequence whose first two entries are passed to
            ``A.Resize`` / ``A.CenterCrop``.
        mode: ``"train"`` or ``"valid"``. Both modes currently use the same
            transform (the train-only augmentations are disabled).

    Raises:
        AssertionError: If ``mode`` is neither ``"train"`` nor ``"valid"``.
    """

    def __init__(self, path, image_size, mode):
        super(CelebADataset, self).__init__()

        # We will slice this dataset so the num images will match the AFCatDataset.
        # That means we will set shuffle=True when we create the Dataloader for CelebA.
        # glob() returns files in an unspecified order; sorting makes the
        # index -> file mapping (and therefore any slice) reproducible.
        self.path = sorted(glob(os.path.join(path, "*")))
        self.image_size = image_size
        assert mode in ['train', 'valid']
        self.mode = mode

        # With the train-only augmentations disabled, the train and valid
        # branches were identical, so one pipeline serves both. To re-enable
        # them, prepend these steps when self.mode == "train":
        #     A.HorizontalFlip(p=0.5),
        #     A.ColorJitter(p=0.2),
        self.transform = A.Compose([
            A.Resize(self.image_size[0], self.image_size[1]),
            A.CenterCrop(self.image_size[0], self.image_size[1]),
            A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToTensorV2(),
        ])

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.path)

    def __getitem__(self, index):
        """Load, transform and return the image at ``index`` as ``{"images": tensor}``."""
        # The context manager closes the file handle promptly. convert("RGB")
        # guarantees three channels, which the 3-value Normalize above
        # requires (grayscale or RGBA files would otherwise break it).
        with Image.open(self.path[index]) as image_file:
            img = np.asarray(image_file.convert("RGB"))
        img = self.transform(image=img)["image"]

        return {"images": img,
                }