In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator
from math import log2
from tqdm import tqdm
import cv2
import numpy as np
import os
from scipy.stats import truncnorm

# Enable cuDNN autotuning. The flag is spelled `benchmark` (singular);
# assigning to `benchmarks` does not turn autotuning on.
torch.backends.cudnn.benchmark = True

In [9]:
# ---- Data and checkpoint locations ----
dataset = "celeb_hq_dataset"  # root folder handed to ImageFolder in create_loader
checkpoint_generator = "generator.pth"
checkpoint_discriminator = "discriminator.pth"

# ---- Hardware ----
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
workers = 4  # DataLoader worker processes

# ---- Model / image dimensions ----
img_channels = 3
img_size = 512
starting_img_size = 4  # resolution at which main() starts training
z_dim = 256
in_channels = 256

# ---- Optimisation ----
lr = 1e-3
l_gp = 10  # weight of the gradient-penalty term in the discriminator loss

# ---- Progressive-growing schedule ----
# One batch size per resolution step; create_loader indexes this list with
# int(log2(image_size / 4)), so entry 0 is 4x4, entry 1 is 8x8, and so on.
batch_size = [4] * 7 + [2, 1]
num_steps = 1 + int(log2(img_size / 4))
progressive_epochs = [10 for _ in batch_size]  # epochs per resolution step

# Latent batch of 8 vectors sampled once at configuration time.
noise = torch.randn(8, z_dim, 1, 1).to(device)

Extra Functions

