In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
""" Training of ProGAN using WGAN-GP loss"""

import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils import (
    gradient_penalty,
    plot_to_tensorboard,
    save_checkpoint,
    load_checkpoint,
    generate_examples,
)
# from model import Discriminator, Generator
from model_gen_unet import Discriminator, Generator

from math import log2
from tqdm import tqdm
import config

# Let cuDNN benchmark convolution algorithms and pick the fastest one.
# The flag is spelled `benchmark`; assigning to `benchmarks` only creates an
# unused attribute and leaves the autotuner disabled.
torch.backends.cudnn.benchmark = True

In [4]:

def get_loader(image_size):
    """Build a shuffled DataLoader over ``config.DATASET`` at one resolution.

    Images are resized to ``image_size`` x ``image_size``, converted to tensors,
    randomly flipped horizontally (p=0.5) and normalized with mean 0.5 / std 0.5
    per channel.

    Args:
        image_size: Target side length in pixels. The batch size is looked up in
            ``config.BATCH_SIZES`` at index ``int(log2(image_size / 4))``
            (4 -> 0, 8 -> 1, 16 -> 2, ...).

    Returns:
        tuple: ``(loader, dataset)``.
    """
    channel_mean = [0.5] * config.CHANNELS_IMG
    channel_std = [0.5] * config.CHANNELS_IMG
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform = transforms.Compose(steps)

    # Each resolution has its own batch size.
    resolution_index = int(log2(image_size / 4))
    batch_size = config.BATCH_SIZES[resolution_index]

    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_kwargs), dataset



In [5]:

