In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import Config
from utils import seed_all, ensure_dirs, device
from dataset import ImageFolderDataset
from models.ie_gan import Generator, Discriminator, IdentityEncoder

def train():
    """Train the IE-GAN generator/discriminator pair and save a checkpoint.

    Everything is driven by a fresh ``Config()``: the dataset is read from
    ``cfg.data_root/"train"``, the networks are optimized with Adam for
    ``cfg.gan_epochs`` epochs, and the final ``state_dict`` of both networks
    is written to ``cfg.gan_ckpt`` once training finishes.

    The generator objective is the sum of three terms:

    * an adversarial BCE-with-logits loss against an all-ones target,
    * ``cfg.lambda_id`` times the L1 distance between the frozen identity
      encoder's output on generated and on real images,
    * ``0.1`` times the mean per-feature variance (over dim 0) of the
      generated batch's identity-encoder output.

    Returns:
        None. The checkpoint file and the console output are the only results.
    """
    cfg = Config()
    seed_all(42)
    ensure_dirs(cfg.out_root, cfg.out_root/"checkpoints")

    dev = device()
    ds = ImageFolderDataset(cfg.data_root/"train", img_size=cfg.img_size, augment=True)
    dl = DataLoader(ds, batch_size=cfg.gan_batch_size, shuffle=True, num_workers=cfg.num_workers)

    G = Generator(z_dim=cfg.z_dim).to(dev)
    D = Discriminator().to(dev)

    # Frozen identity encoder: no parameter gradients, eval mode. Gradients
    # still flow *through* it to the generator in the generator step below.
    id_enc = IdentityEncoder().to(dev)
    for p in id_enc.parameters():
        p.requires_grad = False
    id_enc.eval()

    optG = torch.optim.Adam(G.parameters(), lr=cfg.gan_lr, betas=(0.5, 0.999))
    optD = torch.optim.Adam(D.parameters(), lr=cfg.gan_lr, betas=(0.5, 0.999))

    for epoch in range(cfg.gan_epochs):
        G.train(); D.train()
        pbar = tqdm(dl, desc=f"IE-GAN Epoch {epoch+1}/{cfg.gan_epochs}")

        # The second element of each batch is not used by this loop.
        for real, _ in pbar:
            real = real.to(dev)
            b = real.size(0)

            # ---- Train Discriminator ----
            z = torch.randn(b, cfg.z_dim, device=dev)
            # The discriminator step must not back-propagate into G, so the
            # fakes are sampled without recording an autograd graph at all
            # (same values as G(z).detach(), without the wasted graph memory).
            with torch.no_grad():
                fake = G(z)

            # One-sided label smoothing: only the "real" targets are lowered.
            # NOTE(review): the (b, 1) targets assume D emits one logit per
            # sample with that exact shape -- confirm against Discriminator.
            real_label = torch.ones(b, 1, device=dev) * (1.0 - cfg.label_smoothing)
            fake_label = torch.zeros(b, 1, device=dev)

            D_real = D(real)
            D_fake = D(fake)

            lossD = F.binary_cross_entropy_with_logits(D_real, real_label) + \
                    F.binary_cross_entropy_with_logits(D_fake, fake_label)

            # zero_grad here also clears any gradients that the previous
            # generator step accumulated in D's parameters.
            optD.zero_grad()
            lossD.backward()
            optD.step()

            # ---- Train Generator ----
            z = torch.randn(b, cfg.z_dim, device=dev)
            gen = G(z)
            D_gen = D(gen)

            # Non-smoothed "real" target for the generator's adversarial term.
            adv_loss = F.binary_cross_entropy_with_logits(D_gen, torch.ones(b,1,device=dev))

            # Real embeddings are a fixed target, so no graph is needed.
            with torch.no_grad():
                real_id = id_enc(real)
            gen_id = id_enc(gen)
            id_loss = F.l1_loss(gen_id, real_id)

            # Variance of the generated embeddings across dim 0, averaged over
            # the remaining dims. With a single sample the default (Bessel
            # corrected) variance is NaN, which would poison lossG and, after
            # optG.step(), every generator weight. The DataLoader keeps a
            # short final batch, so skip the term when it is undefined.
            if gen_id.size(0) > 1:
                cons_loss = gen_id.var(dim=0).mean()
            else:
                cons_loss = gen_id.new_zeros(())

            lossG = adv_loss + cfg.lambda_id * id_loss + 0.1 * cons_loss

            optG.zero_grad()
            lossG.backward()
            optG.step()

            pbar.set_postfix({"D": lossD.item(), "G": lossG.item()})

    # Only the final weights are persisted; optimizer state is not saved.
    ckpt = {"G": G.state_dict(), "D": D.state_dict()}
    torch.save(ckpt, cfg.gan_ckpt)
    print("Saved:", cfg.gan_ckpt)

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
