# CcGAN Loaded

This notebook loads a trained CcGAN generator and writes the images it generates to an `.h5` dataset.

## Step 1 - Import and Settings

Import standard-library and third-party packages.

In [1]:
import h5py
import numpy as np
import os
import random
import torch.backends.cudnn as cudnn

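Import our own modules. Note that `torch` and `nn`, which are used in the cells below, are not imported explicitly anywhere in this notebook, so they are expected to come from this wildcard import.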

In [2]:
from models import *

Set hyperparameters.

In [3]:
# Overall Settings
# Input/output locations and run-wide values. Several constants in this cell are
# not read by the code cells visible in this notebook; many others are only
# interpolated into checkpoint file names, so they have to match the names of
# the checkpoint files on disk.
DATA_PATH = "./datasets/Ra_128_indexed.h5"  # not referenced in the visible cells
GENERATE_DATASETS_DIR = "./output/generated_datasets"  # the generated `.h5` file is written here
EMBED_MODELS_DIR = "./output/embed_models"  # embedding checkpoints are read from here
CCGAN_MODELS_DIR = "./output/CcGAN_models"  # the CcGAN checkpoint is read from here
SEED = 42  # seeds `random`, NumPy and PyTorch; also part of every checkpoint file name
NUM_WORKERS = 0  # not referenced in the visible cells

# Dataset
DATA_SPLIT = "train"  # not referenced in the visible cells
MIN_LABEL = 1.3  # lower end of the label range (used for min-max normalisation and the label grid)
MAX_LABEL = 5.2  # upper end of the label range (used for min-max normalisation and the label grid)
NUM_CHANNELS = 3  # not referenced in the visible cells
IMG_SIZE = 128  # part of the CcGAN checkpoint / generated dataset file names
MAX_NUM_IMG_PER_LABEL = 1000  # not referenced in the visible cells
MAX_NUM_IMG_PER_LABEL_AFTER_REPLICA = 0  # not referenced in the visible cells
SHOW_REAL_IMGS = True  # not referenced in the visible cells
VISUALIZE_FAKE_IMAGES = True  # not referenced in the visible cells

# Embedding Settings
NET_EMBED = "ResNet34_embed"  # selects the embedding architecture; part of the embedding file names
DIM_EMBED = 128  # passed as `dim_embed` to the embedding nets and to the generator

# Embedding Training Settings
# In this notebook these values are only used to rebuild the embedding checkpoint file names.
BASE_LR_X2Y = 0.01
BASE_LR_Y2H = 0.01
EPOCH_NET_EMBED = 200
RESUME_EPOCH_NET_EMBED = 0  # not referenced in the visible cells
EPOCH_NET_Y2H = 500
BATCH_SIZE_EMBED = 256

# GAN Settings
GAN = "CcGAN"  # not referenced in the visible cells
GAN_ARCH = "SAGAN"  # part of the CcGAN checkpoint file name
LOSS_TYPE_GAN = "hinge"  # part of the CcGAN checkpoint file name
DIM_GAN = 128  # passed as `dim_z` to the generator; part of the CcGAN checkpoint file name
CGAN_NUM_CLASSES = 20  # not referenced in the visible cells
KERNEL_SIGMA = 0.04697700151079382  # part of the CcGAN checkpoint file name (formatted with 3 decimals)
THRESHOLD_TYPE = "soft"  # part of the CcGAN checkpoint file name
KAPPA = 108.73916897823854  # part of the CcGAN checkpoint file name (formatted with 3 decimals)

# GAN Training Settings
# In this notebook these values are only used to rebuild the CcGAN checkpoint file name,
# except where marked as unreferenced.
NITERS_GAN = 60000
RESUME_NITERS_GAN = 0  # not referenced in the visible cells
SAVE_NITERS_FREQ = 5000  # not referenced in the visible cells
LR_G = 1e-4
LR_D = 1e-4
BATCH_SIZE_DISC = 64
BATCH_SIZE_GENE = 64
NUM_D_STEPS = 4
VISUALIZE_FREQ = 1000  # not referenced in the visible cells
NONZERO_SOFT_WEIGHT_THRESHOLD = 1e-3  # not referenced in the visible cells

# DiffAugment Settings
# Neither value is referenced in the visible cells.
GAN_DIFFAUGMENT = True
GAN_DIFFAUGMENT_POLICY = "color,translation,cutout"

# Generate Settings
NUM_GENERATE = 9192  # number of labels (and therefore images) to generate
BATCH_SIZE_FOR_GENERATE = 64  # number of images produced per generator call

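Apply the settings: select the device, seed the random number generators, create the output directories, and build the checkpoint and dataset paths.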

In [4]:
# Device: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seeds: seed every random number generator used by this notebook.
random.seed(SEED)
torch.manual_seed(SEED)
# Configure cuDNN through the single `cudnn` alias imported at the top:
# request deterministic algorithms and turn off benchmark mode.
cudnn.deterministic = True
cudnn.benchmark = False
np.random.seed(SEED)

# Paths - From arguments
# `exist_ok=True` already makes `os.makedirs` a no-op for an existing directory,
# so a separate `os.path.exists` check is unnecessary.
for output_dir in (GENERATE_DATASETS_DIR, EMBED_MODELS_DIR, CCGAN_MODELS_DIR):
    os.makedirs(output_dir, exist_ok=True)

# Paths - Embedding
# The file names encode the training hyperparameters, so the settings cell has
# to match the checkpoint files that exist on disk.
EMBED_X2Y_PATH = os.path.join(
    EMBED_MODELS_DIR,
    f"embed_x2y_{NET_EMBED}_dim_{DIM_EMBED}_batchSize_{BATCH_SIZE_EMBED}_lr_{BASE_LR_X2Y:.0e}_epoch_{EPOCH_NET_EMBED}_seed_{SEED}.pth",
)
Y2H_PATH = os.path.join(
    EMBED_MODELS_DIR,
    f"y2h_{NET_EMBED}_dim_{DIM_EMBED}_batchSize_{BATCH_SIZE_EMBED}_lr_{BASE_LR_Y2H:.0e}_epoch_{EPOCH_NET_Y2H}_seed_{SEED}.pth",
)

# Paths - CcGAN
# The CcGAN checkpoint and the generated dataset share the same run tag; build
# it once so the two file names cannot drift apart.
_CCGAN_RUN_TAG = f"CcGAN_{GAN_ARCH}_dim_{DIM_GAN}_{IMG_SIZE}_batchSizeG_{BATCH_SIZE_GENE}_batchSizeD_{BATCH_SIZE_DISC}_lrG_{LR_G:.0e}_lrD_{LR_D:.0e}_nIters_{NITERS_GAN}_nDsteps_{NUM_D_STEPS}_{THRESHOLD_TYPE}_{KERNEL_SIGMA:.3f}_{KAPPA:.3f}_loss_{LOSS_TYPE_GAN}_seed_{SEED}"
CCGAN_PATH = os.path.join(CCGAN_MODELS_DIR, f"{_CCGAN_RUN_TAG}.pth")

# Paths - Generate Dataset
GENERATE_DATASETS_PATH = os.path.join(
    GENERATE_DATASETS_DIR, f"{_CCGAN_RUN_TAG}_in_generate.h5"
)

## Step 2 - Load Pretrained Embedding Models

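In this step we load the pretrained `net_embed` and `net_y2h` models onto `device` (the GPU if one is available, otherwise the CPU) and switch them to evaluation mode.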

In [5]:
# Build the embedding network selected by NET_EMBED.
if NET_EMBED == "ResNet18_embed":
    net_embed = ResNet18_embed(dim_embed=DIM_EMBED)
elif NET_EMBED == "ResNet34_embed":
    net_embed = ResNet34_embed(dim_embed=DIM_EMBED)
elif NET_EMBED == "ResNet50_embed":
    net_embed = ResNet50_embed(dim_embed=DIM_EMBED)
else:
    # Fail with a clear message instead of a NameError on `net_embed` below.
    raise ValueError(
        f"Unsupported NET_EMBED {NET_EMBED!r}; expected 'ResNet18_embed', "
        "'ResNet34_embed' or 'ResNet50_embed'."
    )
net_embed = net_embed.to(device)
# The state dict is loaded into the DataParallel wrapper, so the checkpoint was
# presumably saved from a DataParallel-wrapped model - TODO confirm.
net_embed = nn.DataParallel(net_embed)

net_y2h = model_y2h(dim_embed=DIM_EMBED)
net_y2h = net_y2h.to(device)
net_y2h = nn.DataParallel(net_y2h)

## (1). Load net_embed first: x2h+h2y
checkpoint = torch.load(EMBED_X2Y_PATH, map_location=device)
net_embed.load_state_dict(checkpoint["net_state_dict"])

## (2). Load y2h
checkpoint = torch.load(Y2H_PATH, map_location=device)
net_y2h.load_state_dict(checkpoint["net_state_dict"])

# Switch both networks to evaluation mode. `net_h2y` is the `h2y` sub-module of
# the unwrapped embedding network.
net_embed.eval()
net_h2y = net_embed.module.h2y
# Kept as the last expression so the cell displays the module summary.
net_y2h.eval()

DataParallel(
  (module): model_y2h(
    (main): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): GroupNorm(8, 128, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): GroupNorm(8, 128, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): GroupNorm(8, 128, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): GroupNorm(8, 128, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Linear(in_features=128, out_features=128, bias=True)
      (13): ReLU()
    )
  )
)

## Step 3 - Load CcGAN

This step loads `netG` from the `.pth` checkpoint file and switches it to evaluation mode.

In [6]:
# Build the generator on `device` and wrap it in DataParallel before restoring
# its weights from the CcGAN checkpoint.
netG = nn.DataParallel(
    CcGAN_SAGAN_Generator(dim_z=DIM_GAN, dim_embed=DIM_EMBED).to(device)
)

# Restore the trained generator weights.
checkpoint = torch.load(CCGAN_PATH, map_location=device)
netG.load_state_dict(checkpoint["netG_state_dict"])

# Evaluation mode; kept as the last expression so the cell displays the module summary.
netG.eval()

DataParallel(
  (module): CcGAN_SAGAN_Generator(
    (snlinear0): Linear(in_features=128, out_features=16384, bias=True)
    (block1): GenBlock(
      (cond_bn1): ConditionalBatchNorm2d(
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.001, affine=False, track_running_stats=True)
        (embed_gamma): Linear(in_features=128, out_features=1024, bias=False)
        (embed_beta): Linear(in_features=128, out_features=1024, bias=False)
      )
      (relu): ReLU(inplace=True)
      (snconv2d1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (cond_bn2): ConditionalBatchNorm2d(
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.001, affine=False, track_running_stats=True)
        (embed_gamma): Linear(in_features=128, out_features=1024, bias=False)
        (embed_beta): Linear(in_features=128, out_features=1024, bias=False)
      )
      (snconv2d2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (snconv2d0): Conv2d(1024, 1024

## Step 4 - Generate Datasets

Define a function to generate a batch of images `generate_batch_images`.

Define a function to generate all images `generate_images`.

In [None]:
def generate_batch_images(labels, start_index, batch_size) -> np.ndarray:
    """Generate one batch of fake images for a contiguous slice of ``labels``.

    Args:
        labels: 1-D array of raw (un-normalised) labels.
        start_index: Index of the first label of the batch.
        batch_size: Maximum number of labels in the batch; the last batch may
            be shorter when fewer labels remain.

    Returns:
        A channels-last array with one image per label, i.e. the generator
        output with its axes permuted from (N, C, H, W) to (N, H, W, C).
    """
    # Slicing clamps at the end of the array, so the final (short) batch needs
    # no special case.
    labels_to_generate_raw = np.asarray(labels[start_index : start_index + batch_size])

    # Min-max normalise the labels with the configured label range.
    labels_to_generate = (labels_to_generate_raw - MIN_LABEL) / (MAX_LABEL - MIN_LABEL)

    # The latent width must match the generator's `dim_z` (DIM_GAN), not the
    # image size; the two only happen to be equal in the default settings.
    z = torch.randn(labels_to_generate.shape[0], DIM_GAN, dtype=torch.float).to(device)
    y = torch.from_numpy(labels_to_generate).type(torch.float).reshape(-1, 1).to(device)

    # Inference only: skip autograd bookkeeping while running the networks.
    with torch.no_grad():
        batch_fake_images = netG(z, net_y2h(y))

    # (N, C, H, W) -> (N, H, W, C) for the whole batch at once, returned as a
    # contiguous NumPy array.
    batch_fake_images = batch_fake_images.detach().cpu().permute(0, 2, 3, 1)
    return np.ascontiguousarray(batch_fake_images.numpy())


def generate_images(labels) -> np.ndarray:
    """Generate one fake image per label, batch by batch.

    The labels are processed in consecutive chunks of
    ``BATCH_SIZE_FOR_GENERATE``; a progress line is printed for each chunk and
    the per-batch results are concatenated along the first axis.

    Args:
        labels: 1-D array of raw labels.

    Returns:
        All generated images stacked along axis 0.
    """
    total = labels.shape[0]
    batches = []

    for first in range(0, total, BATCH_SIZE_FOR_GENERATE):
        # Index of the last label covered by this batch (inclusive).
        last = min(total, first + BATCH_SIZE_FOR_GENERATE) - 1
        print(
            f"Generating images from index [{first}, {last}], range [{labels[first]:.3f}, {labels[last]:.3f}]."
        )
        batches.append(generate_batch_images(labels, first, BATCH_SIZE_FOR_GENERATE))

    # Merge the per-batch arrays into a single `numpy.ndarray`
    return np.concatenate(batches, axis=0)

Use the `generate_images` function to generate the images. (format: `numpy.ndarray`)

In [8]:
# Linearly calculate labels
labels = np.linspace(MIN_LABEL, MAX_LABEL, num=NUM_GENERATE).astype(float)

# Store images `numpy.ndarray`
fake_images_ndarray = generate_images(labels)

Generating images from index [0, 63], range [1.300, 1.327].
Generating images from index [64, 127], range [1.327, 1.354].
Generating images from index [128, 191], range [1.354, 1.381].
Generating images from index [192, 255], range [1.381, 1.408].
Generating images from index [256, 319], range [1.409, 1.435].
Generating images from index [320, 383], range [1.436, 1.463].
Generating images from index [384, 447], range [1.463, 1.490].
Generating images from index [448, 511], range [1.490, 1.517].
Generating images from index [512, 575], range [1.517, 1.544].
Generating images from index [576, 639], range [1.544, 1.571].
Generating images from index [640, 703], range [1.572, 1.598].
Generating images from index [704, 767], range [1.599, 1.625].
Generating images from index [768, 831], range [1.626, 1.653].
Generating images from index [832, 895], range [1.653, 1.680].
Generating images from index [896, 959], range [1.680, 1.707].
Generating images from index [960, 1023], range [1.707, 1.7

Write the generated images and their labels to the `.h5` file.

In [9]:
# Convert all images to [0, 255] uint8 and clip values
fake_images_ndarray = np.clip((fake_images_ndarray * 0.5 + 0.5) * 255, 0, 255).astype(
    "uint8"
)

with h5py.File(GENERATE_DATASETS_PATH, "w") as f:
    f.create_dataset("images", data=fake_images_ndarray, dtype="uint8")
    f.create_dataset("index_train", data=list(range(0, NUM_GENERATE, 2)), dtype="int64")
    f.create_dataset("index_valid", data=list(range(1, NUM_GENERATE, 2)), dtype="int64")
    f.create_dataset("labels", data=labels, dtype="float64")
    f.create_dataset("types", data=np.zeros(NUM_GENERATE, dtype="int32"), dtype="int32")
print(f"`.h5` file saved in {GENERATE_DATASETS_PATH}.")

`.h5` file saved in ./output/generated_datasets/CcGAN_SAGAN_dim_128_128_batchSizeG_64_batchSizeD_64_lrG_1e-04_lrD_1e-04_nIters_60000_nDsteps_4_soft_0.047_108.739_loss_hinge_seed_42_in_generate.h5.


Define the function `view_dataset` and use it to check the structure of the generated dataset.

In [10]:
def _print_hdf5(name, obj):
    """Print a one-line description of an HDF5 group or dataset.

    Intended as a callback for ``visititems``: ``name`` is the object's path
    and ``obj`` the object itself. The line is indented by two spaces per
    ``/`` in ``name``. Objects that are neither a dataset nor a group print
    nothing. Always returns ``None``.
    """
    pad = "  " * name.count("/")
    if isinstance(obj, h5py.Dataset):
        line = f"{pad}[Dataset] {name} shape={obj.shape} dtype={obj.dtype}"
    elif isinstance(obj, h5py.Group):
        line = f"{pad}[Group]   {name}"
    else:
        return
    print(line)


def view_dataset(dataset_path):
    """Print every group and dataset stored in the HDF5 file at ``dataset_path``.

    The file is opened read-only and closed again when the listing is done.
    """
    with h5py.File(dataset_path, mode="r") as h5_file:
        h5_file.visititems(_print_hdf5)


# Sanity check: list the datasets stored in the generated `.h5` file.
view_dataset(GENERATE_DATASETS_PATH)

[Dataset] images shape=(9192, 128, 128, 3) dtype=uint8
[Dataset] index_train shape=(4596,) dtype=int64
[Dataset] index_valid shape=(4596,) dtype=int64
[Dataset] labels shape=(9192,) dtype=float64
[Dataset] types shape=(9192,) dtype=int32
