# Testing a Generative Adversarial Network (GAN)

### Imports

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image

from training.generator import Generator
from training import utils
from training.settings import *

# Reset the random state before sampling. NOTE(review): presumably this seeds
# the Python/NumPy/torch RNGs so generated samples are reproducible — confirm
# in training.utils.reset_rand.
utils.reset_rand()

### Check GPU

In [None]:
# Inspect GPU availability via the project helper. NOTE(review): presumably it
# reports which device will be used (cf. DEVICE from training.settings) —
# confirm in training.utils.check_gpu.
utils.check_gpu()

### Load the model

In [None]:
# Path to the trained generator weights (a state_dict, since it is passed to
# load_state_dict below).
MODEL_PATH = 'models/faces_256.pt'

# Build the generator on DEVICE and put it in inference mode right away: this
# notebook only samples from the model, so any train-only layer behaviour
# (dropout, batch-norm statistic updates, ... if the Generator has any) must be
# disabled. `.eval()` returns the module itself, so chaining keeps `model` bound
# to the Generator instance.
model = Generator().to(DEVICE).eval()
# map_location remaps tensors saved on another device onto DEVICE. Kept as the
# cell's last expression so the notebook still displays the key-matching report.
model.load_state_dict(torch.load(MODEL_PATH, map_location = DEVICE))

### Mean W image

In [None]:
# Sample one image with psi = 0.0 — the "mean W" image this section displays.
image = model.generate_one(0.0)

# Draw it on a fresh figure with the axes (ticks, frame) hidden.
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_axis_off()
plt.show()

### Settings

In [None]:
# Destination file for the generated grid image; leave as None to skip saving.
SAVE_PATH = None
# Grid layout forwarded to generate_grid (NOTE(review): presumably
# (rows, columns) — confirm against Generator.generate_grid).
SHAPE = (8, 6)
# Psi value forwarded to the generator when sampling the grid.
PSI = 0.7

### Tests

In [None]:
# Sample the whole grid once; the same array is both saved and displayed.
images = model.generate_grid(SHAPE, PSI)

# Persist the grid only when a destination path was configured.
if SAVE_PATH is not None:
	Image.fromarray(images).save(SAVE_PATH)

# Show the grid large, without axis decorations.
fig, ax = plt.subplots(figsize = (20, 20))
ax.imshow(images)
ax.set_axis_off()
plt.show()