In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image

In [2]:
class Generator(nn.Module):
  """DCGAN-style generator that upsamples latent noise into an image.

  Args:
    z_dim: channel count of the latent input, expected as N x z_dim x 1 x 1.
    channels_img: channel count of the generated image.
    features_g: base width; the hidden blocks use 16x, 8x, 4x and 2x this.

  Spatial size grows 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64. The last
  layer is Tanh, so outputs lie in [-1, 1] (matching images normalized to
  that range).
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    # (in_channels, out_channels, kernel_size, stride, padding) per hidden block.
    hidden_specs = [
        (z_dim, features_g * 16, 4, 1, 0),           # 1x1   -> 4x4
        (features_g * 16, features_g * 8, 4, 2, 1),  # 4x4   -> 8x8
        (features_g * 8, features_g * 4, 4, 2, 1),   # 8x8   -> 16x16
        (features_g * 4, features_g * 2, 4, 2, 1),   # 16x16 -> 32x32
    ]
    layers = [self._block(*spec) for spec in hidden_specs]
    # Output projection (32x32 -> 64x64): keeps its bias and has no BatchNorm.
    layers.append(
        nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1)
    )
    layers.append(nn.Tanh())
    # The attribute name `gen` and the layer order define the state_dict keys
    # (gen.<i>.<j>.*) that saved checkpoints are loaded against.
    self.gen = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one bias-free ConvTranspose2d -> BatchNorm2d -> ReLU unit."""
    # Bias is omitted because the following BatchNorm has its own shift term.
    upsample = nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=False
    )
    return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

  def forward(self, x):
    """Map a batch of latent vectors `x` to generated images."""
    return self.gen(x)

In [3]:
#Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so the sampled latent vectors -- and therefore the
# generated images -- are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Architecture settings: these must match the checkpoint that gets loaded.
Z_DIM = 100
IMG_SIZE = 64  # spatial size of the generator output
CHANNELS_IMG = 1
features_gen = 64

# Training-time settings (presumably from the WGAN-GP style training run,
# judging by the names). None of the sampling cells shown here reference
# them; kept for reference/compatibility.
lr = 1e-4
batch_size = 64
num_epochs = 20
features_critic = 64
critic_iterations = 5
lambda_gp = 10

In [4]:
# Rebuild the generator with the training-time architecture and move it to
# the target device (nn.Module.to moves parameters in place and returns self).
output_gen = Generator(Z_DIM, CHANNELS_IMG, features_gen)
output_gen.to(device)

# The checkpoint tensors are materialised on the CPU first; load_state_dict
# then copies them into the model's parameters, wherever those live.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
output_gen.load_state_dict(
    torch.load(
        "../input/final-dataset/gen_3.pth",
        map_location=torch.device("cpu"),
    )
)

# Inference mode for BatchNorm; as the last expression it also displays the model.
output_gen.eval()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 1, 

In [5]:
# Sample N_SAMPLES images from the generator and save each one, downscaled to
# OUTPUT_SIZE x OUTPUT_SIZE, as ./1.png, ./2.png, ...
N_SAMPLES = 10000
OUTPUT_SIZE = 28

# `transforms.Scale` is deprecated (and removed in recent torchvision);
# `Resize` is its drop-in replacement. It is loop-invariant, so build it once.
resize = transforms.Resize((OUTPUT_SIZE, OUTPUT_SIZE))

with torch.no_grad():  # pure inference: no autograd graph needed
    for n in range(N_SAMPLES):
        output_noise = torch.randn(1, Z_DIM, 1, 1).to(device)
        # The generator ends in Tanh, so its output lies in [-1, 1], while
        # save_image expects [0, 1] and clamps anything outside that range.
        # Without this rescale every non-positive pixel is crushed to black.
        image = (output_gen(output_noise) + 1) / 2
        torchvision.utils.save_image(image, "./test.png")

        # Round-trip through a PNG so resizing happens on the PIL image;
        # `with` closes the file handle before the next overwrite.
        with Image.open("./test.png") as img:
            resized_test = resize(img)
        resized_test.save(f"./{n + 1}.png")

