<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/src/gan/FQGAN/notebooks/VectorQuantizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [109]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), "../"))

In [110]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [111]:
import numpy as np
from tqdm import tqdm
from torchvision.utils import make_grid

In [112]:
from config import Config
from datasets import data_utils
from training.logger import Logger
from training.utils import set_seed, count_parameters_float32
from training.metric_log import MetricLog
from models.FQGAN_64 import Generator, Discriminator
from models.losses import ortho_reg, ProbLoss

In [113]:
class VectorQuantizer(nn.Module):

    def __init__(self, emb_dim=64, num_emb=2**10, commitment=0.25):
        super(VectorQuantizer, self).__init__()

        self.emb_dim = emb_dim
        self.num_emb = num_emb
        self.commitment = commitment
        self.embedding = nn.Parameter(torch.randn(emb_dim, num_emb))
    
    def forward(self, inputs):
        # [B, C=D, H, W] --> [B, H, W, C=D]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()

        # [B, H, W, D] --> [N(=BxHxW), D]
        flatten = inputs.view(-1, self.emb_dim)

        # distance d(H[N, D], E[D, K]) --> d[N, K]
        # each element show the distance between Hj and Ei
        distance = (
            flatten.pow(2).sum(1, keepdim=True)
            -2 * flatten @ self.embedding
            + self.embedding.pow(2).sum(0, keepdim=True)
        )

        # embedding_idx: [N, K] --> [N, ]
        embedding_idx = torch.argmin(distance, dim=1)
        # embedding_idx: [N, ] --> [B, H, W, ]
        embedding_idx = embedding_idx.view(*inputs_shape[:-1])
        # quantize: [B, H, W, ] --> [B, H, W, D]
        quantize = F.embedding(embedding_idx, self.embedding.transpose(0, 1))

        # loss
        e_latent_loss = F.mse_loss(quantize.detach(), inputs)
        q_latent_loss = F.mse_loss(quantize, inputs.detach())
        loss = q_latent_loss + self.commitment * e_latent_loss

        quantize = inputs + (quantize - inputs).detach()
        quantize = quantize.permute(0, 3, 1, 2).contiguous()

        return quantize, loss, embedding_idx

In [114]:
class VectorQuantizerEMA(nn.Module):

    def __init__(self, emb_dim=64, num_emb=2**10, commitment=1.0, decay=0.9, eps=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self.emb_dim = emb_dim
        self.num_emb = num_emb
        self.commitment = commitment
        self.decay = decay
        self.eps = eps

        embedding = nn.Parameter(torch.randn(emb_dim, num_emb))
        self.register_buffer("embedding", embedding)
        self.register_buffer("cluster_size", torch.zeros(self.num_emb))
        self.register_buffer("ema_embedding", self.embedding.clone())

    def forward(self, inputs):
        # [B, C=D, H, W] --> [B, H, W, C=D]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()

        # [B, H, W, D] --> [N(=BxHxW), D]
        flatten = inputs.view(-1, self.emb_dim)

        # distance d(H[N, D], E[D, K]) --> d[N, K]
        # each element show the distance between Hj and Ei
        distance = (
            flatten.pow(2).sum(1, keepdim=True)
            -2 * flatten @ self.embedding
            + self.embedding.pow(2).sum(0, keepdim=True)
        )

        # embedding_idx: [N, K] --> [N, ]
        embedding_idx = torch.argmin(distance, dim=1)
        # embedding_onthot: [N, ] --> [N, K]
        embedding_onehot = F.one_hot(embedding_idx, self.num_emb).type(flatten.dtype)
        # embedding_idx: [N, ] --> [B, H, W, ]
        embedding_idx = embedding_idx.view(*inputs_shape[:-1])
        # quantize: [B, H, W, ] --> [B, H, W, D]
        quantize = F.embedding(embedding_idx, self.embedding.transpose(0, 1))

        if self.training:
            self.cluster_size.mul_(self.decay).add_(
                1-self.decay, embedding_onehot.sum(0)
            )
            dw = flatten.transpose(0, 1) @ embedding_onehot
            self.ema_embedding.data.mul_(self.decay).add_(1-self.decay, dw)
            n = self.cluster_size.sum()
            smoother_cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_emb * self.eps) * n
            )
            embedding_norm = self.ema_embedding / smoother_cluster_size.unsqueeze(0)
            self.embedding.data.copy_(embedding_norm)

        # loss
        e_latent_loss = F.mse_loss(quantize.detach(), inputs)
        loss = self.commitment * e_latent_loss

        quantize = inputs + (quantize - inputs).detach()
        quantize = quantize.permute(0, 3, 1, 2).contiguous()

        return quantize, loss, embedding_idx

