In [None]:
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.optim as optim
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import *
from models import *
from PIL import Image
# Create the directory that train() writes sample images into.
# Pure-Python and idempotent: re-running this cell does not fail when the
# directory already exists, and it does not depend on the platform shell.
import os
os.makedirs("generated_images", exist_ok=True)


In [None]:
# ---------------------------------------------------------------------------
# Configuration (values as provided by the original code author)
# ---------------------------------------------------------------------------

# --- optimisation ---
optimizer = 'Adam'       # optimizer selector checked when the optimizers are built
lr_gen = 0.0006          # generator learning rate
lr_dis = 0.0003          # discriminator learning rate
beta1 = 0                # first Adam beta (passed as betas[0])
beta2 = 0.99             # second Adam beta (passed as betas[1])
weight_decay = 1e-3      # weight decay (not read by the visible cells)
lr_decay = True          # when True, the LR schedulers are handed to train()
max_iter = 500000        # multiplied by n_critic to form the LR-decay horizon
n_critic = 5             # multiplied with max_iter for the schedulers; also passed to train()

# --- loss ---
loss = "wgangp_eps"      # loss-function name
phi = 1                  # target gradient norm used in the gradient penalty

# --- training run ---
epoch = 200              # number of epochs
drop_rate = 0.5          # dropout rate
latent_dim = 256         # width of the generator's noise input
diff_aug = "translation,cutout,color"  # augmentation spec given to the discriminator

# --- architecture ---
image_size = 64          # H and W of images for the discriminator
initial_size = 8         # initial size for the generator
patch_size = 4           # patch size for generated image
num_classes = 1          # number of classes for the discriminator
dim = 384                # embedding dimension

# --- output ---
output_dir = 'checkpoint'  # saved-model path
img_name = "img_name"


In [None]:
# Pick the compute device: the first CUDA GPU when available, otherwise the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Generator. `drop_rate` and `latent_dim` are read from the config cell so the
# network stays in sync with the noise sampled during training.
# NOTE(review): `initial_size=16` and `dim=384 * 2` intentionally stay literal
# because they differ from the config values (initial_size=8, dim=384) --
# confirm which values are intended before wiring them to the config.
generator = Generator(
    depth1=6,
    depth2=4,
    depth3=1,
    initial_size=16,
    dim=384 * 2,
    heads=2,
    mlp_ratio=4,
    drop_rate=drop_rate,
    latent_dim=latent_dim,
).to(device).apply(inits_weight)

# Discriminator. `image_size`, `num_classes`, `dim` and `drop_rate` come from
# the config cell (same values as the previous literals).
# NOTE(review): `patch_size=16` stays literal because the config says 4 -- confirm.
discriminator = Discriminator(
    diff_aug=diff_aug,
    image_size=image_size,
    patch_size=16,
    input_channel=1,
    num_classes=num_classes,
    dim=dim,
    depth=4,
    heads=4,
    mlp_ratio=2,
    drop_rate=drop_rate,
).to(device).apply(inits_weight)

In [None]:
# Build one optimizer per network over its trainable parameters only.
if optimizer == 'Adam':
    optim_gen = optim.Adam(
        (p for p in generator.parameters() if p.requires_grad),
        lr=lr_gen,
        betas=(beta1, beta2),
    )
    optim_dis = optim.Adam(
        (p for p in discriminator.parameters() if p.requires_grad),
        lr=lr_dis,
        betas=(beta1, beta2),
    )
else:
    # Fail here with a clear message instead of a NameError on `optim_gen`
    # a few lines further down.
    # Reference signature if RMSprop is ever added:
    # RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
    raise ValueError(
        f"Unsupported optimizer {optimizer!r}; only 'Adam' is implemented."
    )

# LR schedulers: arguments are (optimizer, start lr, 0.0, 0, max_iter * n_critic).
# NOTE(review): presumably (optimizer, start_lr, end_lr, decay_start_step,
# decay_end_step) -- confirm against LinearLrDecay's definition.
gen_scheduler = LinearLrDecay(optim_gen, lr_gen, 0.0, 0, max_iter * n_critic)
dis_scheduler = LinearLrDecay(optim_dis, lr_dis, 0.0, 0, max_iter * n_critic)

print("optimizer:", optimizer)