def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
    tensorboard_step,
    writer,
    scaler_gen,
    scaler_critic,
):
    """Train the critic and the generator for one pass over ``loader`` (WGAN-GP).

    For every batch the critic is updated first, then the generator, each with
    its own optimizer and gradient scaler under ``torch.cuda.amp.autocast``.
    ``alpha`` grows after every batch and is capped at 1; it reaches 1 once half
    of ``config.PROGRESSIVE_EPOCHS[step]`` epochs' worth of samples was seen.

    Alternative setups that were tried with this loop: feeding the generator the
    image itself or its grayscale version (``noise = real_gray``) instead of
    random noise, and conditioning the critic on
    ``torch.cat([real_gray, image], dim=1)`` for real and fake images.

    Args:
        critic: Model called as ``critic(images, alpha, step)``.
        gen: Model called as ``gen(noise, alpha, step)``.
        loader: Iterable yielding ``(real_images, label)`` batches; the label is
            ignored.
        dataset: Only ``len(dataset)`` is used, for the ``alpha`` schedule.
        step: Progressive-growing step; forwarded to both models and used to
            index ``config.PROGRESSIVE_EPOCHS``.
        alpha: Current value forwarded to both models (presumably the ProGAN
            fade-in coefficient -- confirm against the model code).
        opt_critic: Optimizer stepped for the critic loss.
        opt_gen: Optimizer stepped for the generator loss.
        tensorboard_step: Counter handed to ``plot_to_tensorboard``; incremented
            after every logging call (every 50th batch).
        writer: Writer object handed to ``plot_to_tensorboard``.
        scaler_gen: Gradient scaler used for the generator backward/step.
        scaler_critic: Gradient scaler used for the critic backward/step.

    Returns:
        tuple: The updated ``(tensorboard_step, alpha)``.
    """
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        # Move the batch once and derive the single-channel mean on the device,
        # instead of averaging on the host and transferring a second tensor.
        real = real.to(config.DEVICE)
        real_gray = real.mean(dim=1, keepdim=True)
        cur_batch_size = real.shape[0]

        # Generator input: noise shaped like the grayscale batch, created
        # directly on the target device (no host allocation + copy per batch).
        # To colorize instead of generating from noise, use `noise = real_gray`.
        # NOTE(review): `* 2 - 1` turns N(0, 1) samples into N(-1, 4); that
        # rescale is the usual way to map torch.rand's [0, 1) to [-1, 1).
        # Kept as-is so behaviour matches existing checkpoints -- confirm intent.
        noise = torch.randn(real_gray.shape, device=config.DEVICE) * 2 - 1

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # <-> min -E[critic(real)] + E[critic(fake)]
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
                # small penalty on the magnitude of the critic's output on reals
                + (0.001 * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Advance alpha in proportion to the samples seen; capped at 1.
        alpha += cur_batch_size / (
            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 50 == 0:
            with torch.no_grad():
                # Logged images are produced from the current batch's generator
                # input and mapped back with `* 0.5 + 0.5`.
                # (Alternatives tried here: config.FIXED_NOISE or real_gray as
                # the generator input.)
                fixed_fakes = gen(noise, alpha, step) * 0.5 + 0.5
                fixed_fakes0 = gen(noise, 0, step) * 0.5 + 0.5
                fixed_fakes1 = noise * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                fixed_fakes0.detach(),
                fixed_fakes1.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha



In [6]:
##########

In [None]:
# Initialize the generator and the discriminator. Following the WGAN paper the
# discriminator is called "critic", since it no longer outputs values in [0, 1].
gen = Generator(
    config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)
critic = Discriminator(
    config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
).to(config.DEVICE)

# initialize optimizers and scalers for FP16 training
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
)
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()

# for tensorboard plotting
writer = SummaryWriter("logs/gan1")

if config.LOAD_MODEL:
    load_checkpoint(
        config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
    )
    load_checkpoint(
        config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
    )

gen.train()
critic.train()

tensorboard_step = 0
# start at step that corresponds to img size that we set in config
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
try:
    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
        print(f"Current image size: {4 * 2 ** step}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
            )

            # Checkpoints are overwritten after every epoch.
            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

        step += 1  # progress to the next img size
finally:
    # Flush buffered TensorBoard events even when the run is interrupted
    # (e.g. the cell is stopped by hand part-way through a resolution).
    writer.close()

Current image size: 8
Epoch [1/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

Current image size: 8
Epoch [1/6]


100%|███████████████████████████████████████████████| 6250/6250 [04:51<00:00, 21.44it/s, gp=0.00572, loss_critic=-.535]


=> Saving checkpoint
=> Saving checkpoint





Epoch [2/6]
=> Saving checkpoint
=> Saving checkpoint


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

Epoch [2/6]


100%|███████████████████████████████████████████████| 6250/6250 [04:44<00:00, 21.97it/s, gp=0.00381, loss_critic=-.155]


=> Saving checkpoint
=> Saving checkpoint





Epoch [3/6]
=> Saving checkpoint
=> Saving checkpoint


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

Epoch [3/6]


100%|███████████████████████████████████████████████| 6250/6250 [04:47<00:00, 21.71it/s, gp=0.00696, loss_critic=-.183]


=> Saving checkpoint
=> Saving checkpoint





Epoch [4/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [4/6]


100%|███████████████████████████████████████████████| 6250/6250 [04:48<00:00, 21.69it/s, gp=0.00306, loss_critic=-.338]


=> Saving checkpoint
=> Saving checkpoint





Epoch [5/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [5/6]


100%|███████████████████████████████████████████████| 6250/6250 [04:48<00:00, 21.65it/s, gp=0.00546, loss_critic=0.137]


=> Saving checkpoint
=> Saving checkpoint





Epoch [6/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [6/6]


100%|████████████████████████████████████████████████| 6250/6250 [04:47<00:00, 21.73it/s, gp=0.00247, loss_critic=0.13]


=> Saving checkpoint
=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Current image size: 16
Epoch [1/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

Current image size: 16
Epoch [1/6]


100%|██████████████████████████████████████████████| 6250/6250 [06:26<00:00, 16.18it/s, gp=0.00805, loss_critic=0.0448]


=> Saving checkpoint
=> Saving checkpoint





Epoch [2/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [2/6]


100%|███████████████████████████████████████████████| 6250/6250 [06:26<00:00, 16.16it/s, gp=0.00235, loss_critic=-.103]



=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
Epoch [3/6]
Epoch [3/6]


100%|███████████████████████████████████████████████| 6250/6250 [06:24<00:00, 16.26it/s, gp=0.00511, loss_critic=-.331]


=> Saving checkpoint
=> Saving checkpoint





Epoch [4/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [4/6]


100%|███████████████████████████████████████████████| 6250/6250 [06:31<00:00, 15.98it/s, gp=0.00369, loss_critic=-.219]


=> Saving checkpoint
=> Saving checkpoint





Epoch [5/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [5/6]


100%|███████████████████████████████████████████████| 6250/6250 [06:33<00:00, 15.87it/s, gp=0.00457, loss_critic=-.274]


=> Saving checkpoint
=> Saving checkpoint





Epoch [6/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
=> Saving checkpoint
Epoch [6/6]


100%|███████████████████████████████████████████████| 6250/6250 [06:33<00:00, 15.88it/s, gp=0.00321, loss_critic=-.325]


=> Saving checkpoint
=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Current image size: 32
Epoch [1/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

Current image size: 32
Epoch [1/6]


100%|███████████████████████████████████████████████| 6250/6250 [11:23<00:00,  9.15it/s, gp=0.00426, loss_critic=0.604]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [2/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [2/6]


100%|███████████████████████████████████████████████| 6250/6250 [11:28<00:00,  9.08it/s, gp=0.00267, loss_critic=0.418]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [3/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [3/6]


100%|███████████████████████████████████████████████| 6250/6250 [11:30<00:00,  9.06it/s, gp=0.00366, loss_critic=0.517]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [4/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [4/6]


100%|███████████████████████████████████████████████| 6250/6250 [11:27<00:00,  9.09it/s, gp=0.00402, loss_critic=-1.02]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [5/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [5/6]


100%|████████████████████████████████████████████████| 6250/6250 [11:25<00:00,  9.11it/s, gp=0.00222, loss_critic=0.56]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [6/6]


  0%|                                                                                         | 0/6250 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [6/6]


100%|███████████████████████████████████████████████| 6250/6250 [11:27<00:00,  9.10it/s, gp=0.00569, loss_critic=-.201]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
Current image size: 64
Epoch [1/6]


  0%|                                                                                        | 0/12500 [00:00<?, ?it/s]

Current image size: 64
Epoch [1/6]


100%|█████████████████████████████████████████████| 12500/12500 [37:11<00:00,  5.60it/s, gp=0.0095, loss_critic=0.0397]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [2/6]


  0%|                                                                                        | 0/12500 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [2/6]


100%|██████████████████████████████████████████████| 12500/12500 [36:59<00:00,  5.63it/s, gp=0.00175, loss_critic=-1.8]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [3/6]
=> Saving checkpoint
Epoch [3/6]


100%|███████████████████████████████████████████████| 12500/12500 [36:59<00:00,  5.63it/s, gp=0.00225, loss_critic=2.7]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [4/6]


  0%|                                                                                        | 0/12500 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [4/6]


100%|█████████████████████████████████████████████| 12500/12500 [39:25<00:00,  5.28it/s, gp=0.00193, loss_critic=-1.33]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [5/6]
=> Saving checkpoint
Epoch [5/6]


100%|██████████████████████████████████████████████| 12500/12500 [38:13<00:00,  5.45it/s, gp=0.00334, loss_critic=1.32]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
Epoch [6/6]


  0%|                                                                                        | 0/12500 [00:00<?, ?it/s]

=> Saving checkpoint
Epoch [6/6]


100%|██████████████████████████████████████████████| 12500/12500 [37:47<00:00,  5.51it/s, gp=0.00732, loss_critic=-1.5]


=> Saving checkpoint





=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
Current image size: 128
Epoch [1/6]


  0%|                                                                                        | 0/12500 [00:00<?, ?it/s]

Current image size: 128
Epoch [1/6]


 48%|██████████████████████▏                       | 6025/12500 [36:30<38:30,  2.80it/s, gp=0.00535, loss_critic=-3.38]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_images(image_batch):
    """Show one image, or a row of images, whose pixel values lie in [-1, 1].

    Args:
        image_batch: A ``torch.Tensor`` or array-like, either a single
            channels-first image ``(C, H, W)`` or a batch ``(N, C, H, W)``.
            Tensors may live on any device and may require grad.

    The data is moved to channels-last layout and rescaled with ``(x + 1) / 2``
    before being handed to ``imshow``. Single-channel images are plotted as 2-D
    arrays.
    """
    # Convert tensors first: NumPy cannot wrap a tensor that lives on a GPU or
    # requires grad, so any NumPy call must come after detach().cpu().numpy().
    if isinstance(image_batch, torch.Tensor):
        image_batch = image_batch.detach().cpu().numpy()
    image_batch = np.asarray(image_batch)

    if image_batch.ndim == 3:  # Single image, not batched
        image_batch = np.expand_dims(image_batch, axis=0)

    # (N, C, H, W) -> (N, H, W, C), then [-1, 1] -> [0, 1]
    image_batch = np.transpose(image_batch, (0, 2, 3, 1))
    image_batch = (image_batch + 1) / 2.0

    if len(image_batch) == 1:
        # np.squeeze also drops a trailing single-channel axis
        plt.imshow(np.squeeze(image_batch[0]))
        plt.axis('off')
        plt.show()
    else:
        _, axs = plt.subplots(1, len(image_batch), figsize=(15, 15))
        for i, img in enumerate(image_batch):
            if img.shape[-1] == 1:
                img = img.squeeze(-1)
            axs[i].imshow(img)
            axs[i].axis('off')
        plt.show()
# Expected input for plot_images: a batch shaped (batch_size, channels, height, width)
# (or a single (channels, height, width) image) with pixel values in [-1, 1].

In [None]:
# One loader per resolution (8x8, 16x16, 32x32); the dataset handle that
# get_loader also returns is not needed here.
(loader8, _), (loader16, _), (loader32, _) = (
    get_loader(size) for size in (8, 16, 32)
)

In [None]:
# Take the image tensor (element 0) of the first batch from each loader.
images8, images16, images32 = (
    next(iter(resolution_loader))[0]
    for resolution_loader in (loader8, loader16, loader32)
)

In [None]:
# Derive half- and quarter-size batches from images32 by nearest-neighbour
# downsampling, so every resolution shows the same pictures.
# NOTE: this rebinds images16 and images8, replacing whatever those names held.
images16, images8 = (
    F.interpolate(images32, scale_factor=factor, mode="nearest")
    for factor in (0.5, 0.25)
)

In [None]:
# One figure per resolution, smallest first.
for resolution_batch in (images8, images16, images32):
    plot_images(resolution_batch)

In [None]:
# Run the generator on the 16x16 batch with alpha=1 at step=2
# (4 * 2 ** 2 == 16). The output is only inspected, so skip autograd
# bookkeeping instead of building a graph that is never back-propagated.
batch = images16.to(config.DEVICE)
with torch.no_grad():
    fakes1 = gen(batch, 1, 2)

In [None]:
# Run the generator on the 32x32 batch with alpha=0 at step=3
# (4 * 2 ** 3 == 32). The output is only inspected, so skip autograd
# bookkeeping instead of building a graph that is never back-propagated.
batch = images32.to(config.DEVICE)
with torch.no_grad():
    fakes2 = gen(batch, 0, 3)

In [None]:
# Bring the generator output `fakes1` back to host memory as a NumPy array,
# plot it, and display the array's shape as the cell output.
image_batch = fakes1.detach().cpu().numpy()
plot_images(image_batch)
image_batch.shape

In [None]:
# Bring the generator output `fakes2` back to host memory as a NumPy array,
# plot it, and display the array's shape as the cell output.
# Note: this rebinds `image_batch`.
image_batch = fakes2.detach().cpu().numpy()
plot_images(image_batch)
image_batch.shape

In [None]:
# NOTE(review): `out41` is not assigned in any visible cell, so this raises
# NameError on a fresh kernel -- define it explicitly or remove this cell.
out41[0]

In [None]:
# NOTE(review): `out42` is not assigned in any visible cell, so this raises
# NameError on a fresh kernel -- define it explicitly or remove this cell.
out42[1]