# Train a GAN using public data
The public data is assumed to be KITTI.

In [3]:
import os
import time
import utils
import torch
from datasets.kitti import KITTIdataset
import torchvision
from utils import *
from torch.nn import BCELoss
from torch.autograd import grad
import torchvision
import torchvision.utils as tvls
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from models.gan import Generator
from models.gan import Discriminator
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.utils import save_image
from tqdm.notebook import tqdm

In [4]:

def freeze(net):
    """Disable gradient tracking for every parameter of ``net``.

    Args:
        net: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for param in net.parameters():
        param.requires_grad = False

def unfreeze(net):
    """Enable gradient tracking for every parameter of ``net``.

    Args:
        net: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for param in net.parameters():
        param.requires_grad = True

def gradient_penalty(D, x, y):
    """Return the gradient penalty of ``D`` on random interpolates of x and y.

    For each sample a coefficient ``alpha`` is drawn uniformly from [0, 1)
    and the point ``z = x + alpha * (y - x)`` is formed. The penalty is the
    batch mean of ``(||dD(z)/dz||_2 - 1) ** 2``, with the norm taken over all
    non-batch dimensions.

    All intermediate tensors are created on ``x``'s device, so the function
    works on CPU as well as on any CUDA device.

    Args:
        D: Callable mapping a batch tensor to an output tensor.
        x: Batch tensor; dimension 0 is the batch dimension.
        y: Tensor broadcastable against ``x`` (the caller in this file passes
            a tensor of the same shape), on the same device as ``x``.

    Returns:
        A scalar tensor. The gradient is computed with ``create_graph=True``
        so the penalty can itself be back-propagated through ``D``.
    """
    # One interpolation coefficient per sample, broadcast over the other dims.
    shape = [x.size(0)] + [1] * (x.dim() - 1)
    alpha = torch.rand(shape, device=x.device)
    z = x + alpha * (y - x)
    z.requires_grad_(True)

    o = D(z)
    # grad_outputs of ones yields d(sum(o))/dz.
    g = grad(o, z, grad_outputs=torch.ones_like(o), create_graph=True)[0]
    g = g.view(z.size(0), -1)
    gp = ((g.norm(p=2, dim=1) - 1) ** 2).mean()

    return gp
# Timestamp identifying this run; all outputs go under gan_result/<run_time>/.
run_time = time.strftime("%Y-%m-%d_%H-%M-%S")  # strftime defaults to local time
dataset_name = "kitti"

# Output locations for sample images, model checkpoints and TensorBoard logs.
save_img_dir = f"gan_result/{run_time}/imgs_kitti_gan"
save_model_dir = f"gan_result/{run_time}/models_kitti_gan"
save_log_dir = f"gan_result/{run_time}/log"

os.makedirs(save_img_dir, exist_ok=True)
os.makedirs(save_model_dir, exist_ok=True)
os.makedirs(save_log_dir, exist_ok=True)

# TensorBoard writer used by the training loop.
writer = SummaryWriter(save_log_dir)


