In [95]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


# Libraries

In [96]:
# torch stuff
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

# needed for model
from math import log2
from tqdm import tqdm
import random
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
import torchvision
from torchvision.utils import save_image

# SAGAN Variables

In [97]:
# Folder of training images (read with torchvision ImageFolder in get_loader).
# Override with the SAGAN_DATAPATH environment variable instead of editing a
# user-specific absolute path; the original location remains the default.
datapath = os.environ.get("SAGAN_DATAPATH", "/home/dsu/Desktop/Projects/FinalProject/cat_imgs")

# Checkpoint files written by save_checkpoint and read by load_checkpoint.
gen_save_path = "gen_sagan.pth"
dis_save_path = "dis_sagan.pth"

# Prefer the GPU; fall back to the CPU with a visible warning.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Warning! Using CPU!")

# Run mode:
#   load_model = False -> train from scratch (checkpoints saved when save_model)
#   load_model = True  -> no training; load checkpoints and display samples
save_model = True
load_model = False

# Optimisation
lr = 1e-4
batch_size = 64

# Data / latent sizes
noise_dim = 100   # channels of the (N, noise_dim, 1, 1) latent input
img_size = 64     # training images are resized to img_size x img_size
channels = 3      # image channels

# Layer hyper-parameters
relu_slope = 0.2  # negative slope of every LeakyReLU
dropout = 0.3     # dropout probability in the discriminator
momentum = 0.1    # BatchNorm running-stat momentum (PyTorch default); TODO: try 0.9?

# Loss selection: False -> SAGAN hinge loss, True -> Wasserstein loss + gradient penalty
ws_loss = False

total_epochs = 100

# Self Attention

In [98]:
class Self_Attn(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Every position attends to every other position of the same feature map.
    The attended features are added back onto the input, scaled by a learnable
    scalar ``gamma`` that starts at 0, so the layer is the identity at
    initialisation.

    Args:
        in_dim: channel count of the incoming feature map. The output has the
            same channel count and spatial size as the input.
    """

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()

        # Query/key projections are shrunk to in_dim // 8 channels to make the
        # attention map cheaper; the value projection keeps in_dim channels so
        # its output can be added back onto the input.
        reduced = in_dim // 8
        self.query = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)

        # Residual weight for the attention output, initialised to zero.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (N, C, H, W); returns the same shape."""
        n, c, h, w = x.shape
        positions = h * w

        # q: (N, HW, C'), k: (N, C', HW), v: (N, C, HW)
        q = self.query(x).view(n, -1, positions).transpose(1, 2)
        k = self.key(x).view(n, -1, positions)
        v = self.value(x).view(n, -1, positions)

        # attn[b, i, j]: weight of position j when updating position i (rows sum to 1).
        attn = torch.softmax(torch.bmm(q, k), dim=-1)

        # attended[b, :, i] = sum_j v[b, :, j] * attn[b, i, j]
        attended = torch.bmm(v, attn.transpose(1, 2)).view(n, c, h, w)

        return x + self.gamma * attended

# Generator

In [99]:
class Generator(nn.Module):
    """DCGAN-style generator with two self-attention layers and spectral norm.

    Maps a latent tensor of shape (N, noise_dim, 1, 1) to an image of shape
    (N, img_channels, 64, 64) with values in [-1, 1] (tanh output).
    Spatial size per stage: 1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        noise_dim: number of latent channels.
        in_channels: unused; kept so existing call sites keep working.
        img_channels: channels of the generated image.
    """

    def __init__(self, noise_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()

        # Stage 1: 1x1 -> 4x4
        self.conv2d_1 = spectral_norm(nn.ConvTranspose2d(noise_dim, 512, kernel_size=4, stride=1, padding=0, bias=False))
        self.batch_norm1 = nn.BatchNorm2d(512, momentum=momentum)
        self.relu1 = nn.LeakyReLU(relu_slope)

        # Stage 2: 4x4 -> 8x8
        self.conv2d_2 = spectral_norm(nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm2 = nn.BatchNorm2d(256, momentum=momentum)
        self.relu2 = nn.LeakyReLU(relu_slope)

        # The SAGAN paper recommends self-attention partway through the network,
        # once the feature maps carry enough information to attend over.
        self.attn = Self_Attn(256)

        # Stage 3: 8x8 -> 16x16
        self.conv2d_3 = spectral_norm(nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm3 = nn.BatchNorm2d(128, momentum=momentum)
        self.relu3 = nn.LeakyReLU(relu_slope)

        self.attn2 = Self_Attn(128)

        # Stage 4: 16x16 -> 32x32
        self.conv2d_4 = spectral_norm(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm4 = nn.BatchNorm2d(64, momentum=momentum)
        self.relu4 = nn.LeakyReLU(relu_slope)

        # Output: 32x32 -> 64x64, squashed to [-1, 1]
        self.conv2d_final = spectral_norm(nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1, bias=False))
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Generate images from latent tensor ``x`` of shape (N, noise_dim, 1, 1)."""
        # (upsampling conv, normalisation, activation, optional attention after it)
        stages = (
            (self.conv2d_1, self.batch_norm1, self.relu1, None),
            (self.conv2d_2, self.batch_norm2, self.relu2, self.attn),
            (self.conv2d_3, self.batch_norm3, self.relu3, self.attn2),
            (self.conv2d_4, self.batch_norm4, self.relu4, None),
        )
        for conv, norm, act, attention in stages:
            x = act(norm(conv(x)))
            if attention is not None:
                x = attention(x)

        return self.tanh(self.conv2d_final(x))

# Discriminator

In [100]:
class Discriminator(nn.Module):
    """Convolutional discriminator with two self-attention layers and spectral norm.

    Four stride-2 convolutions halve the spatial size each time (64x64 inputs
    end as a 4x4 map). A final 1x1 convolution produces one unbounded score
    per spatial position, so 64x64 inputs yield an (N, 16) output. No sigmoid
    is applied.

    Args:
        in_channels: unused; kept so existing call sites keep working.
        img_channels: channels of the input images.
    """

    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()

        # Block 1 (no batch norm on the first layer)
        self.conv2d_1 = spectral_norm(nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1, bias=False))
        self.relu1 = nn.LeakyReLU(relu_slope)
        self.dropout1 = nn.Dropout(dropout)

        # Block 2
        self.conv2d_2 = spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm2 = nn.BatchNorm2d(128, momentum=momentum)
        self.relu2 = nn.LeakyReLU(relu_slope)
        self.dropout2 = nn.Dropout(dropout)

        self.attn = Self_Attn(128)

        # Block 3
        self.conv2d_3 = spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm3 = nn.BatchNorm2d(256, momentum=momentum)
        self.relu3 = nn.LeakyReLU(relu_slope)
        self.dropout3 = nn.Dropout(dropout)

        self.attn2 = Self_Attn(256)

        # Block 4
        self.conv2d_4 = spectral_norm(nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False))
        self.batch_norm4 = nn.BatchNorm2d(512, momentum=momentum)
        self.relu4 = nn.LeakyReLU(relu_slope)
        self.dropout4 = nn.Dropout(dropout)

        # 1x1 convolution down to a single score channel
        self.conv2d_final = spectral_norm(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0, bias=False))

    def forward(self, x):
        """Score image batch ``x`` (N, C, H, W); returns an (N, H/16 * W/16) tensor."""
        # (conv, optional batch norm, activation, dropout, optional attention after it)
        blocks = (
            (self.conv2d_1, None, self.relu1, self.dropout1, None),
            (self.conv2d_2, self.batch_norm2, self.relu2, self.dropout2, self.attn),
            (self.conv2d_3, self.batch_norm3, self.relu3, self.dropout3, self.attn2),
            (self.conv2d_4, self.batch_norm4, self.relu4, self.dropout4, None),
        )
        for conv, norm, act, drop, attention in blocks:
            x = conv(x)
            if norm is not None:
                x = norm(x)
            x = drop(act(x))
            if attention is not None:
                x = attention(x)

        scores = self.conv2d_final(x)
        return scores.view(x.shape[0], -1)

# Utility Functions

In [101]:
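# Gradient penalty (WGAN-GP): penalizes the discriminator whenever the L2 norm
# of its gradient with respect to its input drifts away from 1. This keeps the
# discriminator's scores from changing too sharply, which stabilizes training.
# It only contributes to the Wasserstein loss (ws_loss = True).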

def gradient_penalty(discriminator, real, fake, device="cpu"):
    """WGAN-GP gradient penalty.

    Samples random points on straight lines between real and fake images and
    penalises the discriminator when the L2 norm of its gradient w.r.t. those
    points deviates from 1.

    Args:
        discriminator: callable mapping an image batch to scores.
        real: real images, shape (N, C, H, W).
        fake: generated images with at least N samples and the same C, H, W.
            Only the first N are used, and they are detached so the penalty
            does not back-propagate into the generator.
        device: device on which the interpolation weights are created.

    Returns:
        Scalar tensor: batch mean of (||grad||_2 - 1) ** 2.
    """
    cur_batch_size = real.shape[0]

    # The last loader batch can be smaller than the fake batch. Slicing is a
    # no-op when the sizes already match, so no global batch size is needed.
    fake = fake[:cur_batch_size].detach()

    # One interpolation weight per sample, broadcast over C, H and W.
    beta = torch.rand((cur_batch_size, 1, 1, 1), device=device)
    interpolated_images = (real * beta + fake * (1 - beta)).requires_grad_(True)

    mixed_scores = discriminator(interpolated_images)

    # create_graph=True keeps the gradient differentiable, so the penalty can be
    # back-propagated into the discriminator's parameters.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradient_norm = gradient.flatten(start_dim=1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="checkpoint.pth.tar"):
    """Serialise ``model`` and ``optimizer`` state to ``filename``.

    The file holds a dict with the keys "state_dict" (model weights) and
    "optimizer" (optimizer state), the layout load_checkpoint reads back.
    """
    print("Saving Checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by save_checkpoint.

    Tensors are loaded onto the CPU first, so this also works on machines
    without a GPU. ``model.load_state_dict`` copies the weights into the
    model's existing parameters (on whatever device they live), and
    ``Optimizer.load_state_dict`` moves the optimizer state to the
    parameters' device.

    After loading, every parameter group's learning rate is overwritten with
    ``lr``, so resumed training uses the currently configured rate rather than
    the one stored in the checkpoint.
    """
    print("Loading Checkpoint")
    # The previous hard-coded map_location="cuda" failed on CPU-only machines.
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def generate_examples(gen, n=100, out_dir="sagan_examples"):
    """Sample ``n`` images from ``gen`` and write them to ``out_dir``/img_{i}.png.

    Each image comes from its own standard-normal latent of shape
    (1, noise_dim, 1, 1). Outputs are mapped from [-1, 1] (the Generator ends
    in tanh) to [0, 1] before saving. Existing files with the same names are
    overwritten.

    The generator is put in eval mode for sampling and then restored to the
    mode it was in before the call.
    """
    # save_image does not create missing directories; on a fresh checkout the
    # first call would otherwise raise FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.randn(1, noise_dim, 1, 1).to(device)
                img = gen(noise)
                save_image(img * 0.5 + 0.5, os.path.join(out_dir, f"img_{i}.png"))
    finally:
        gen.train(was_training)


# Training

In [102]:
# Every training image is resized to the same resolution, so input shapes are
# fixed and letting cuDNN benchmark convolution algorithms should speed things up.
# The flag is `benchmark`; assigning a misspelt name such as `benchmarks` only
# creates an unused attribute and leaves benchmarking off.
torch.backends.cudnn.benchmark = True

# Builds the training DataLoader: every batch has batch_size images except possibly the last, which can be smaller.
def get_loader(image_size):
    """Build the training DataLoader over the image folder at ``datapath``.

    Images are resized to ``image_size`` x ``image_size``, converted to
    tensors, mirrored horizontally with probability 0.5, and normalised per
    channel from [0, 1] to [-1, 1] to match the generator's tanh output.

    Returns:
        (loader, dataset): a shuffling DataLoader yielding ``batch_size``
        images per batch, and the underlying ImageFolder dataset.
    """
    half = [0.5] * channels
    pipeline = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(half, half),
    ])

    dataset = datasets.ImageFolder(root=datapath, transform=pipeline)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    return loader, dataset


# Training loop
def train_loop(dis, gen, loader, dataset, opt_dis, opt_gen, scaler_gen, scaler_dis, do_ws_loss=True):
    """Run one epoch of adversarial training over ``loader``.

    For every batch of real images the discriminator takes one optimizer step,
    then the generator takes one step on the same fake batch.

    Args:
        dis, gen: discriminator and generator (updated in place).
        loader: iterable of (images, labels) batches; labels are ignored.
        dataset: unused; kept for backward compatibility with existing calls.
        opt_dis, opt_gen: optimizers for the two networks.
        scaler_gen, scaler_dis: ``torch.amp.GradScaler`` instances. They are
            used for both loss types because the losses are computed under
            fp16 autocast, where unscaled gradients can underflow.
        do_ws_loss: True -> Wasserstein loss with gradient penalty and a small
            drift term; False -> SAGAN hinge loss (no gradient penalty).
    """
    use_amp = device == "cuda"
    loop = tqdm(loader, leave=True)
    for real, _ in loop:
        real = real.to(device)
        # Match the fake batch to the real one (the last batch may be smaller).
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size, noise_dim, 1, 1).to(device)

        # ---- Discriminator step ----
        with torch.amp.autocast("cuda", enabled=use_amp):
            fake = gen(noise)
            dis_real = dis(real)
            dis_fake = dis(fake.detach())
            if do_ws_loss:
                # The penalty costs an extra discriminator forward pass plus a
                # double backward, so it is only computed when the loss uses it.
                gp = gradient_penalty(dis, real, fake, device)
                loss_dis = (
                    -(torch.mean(dis_real) - torch.mean(dis_fake))
                    + 10 * gp
                    + 0.001 * torch.mean(dis_real ** 2)
                )
            else:
                # SAGAN hinge loss: mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))
                loss_dis = F.relu(1.0 - dis_real).mean() + F.relu(1.0 + dis_fake).mean()

        opt_dis.zero_grad()
        scaler_dis.scale(loss_dis).backward()
        scaler_dis.step(opt_dis)
        scaler_dis.update()

        # ---- Generator step (same fake batch, updated discriminator) ----
        with torch.amp.autocast("cuda", enabled=use_amp):
            loss_gen = -torch.mean(dis(fake))

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # show losses in tqdm
        postfix = {"loss_dis": loss_dis.item(), "loss_gen": loss_gen.item()}
        if do_ws_loss:
            postfix["gp"] = gp.item()
        loop.set_postfix(**postfix)

In [103]:
# Build the models and optimizers, then either restore them from the saved
# checkpoints (load_model = True) or train from scratch (load_model = False).
# The loss is chosen by ws_loss in the configuration cell (False = hinge loss).

# models
gen = Generator(noise_dim, img_size, channels).to(device)
dis = Discriminator(img_size, channels).to(device)

# Adam with beta1 = 0.0, beta2 = 0.99 for both networks
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
opt_dis = optim.Adam(dis.parameters(), lr=lr, betas=(0.0, 0.99))

# one mixed-precision gradient scaler per network
scaler_dis = torch.amp.GradScaler("cuda")
scaler_gen = torch.amp.GradScaler("cuda")

if load_model:
    # No training: restore both networks; a later cell samples from gen.
    load_checkpoint(gen_save_path, gen, opt_gen, lr)
    load_checkpoint(dis_save_path, dis, opt_dis, lr)
else:
    gen.train()
    dis.train()

    loader, dataset = get_loader(img_size)

    for epoch in range(total_epochs):
        print(f"Epoch [{epoch+1}/{total_epochs}]")
        train_loop(dis, gen, loader, dataset, opt_dis, opt_gen, scaler_gen, scaler_dis, ws_loss)

        # checkpoint both networks and refresh the sample images every epoch
        if save_model:
            save_checkpoint(gen, opt_gen, filename=gen_save_path)
            save_checkpoint(dis, opt_dis, filename=dis_save_path)
            generate_examples(gen, 20)

Epoch [1/100]


100%|██████████| 467/467 [00:44<00:00, 10.44it/s, loss_dis=1.88]


Saving Checkpoint
Saving Checkpoint
Epoch [2/100]


100%|██████████| 467/467 [00:44<00:00, 10.45it/s, loss_dis=1.94]


Saving Checkpoint
Saving Checkpoint
Epoch [3/100]


100%|██████████| 467/467 [00:45<00:00, 10.19it/s, loss_dis=1.98]


Saving Checkpoint
Saving Checkpoint
Epoch [4/100]


100%|██████████| 467/467 [00:45<00:00, 10.24it/s, loss_dis=1.86]


Saving Checkpoint
Saving Checkpoint
Epoch [5/100]


100%|██████████| 467/467 [00:45<00:00, 10.25it/s, loss_dis=1.84]


Saving Checkpoint
Saving Checkpoint
Epoch [6/100]


100%|██████████| 467/467 [00:45<00:00, 10.31it/s, loss_dis=1.8] 


Saving Checkpoint
Saving Checkpoint
Epoch [7/100]


100%|██████████| 467/467 [00:45<00:00, 10.34it/s, loss_dis=1.95]


Saving Checkpoint
Saving Checkpoint
Epoch [8/100]


100%|██████████| 467/467 [00:46<00:00, 10.02it/s, loss_dis=1.78]


Saving Checkpoint
Saving Checkpoint
Epoch [9/100]


100%|██████████| 467/467 [00:46<00:00, 10.06it/s, loss_dis=1.86]


Saving Checkpoint
Saving Checkpoint
Epoch [10/100]


100%|██████████| 467/467 [00:45<00:00, 10.22it/s, loss_dis=1.94]


Saving Checkpoint
Saving Checkpoint
Epoch [11/100]


100%|██████████| 467/467 [00:45<00:00, 10.34it/s, loss_dis=1.84]


Saving Checkpoint
Saving Checkpoint
Epoch [12/100]


100%|██████████| 467/467 [00:45<00:00, 10.18it/s, loss_dis=2.06]


Saving Checkpoint
Saving Checkpoint
Epoch [13/100]


100%|██████████| 467/467 [00:44<00:00, 10.53it/s, loss_dis=1.5]  


Saving Checkpoint
Saving Checkpoint
Epoch [14/100]


100%|██████████| 467/467 [00:44<00:00, 10.41it/s, loss_dis=1.46] 


Saving Checkpoint
Saving Checkpoint
Epoch [15/100]


100%|██████████| 467/467 [00:46<00:00, 10.06it/s, loss_dis=1.16] 


Saving Checkpoint
Saving Checkpoint
Epoch [16/100]


100%|██████████| 467/467 [00:47<00:00,  9.78it/s, loss_dis=1.26] 


Saving Checkpoint
Saving Checkpoint
Epoch [17/100]


100%|██████████| 467/467 [00:48<00:00,  9.72it/s, loss_dis=1.89] 


Saving Checkpoint
Saving Checkpoint
Epoch [18/100]


100%|██████████| 467/467 [00:46<00:00, 10.12it/s, loss_dis=1.31] 


Saving Checkpoint
Saving Checkpoint
Epoch [19/100]


100%|██████████| 467/467 [00:46<00:00, 10.07it/s, loss_dis=1.63] 


Saving Checkpoint
Saving Checkpoint
Epoch [20/100]


100%|██████████| 467/467 [00:47<00:00,  9.89it/s, loss_dis=1.18] 


Saving Checkpoint
Saving Checkpoint
Epoch [21/100]


100%|██████████| 467/467 [00:45<00:00, 10.25it/s, loss_dis=1.44] 


Saving Checkpoint
Saving Checkpoint
Epoch [22/100]


100%|██████████| 467/467 [00:45<00:00, 10.18it/s, loss_dis=0.711]


Saving Checkpoint
Saving Checkpoint
Epoch [23/100]


100%|██████████| 467/467 [00:46<00:00, 10.02it/s, loss_dis=1.56] 


Saving Checkpoint
Saving Checkpoint
Epoch [24/100]


100%|██████████| 467/467 [00:48<00:00,  9.59it/s, loss_dis=1.68] 


Saving Checkpoint
Saving Checkpoint
Epoch [25/100]


100%|██████████| 467/467 [00:48<00:00,  9.73it/s, loss_dis=1.17] 


Saving Checkpoint
Saving Checkpoint
Epoch [26/100]


100%|██████████| 467/467 [00:46<00:00, 10.14it/s, loss_dis=0.884]


Saving Checkpoint
Saving Checkpoint
Epoch [27/100]


100%|██████████| 467/467 [00:46<00:00, 10.02it/s, loss_dis=1.34] 


Saving Checkpoint
Saving Checkpoint
Epoch [28/100]


100%|██████████| 467/467 [00:46<00:00, 10.10it/s, loss_dis=0.852]


Saving Checkpoint
Saving Checkpoint
Epoch [29/100]


100%|██████████| 467/467 [00:47<00:00,  9.79it/s, loss_dis=2.02] 


Saving Checkpoint
Saving Checkpoint
Epoch [30/100]


100%|██████████| 467/467 [00:47<00:00,  9.87it/s, loss_dis=1.04] 


Saving Checkpoint
Saving Checkpoint
Epoch [31/100]


100%|██████████| 467/467 [00:46<00:00,  9.99it/s, loss_dis=1.72] 


Saving Checkpoint
Saving Checkpoint
Epoch [32/100]


100%|██████████| 467/467 [00:45<00:00, 10.24it/s, loss_dis=1.02] 


Saving Checkpoint
Saving Checkpoint
Epoch [33/100]


100%|██████████| 467/467 [00:47<00:00,  9.86it/s, loss_dis=0.782]


Saving Checkpoint
Saving Checkpoint
Epoch [34/100]


100%|██████████| 467/467 [00:47<00:00,  9.86it/s, loss_dis=1.6]  


Saving Checkpoint
Saving Checkpoint
Epoch [35/100]


100%|██████████| 467/467 [00:46<00:00, 10.15it/s, loss_dis=0.793]


Saving Checkpoint
Saving Checkpoint
Epoch [36/100]


100%|██████████| 467/467 [00:44<00:00, 10.42it/s, loss_dis=0.744]


Saving Checkpoint
Saving Checkpoint
Epoch [37/100]


100%|██████████| 467/467 [00:45<00:00, 10.36it/s, loss_dis=0.478]


Saving Checkpoint
Saving Checkpoint
Epoch [38/100]


100%|██████████| 467/467 [00:45<00:00, 10.36it/s, loss_dis=1.39] 


Saving Checkpoint
Saving Checkpoint
Epoch [39/100]


100%|██████████| 467/467 [00:46<00:00, 10.13it/s, loss_dis=0.874]


Saving Checkpoint
Saving Checkpoint
Epoch [40/100]


100%|██████████| 467/467 [00:45<00:00, 10.34it/s, loss_dis=1.68] 


Saving Checkpoint
Saving Checkpoint
Epoch [41/100]


100%|██████████| 467/467 [00:44<00:00, 10.50it/s, loss_dis=1.14] 


Saving Checkpoint
Saving Checkpoint
Epoch [42/100]


100%|██████████| 467/467 [00:43<00:00, 10.73it/s, loss_dis=1.05] 


Saving Checkpoint
Saving Checkpoint
Epoch [43/100]


100%|██████████| 467/467 [00:45<00:00, 10.33it/s, loss_dis=0.603]


Saving Checkpoint
Saving Checkpoint
Epoch [44/100]


100%|██████████| 467/467 [00:45<00:00, 10.23it/s, loss_dis=1.85] 


Saving Checkpoint
Saving Checkpoint
Epoch [45/100]


100%|██████████| 467/467 [00:45<00:00, 10.36it/s, loss_dis=0.581]


Saving Checkpoint
Saving Checkpoint
Epoch [46/100]


100%|██████████| 467/467 [00:45<00:00, 10.32it/s, loss_dis=1.42] 


Saving Checkpoint
Saving Checkpoint
Epoch [47/100]


100%|██████████| 467/467 [00:46<00:00,  9.95it/s, loss_dis=0.542]


Saving Checkpoint
Saving Checkpoint
Epoch [48/100]


100%|██████████| 467/467 [00:46<00:00,  9.95it/s, loss_dis=1.58] 


Saving Checkpoint
Saving Checkpoint
Epoch [49/100]


100%|██████████| 467/467 [00:45<00:00, 10.16it/s, loss_dis=1.16] 


Saving Checkpoint
Saving Checkpoint
Epoch [50/100]


100%|██████████| 467/467 [00:45<00:00, 10.20it/s, loss_dis=0.867]


Saving Checkpoint
Saving Checkpoint
Epoch [51/100]


100%|██████████| 467/467 [00:45<00:00, 10.35it/s, loss_dis=0.706]


Saving Checkpoint
Saving Checkpoint
Epoch [52/100]


100%|██████████| 467/467 [00:45<00:00, 10.35it/s, loss_dis=1.08] 


Saving Checkpoint
Saving Checkpoint
Epoch [53/100]


100%|██████████| 467/467 [00:46<00:00, 10.09it/s, loss_dis=1.39] 


Saving Checkpoint
Saving Checkpoint
Epoch [54/100]


100%|██████████| 467/467 [00:46<00:00,  9.96it/s, loss_dis=0.782]


Saving Checkpoint
Saving Checkpoint
Epoch [55/100]


100%|██████████| 467/467 [00:46<00:00, 10.01it/s, loss_dis=1.28] 


Saving Checkpoint
Saving Checkpoint
Epoch [56/100]


100%|██████████| 467/467 [00:45<00:00, 10.30it/s, loss_dis=1.75] 


Saving Checkpoint
Saving Checkpoint
Epoch [57/100]


100%|██████████| 467/467 [00:45<00:00, 10.18it/s, loss_dis=0.746]


Saving Checkpoint
Saving Checkpoint
Epoch [58/100]


100%|██████████| 467/467 [00:47<00:00,  9.85it/s, loss_dis=0.835]


Saving Checkpoint
Saving Checkpoint
Epoch [59/100]


100%|██████████| 467/467 [00:48<00:00,  9.68it/s, loss_dis=2.04] 


Saving Checkpoint
Saving Checkpoint
Epoch [60/100]


100%|██████████| 467/467 [00:45<00:00, 10.17it/s, loss_dis=1.19] 


Saving Checkpoint
Saving Checkpoint
Epoch [61/100]


100%|██████████| 467/467 [00:46<00:00, 10.09it/s, loss_dis=1.46] 


Saving Checkpoint
Saving Checkpoint
Epoch [62/100]


100%|██████████| 467/467 [00:48<00:00,  9.69it/s, loss_dis=0.625]


Saving Checkpoint
Saving Checkpoint
Epoch [63/100]


100%|██████████| 467/467 [00:48<00:00,  9.68it/s, loss_dis=0.734]


Saving Checkpoint
Saving Checkpoint
Epoch [64/100]


100%|██████████| 467/467 [00:45<00:00, 10.26it/s, loss_dis=1.12] 


Saving Checkpoint
Saving Checkpoint
Epoch [65/100]


100%|██████████| 467/467 [00:45<00:00, 10.21it/s, loss_dis=1.24] 


Saving Checkpoint
Saving Checkpoint
Epoch [66/100]


100%|██████████| 467/467 [00:45<00:00, 10.33it/s, loss_dis=1.97] 


Saving Checkpoint
Saving Checkpoint
Epoch [67/100]


100%|██████████| 467/467 [00:45<00:00, 10.30it/s, loss_dis=1.03] 


Saving Checkpoint
Saving Checkpoint
Epoch [68/100]


100%|██████████| 467/467 [00:45<00:00, 10.38it/s, loss_dis=0.914]


Saving Checkpoint
Saving Checkpoint
Epoch [69/100]


100%|██████████| 467/467 [00:47<00:00,  9.86it/s, loss_dis=1.09] 


Saving Checkpoint
Saving Checkpoint
Epoch [70/100]


100%|██████████| 467/467 [00:49<00:00,  9.51it/s, loss_dis=0.536]


Saving Checkpoint
Saving Checkpoint
Epoch [71/100]


100%|██████████| 467/467 [00:47<00:00,  9.74it/s, loss_dis=0.898]


Saving Checkpoint
Saving Checkpoint
Epoch [72/100]


100%|██████████| 467/467 [00:49<00:00,  9.51it/s, loss_dis=0.882]


Saving Checkpoint
Saving Checkpoint
Epoch [73/100]


100%|██████████| 467/467 [00:48<00:00,  9.59it/s, loss_dis=1.21] 


Saving Checkpoint
Saving Checkpoint
Epoch [74/100]


100%|██████████| 467/467 [00:47<00:00,  9.88it/s, loss_dis=1.77] 


Saving Checkpoint
Saving Checkpoint
Epoch [75/100]


100%|██████████| 467/467 [00:46<00:00,  9.99it/s, loss_dis=0.846]


Saving Checkpoint
Saving Checkpoint
Epoch [76/100]


100%|██████████| 467/467 [00:47<00:00,  9.87it/s, loss_dis=1.04] 


Saving Checkpoint
Saving Checkpoint
Epoch [77/100]


100%|██████████| 467/467 [00:48<00:00,  9.72it/s, loss_dis=1.56]  


Saving Checkpoint
Saving Checkpoint
Epoch [78/100]


100%|██████████| 467/467 [00:48<00:00,  9.71it/s, loss_dis=0.815]


Saving Checkpoint
Saving Checkpoint
Epoch [79/100]


100%|██████████| 467/467 [00:50<00:00,  9.32it/s, loss_dis=1.12] 


Saving Checkpoint
Saving Checkpoint
Epoch [80/100]


100%|██████████| 467/467 [00:50<00:00,  9.27it/s, loss_dis=0.64] 


Saving Checkpoint
Saving Checkpoint
Epoch [81/100]


100%|██████████| 467/467 [00:49<00:00,  9.49it/s, loss_dis=0.151]


Saving Checkpoint
Saving Checkpoint
Epoch [82/100]


100%|██████████| 467/467 [00:48<00:00,  9.68it/s, loss_dis=0.693]


Saving Checkpoint
Saving Checkpoint
Epoch [83/100]


100%|██████████| 467/467 [00:47<00:00,  9.79it/s, loss_dis=1.55] 


Saving Checkpoint
Saving Checkpoint
Epoch [84/100]


100%|██████████| 467/467 [00:47<00:00,  9.83it/s, loss_dis=0.54] 


Saving Checkpoint
Saving Checkpoint
Epoch [85/100]


100%|██████████| 467/467 [00:46<00:00,  9.97it/s, loss_dis=0.828]


Saving Checkpoint
Saving Checkpoint
Epoch [86/100]


100%|██████████| 467/467 [00:47<00:00,  9.83it/s, loss_dis=0.211] 


Saving Checkpoint
Saving Checkpoint
Epoch [87/100]


100%|██████████| 467/467 [00:48<00:00,  9.73it/s, loss_dis=1.52] 


Saving Checkpoint
Saving Checkpoint
Epoch [88/100]


100%|██████████| 467/467 [00:48<00:00,  9.61it/s, loss_dis=1.03] 


Saving Checkpoint
Saving Checkpoint
Epoch [89/100]


100%|██████████| 467/467 [00:46<00:00, 10.01it/s, loss_dis=1.02] 


Saving Checkpoint
Saving Checkpoint
Epoch [90/100]


100%|██████████| 467/467 [00:47<00:00,  9.89it/s, loss_dis=2.32]  


Saving Checkpoint
Saving Checkpoint
Epoch [91/100]


100%|██████████| 467/467 [00:47<00:00,  9.77it/s, loss_dis=0.834] 


Saving Checkpoint
Saving Checkpoint
Epoch [92/100]


100%|██████████| 467/467 [00:48<00:00,  9.68it/s, loss_dis=0.885] 


Saving Checkpoint
Saving Checkpoint
Epoch [93/100]


100%|██████████| 467/467 [00:47<00:00,  9.93it/s, loss_dis=0.651] 


Saving Checkpoint
Saving Checkpoint
Epoch [94/100]


100%|██████████| 467/467 [00:48<00:00,  9.69it/s, loss_dis=1.64]  


Saving Checkpoint
Saving Checkpoint
Epoch [95/100]


100%|██████████| 467/467 [00:46<00:00, 10.10it/s, loss_dis=0.497] 


Saving Checkpoint
Saving Checkpoint
Epoch [96/100]


100%|██████████| 467/467 [00:47<00:00,  9.79it/s, loss_dis=0.234] 


Saving Checkpoint
Saving Checkpoint
Epoch [97/100]


100%|██████████| 467/467 [00:47<00:00,  9.88it/s, loss_dis=1.8]   


Saving Checkpoint
Saving Checkpoint
Epoch [98/100]


100%|██████████| 467/467 [00:47<00:00,  9.83it/s, loss_dis=0.307] 


Saving Checkpoint
Saving Checkpoint
Epoch [99/100]


100%|██████████| 467/467 [00:46<00:00, 10.08it/s, loss_dis=0.46]  


Saving Checkpoint
Saving Checkpoint
Epoch [100/100]


100%|██████████| 467/467 [00:46<00:00, 10.03it/s, loss_dis=0.809] 


Saving Checkpoint
Saving Checkpoint


In [104]:
# Display for presentation purposes
def display(
    images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
):
    """
    Show the first ``n`` images of ``images`` side by side in one row.

    Args:
        images: sequence of numpy arrays with values in [-1, 1]; they are
            rescaled to [0, 1] for plotting. H x W x 3 arrays are drawn as
            RGB, so the channel order must already be RGB.
        n: maximum number of images to show; clipped to ``len(images)``.
        size: matplotlib figure size in inches.
        cmap: colormap; matplotlib only applies it to single-channel images.
        as_type: dtype each image is cast to before rescaling.
        save_to: optional path; when given, the figure is also saved there.

    Note: this shadows IPython's built-in ``display`` in the notebook namespace.
    """
    n = min(n, len(images))
    if n == 0:
        print("No images to display")
        return

    fig, axes = plt.subplots(1, n, figsize=size, squeeze=False)
    for ax, image in zip(axes[0], images[:n]):
        ax.imshow((image.astype(as_type) + 1) / 2, cmap=cmap)
        ax.axis("off")

    if save_to:
        fig.savefig(save_to)
        print(f"\nSaved to {save_to}")

    plt.show()


def preprocess(img):
    """Map a 0..255 image array to float32 in [-1, 1] (0 -> -1, 255 -> 1)."""
    midpoint = 127.5
    return (img.astype(np.float32) - midpoint) / midpoint

In [105]:
if load_model:
    # Evaluate: sample fresh images from the restored generator and show them.
    n_examples = 10
    generate_examples(gen, n_examples)

    # generate_examples writes into ./sagan_examples
    cwd = os.getcwd()
    path = os.path.join(cwd, "sagan_examples")
    images = []

    # Read exactly the files just written. A lexicographic sort of the
    # directory listing yields img_0, img_1, img_10, ... and would mix in older
    # samples left there by training (which writes 20 images per epoch).
    for i in range(n_examples):
        image_path = os.path.join(path, f"img_{i}.png")
        img = cv2.imread(image_path)
        # cv2.imread returns None instead of raising for unreadable files.
        if img is None:
            print(f"Warning: could not read {image_path}")
            continue
        # OpenCV loads BGR; matplotlib's imshow expects RGB.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    train = [preprocess(img) for img in images]

    display(train[:n_examples])