# Generate images using a pretrained WGAN generator

In [1]:
import torch
import matplotlib.pyplot as plt

from datetime import datetime
from torchvision import transforms
from torchvision.utils import make_grid

from pl_module import PL_Module
from model import Generator, Critic

## Hyperparameters

In [2]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture sizes passed to Generator / Critic below.
# NOTE(review): presumably these must match the values the checkpoint was trained
# with, otherwise loading its state_dict will fail on shape mismatches — confirm.
#   nz  -> latent vector length (also used to shape the sampled noise)
#   ngf -> Generator's `ngf` argument
#   ndf -> Critic's `ndf` argument
nz, ngf, ndf = 100, 64, 64

# Checkpoint file holding the trained weights (read by torch.load during initialization).
weights_path = "./data/weights/wgan.weights.pth"

## Initialization

In [None]:
# Inverse of a Normalize(mean=0.5, std=0.5) transform: x -> x / 2 + 0.5, which maps
# [-1, 1] back to [0, 1] for display.
# NOTE(review): assumes the training transforms used mean=std=0.5 per channel — confirm.
reverse_normalize = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=[1/0.5, 1/0.5, 1/0.5]),  # (x - 0) / 2
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1, 1, 1])      # (x + 0.5) / 1
])

# Build the architectures; their parameters are overwritten from the checkpoint below.
generator = Generator(nz=nz, ngf=ngf)
critic = Critic(ndf=ndf)

# Wrap both networks in the module whose state_dict the checkpoint stores.
# NOTE(review): the three `None`s and `16` are positional PL_Module arguments whose
# meaning is not visible here (presumably training-only settings) — confirm in pl_module.
module = PL_Module(generator, critic, None, None, None, nz, 16)

# Load the checkpoint onto the CPU. The module still lives on the CPU at this point, so
# mapping to `device` would stage every tensor (critic and any extra checkpoint entries
# included) on the GPU only to copy it straight back. The generator is moved to
# `device` exactly once, below.
# NOTE(review): torch >= 2.6 defaults to weights_only=True; if this checkpoint contains
# non-tensor objects that refuse to unpickle, pass weights_only=False for a TRUSTED file only.
checkpoint = torch.load(weights_path, map_location='cpu')
state_dict = checkpoint['state_dict']

# Copy the trained weights into the module (PyTorch's default strict=True raises on
# missing/unexpected keys or shape mismatches).
module.load_state_dict(state_dict)

# Only the generator is needed for sampling: switch it to eval mode and move it to `device`.
generator = module.generator.eval().to(device)

## Image generation

### Single image

In [None]:
# Sample a single latent vector shaped (1, nz, 1, 1), as the generator is fed here.
z = torch.randn((1, nz, 1, 1), device=device)

# Generate without building an autograd graph.
with torch.no_grad():
    generated_image = generator(z)

# Undo the training normalization on the first (only) image of the batch, and clamp so
# imshow receives values in its expected [0, 1] float range.
image = reverse_normalize(generated_image[0]).clamp(0, 1)

# Display the image: matplotlib expects (H, W, C), the tensor is channel-first.
plt.imshow(image.permute(1, 2, 0).cpu())
plt.axis('off')
plt.show()

### Grid

In [None]:
# Side length of the square grid (n x n images).
n = 6

# Sample n*n latent vectors shaped (n*n, nz, 1, 1).
z = torch.randn((n*n, nz, 1, 1), device=device)

# Generate without building an autograd graph.
with torch.no_grad():
    generated_images = generator(z)

# Undo the training normalization exactly as for the single image (Normalize accepts a
# batched (B, C, H, W) tensor). A per-grid min-max rescale would instead stretch contrast
# based on whatever extremes this particular batch happens to contain.
images = reverse_normalize(generated_images).clamp(0, 1)

# Tile into an n-wide grid and convert to (H, W, C) for matplotlib.
grid = make_grid(images, nrow=n).permute(1, 2, 0).cpu()

# Display the grid.
plt.imshow(grid)
plt.axis('off')
plt.show()