In [None]:
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import *
from models import *
from PIL import Image
import os

# Output directory for the per-epoch sample images written by train().
# exist_ok=True keeps this cell idempotent under Restart & Run All (a shell
# `mkdir` reports an error when the directory already exists) and avoids
# depending on the platform's shell.
os.makedirs("generated_images", exist_ok=True)


In [None]:
# ---- Optimisation ----
lr_gen = 2 * 6e-4         # generator learning rate (6e-4, doubled)
lr_dis = 2 * 3e-4         # discriminator learning rate (3e-4, doubled)
beta1, beta2 = 0.9, 0.99  # AdamW betas (first / second moment decay rates)
weight_decay = 0.001      # weight-decay coefficient
epoch = 200               # number of training epochs

# ---- Model sizes (architecture details by the authors) ----
latent_dim = 256          # length of the generator's input noise vector
dim = 384                 # embedding dimension
drop_rate = 0.5           # dropout probability
image_size = 128          # H and W of the images seen by the discriminator
initial_size = 8          # initial spatial size for the generator
patch_size = 4            # patch size for the generated image
num_classes = 1           # number of output classes for the discriminator

# ---- Loss, augmentation, output ----
phi = 2                   # target gradient norm for the gradient penalty
diff_aug = "translation,cutout,color"  # augmentation policy string for the discriminator
output_dir = 'checkpoint'              # saved model path


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Generator. Values shared with the config cell are read from it so the two
# cannot drift apart; the rest are architecture choices for this run.
# NOTE(review): initial_size=32 differs from the config's `initial_size = 8` —
# confirm which one is intended.
generator = (
    Generator(
        depth1=6,
        depth2=4,
        depth3=1,
        initial_size=32,
        dim=dim,
        heads=3,
        mlp_ratio=4,
        drop_rate=drop_rate,
        latent_dim=latent_dim,
    )
    .to(device)
    .apply(inits_weight)
)

# Discriminator (critic) over single-channel images.
# NOTE(review): patch_size=8 differs from the config's `patch_size = 4` —
# confirm which one is intended.
discriminator = (
    Discriminator(
        diff_aug=diff_aug,
        image_size=image_size,
        patch_size=8,
        input_channel=1,
        num_classes=num_classes,
        dim=dim,
        depth=3,
        heads=4,
        mlp_ratio=2,
        drop_rate=drop_rate,
    )
    .to(device)
    .apply(inits_weight)
)

# NOTE(review): the config's `weight_decay` is not passed to either optimizer,
# so AdamW's library default (0.01) applies — confirm this is intended.
optim_gen = optim.AdamW(
    filter(lambda p: p.requires_grad, generator.parameters()),
    lr=lr_gen,
    betas=(beta1, beta2),
)
optim_dis = optim.AdamW(
    filter(lambda p: p.requires_grad, discriminator.parameters()),
    lr=lr_dis,
    betas=(beta1, beta2),
)

# Linearly decay each optimizer's LR to half its initial value over
# `scheduler_total_iters` scheduler steps (train() steps both once per batch).
# Each scheduler must wrap its OWN optimizer: previously dis_scheduler wrapped
# optim_gen, which decayed the generator LR twice and never the discriminator's.
scheduler_total_iters = 2000 * 200
gen_scheduler = optim.lr_scheduler.LinearLR(
    optim_gen, start_factor=1, end_factor=0.5, total_iters=scheduler_total_iters
)
dis_scheduler = optim.lr_scheduler.LinearLR(
    optim_dis, start_factor=1, end_factor=0.5, total_iters=scheduler_total_iters
)

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculate the gradient penalty loss for WGAN-GP with target norm ``phi``.

    A random point on the segment between each real/fake pair is scored by
    ``D``. The penalty is ``mean((||dD(x_hat)/dx_hat||_2 - phi) ** 2)``, where
    the norm is taken over all non-batch dimensions of each sample.

    Args:
        D: callable critic; its output is differentiated w.r.t. its input.
        real_samples: batch of real inputs, batch on dim 0.
        fake_samples: batch of generated inputs, same shape as ``real_samples``
            (callers here pass it already detached).
        phi: target gradient norm (``phi=1`` is the standard WGAN-GP penalty).

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        backpropagated into ``D``'s parameters.
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over every other dim.
    # `.device` is used because `get_device()` returns -1 for CPU tensors.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Seed the vector-Jacobian product with ones shaped like D's output, so the
    # critic's output shape does not have to be exactly (batch, 1).
    grad_seed = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - phi) ** 2).mean()

In [None]:
# Explicit imports. Previously `torchvision` / `transforms` were only available
# if one of the star imports above happened to re-export them.
import torchvision
from torchvision import transforms