In [115]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [116]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        input = input.view(-1, self.nz, 1, 1)
        output = self.main(input)
        return output


In [117]:
Generator(100, 64, 3)(torch.randn(10, 100)).shape

torch.Size([10, 3, 64, 64])

In [118]:
count_parameters_float32(Generator(100, 64, 3))

13.64404296875

In [119]:
class Discriminator(nn.Module):
    def __init__(self, nc, ndf, emb_dim, num_emb):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True))
        
        # self.conv_pre1 = nn.Conv2d(ndf * 1, emb_dim, 1, 1, 0)
        # self.vq1 = VectorQuantizerEMA(emb_dim, num_emb)
        # self.conv_pos1 = nn.Conv2d(emb_dim, ndf * 1, 1, 1, 0)

        self.layer2 = nn.Sequential(
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True))
        
        # self.conv_pre2 = nn.Conv2d(ndf * 2, emb_dim, 1, 1, 0)
        # self.vq2 = VectorQuantizerEMA(emb_dim, num_emb)
        # self.conv_pos2 = nn.Conv2d(emb_dim, ndf * 2, 1, 1, 0)

        self.layer3 = nn.Sequential(
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_pre3 = nn.Conv2d(ndf * 4, emb_dim, 1, 1, 0)
        self.vq = VectorQuantizerEMA(emb_dim, num_emb)
        self.conv_pos3 = nn.Conv2d(emb_dim, ndf * 4, 1, 1, 0)

        self.layer4 = nn.Sequential(
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True))

        # self.conv_pre4 = nn.Conv2d(ndf * 8, emb_dim, 1, 1, 0)
        # self.vq4 = VectorQuantizerEMA(emb_dim, num_emb)
        # self.conv_pos4 = nn.Conv2d(emb_dim, ndf * 8, 1, 1, 0)

        self.layer5 = nn.Sequential(
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        total_loss = torch.tensor(0.0)

        output = self.layer1(input)
        # pre = self.conv_pre1(output)
        # quantize, loss, embedding_idx = self.vq(pre); total_loss += loss
        # output = self.conv_pos1(quantize)

        output = self.layer2(output)
        # pre = self.conv_pre2(output)
        # quantize, loss, embedding_idx = self.vq(pre); total_loss += loss
        # output = self.conv_pos2(quantize)

        output = self.layer3(output)
        pre = self.conv_pre3(output)
        quantize, loss, embedding_idx = self.vq(pre); total_loss += loss
        output = self.conv_pos3(quantize)

        output = self.layer4(output)
        # pre = self.conv_pre4(output)
        # quantize, loss, embedding_idx = self.vq(pre); total_loss += loss
        # output = self.conv_pos4(quantize)

        output = self.layer5(output)
        return output.view(-1, 1).squeeze(1), pre, quantize, loss, embedding_idx

In [120]:
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[0].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[1].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[2].shape, \
# Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[3].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[4].shape

torch.Size([10, 8, 8])

In [121]:
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[3]

tensor(0.7616, grad_fn=<MulBackward0>)

In [122]:
count_parameters_float32(Discriminator(3, 64, 128, 2**10))

10.80126953125

In [123]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [124]:
nz = 100
ngf = 64
ndf = 64
nc = 3
ortho = 1e-4
emb_dim = 256
num_emb = 2**10
dataset_name = "celeba"
n_epochs = 100
n_dis = 4
batch_size = 128
num_workers = 4

netG = Generator(nz, ngf, nc).to(device)
netG.apply(weights_init)

netD = Discriminator(nc, ndf, emb_dim, num_emb).to(device)
netD.apply(weights_init)

Discriminator(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (layer3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv_pre3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (vq): VectorQuantizerEMA()
  (conv_pos3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  (layer4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(51

In [125]:
from torchvision import datasets as dsets
from torchvision import transforms
from torchvision import utils as vutils

In [126]:
if dataset_name == "cifar10":
    dataset = dsets.CIFAR10(
        root="./data", 
        download=True,
        transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )

elif dataset_name == "cifar100":
    dataset = dsets.CIFAR100(
        root="./data", 
        download=True,
        transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )

elif dataset_name == "celeba":
    dataset = dsets.ImageFolder(
        root="../../data/celeba", 
        transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )

In [127]:
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=batch_size,
    shuffle=True, 
    num_workers=int(num_workers))

In [128]:
inputs, labels = next(iter(dataloader))

print(inputs.shape)
print(labels.shape)

torch.Size([128, 3, 64, 64])
torch.Size([128])


In [129]:
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(128, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

In [130]:
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [131]:
import time
from pathlib import Path

output_dir = Path("./output").joinpath(time.strftime("%Y-%m-%d_%H-%M-%S"))
output_dir.mkdir(parents=True, exist_ok=True)

In [132]:
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader, 0):

        # for i in range(n_dis):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        # Quantize
        output, pre_real, quant_real, loss_quant_real, embedding_idx_real = netD(real_cpu)
        errD_real = criterion(output, label)
        lossD_real = errD_real +  loss_quant_real
        lossD_real.backward()
        D_x = output.sigmoid().mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # Quantize
        output, pre_fake, quant_fake, loss_quant_fake, embedding_idx_fake = netD(fake.detach())
        errD_fake = criterion(output, label)
        lossD_fake = errD_fake +  loss_quant_fake
        lossD_fake.backward()
        D_G_z1 = output.sigmoid().mean().item()

        # total loss
        errD = lossD_real + lossD_fake
        if ortho:
            ortho_reg(netD)

        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        if i % n_dis == 0:
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Quantize
            output, pre_G, quant_G, loss_quant_G, embedding_idx_G = netD(fake)
            errG = criterion(output, label)
            lossG = errG + loss_quant_G
            lossG.backward()
            D_G_z2 = output.sigmoid().mean().item()

            if ortho:
                ortho_reg(netG)
            optimizerG.step()

        
        
        if i % (100*n_dis) == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    % (epoch, n_epochs, i, len(dataloader),
                        errD.item(), lossG.item(), D_x, D_G_z1, D_G_z2))

            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % output_dir,
                    normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                    '%s/amp_fake_samples_epoch_%03d.png' % (output_dir, epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (output_dir, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (output_dir, epoch))

0261 D(x): 0.9960 D(G(z)): 0.3677 / 0.0056
[42/100][800/1583] Loss_D: 0.1342 Loss_G: 3.7202 D(x): 0.9971 D(G(z)): 0.0448 / 0.0413
[42/100][1200/1583] Loss_D: 0.1674 Loss_G: 4.4171 D(x): 0.9938 D(G(z)): 0.0661 / 0.0219
[43/100][0/1583] Loss_D: 0.1426 Loss_G: 4.4738 D(x): 0.9669 D(G(z)): 0.0229 / 0.0203
[43/100][400/1583] Loss_D: 0.1210 Loss_G: 4.6955 D(x): 0.9836 D(G(z)): 0.0155 / 0.0246
[43/100][800/1583] Loss_D: 0.2977 Loss_G: 4.6139 D(x): 0.8309 D(G(z)): 0.0069 / 0.0209
[43/100][1200/1583] Loss_D: 1.7202 Loss_G: 2.1675 D(x): 0.7961 D(G(z)): 0.6513 / 0.2818
[44/100][0/1583] Loss_D: 0.2190 Loss_G: 3.6193 D(x): 0.9582 D(G(z)): 0.0749 / 0.0489
[44/100][400/1583] Loss_D: 0.2252 Loss_G: 5.3293 D(x): 0.8909 D(G(z)): 0.0093 / 0.0116
[44/100][800/1583] Loss_D: 0.3539 Loss_G: 2.4852 D(x): 0.8010 D(G(z)): 0.0067 / 0.1353
[44/100][1200/1583] Loss_D: 0.1646 Loss_G: 5.3116 D(x): 0.9365 D(G(z)): 0.0128 / 0.0087
[45/100][0/1583] Loss_D: 0.1460 Loss_G: 5.3888 D(x): 0.9556 D(G(z)): 0.0145 / 0.0094
[45

In [133]:
def normalize_images(tensor):

    min_val = float(tensor.min())
    max_val = float(tensor.max())
    tensor.clamp_(min=min_val, max=max_val)
    tensor.add_(-min_val).div_(max_val-min_val+1e-5)

    images = (
        tensor
        .mul_(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(0, 2, 3, 1)
        .to("cpu", dtype=torch.uint8)
        .numpy()
    )

    return images

In [134]:
def calculate_quantizer(model, optimzier, niters, fname):
    model.train()
    with tqdm(range(niters), position=0, leave=True) as pbar:
        for idx in pbar:

            quantize, loss, embedding_idx = model(torch.load(fname).detach())

            model.zero_grad()
            if loss.grad_fn is not None: 
                loss.backward()
            optimizer.step()

            pbar.set_description(f"loss: {loss.item():.6f}")

    return quantize.cpu(), loss.cpu(), embedding_idx.cpu()

In [135]:
def get_histgram(embedding_idx):
    indexes, values = np.unique(embedding_idx.numpy(), return_counts=True)
    return indexes, values