if __name__ == "__main__":
    # Training script: the critic DG is updated on every batch with a
    # Wasserstein loss plus gradient penalty; the generator G is updated
    # every `n_critic` batches. Requires CUDA.
    # The statements are intentionally left at this level (no main()) so that
    # G, DG, the optimizers, etc. stay module-level names.

    ###
    # hyper params and settings
    ###

    lr = 2e-4
    batch_size = 4
    epochs = 20
    n_critic = 10      # generator is updated once every n_critic steps
    gp_weight = 10.0   # weight of the gradient penalty in the critic loss
    z_dim = 256        # channel count of the latent tensor fed to G
    log_every = 100    # steps between TensorBoard logs
    save_every = 5     # epochs between checkpoints

    print(f"---------------------Training {'GAN'}------------------------------")

    # Named `train_transform` so the imported `torchvision.transforms` module
    # alias (`transforms`) is not rebound to a Compose object.
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),  # to (0 1)
        torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # to (-1 1)
        torchvision.transforms.Resize((128, 256)),
    ])
    # Inverse of the normalization above: (x - (-1)) / 2 maps (-1 1) to (0 1).
    vis_transform = torchvision.transforms.Compose([
        torchvision.transforms.Normalize([-1, -1, -1], [2, 2, 2]),
    ])
    dataset = KITTIdataset(transforms=train_transform)
    dataset_name = 'kitti'

    # NOTE: an earlier MNIST variant of this script used
    # torchvision.datasets.MNIST with Resize((32, 32)), Normalize([0.5], [0.5])
    # and a single-channel vis_transform Normalize([-1], [2]).

    dataloader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=batch_size,
                            num_workers=0,
                            pin_memory=False,
                            drop_last=True)

    G = Generator(3, z_dim)
    DG = Discriminator(3)

    G = torch.nn.DataParallel(G).cuda()
    DG = torch.nn.DataParallel(DG).cuda()

    dg_optimizer = torch.optim.Adam(DG.parameters(), lr=lr, betas=(0.5, 0.9))
    g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.9))

    step = 0
    # Defined up front so logging / saving never hits an unbound name when the
    # intervals or the dataset size change.
    g_loss = None
    fake_image = None

    for epoch in range(epochs):
        start = time.time()
        # The dataset yields triples; only the middle element (images) is used.
        for _, imgs, _ in tqdm(dataloader):
            step += 1
            imgs = imgs.cuda()
            b, c, h, w = imgs.size()
            # Latent spatial size is 1/8 of the image size
            # (presumably G upsamples by 8x -- confirm against models.gan).
            h = h // 8
            w = w // 8

            # train DG (critic)
            freeze(G)
            unfreeze(DG)

            z = torch.randn(b, z_dim, h, w).cuda()  # latent vector
            f_imgs = G(z)

            r_logit = DG(imgs)
            f_logit = DG(f_imgs)

            wd = r_logit.mean() - f_logit.mean()  # Wasserstein-1 Distance
            gp = gradient_penalty(DG, imgs.data, f_imgs.data)
            dg_loss = - wd + gp * gp_weight

            dg_optimizer.zero_grad()
            dg_loss.backward()
            dg_optimizer.step()

            # train G

            if step % n_critic == 0:
                freeze(DG)
                unfreeze(G)
                z = torch.randn(b, z_dim, h, w).cuda()
                f_imgs = G(z)
                logit_dg = DG(f_imgs)
                # calculate g_loss
                g_loss = - logit_dg.mean()

                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if step % log_every == 0:
                # Preview samples are only visualised, so skip autograd
                # bookkeeping (G may be unfrozen at this point).
                with torch.no_grad():
                    z = torch.randn(b, z_dim, h, w).cuda()
                    fake_image = vis_transform(G(z))
                grid = make_grid(fake_image, nrow=4)
                if g_loss is not None:
                    writer.add_scalar("g_loss", g_loss, step)
                writer.add_scalar("dg_loss", dg_loss, step)
                writer.add_image(f'img_gt_pred', grid, step)
                writer.add_scalar("Wasserstein_distance", wd, step)
                writer.add_scalar("gradient_penalty", gp * gp_weight, step)

        # Save the most recent preview batch, if one was generated yet.
        if fake_image is not None:
            save_image(fake_image, os.path.join(save_img_dir, f"{dataset_name}_gan_{step}.png"))
        end = time.time()
        interval = end - start
        print("Epoch:%d \t Time:%.2f" % (epoch, interval))

        # Checkpoint periodically and always after the final epoch, so the
        # fully trained weights are not lost.
        if epoch % save_every == 0 or epoch == epochs - 1:
            torch.save(G.state_dict(), os.path.join(save_model_dir, f"{dataset_name}_{epoch}_G.pth"))
            torch.save(DG.state_dict(), os.path.join(save_model_dir, f"{dataset_name}_{epoch}_D.pth"))

    # Make sure buffered TensorBoard events reach disk.
    writer.flush()

---------------------Training GAN------------------------------


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1870.0), HTML(value='')))