In [1]:
import os

from torch import nn
from torchvision.utils import make_grid, save_image

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import torch
from torchvision import transforms

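## DCGANs (Deep Convolutional GANs)

Load every trained DCGAN generator saved under `GANs/` and write an 8×8 grid of its samples to `generated_images/`.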

In [2]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still survives "Restart & Run All" on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
GENERATOR_FEATURES_SIZE = 64
IMG_CHANNELS = 3


class Generator(nn.Module):
    """DCGAN generator mapping a latent vector to an ``IMG_CHANNELS x 64 x 64`` image.

    The latent batch is fed as a ``(batch, latent_space_dim, 1, 1)`` tensor and is
    upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64 by transposed convolutions. The last
    layer is ``tanh``, so every output value lies in ``[-1, 1]``.
    """

    def __init__(self, latent_space_dim):
        super().__init__()
        ngf = GENERATOR_FEATURES_SIZE

        # (in_channels, out_channels, stride, padding) for each 4x4 transposed conv
        # that is followed by BatchNorm + ReLU.
        hidden_stages = [
            (latent_space_dim, ngf * 8, 1, 0),  # 1x1   -> 4x4
            (ngf * 8, ngf * 4, 2, 1),           # 4x4   -> 8x8
            (ngf * 4, ngf * 2, 2, 1),           # 8x8   -> 16x16
            (ngf * 2, ngf, 2, 1),               # 16x16 -> 32x32
        ]

        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])

        # Output stage: 32x32 -> 64x64 image squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, IMG_CHANNELS, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Layer order (and hence the ``main.<idx>.*`` state_dict keys) must stay
        # exactly this so previously saved checkpoints keep loading.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Upsample a latent batch into images."""
        return self.main(input)

In [4]:
# Sample a reproducible 8x8 grid of images from every saved DC-GAN generator.
# Each run is expected at ``GANs/<run name>/generator`` and the run name must end
# in ``_<latent dim>``, because the latent size is parsed from it.
batch_size = 64
SAMPLE_SEED = 0  # re-applied per model, so each grid is reproducible and independent of processing order
GAN_DIR = "GANs"
OUTPUT_DIR = "generated_images"

os.makedirs(OUTPUT_DIR, exist_ok=True)  # save_image does not create missing folders

# sorted(): os.listdir order is arbitrary. Only directories are kept because a
# ``generator`` file is loaded from inside each matching entry.
dc_gans = sorted(
    file for file in os.listdir(GAN_DIR)
    if file.startswith("DC-GAN") and os.path.isdir(os.path.join(GAN_DIR, file))
)
for dc_gan in dc_gans:
    latent_dim = int(dc_gan.split("_")[-1])
    netG = Generator(latent_dim).to(device)
    # map_location lets checkpoints saved on a GPU load onto whatever `device` is.
    # NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
    state_dict = torch.load(os.path.join(GAN_DIR, dc_gan, "generator"), map_location=device)
    netG.load_state_dict(state_dict)
    # Inference mode for BatchNorm: use the learned running statistics rather than
    # the statistics of this particular noise batch.
    netG.eval()
    torch.manual_seed(SAMPLE_SEED)
    fixed_noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
    with torch.no_grad():  # sampling only; no autograd graph needed
        output = (netG(fixed_noise) + 1) / 2  # tanh range [-1, 1] -> [0, 1] for save_image
    grid = make_grid(output, nrow=8)
    save_image(grid, os.path.join(OUTPUT_DIR, f"{dc_gan}_grid.png"))