In [10]:
def gradient_penalty(discriminator, real, fake, alpha, train_step, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        discriminator: callable invoked as ``discriminator(images, alpha, train_step)``.
        real: 4-D batch of real images ``(N, C, H, W)``.
        fake: batch of generated images with the same shape as ``real``;
            it is detached, so no gradient reaches the generator from here.
        alpha: fade-in factor forwarded unchanged to the discriminator.
        train_step: resolution step forwarded unchanged to the discriminator.
        device: device on which the random mixing coefficients are created.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)`` where the gradient of the
        discriminator scores is taken with respect to the interpolated images.
        The graph is kept (``create_graph=True``) so the penalty can be
        back-propagated as part of the discriminator loss.
    """
    # Unpacking keeps the original requirement that `real` is 4-D.
    batch, _, _, _ = real.shape

    # One mixing coefficient per sample. The (N, 1, 1, 1) shape broadcasts over
    # channels and pixels, so the coefficients do not need to be materialised
    # at full image size.
    beta = torch.rand((batch, 1, 1, 1), device=device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    mixed_scores = discriminator(interpolated_images, alpha, train_step)
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient.
    gradient_norm = gradient.view(gradient.shape[0], -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer).
    """
    print("=> Saving checkpoint")
    snapshot = dict()
    snapshot["state_dict"] = model.state_dict()
    snapshot["optimizer"] = optimizer.state_dict()
    torch.save(snapshot, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        checkpoint_file: path to a file holding a dict with the keys
            ``"state_dict"`` and ``"optimizer"``.
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate applied to every parameter group after loading, so
            the caller's value wins over the one stored in the checkpoint.
    """
    print("=> Loading checkpoint")

    # Map the stored tensors onto the device the model already lives on rather
    # than a hard-coded "cuda", so checkpoints also load on CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"

    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The optimizer state carries its own learning rate; override it.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def plot_to_tensorboard(writer, disc_loss, loss_gen, real, fake, tensorboard_step):
    """Log both losses and a grid of real and generated images to TensorBoard.

    Args:
        writer: object exposing ``add_scalar`` and ``add_image``.
        disc_loss: discriminator loss value, logged as "Loss discriminator".
        loss_gen: generator loss value, logged as "Loss generator".
        real: batch of real images; at most the first 8 are shown.
        fake: batch of generated images; at most the first 8 are shown.
        tensorboard_step: global step attached to every entry.
    """
    writer.add_scalar("Loss discriminator", disc_loss, global_step=tensorboard_step)
    # The generator loss was previously accepted but never written out.
    writer.add_scalar("Loss generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

Creating a loader

In [11]:
def create_loader(image_size):
    """Build the dataset and DataLoader for one training resolution.

    Images under the module-level ``dataset`` folder are resized to
    ``image_size x image_size``, converted to tensors, randomly flipped
    horizontally with probability 0.5, and normalised with mean 0.5 and std 0.5
    per channel. The batch size is looked up in the module-level ``batch_size``
    list at index ``int(log2(image_size / 4))``.

    Returns:
        ``(loader, data)``: the shuffling DataLoader and the underlying
        ImageFolder dataset.
    """
    channel_mean = [0.5] * img_channels
    channel_std = [0.5] * img_channels
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform = transforms.Compose(pipeline)

    resolution_index = int(log2(image_size / 4))
    data = datasets.ImageFolder(root=dataset, transform=transform)
    loader = DataLoader(
        data,
        batch_size=batch_size[resolution_index],
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    return loader, data

Training

In [12]:
def train(discriminator, gen, loader, dataset, step, alpha, discriminator_optimizer, opt_gen, tensorboard_step, writer, generator_Scaler, discriminator_Scaler,):
    """Train the discriminator and generator for one pass over ``loader``.

    For every batch the discriminator is updated first with the loss
    ``-(mean(D(real)) - mean(D(fake))) + l_gp * gp + 0.001 * mean(D(real) ** 2)``,
    then the generator is updated with ``-mean(D(fake))``. Both forward passes
    run under CUDA autocast and both backward passes go through their own
    GradScaler.

    Args:
        discriminator, gen: the two networks, each called with ``(x, alpha, step)``.
        loader: iterable yielding ``(images, labels)``; labels are ignored.
        dataset: only ``len(dataset)`` is used, for the fade-in schedule.
        step: current resolution step, used to index ``progressive_epochs``.
        alpha: current fade-in factor.
        discriminator_optimizer, opt_gen: optimizers for the two networks.
        tensorboard_step: running counter used as TensorBoard global step.
        writer: TensorBoard writer handed to ``plot_to_tensorboard``.
        generator_Scaler, discriminator_Scaler: GradScaler instances.

    Returns:
        ``(tensorboard_step, alpha)`` after the pass, so the caller can carry
        both into the next epoch.
    """
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Named `batch_noise` so it does not shadow the module-level `noise`,
        # which is the fixed latent batch used for the preview images below.
        batch_noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)

        # ---- Discriminator update ----
        with torch.cuda.amp.autocast():
            fake = gen(batch_noise, alpha, step)
            discriminator_real = discriminator(real, alpha, step)
            discriminator_fake = discriminator(fake.detach(), alpha, step)
            gp = gradient_penalty(discriminator, real, fake, alpha, step, device=device)
            disc_loss = (
                -(torch.mean(discriminator_real) - torch.mean(discriminator_fake))
                + l_gp * gp
                + (0.001 * torch.mean(discriminator_real ** 2))
            )

        discriminator_optimizer.zero_grad()
        discriminator_Scaler.scale(disc_loss).backward()
        discriminator_Scaler.step(discriminator_optimizer)
        discriminator_Scaler.update()

        # ---- Generator update ----
        with torch.cuda.amp.autocast():
            gen_fake = discriminator(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        generator_Scaler.scale(loss_gen).backward()
        generator_Scaler.step(opt_gen)
        generator_Scaler.update()

        # Fade-in: alpha grows linearly with the number of samples seen and
        # reaches 1 after half of this step's epochs.
        alpha += cur_batch_size / (
            (progressive_epochs[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                # Use the fixed module-level latent batch so successive
                # previews show the same samples; map the output to [0, 1].
                fixed_fakes = gen(noise, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                disc_loss.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            disc_loss=disc_loss.item(),
        )

    return tensorboard_step, alpha

In [13]:
def main():
    """Run progressive training from ``starting_img_size`` through every step.

    Builds the generator, discriminator, their Adam optimizers and GradScalers,
    then for each remaining entry of ``progressive_epochs`` creates a loader at
    resolution ``4 * 2 ** step``, resets ``alpha`` to 0 and trains for that
    many epochs, saving both checkpoints after every epoch.

    NOTE(review): the loop is bounded by ``len(progressive_epochs)`` (9 entries,
    so the final resolution is 4 * 2 ** 8 = 1024), not by ``img_size`` or
    ``num_steps`` — confirm this is intended.
    NOTE(review): ``load_checkpoint`` is not called here, so a rerun starts
    from freshly initialised networks — confirm this is intended.
    """
    gen = Generator(z_dim, in_channels, img_channels=img_channels).to(device)
    discriminator = Discriminator(z_dim, in_channels, img_channels=img_channels).to(device)
    generator_optimizer = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.0, 0.99))
    discriminator_scaler = torch.cuda.amp.GradScaler()
    generator_scaler = torch.cuda.amp.GradScaler()
    writer = SummaryWriter("logs/true_GAN")
    gen.train()
    discriminator.train()

    tensorboard_step = 0
    step = int(log2(starting_img_size / 4))
    try:
        for num_epochs in progressive_epochs[step:]:
            alpha = 0  # restart the fade-in at every new resolution
            loader, dataset = create_loader(4 * 2 ** step)
            print(f"Current image size: {4 * 2 ** step}")
            print("----------------------------------------------------------------------------------")
            for epoch in range(num_epochs):
                print(f"Epoch [{epoch+1}/{num_epochs}]")
                tensorboard_step, alpha = train(
                    discriminator,
                    gen,
                    loader,
                    dataset,
                    step,
                    alpha,
                    discriminator_optimizer,
                    generator_optimizer,
                    tensorboard_step,
                    writer,
                    generator_scaler,
                    discriminator_scaler,
                )
                save_checkpoint(gen, generator_optimizer, filename=checkpoint_generator)
                save_checkpoint(discriminator, discriminator_optimizer, filename=checkpoint_discriminator)

            step += 1
    finally:
        # Close the writer even when training aborts, so pending events are
        # written and the log file handle is released.
        writer.close()

In [14]:
# Start the full progressive training run. The output recorded below stops
# during the 64x64 stage with a MemoryError raised in a DataLoader worker.
main()

Current image size: 4
----------------------------------------------------------------------------------
Epoch [1/10]


100%|██████████| 4486/4486 [01:39<00:00, 45.04it/s, disc_loss=0.251, gp=0.00156]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10]


100%|██████████| 4486/4486 [01:37<00:00, 45.97it/s, disc_loss=0.422, gp=0.00296]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/10]


100%|██████████| 4486/4486 [01:39<00:00, 45.30it/s, disc_loss=-.12, gp=0.00756]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/10]


100%|██████████| 4486/4486 [01:37<00:00, 45.85it/s, disc_loss=0.219, gp=0.00701]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/10]


100%|██████████| 4486/4486 [01:36<00:00, 46.25it/s, disc_loss=-.199, gp=0.0113]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/10]


100%|██████████| 4486/4486 [01:37<00:00, 45.78it/s, disc_loss=0.271, gp=0.00539]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [7/10]


100%|██████████| 4486/4486 [01:38<00:00, 45.47it/s, disc_loss=-.496, gp=0.0225]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [8/10]


100%|██████████| 4486/4486 [01:38<00:00, 45.67it/s, disc_loss=0.203, gp=0.00322]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10]


100%|██████████| 4486/4486 [01:36<00:00, 46.40it/s, disc_loss=-.000205, gp=0.0164] 


=> Saving checkpoint
=> Saving checkpoint
Epoch [10/10]


100%|██████████| 4486/4486 [01:37<00:00, 46.09it/s, disc_loss=0.261, gp=0.0182]    


=> Saving checkpoint
=> Saving checkpoint
Current image size: 8
----------------------------------------------------------------------------------
Epoch [1/10]


100%|██████████| 4486/4486 [02:54<00:00, 25.73it/s, disc_loss=0.638, gp=0.00525]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10]


100%|██████████| 4486/4486 [02:53<00:00, 25.89it/s, disc_loss=-.499, gp=0.00598]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/10]


100%|██████████| 4486/4486 [02:52<00:00, 25.96it/s, disc_loss=0.706, gp=0.0108]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/10]


100%|██████████| 4486/4486 [02:53<00:00, 25.93it/s, disc_loss=-.204, gp=0.0116]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/10]


100%|██████████| 4486/4486 [02:58<00:00, 25.19it/s, disc_loss=0.731, gp=0.0186]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/10]


100%|██████████| 4486/4486 [02:55<00:00, 25.61it/s, disc_loss=0.357, gp=0.00563]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [7/10]


100%|██████████| 4486/4486 [02:56<00:00, 25.45it/s, disc_loss=-.133, gp=0.0575]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [8/10]


100%|██████████| 4486/4486 [02:54<00:00, 25.73it/s, disc_loss=0.367, gp=0.00735]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10]


100%|██████████| 4486/4486 [02:56<00:00, 25.47it/s, disc_loss=-.434, gp=0.00436]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [10/10]


100%|██████████| 4486/4486 [02:54<00:00, 25.68it/s, disc_loss=1.11, gp=0.00763]    


=> Saving checkpoint
=> Saving checkpoint
Current image size: 16
----------------------------------------------------------------------------------
Epoch [1/10]


100%|██████████| 4486/4486 [03:37<00:00, 20.60it/s, disc_loss=-.457, gp=0.0358]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10]


100%|██████████| 4486/4486 [03:42<00:00, 20.18it/s, disc_loss=0.919, gp=0.021]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/10]


100%|██████████| 4486/4486 [03:44<00:00, 20.02it/s, disc_loss=-1.21, gp=0.0102]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/10]


100%|██████████| 4486/4486 [03:41<00:00, 20.27it/s, disc_loss=0.995, gp=0.00622]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/10]


100%|██████████| 4486/4486 [03:40<00:00, 20.35it/s, disc_loss=0.479, gp=0.0156]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/10]


100%|██████████| 4486/4486 [03:43<00:00, 20.08it/s, disc_loss=-.805, gp=0.00415]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [7/10]


100%|██████████| 4486/4486 [03:45<00:00, 19.91it/s, disc_loss=0.237, gp=0.0126]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [8/10]


100%|██████████| 4486/4486 [03:42<00:00, 20.21it/s, disc_loss=-.0692, gp=5.2e-5]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10]


100%|██████████| 4486/4486 [03:38<00:00, 20.49it/s, disc_loss=0.689, gp=0.00524]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [10/10]


100%|██████████| 4486/4486 [03:41<00:00, 20.25it/s, disc_loss=0.4, gp=0.00125]     


=> Saving checkpoint
=> Saving checkpoint
Current image size: 32
----------------------------------------------------------------------------------
Epoch [1/10]


100%|██████████| 4486/4486 [04:41<00:00, 15.92it/s, disc_loss=-1.28, gp=0.0155]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10]


100%|██████████| 4486/4486 [04:36<00:00, 16.22it/s, disc_loss=1.2, gp=0.0411]       


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/10]


100%|██████████| 4486/4486 [04:36<00:00, 16.24it/s, disc_loss=-.785, gp=0.00389]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/10]


100%|██████████| 4486/4486 [04:37<00:00, 16.18it/s, disc_loss=0.241, gp=0.00898]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/10]


100%|██████████| 4486/4486 [04:42<00:00, 15.89it/s, disc_loss=-1.02, gp=0.00709]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/10]


100%|██████████| 4486/4486 [04:39<00:00, 16.05it/s, disc_loss=-3.32, gp=0.00193]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [7/10]


100%|██████████| 4486/4486 [04:37<00:00, 16.16it/s, disc_loss=0.513, gp=0.00181]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [8/10]


100%|██████████| 4486/4486 [04:39<00:00, 16.06it/s, disc_loss=-1.7, gp=0.0495]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/10]


100%|██████████| 4486/4486 [04:43<00:00, 15.82it/s, disc_loss=-.0406, gp=0.00565]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [10/10]


100%|██████████| 4486/4486 [04:40<00:00, 15.98it/s, disc_loss=0.298, gp=0.00548]    


=> Saving checkpoint
=> Saving checkpoint
Current image size: 64
----------------------------------------------------------------------------------
Epoch [1/10]


100%|██████████| 4486/4486 [07:49<00:00,  9.56it/s, disc_loss=-2.02, gp=0.0172]     


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10]


100%|██████████| 4486/4486 [08:02<00:00,  9.30it/s, disc_loss=-2.3, gp=0.00285]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/10]


 66%|██████▌   | 2942/4486 [05:26<02:51,  9.01it/s, disc_loss=-.545, gp=0.00362]  


MemoryError: Caught MemoryError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "c:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "c:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "c:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "c:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\folder.py", line 230, in __getitem__
    sample = self.loader(path)
  File "c:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\folder.py", line 269, in default_loader
    return pil_loader(path)
  File "c:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\folder.py", line 249, in pil_loader
    return img.convert("RGB")
  File "c:\ProgramData\Anaconda3\lib\site-packages\PIL\Image.py", line 901, in convert
    return self.copy()
  File "c:\ProgramData\Anaconda3\lib\site-packages\PIL\Image.py", line 1126, in copy
    return self._new(self.im.copy())
MemoryError