In [None]:
# Smoke test: push a single standard-normal latent vector through the generator
# and show the output shape. The noise is created directly on `device`, so the
# cell also runs when the notebook fell back to the CPU.
a = torch.randn(1, latent_dim, device=device)
print(a.shape)
generator(a).shape

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculate the gradient penalty loss for WGAN-GP.

    Real and fake samples are mixed with one random coefficient per sample,
    the critic ``D`` is evaluated on the mixture, and the deviation of the
    per-sample gradient L2 norm from ``phi`` is penalised.

    Args:
        D: Callable mapping a batch of samples to critic outputs.
        real_samples: Tensor whose first dimension is the batch dimension.
        fake_samples: Tensor broadcastable against ``real_samples``.
        phi: Target gradient norm.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - phi) ** 2``.
        It is built with ``create_graph=True`` so it can be backpropagated.
    """
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample, broadcast over every non-batch
    # dimension. Created on the samples' own device: `.device` is valid for
    # CPU tensors too, whereas `get_device()` returns -1 there.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(
        alpha_shape, dtype=real_samples.dtype, device=real_samples.device
    )

    # Random interpolation between real and fake samples; gradients are taken
    # with respect to this tensor.
    interpolates = (
        alpha * real_samples + (1 - alpha) * fake_samples
    ).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Ones with the critic output's exact shape/dtype/device, so the function
    # does not assume an (N, 1) output.
    grad_outputs = torch.ones_like(d_interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # Flatten each sample's gradient before taking its L2 norm.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty


In [None]:
def train(
    noise,
    generator,
    discriminator,
    optim_gen,
    optim_dis,
    epoch,
    schedulers,
    img_size=64,
    latent_dim=latent_dim,
    n_critic=n_critic,
    device="cuda:0",
    data_root=r"C:\Users\aashr\Desktop\research\testing_grounds\images",
    batch_size=8,
):
    """Run one epoch of WGAN-GP training and save/display a sample image.

    For every batch the discriminator is updated on real images, detached
    generator output and the gradient penalty; the generator is then updated
    on the negated mean critic score of fresh samples. After the last batch
    the most recent generator output is written to ``generated_images/``,
    displayed, and a one-line loss summary is logged.

    Args:
        noise: Unused; fresh noise is sampled for every batch. Kept so
            existing positional calls keep working.
        generator: Generator network, called with ``(N, latent_dim)`` noise.
        discriminator: Critic network, called with image batches.
        optim_gen: Optimizer for the generator.
        optim_dis: Optimizer for the discriminator.
        epoch: Zero-based epoch index, used in the file name and log line.
        schedulers: ``(gen_scheduler, dis_scheduler)`` pair, or a falsy value
            to skip scheduler stepping.
        img_size: Side length images are resized to.
        latent_dim: Width of the noise vectors.
        n_critic: Accepted for interface compatibility.
            NOTE(review): currently not used -- the generator is updated on
            every batch. Confirm whether it should only be updated every
            ``n_critic`` batches.
        device: Device on which image and noise tensors are placed; the
            networks are expected to live on the same device.
        data_root: Root folder handed to ``ImageFolder``.
        batch_size: Batch size for the data loader and the generator step.

    Returns:
        None.
    """
    device = torch.device(device)
    generator = generator.train()
    discriminator = discriminator.train()

    train_loader = _build_train_loader(img_size, data_root, batch_size)

    if schedulers:
        gen_scheduler, dis_scheduler = schedulers

    for index, (img, _) in enumerate(train_loader):
        real_imgs = img.to(device=device, dtype=torch.float32)
        noise = torch.randn(img.shape[0], latent_dim, device=device)

        # ---- discriminator (critic) update ----
        optim_dis.zero_grad()
        real_valid = discriminator(real_imgs)
        # Detached so the critic loss does not backpropagate into the generator.
        fake_imgs = generator(noise).detach()
        fake_valid = discriminator(fake_imgs)

        gradient_penalty = compute_gradient_penalty(
            discriminator, real_imgs, fake_imgs, phi
        )
        loss_dis = (
            -torch.mean(real_valid)
            + torch.mean(fake_valid)
            + gradient_penalty * 10 / (phi**2)
        )
        loss_dis.backward()
        optim_dis.step()

        # ---- generator update ----
        optim_gen.zero_grad()
        if schedulers:
            # NOTE(review): both schedulers are stepped with the constant 1 on
            # every batch. If LinearLrDecay.step expects the current global
            # step, the learning rate never decays -- confirm against its
            # definition before changing.
            gen_scheduler.step(1)
            dis_scheduler.step(1)

        gener_noise = torch.randn(batch_size, latent_dim, device=device)
        generated_imgs = generator(gener_noise)
        fake_valid = discriminator(generated_imgs)
        gener_loss = -torch.mean(fake_valid)
        gener_loss.backward()
        optim_gen.step()

    # Persist and show the samples from the final generator step of the epoch.
    sample_path = f"generated_images/generated_img_{epoch}_{index}.jpg"
    save_image(
        generated_imgs.detach(),
        sample_path,
        normalize=True,
        scale_each=True,
    )
    display(Image.open(sample_path))
    tqdm.write(
        "[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (
            epoch + 1,
            index,
            len(train_loader),
            loss_dis.item(),
            gener_loss.item(),
        )
    )


def _build_train_loader(img_size, data_root, batch_size):
    """Build the shuffled ImageFolder loader: resize, random flip, normalise, 1-channel grayscale."""
    transform = transforms.Compose(
        [
            transforms.Resize(size=(img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Grayscale(1),
        ]
    )
    train_set = torchvision.datasets.ImageFolder(root=data_root, transform=transform)
    return torch.utils.data.DataLoader(
        dataset=train_set, batch_size=batch_size, shuffle=True
    )

In [None]:
best = 1e4  # not read in this cell; kept in case other cells rely on it
num_epochs = 2000  # separate name so the loop variable does not overwrite the epoch count

# The scheduler pair does not change between epochs, so build it once.
lr_schedulers = (gen_scheduler, dis_scheduler) if lr_decay else None

for epoch in range(num_epochs):
    train(
        # train() samples its own noise for every batch and ignores this
        # argument; passing None avoids a NameError, since no `noise`
        # variable is defined in the visible notebook.
        None,
        generator,
        discriminator,
        optim_gen,
        optim_dis,
        epoch,
        lr_schedulers,
        img_size=image_size,
        latent_dim=latent_dim,
        n_critic=n_critic,
        device=device,  # same device the networks were moved to
    )