In [1]:
from pathlib import Path
# Folder on the mounted Google Drive that receives the generated fake-face images.
dataset_base_path: Path = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Datasets/FakeFaces/WGAN-GP")

# Switch to False when not running in Google Colab: `google.colab` is only importable there,
# and later cells only touch `drive` when this flag is set.
USE_COLAB: bool = True

if USE_COLAB:
  # Imported inside the guard because the module is Colab-specific.
  from google.colab import drive

  # (Re)mount Google Drive so the shared course folders appear under /content/drive.
  drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [6]:
import torch
import torch.nn as nn
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np

In [3]:
class Generator(nn.Module):
  """Transposed-convolution generator mapping a latent tensor to an image.

  Layout of `self.main` (indices 0-13):
    four stages of ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2)
    with 8*ngf, 4*ngf, 2*ngf and ngf channels, then a final
    ConvTranspose2d(ngf -> n_channels) -> Tanh, so every output value lies in [-1, 1].

  For an input of shape (B, latent_dim, 1, 1) the spatial size grows
  1 -> 4 -> 8 -> 16 -> 32 -> 64, i.e. the output is (B, n_channels, 64, 64).

  Args:
    latent_dim: number of channels of the input noise tensor.
    ngf: base feature-map width of the hidden stages.
    n_channels: number of channels of the generated image.

  The module order must not change: checkpoints are restored with `load_state_dict`,
  whose keys ("main.<index>.<param>") are derived from these positions.
  """

  def __init__(self, latent_dim: int = 100, ngf: int = 64, n_channels: int = 3):
    super().__init__()
    hidden_widths = (ngf * 8, ngf * 4, ngf * 2, ngf)
    layers = []
    prev_width = latent_dim
    for stage, width in enumerate(hidden_widths):
      # Stage 0 uses stride 1 / no padding (1x1 -> 4x4 for a 1x1 input);
      # every later stage uses stride 2 / padding 1, which doubles height and width.
      if stage == 0:
        stride, padding = (1, 1), (0, 0)
      else:
        stride, padding = (2, 2), (1, 1)
      layers.extend([
          nn.ConvTranspose2d(prev_width, width, kernel_size=(4, 4), stride=stride, padding=padding),
          nn.BatchNorm2d(width),
          nn.LeakyReLU(0.2),
      ])
      prev_width = width
    # Output stage: project to image channels, then squash into [-1, 1].
    layers.append(nn.ConvTranspose2d(prev_width, n_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)))
    layers.append(nn.Tanh())
    self.main = nn.Sequential(*layers)

  def forward(self, input):
    """Run `input` through the layer stack and return the generated images."""
    return self.main(input)

In [4]:
import re
def get_latest_pth_model(base_path) -> Path:
  """Return the `.pth` checkpoint in `base_path` with the highest epoch number.

  Checkpoints are expected to be named `<prefix>--<epoch>.pth` (e.g. `WGANGP--30.pth`).
  `.pth` files that do not follow this pattern are ignored. If several files share the
  highest epoch, the first one in sorted path order is returned.

  Args:
    base_path: directory (str or Path) containing the checkpoints; not searched recursively.

  Returns:
    Path of the checkpoint with the largest epoch number.

  Raises:
    FileNotFoundError: if no file in `base_path` matches `<prefix>--<epoch>.pth`.
  """
  # Match against the file *name* only: searching the whole path would pick up a "--"
  # or "pt" inside a directory name and mis-slice the epoch number.
  epoch_pattern = re.compile(r"--(\d+)\.pth$")
  candidates = []
  for file_ in sorted(Path(base_path).glob("*.pth")):
    match = epoch_pattern.search(file_.name)
    if match is not None:
      candidates.append((int(match.group(1)), file_))

  if not candidates:
    raise FileNotFoundError(f"No '<name>--<epoch>.pth' checkpoints found in '{base_path}'")

  # max() keeps the first maximal entry, so ties resolve to the earliest file in sorted order.
  return max(candidates, key=lambda pair: pair[0])[1]

In [7]:
# Folder holding the training checkpoints (files named like "WGANGP--<epoch>.pth").
model_dir = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Code/DatasetGeneration/WGANGP/models")
model_path = get_latest_pth_model(model_dir)
print(f"model_path: '{model_path}'")

# Generator hyper-parameters must match the ones used at training time,
# otherwise load_state_dict below fails on mismatched tensor shapes.
config = {
    "latent_dim": 100,
    "ngf": 64,
    "n_channels": 3,
    "n_imgs_to_generate": 40000,
    "batch_size": 1,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator(config["latent_dim"], config["ngf"], config["n_channels"])
# map_location remaps tensors saved on a GPU onto `device`, so the CPU fallback above
# can actually load a checkpoint produced on a GPU runtime.
# NOTE(review): torch.load unpickles the file — only load checkpoints you produced yourself.
checkpoint = torch.load(str(model_path), map_location=device)
generator.load_state_dict(checkpoint["Generator"])
generator.to(device)

model_path: '/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Code/DatasetGeneration/WGANGP/models/WGANGP--30.pth'


Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.2)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)

In [8]:
import shutil

# Images are written straight into the dataset folder configured in the first cell.
output_imgs_path = dataset_base_path

# Regenerate from scratch: irreversibly delete whatever a previous run left in the folder,
# then recreate it together with any missing parent directories.
if output_imgs_path.exists():
  shutil.rmtree(str(output_imgs_path))
output_imgs_path.mkdir(parents=True, exist_ok=True)

print(output_imgs_path)

/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Datasets/FakeFaces/WGAN-GP


In [10]:
# Fixed seed so the same latent vectors (and therefore the same images) are drawn on every run.
torch.manual_seed(999)
n_imgs = config["n_imgs_to_generate"]
batch_size = config["batch_size"]
if n_imgs % batch_size != 0:
  raise RuntimeError(
      f"n_imgs_to_generate ({n_imgs}) not divisible by batch_size ({batch_size})"
  )
iterations = n_imgs // batch_size
img_cnt = 0
# Eval mode makes BatchNorm use its stored running statistics instead of per-batch ones.
generator.eval()
with torch.no_grad():
  for _ in tqdm(range(iterations)):
    noise = torch.randn(batch_size, config["latent_dim"], 1, 1, dtype=torch.float, device=device)
    fakes = generator(noise)
    # The generator ends in Tanh, so outputs lie in [-1, 1]. save_image only clamps to
    # [0, 1] (no rescaling), which would turn every negative value black; map to [0, 1] first.
    fakes = (fakes + 1) / 2
    for fake in fakes:
      save_image(fake, output_imgs_path / f"{img_cnt}.jpg")
      img_cnt += 1

# `drive` is only defined when the first cell mounted Google Drive (USE_COLAB=True).
if USE_COLAB:
  # Flush pending writes to Drive before detaching it.
  drive.flush_and_unmount()

100%|██████████| 40000/40000 [04:41<00:00, 142.08it/s]
