# Simple walkthrough/tutorial on using Pix2Pix for the classic deep-learning inverse problem of image colorization

### To train in a scalable fashion, use the trainer as in train.py. This lets you easily test many different configurations for your problem by changing the config.yaml file. The name of the Pix2Pix game is hyperparameter tuning to avoid mode collapse: the generator and the discriminator need to be evenly matched. Have fun!

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

# Load Pix2pix code
from pix2pix.data_demo.coco import COCO, reverse_transform
from pix2pix import initialize_model

# Visualize networks
from torchview import draw_graph
import graphviz
graphviz.set_jupyter_format('png')

In [None]:
# Use the example COCO dataloader included that has a few samples
# NOTE: root_dir is an absolute path from the original author's machine; point
# it at your own copy of the sample folder before running this cell.
# NOTE(review): num_dat=-1 presumably means "use every available sample" - confirm in COCO.
dataset = COCO(
    root_dir="/home/deanhazineh/Downloads/Pix2Pix/pix2pix/data_demo/coco_2017_train_samples/",
    train_fold="COCO_2017_Train_Samples",
    num_dat=-1
)
train_test_split = 0.9

total_count = len(dataset)
train_count = int(train_test_split * total_count)
test_count = total_count - train_count

# Seed the split so the train/test partition is identical across kernel
# restarts. The training cell resumes from a checkpoint; with an unseeded split
# a resumed run would draw a new partition and held-out samples would leak
# into training.
split_seed = 0
train_dataset, test_dataset = random_split(
    dataset,
    [train_count, test_count],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader is shuffled.
train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Keys used to pull the model input (xkey) and the target (ykey) out of each
# batch dictionary.
# NOTE(review): the names suggest the L and AB channels of Lab color space - confirm in COCO.
xkey = "L"
ykey = "AB"

In [None]:
# Pull one training batch and print the shape and value range of input and target.
sample = next(iter(train_dl))
L, AB = sample[xkey], sample[ykey]

print(L.shape, L.min(), L.max())
print(AB.shape, AB.min(), AB.max())

disp_num = 6
# Never index past the end of the batch: the demo dataset is small, so a batch
# can hold fewer than disp_num samples.
num_shown = min(disp_num, L.shape[0])
# squeeze=False keeps `ax` two-dimensional even when a single column is drawn.
fig, ax = plt.subplots(2, num_shown, figsize=(3*num_shown, 6), squeeze=False)
for i in range(num_shown):
    # reverse_transform returns a grayscale image and an RGB image for one sample.
    gs, rgb = reverse_transform(L[i], AB[i])
    ax[0, i].imshow(gs, cmap='gray')
    ax[1, i].imshow(rgb)

# axis('off') would also hide the row labels set below, so remove only the
# ticks and the frame instead.
for axi in ax.flatten():
    axi.set_xticks([])
    axi.set_yticks([])
    for spine in axi.spines.values():
        spine.set_visible(False)
ax[0,0].set_ylabel("Grayscale")
ax[1,0].set_ylabel("RGB")


In [None]:
# Build the model from ./config.yaml. The returned object provides the
# `generator` and `discriminator` sub-networks that the following cells use.
model = initialize_model("./config.yaml")

In [None]:
# Draw the generator's layer graph with torchview from a random dummy input.
# NOTE(review): (1, 1, 256, 256) is presumably (batch, L channel, height, width);
# confirm that the size matches what config.yaml expects.
generator = model.generator
model_graph = draw_graph(generator,input_data=torch.rand((1, 1, 256, 256)))
# Leaving the graph as the last expression of the cell makes Jupyter display it.
model_graph.visual_graph

In [None]:
# Draw the discriminator's layer graph. The discriminator takes two inputs:
# the conditioning image and the image being judged.
# NOTE(review): the channel counts (1 and 2) presumably mirror the L input and
# the AB target - confirm against config.yaml.
discriminator = model.discriminator
cond = torch.rand(1,1,256,256)
targ = torch.rand(1,2,256,256)
model_graph = draw_graph(discriminator, input_data=(cond, targ))
# Leaving the graph as the last expression of the cell makes Jupyter display it.
model_graph.visual_graph

In [None]:
# We code out the training loop in full for a tutorial but you can skip this and use the trainer as shown in the demo_train code
import os
import numpy as np
import torch.optim as optim
import itertools  # NEVER USE ITERTOOLS.CYCLE ON TRAINING DATA WITH RANDOM AUGMENTATIONS
from matplotlib import gridspec

def compute_ema(data, alpha=0.1):
    """Return the exponential moving average (EMA) of a 1-D sequence.

    The first output equals the first input; every later output is
    ``alpha * data[i] + (1 - alpha) * ema[i - 1]``.

    Args:
        data: Sequence of numbers, e.g. a list of per-epoch losses.
        alpha: Weight given to the newest value; smaller values smooth more.

    Returns:
        A NumPy array with one smoothed value per input value. An empty
        input yields an empty array instead of raising ``IndexError``.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(data) == 0:
        return np.array([])

    smoothed = [data[0]]
    for value in data[1:]:
        smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return np.array(smoothed)

def plot_losses(losses, save_to="./out_tutorial/losses.png"):
    """Plot the recorded loss curves and write the figure to disk.

    The left panel shows the generator L1 loss, the right panel the
    discriminator loss and the generator GAN loss. Raw values are drawn as
    faint dots and their exponential moving average as a solid line.

    Args:
        losses: Dict holding equally long lists under the keys ``"l1_loss"``,
            ``"loss_D"`` and ``"gan_loss"``.
        save_to: Path of the image file to write. Defaults to the location
            used throughout this tutorial.
    """
    lw = 1
    alp = 0.2

    fig, ax = plt.subplots(1, 2)

    ax[0].plot(losses["l1_loss"], 'bo', alpha=alp)
    ax[0].plot(compute_ema(losses["l1_loss"]), 'b-', linewidth=lw, label="L1 Loss")
    ax[0].set_title("Generator L1 Loss")
    ax[0].grid(True, which="both", linestyle="--", linewidth=0.5)

    ax[1].plot(losses["loss_D"], 'ko', alpha=alp)
    ax[1].plot(compute_ema(losses["loss_D"]), 'k-', linewidth=lw, label="Discriminator Loss")
    ax[1].plot(losses["gan_loss"], 'ro', alpha=alp)
    ax[1].plot(compute_ema(losses["gan_loss"]), 'r-', linewidth=lw, label="Generator-GAN Loss")
    # Horizontal reference line at 0.69.
    # NOTE(review): 0.69 is approximately ln(2), presumably the binary
    # cross-entropy of a discriminator that cannot tell real from fake -
    # confirm against the GAN loss configured for the model.
    ax[1].plot([0, len(losses["loss_D"])], [0.69, 0.69], 'k--')
    ax[1].legend()

    fig.tight_layout()
    # Create the target folder on demand so the function also works when it is
    # called before the training cell has created the output directory.
    os.makedirs(os.path.dirname(save_to) or ".", exist_ok=True)
    fig.savefig(save_to)
    # Close this figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)
    return

def visualize(real_A, real_B, disp_num, reverse_transform, save_to=None):
    """Show model predictions next to the input and the ground truth.

    Runs the notebook-level ``model`` on ``real_A`` and draws a grid with
    three rows: the grayscale input, the predicted color image and the
    ground-truth color image. One column is drawn per sample.

    Args:
        real_A: Batch of model inputs.
        real_B: Batch of matching targets.
        disp_num: Maximum number of samples (columns) to draw.
        reverse_transform: Callable that maps one ``(input, target)`` pair
            to a ``(grayscale_image, rgb_image)`` pair ready for ``imshow``.
        save_to: If given, the figure is written to this path and closed;
            otherwise it stays open so the notebook can display it.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        pred_B = model.forward(real_A)

    snum = min(disp_num, pred_B.shape[0])
    # squeeze=False keeps `ax` two-dimensional, so ax[row, col] also works
    # when only a single sample is drawn.
    fig, ax = plt.subplots(3, snum, figsize=(3 * snum, 9), squeeze=False)
    for i in range(snum):
        gs, rgb = reverse_transform(real_A[i], pred_B[i])
        _, rgb_gt = reverse_transform(real_A[i], real_B[i])
        ax[0, i].imshow(gs, cmap="gray")
        ax[1, i].imshow(rgb)
        ax[2, i].imshow(rgb_gt)

    for axi in ax.flatten():
        axi.axis("off")
    fig.tight_layout()

    if save_to is not None:
        fig.savefig(save_to)
        plt.close(fig)

    return

###

# Make sure the folder for figures and checkpoints exists.
os.makedirs("./out_tutorial/", exist_ok=True)

# Train on the GPU when one is present, otherwise fall back to the CPU so the
# tutorial still runs (slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

valid_iter = itertools.cycle(test_dl) # used for visualization
train_iter = itertools.cycle(train_dl) # used for visualization (Never train with itertools)

optimizer_G = optim.Adam(model.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(model.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Resume from the last checkpoint if there is one, otherwise start from scratch.
if os.path.exists("./out_tutorial/ckpt_last.ckpt"):
    ckpt_dict = torch.load("./out_tutorial/ckpt_last.ckpt", map_location="cpu")
    model.load_state_dict(ckpt_dict["state_dict"])
    # Move the model before restoring the optimizers so that their state is
    # placed on the same device as the parameters.
    model.to(device)

    optimizer_G.load_state_dict(ckpt_dict["optimizer_G_state_dict"])
    optimizer_D.load_state_dict(ckpt_dict["optimizer_D_state_dict"])

    # "epoch" holds the number of completed epochs, i.e. the index to resume at.
    epoch = int(ckpt_dict["epoch"])
    losses = ckpt_dict["losses"]
    print(f"Loaded checkpoint from epoch {epoch}")
else:
    model.to(device)
    epoch = 0
    # One list per tracked quantity; each receives one mean value per epoch.
    losses = {
        "loss_D": [],
        "loss_D_real": [],
        "loss_D_fake": [],
        "loss_G": [],
        "gan_loss": [],
        "l1_loss": []
    }

disp_num = 6
max_epochs = 100
snapshot_every_n = 10
ldl = len(train_dl)  # batches per epoch, used to turn batch losses into epoch means

# range() yields plain Python ints, which keeps NumPy scalars out of the
# checkpoint; restricted (weights_only) loading rejects those by default.
for ep in range(epoch, max_epochs):

    ### Training step
    epoch_loss_D = 0
    epoch_loss_D_real = 0
    epoch_loss_D_fake = 0
    epoch_loss_G = 0
    epoch_gan_loss = 0
    epoch_l1_loss = 0
    for sample in train_dl:
        # Switch to train mode BEFORE the generator forward pass. The snapshot
        # step below leaves the generator in eval mode, so otherwise the first
        # batch after every snapshot would be generated in eval mode.
        model.generator.train()
        model.discriminator.train()

        real_A = sample[xkey].to(dtype=torch.float32, device=device)
        real_B = sample[ykey].to(dtype=torch.float32, device=device)
        fake_B = model.generator(real_A)

        # Discriminator update.
        optimizer_D.zero_grad()
        loss_D, loss_D_real, loss_D_fake = model.compute_discriminator_loss(
            real_A, real_B, fake_B
        )
        loss_D.backward()
        optimizer_D.step()
        epoch_loss_D += loss_D.item()/ldl
        epoch_loss_D_real += loss_D_real.item()/ldl
        epoch_loss_D_fake += loss_D_fake.item()/ldl

        # Generator update. zero_grad() also discards any gradient that the
        # discriminator backward pass may have left on the generator.
        optimizer_G.zero_grad()
        loss_G, gan_loss, l1_loss = model.compute_generator_loss(
            real_A, real_B
        )
        loss_G.backward()
        optimizer_G.step()
        epoch_loss_G += loss_G.item()/ldl
        epoch_gan_loss += gan_loss.item()/ldl
        epoch_l1_loss += l1_loss.item()/ldl

    losses["loss_D"].append(epoch_loss_D)
    losses["loss_D_real"].append(epoch_loss_D_real)
    losses["loss_D_fake"].append(epoch_loss_D_fake)
    losses["loss_G"].append(epoch_loss_G)
    losses["gan_loss"].append(epoch_gan_loss)
    losses["l1_loss"].append(epoch_l1_loss)
    print(f"epoch: {ep} Gen. Loss: {epoch_loss_G:.3f} Disc. Loss: {epoch_loss_D:.3f}")

    # Snapshot every n-th epoch and always after the final one, so the end
    # result of the run is never lost.
    if ep % snapshot_every_n == 0 or ep == max_epochs - 1:
        plot_losses(losses)

        model.generator.eval()
        with torch.no_grad():
            sample = next(train_iter)
            real_A = sample[xkey].to(dtype=torch.float32, device=device)
            real_B = sample[ykey].to(dtype=torch.float32, device=device)
            visualize(real_A, real_B, disp_num=disp_num, reverse_transform=reverse_transform, save_to=f"./out_tutorial/train_{ep+1}.png")

            sample = next(valid_iter)
            real_A = sample[xkey].to(dtype=torch.float32, device=device)
            real_B = sample[ykey].to(dtype=torch.float32, device=device)
            visualize(real_A, real_B, disp_num=disp_num, reverse_transform=reverse_transform, save_to=f"./out_tutorial/test_{ep+1}.png")

        # "epoch" stores the number of completed epochs (ep is zero-based).
        state = {
            "epoch": ep+1,
            "state_dict": model.state_dict(),
            "optimizer_D_state_dict": optimizer_D.state_dict(),
            "optimizer_G_state_dict": optimizer_G.state_dict(),
            "losses": losses
        }
        torch.save(state, "./out_tutorial/ckpt_last.ckpt")

