In [1]:
from pathlib import Path
# Set to False when running outside Google Colab (Drive will not be mounted).
USE_COLAB: bool = True

# All course data lives under this shared Drive folder; generated fake faces go in a CDCGAN subfolder.
course_root = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning")
dataset_base_path = course_root / "Datasets" / "FakeFaces" / "CDCGAN"

if USE_COLAB:
  from google.colab import drive

  # Mount Google Drive so the shared course folders are reachable under /content/drive.
  drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [2]:
from typing import Tuple, Union, List, Optional, Callable
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
from tqdm import tqdm
from dataclasses import dataclass
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torchvision.utils import save_image, make_grid
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from tqdm import tqdm
from zipfile import ZipFile
from PIL import Image
import pandas as pd

In [3]:
class Generator(nn.Module):
  """Conditional generator built from five stride-2 transposed-convolution stages.

  The first stage takes 104 input channels. The sampling cell feeds it latent noise
  concatenated with an attribute vector, reshaped to a 1x1 spatial map. From a 1x1 input,
  the stages grow the spatial size 1 -> 3 -> 7 -> 15 -> 32 -> 64 and end with 3 channels.
  """

  def __init__(self):
    super().__init__()
    # (in_channels, out_channels, kernel, padding) per stage; every stage uses stride 2.
    stages = [
        (104, 256, 3, 0),
        (256, 128, 3, 0),
        (128, 64, 3, 0),
        (64, 32, 4, 0),
        (32, 3, 4, 1),
    ]
    *hidden_stages, (last_in, last_out, last_k, last_p) = stages

    layers = []
    for c_in, c_out, k, p in hidden_stages:
      layers += [
          nn.ConvTranspose2d(c_in, c_out, kernel_size=(k, k), stride=(2, 2), padding=(p, p)),
          nn.BatchNorm2d(c_out),
          nn.Tanh(),
      ]
    # Output stage: no BatchNorm, and a ReLU that keeps pixel values non-negative.
    layers += [
        nn.ConvTranspose2d(last_in, last_out, kernel_size=(last_k, last_k), stride=(2, 2), padding=(last_p, last_p)),
        nn.ReLU(),
    ]
    # Module order fixes the state_dict keys (deconv.0 ... deconv.13) that saved checkpoints are loaded into.
    self.deconv = nn.Sequential(*layers)

  def forward(self, x):
    """Run the input code through the transposed-convolution stack."""
    return self.deconv(x)

In [4]:
import re
def get_latest_pth_model(base_path) -> Path:
  """Return the ``*.pth`` checkpoint with the highest epoch number in ``base_path``.

  Checkpoints are expected to be named ``<prefix>--<epoch>.pth`` (e.g. ``CDCGAN--31.pth``).
  Only the file stem is parsed, so dashes or "pt" elsewhere in the directory path cannot
  corrupt the epoch extraction. ``.pth`` files that do not follow the pattern are ignored.
  On ties, the first file in sorted order is returned.

  Args:
    base_path: directory (str or Path) searched non-recursively.

  Returns:
    Path of the checkpoint with the largest epoch number.

  Raises:
    FileNotFoundError: if no ``*.pth`` file in ``base_path`` matches the naming pattern.
  """
  epoch_pattern = re.compile(r"--(\d+)$")
  candidates = []
  for file_ in sorted(Path(base_path).glob("*.pth")):
    match = epoch_pattern.search(file_.stem)
    if match is not None:
      candidates.append((int(match.group(1)), file_))

  if not candidates:
    raise FileNotFoundError(f"No '<name>--<epoch>.pth' checkpoint found in '{base_path}'")

  # max() returns the first maximal item, i.e. the earliest file in sorted order on ties.
  return max(candidates, key=lambda item: item[0])[1]

In [5]:
# CelebA attribute table: sampled rows become the conditioning vectors concatenated to the latent noise.
attr_file = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Datasets/RealFaces/CelebA/img_align_celeba_attr.csv")
attrs = pd.read_csv(str(attr_file), index_col=0)

config = {
    "batch_size": 1,              # images generated per forward pass
    "latent_dim": 64,             # length of the random noise vector z
    "img_size": 64,               # generator output is 64x64; not read by the generation loop
    "n_imgs_to_generate": 40000,  # total number of fake images to write
}

model_dir = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Code/DatasetGeneration/CDCGAN/models")
model_path = get_latest_pth_model(model_dir)
print(f"model_path: '{model_path}'")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator()
# map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only runtime.
checkpoint = torch.load(str(model_path), map_location=device)
generator.load_state_dict(checkpoint["Generator"])
# The first layer consumes [noise, attributes] concatenated along the channel axis.
assert config["latent_dim"] + attrs.shape[1] == generator.deconv[0].in_channels, (
    "latent_dim + number of attribute columns must equal the generator's input channels"
)
generator.to(device)

model_path: '/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Code/DatasetGeneration/CDCGAN/models/CDCGAN--31.pth'


Generator(
  (deconv): Sequential(
    (0): ConvTranspose2d(104, 256, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
    (3): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Tanh()
    (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Tanh()
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2))
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Tanh()
    (12): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
  )
)

In [6]:
import shutil

# Start from an empty output folder so images from an earlier run never mix with this one.
output_imgs_path = dataset_base_path
if output_imgs_path.exists():
  shutil.rmtree(output_imgs_path)
output_imgs_path.mkdir(parents=True, exist_ok=True)
print(output_imgs_path)

/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Datasets/FakeFaces/CDCGAN


In [7]:
np.random.seed(999)
# Sample 0-based row *positions* and select with .iloc, so the draw does not depend on what the
# CSV's index labels are. When the index is exactly 1..N in file order, this should pick the same
# rows as the earlier label-based randint(1, N + 1) + .loc selection under the same seed.
attr_positions = np.random.randint(low=0, high=len(attrs), size=config["n_imgs_to_generate"])
attrs_sampled = attrs.iloc[attr_positions]
# Index labels of the sampled rows; not used by the generation loop.
attr_indices = list(attrs_sampled.index)
# One float32 conditioning row per image to generate.
attrs_rand = torch.tensor(attrs_sampled.to_numpy(dtype=np.float32))
assert attrs_rand.shape[0] == config["n_imgs_to_generate"]

In [22]:
from torchvision.utils import save_image
from tqdm import tqdm

torch.manual_seed(999)
batch_size = config["batch_size"]
n_imgs = config["n_imgs_to_generate"]
if n_imgs % batch_size != 0:
  raise RuntimeError("n_imgs_to_generate not divisible by batch_size")
iterations = n_imgs // batch_size

img_cnt = 0
generator.eval()
with torch.no_grad():
  for idx in tqdm(range(iterations)):
    # Noise is drawn with torch's CPU RNG (seeded above) and only then moved to `device`.
    z = torch.randn(batch_size, config["latent_dim"])
    # One attribute row per noise vector; slicing works for any batch_size (B=1 matches the old unsqueeze).
    attrs_batch = attrs_rand[idx * batch_size:(idx + 1) * batch_size]
    # Concatenate along channels, then add 1x1 spatial dims for the transposed convolutions.
    z_cat = torch.cat([z, attrs_batch], dim=1).unsqueeze(-1).unsqueeze(-1).to(device)
    fakes = generator(z_cat)
    for fake in fakes:
      output_path = output_imgs_path / f"{img_cnt}.jpg"
      save_image(fake, output_path)
      img_cnt += 1

# `drive` is only imported when the Colab setup cell ran with USE_COLAB enabled.
if USE_COLAB:
  drive.flush_and_unmount()

100%|██████████| 40000/40000 [04:52<00:00, 136.68it/s]