# Root of an ImageFolder-style dataset (one sub-directory per class).
# Machine-specific: edit this for your environment.
data_root = r"C:\Users\aashr\Desktop\research\testing_grounds\images"
img_size = image_size  # resize target; same value the config cell gives the discriminator
batch_size = 2

# Resize to a square, convert to a tensor, then reduce to a single channel
# (the discriminator is constructed with input_channel=1).
transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Grayscale(1),
    ]
)

# Alternative dataset:
# train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_set = torchvision.datasets.ImageFolder(root=data_root, transform=transform)
# drop_last=True keeps every batch the same size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, drop_last=True
)

def train(
    noise,
    generator,
    discriminator,
    optim_gen,
    optim_dis,
    epoch,
    gen_scheduler,
    dis_scheduler,
    latent_dim=latent_dim,
    device="cuda:0",
):
    """Run one epoch of WGAN-GP training over the global ``train_loader``.

    For every batch: one discriminator (critic) update with gradient penalty,
    then one generator update. Both LR schedulers are stepped once per batch.
    After the epoch, the last batch of generated images is saved under
    ``generated_images/`` and displayed inline.

    Args:
        noise: unused. Fresh latent noise is drawn for every batch; the
            parameter is kept so existing callers keep working.
        generator, discriminator: models being trained (set to train mode).
        optim_gen, optim_dis: their optimizers.
        epoch: zero-based epoch index, used in the log line and the filename.
        gen_scheduler, dis_scheduler: LR schedulers stepped once per batch.
        latent_dim: length of the noise vector fed to the generator.
        device: device for the image batches and the sampled noise.
    """
    generator.train()
    discriminator.train()

    generated_imgs = None
    for index, (img, _) in enumerate(train_loader):
        # Honour `device` instead of hard-coding torch.cuda.FloatTensor, which
        # ignored the argument and failed on CPU-only machines.
        real_imgs = img.to(device=device, dtype=torch.float32)
        batch_size = real_imgs.size(0)

        # ---- Discriminator (critic) step ----
        optim_dis.zero_grad()
        noise = torch.randn(batch_size, latent_dim, device=device)
        real_valid = discriminator(real_imgs)
        # Detached so the critic loss does not backpropagate into the generator.
        fake_imgs = generator(noise).detach()
        fake_valid = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(
            discriminator, real_imgs, fake_imgs, phi
        )
        disc_loss = (
            -torch.mean(real_valid)
            + torch.mean(fake_valid)
            + gradient_penalty * 10 / (phi**2)
        )
        disc_loss.backward()
        optim_dis.step()
        dis_scheduler.step()
        d_lr = dis_scheduler.get_last_lr()

        # ---- Generator step ----
        optim_gen.zero_grad()
        gener_noise = torch.randn(batch_size, latent_dim, device=device)
        generated_imgs = generator(gener_noise)
        fake_valid = discriminator(generated_imgs)
        # This backward also writes the critic's .grad; those are cleared by
        # optim_dis.zero_grad() at the start of the next batch.
        gener_loss = -torch.mean(fake_valid)
        gener_loss.backward()
        optim_gen.step()
        gen_scheduler.step()
        g_lr = gen_scheduler.get_last_lr()

        print(f"\r[Epoch {epoch+1}] [Batch {index+1}/{len(train_loader)}] [D loss: {disc_loss.item()}] [G loss: {gener_loss.item()}] [D lr: {d_lr[0]}] [G lr: {g_lr[0]}] {img.shape[0], generated_imgs.shape}                \r",end = "\r")
        # Drop references before the next batch to lower peak memory.
        del fake_valid, real_valid, fake_imgs, real_imgs, disc_loss, gener_loss, gradient_penalty

    if generated_imgs is None:
        # Empty loader: nothing was generated, so there is nothing to save.
        return

    sample_path = f"generated_images/generated_img_{epoch}_{128}.png"
    save_image(generated_imgs.detach(), sample_path)
    # Context manager closes the file handle once the image has been displayed.
    with Image.open(sample_path) as sample:
        display(sample)
    del generated_imgs

In [None]:

# Main training loop. The epoch count gets its own name so the loop variable
# does not shadow it.
# NOTE(review): this runs 2000 epochs, overriding the config cell's
# `epoch = 200` — confirm which count is intended.
num_epochs = 2000
for epoch in range(num_epochs):
    train(
        None,  # `noise`: train() samples fresh noise for every batch and ignores this
        generator,
        discriminator,
        optim_gen,
        optim_dis,
        epoch,
        gen_scheduler,
        dis_scheduler,
        latent_dim=latent_dim,
        device=device,
    )