In [None]:
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils

# --- Generator configuration ---------------------------------------------
# nz, ngf and nc fix every layer's shape, so they must match the architecture
# the checkpoint was trained with or load_state_dict will reject it.
nz = 100         # length of each latent noise vector
ngf = 64         # base number of generator feature maps
nc = 3           # channels in the generated images
image_size = 64  # NOTE(review): unused in the visible cells; presumably the 64x64 output size

# Run on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator(nn.Module):
    """DCGAN generator: latent noise -> image.

    For a ``(N, nz, 1, 1)`` input the five transposed convolutions grow the
    spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64, producing ``(N, nc, 64, 64)``
    images squashed into [-1, 1] by the final Tanh.
    """

    def __init__(self):
        super().__init__()
        # Feature-map widths from the latent input down to the last hidden layer.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # The first block projects 1x1 -> 4x4 (stride 1, no padding);
            # every later block doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output block: ngf -> nc channels, 32x32 -> 64x64, values in [-1, 1].
        layers.extend([nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()])
        # One nn.Sequential named `main`, holding the layers in their original
        # order, keeps state_dict keys (main.0.weight, main.1.running_mean, ...)
        # compatible with saved checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a latent batch to images by running it through ``self.main``."""
        output = self.main(input)
        return output


# Build the generator on the selected device and restore its trained weights.
netG = Generator().to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto `device`; without it torch.load restores them to the device
# they were saved from, which fails on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
netG.load_state_dict(torch.load('../checkpoint/dcgan_checkpoint.pth', map_location=device))
# NOTE(review): netG is left in training mode, so its BatchNorm layers normalize
# with per-batch statistics and each generated image depends on the rest of its
# batch (this also partly dampens a latent offset added to the whole batch).
# The hand-picked sample indices used later were presumably chosen from images
# rendered this way -- re-check them before switching to netG.eval().

In [None]:
def seed_given_generator(seed, va, test_batch_size=64):
    """Generate a batch of images from a fixed seed and tile them into one grid.

    Parameters
    ----------
    seed : int
        Passed to ``torch.manual_seed`` before drawing the latent noise, so the
        same seed always yields the same batch. This reseeds torch's *global*
        RNG as a side effect.
    va : torch.Tensor, array-like or None
        Optional offset added to every latent vector, e.g. an attribute
        direction of shape ``(nz, 1, 1)`` that broadcasts over the batch.
        ``None`` leaves the noise untouched.
    test_batch_size : int, optional
        Number of latent vectors / images in the grid (default 64).

    Returns
    -------
    numpy.ndarray
        ``uint8`` image in (height, width, channels) layout, ready for
        ``plt.imshow``.

    Raises
    ------
    ValueError
        If ``test_batch_size`` is not positive.

    NOTE(review): if ``netG`` is in training mode, this forward pass also
    updates its BatchNorm running statistics, even under ``torch.no_grad``.
    """
    if test_batch_size < 1:
        raise ValueError(f"test_batch_size must be positive, got {test_batch_size}")

    torch.manual_seed(seed)
    noise = torch.randn(test_batch_size, nz, 1, 1, device=device)

    if va is not None:
        # Match the noise's dtype/device so a CPU tensor, a float64 tensor or a
        # NumPy array also works (a no-op when va already matches).
        noise = noise + torch.as_tensor(va, dtype=noise.dtype, device=noise.device)

    with torch.no_grad():
        # A no_grad forward pass builds no autograd graph, so .detach() is unneeded.
        fake = netG(noise).cpu()
    # normalize=True rescales the whole grid to [0, 1] using its own min/max.
    vis = vutils.make_grid(fake, padding=2, normalize=True)
    # CHW float in [0, 1] -> HWC uint8 for matplotlib.
    vis = (vis.numpy() * 255).astype(np.uint8).transpose(1, 2, 0)

    return vis

In [None]:
# Baseline: the seed-42 sample grid with no latent offset.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(seed_given_generator(42, None))
ax.set_title('Result 42')
ax.set_axis_off()
plt.show()

In [None]:
# Rebuild the latent batch behind the "Result 42" grid: same seed, batch size
# and shape as seed_given_generator(42, None), so noise[i] is the latent vector
# of grid cell i.
torch.manual_seed(42)
noise = torch.randn(64, nz, 1, 1, device=device)

# Grid indices presumably hand-labelled by inspecting the "Result 42" image.
male = [2, 11, 12, 21, 22, 26, 27, 29, 33, 62]
female = [0, 4, 10, 15, 20, 24, 28, 36, 41, 45, 56, 58]

# Guard against labelling typos: a shared index would silently bias the
# direction, and a negative index would silently wrap to the end of the batch.
assert not set(male) & set(female), "an index is labelled both male and female"
assert all(0 <= i < noise.shape[0] for i in male + female), "label index outside the batch"

# Attribute direction: mean male latent minus mean female latent, shape
# (nz, 1, 1), which broadcasts when added to a (batch, nz, 1, 1) noise tensor.
va = noise[male].mean(dim=0) - noise[female].mean(dim=0)

In [None]:
# A fresh batch (seed 999) with no offset, as the reference for the edit below.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(seed_given_generator(999, None))
ax.set_title('Result 999')
ax.set_axis_off()
plt.show()

In [None]:
# The same seed-999 batch, shifted along the male-minus-female latent direction `va`.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(seed_given_generator(999, va))
ax.set_title('Result 999 Male')
ax.set_axis_off()
plt.show()