<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/src/gan/FQGAN/notebooks/VectorQuantizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [0]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

In [0]:
def count_parameters_float32(model):
    param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_mb = param / 1024 / 1024 / 8 * 32
    return param_mb

In [0]:
class VectorQuantizer(nn.Module):

    def __init__(self, emb_dim=64, num_emb=2**10, commitment=0.25):
        super(VectorQuantizer, self).__init__()

        self.emb_dim = emb_dim
        self.num_emb = num_emb
        self.commitment = commitment
        self.embedding = nn.Parameter(torch.randn(emb_dim, num_emb))

        self.register_buffer("cluster_size", torch.zeros(self.num_emb))
        self.register_buffer("ema_embedding", self.embedding.clone())
    
    def forward(self, inputs):
        # [B, C=D, H, W] --> [B, H, W, C=D]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()

        # [B, H, W, D] --> [N(=BxHxW), D]
        flatten = inputs.view(-1, self.emb_dim)

        # distance d(H[N, D], E[D, K]) --> d[N, K]
        # each element show the distance between Hj and Ei
        distance = (
            flatten.pow(2).sum(1, keepdim=True)
            -2 * flatten @ self.embedding
            + self.embedding.pow(2).sum(0, keepdim=True)
        )

        # embedding_idx: [N, K] --> [N, ]
        embedding_idx = torch.argmin(distance, dim=1)
        # embedding_idx: [N, ] --> [B, H, W, ]
        embedding_idx = embedding_idx.view(*inputs_shape[:-1])
        # quantize: [B, H, W, ] --> [B, H, W, D]
        quantize = F.embedding(embedding_idx, self.embedding.transpose(0, 1))

        # loss
        e_latent_loss = F.mse_loss(quantize.detach(), inputs)
        q_latent_loss = F.mse_loss(quantize, inputs.detach())
        loss = q_latent_loss + self.commitment * e_latent_loss

        quantize = inputs + (quantize - inputs).detach()
        quantize = quantize.permute(0, 3, 1, 2).contiguous()

        return quantize, loss, embedding_idx

In [0]:
class VectorQuantizerEMA(nn.Module):

    def __init__(self, emb_dim=64, num_emb=2**10, commitment=1.0, decay=0.9, eps=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self.emb_dim = emb_dim
        self.num_emb = num_emb
        self.commitment = commitment
        self.decay = decay
        self.eps = eps

        embedding = nn.Parameter(torch.randn(emb_dim, num_emb))
        self.register_buffer("embedding", embedding)
        self.register_buffer("cluster_size", torch.zeros(self.num_emb))
        self.register_buffer("ema_embedding", self.embedding.clone())

    def forward(self, inputs):
        # [B, C=D, H, W] --> [B, H, W, C=D]
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()

        # [B, H, W, D] --> [N(=BxHxW), D]
        flatten = inputs.view(-1, self.emb_dim)

        # distance d(H[N, D], E[D, K]) --> d[N, K]
        # each element show the distance between Hj and Ei
        distance = (
            flatten.pow(2).sum(1, keepdim=True)
            -2 * flatten @ self.embedding
            + self.embedding.pow(2).sum(0, keepdim=True)
        )

        # embedding_idx: [N, K] --> [N, ]
        embedding_idx = torch.argmin(distance, dim=1)
        # embedding_onthot: [N, ] --> [N, K]
        embedding_onehot = F.one_hot(embedding_idx, self.num_emb).type(flatten.dtype)
        # embedding_idx: [N, ] --> [B, H, W, ]
        embedding_idx = embedding_idx.view(*inputs_shape[:-1])
        # quantize: [B, H, W, ] --> [B, H, W, D]
        quantize = F.embedding(embedding_idx, self.embedding.transpose(0, 1))

        if self.training:
            self.cluster_size.mul_(self.decay).add_(
                1-self.decay, embedding_onehot.sum(0)
            )
            dw = flatten.transpose(0, 1) @ embedding_onehot
            self.ema_embedding.data.mul_(self.decay).add_(1-self.decay, dw)
            n = self.cluster_size.sum()
            smoother_cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_emb * self.eps) * n
            )
            embedding_norm = self.ema_embedding / smoother_cluster_size.unsqueeze(0)
            self.embedding.data.copy_(embedding_norm)

        # loss
        e_latent_loss = F.mse_loss(quantize.detach(), inputs)
        loss = self.commitment * e_latent_loss

        quantize = inputs + (quantize - inputs).detach()
        quantize = quantize.permute(0, 3, 1, 2).contiguous()

        return quantize, loss, embedding_idx

In [0]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [0]:
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        input = input.view(-1, self.nz, 1, 1)
        output = self.main(input)
        return output


In [24]:
Generator(100, 64, 3)(torch.randn(10, 100)).shape

torch.Size([10, 3, 64, 64])

In [25]:
count_parameters_float32(Generator(100, 64, 3))

13.64404296875

In [0]:
class Discriminator(nn.Module):
    def __init__(self, nc, ndf, emb_dim, num_emb):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True))
        
        self.conv_pre1 = nn.Conv2d(ndf * 1, emb_dim, 1, 1, 0)
        self.vq1 = VectorQuantizerEMA(emb_dim, num_emb)
        self.conv_pos1 = nn.Conv2d(emb_dim, ndf * 1, 1, 1, 0)

        self.layer2 = nn.Sequential(
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True))
        
        self.conv_pre2 = nn.Conv2d(ndf * 2, emb_dim, 1, 1, 0)
        self.vq2 = VectorQuantizerEMA(emb_dim, num_emb)
        self.conv_pos2 = nn.Conv2d(emb_dim, ndf * 2, 1, 1, 0)

        self.layer3 = nn.Sequential(
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_pre3 = nn.Conv2d(ndf * 4, emb_dim, 1, 1, 0)
        self.vq3 = VectorQuantizerEMA(emb_dim, num_emb)
        self.conv_pos3 = nn.Conv2d(emb_dim, ndf * 4, 1, 1, 0)

        self.layer4 = nn.Sequential(
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer5 = nn.Sequential(
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        total_loss = torch.tensor(0.0)

        output = self.layer1(input)
        pre = self.conv_pre1(output)
        quantize, loss, embedding_idx = self.vq1(pre); total_loss += loss
        pos = self.conv_pos1(quantize)

        output = self.layer2(pos)
        pre = self.conv_pre2(output)
        quantize, loss, embedding_idx = self.vq2(pre); total_loss += loss
        pos = self.conv_pos2(quantize)

        output = self.layer3(pos)
        pre = self.conv_pre3(output)
        quantize, loss, embedding_idx = self.vq3(pre); total_loss += loss
        pos = self.conv_pos3(quantize)

        output = self.layer4(pos)
        output = self.layer5(output)
        return output.view(-1, 1).squeeze(1), pre, quantize, loss, embedding_idx

In [27]:
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[0].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[1].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[2].shape, \
# Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[3].shape, \
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[4].shape

torch.Size([10, 8, 8])

In [28]:
Discriminator(3, 64, 128, 2**10)(torch.randn(10, 3, 64, 64))[3]

tensor(0.7769, grad_fn=<MulBackward0>)

In [29]:
count_parameters_float32(Discriminator(3, 64, 128, 2**10))

10.990478515625

In [30]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [37]:
nz = 100
ngf = 64
ndf = 64
nc = 3
emb_dim = 256
num_emb = 2**10

netG = Generator(nz, ngf, nc).to(device)
netG.apply(weights_init)

netD = Discriminator(nc, ndf, emb_dim, num_emb).to(device)
netD.apply(weights_init)

Discriminator(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv_pre1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
  (vq1): VectorQuantizerEMA()
  (conv_pos1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv_pre2): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
  (vq2): VectorQuantizerEMA()
  (conv_pos2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
  (layer3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_sl

In [32]:
from torchvision import datasets as dsets
from torchvision import transforms
from torchvision import utils as vutils


dataset = dsets.CIFAR10(
    root="./data", 
    download=True,
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)

dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=128,
    shuffle=True, 
    num_workers=int(4))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data


In [0]:
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(128, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

In [0]:
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [0]:
from pathlib import Path

Path("./output").mkdir(parents=True, exist_ok=True)

In [36]:
n_epochs = 200
n_dis = 4

for epoch in range(n_epochs):
    for i, data in enumerate(dataloader, 0):

        # for i in range(n_dis):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        # Quantize
        output, pre_real, quant_real, loss_quant_real, embedding_idx_real = netD(real_cpu)
        errD_real = criterion(output, label)
        lossD_real = errD_real +  loss_quant_real
        lossD_real.backward()
        D_x = output.sigmoid().mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # Quantize
        output, pre_fake, quant_fake, loss_quant_fake, embedding_idx_fake = netD(fake.detach())
        errD_fake = criterion(output, label)
        lossD_fake = errD_fake +  loss_quant_fake
        lossD_fake.backward()
        D_G_z1 = output.sigmoid().mean().item()

        # total loss
        errD = lossD_real + lossD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Quantize
        output, pre_G, quant_G, loss_quant_G, embedding_idx_G = netD(fake)
        errG = criterion(output, label)
        lossG = errG + loss_quant_G
        lossG.backward()
        D_G_z2 = output.sigmoid().mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, n_epochs, i, len(dataloader),
                 errD.item(), lossG.item(), D_x, D_G_z1, D_G_z2))
        
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % "./output",
                    normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                    '%s/amp_fake_samples_epoch_%03d.png' % ("./output", epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % ("./output", epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % ("./output", epoch))



[0/200][0/391] Loss_D: 3.6452 Loss_G: 5.8925 D(x): 0.0944 D(G(z)): 0.2351 / 0.0029
[0/200][1/391] Loss_D: 6.0402 Loss_G: 0.1530 D(x): 0.0026 D(G(z)): 0.0025 / 0.8924
[0/200][2/391] Loss_D: 2.5503 Loss_G: 1.0516 D(x): 0.9008 D(G(z)): 0.9063 / 0.3639
[0/200][3/391] Loss_D: 1.5445 Loss_G: 0.7530 D(x): 0.3623 D(G(z)): 0.3610 / 0.4909
[0/200][4/391] Loss_D: 1.4692 Loss_G: 0.5629 D(x): 0.4906 D(G(z)): 0.4904 / 0.5939
[0/200][5/391] Loss_D: 1.5069 Loss_G: 1.2051 D(x): 0.5942 D(G(z)): 0.5946 / 0.3125
[0/200][6/391] Loss_D: 1.6231 Loss_G: 0.2715 D(x): 0.3102 D(G(z)): 0.3081 / 0.7949
[0/200][7/391] Loss_D: 1.9098 Loss_G: 1.9501 D(x): 0.7964 D(G(z)): 0.7977 / 0.1484
[0/200][8/391] Loss_D: 2.1626 Loss_G: 0.2819 D(x): 0.1464 D(G(z)): 0.1446 / 0.7870
[0/200][9/391] Loss_D: 1.8794 Loss_G: 1.2055 D(x): 0.7881 D(G(z)): 0.7892 / 0.3126
[0/200][10/391] Loss_D: 1.6237 Loss_G: 0.5244 D(x): 0.3109 D(G(z)): 0.3094 / 0.6177
[0/200][11/391] Loss_D: 1.5295 Loss_G: 0.9238 D(x): 0.6180 D(G(z)): 0.6183 / 0.4143
[0

KeyboardInterrupt: ignored

In [0]:
def normalize_images(tensor):

    min_val = float(tensor.min())
    max_val = float(tensor.max())
    tensor.clamp_(min=min_val, max=max_val)
    tensor.add_(-min_val).div_(max_val-min_val+1e-5)

    images = (
        tensor
        .mul_(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(0, 2, 3, 1)
        .to("cpu", dtype=torch.uint8)
        .numpy()
    )

    return images

In [0]:
def calculate_quantizer(model, optimzier, niters, fname):
    model.train()
    with tqdm(range(niters), position=0, leave=True) as pbar:
        for idx in pbar:

            quantize, loss, embedding_idx = model(torch.load(fname).detach())

            model.zero_grad()
            if loss.grad_fn is not None: 
                loss.backward()
            optimizer.step()

            pbar.set_description(f"loss: {loss.item():.6f}")

    return quantize.cpu(), loss.cpu(), embedding_idx.cpu()

In [0]:
def get_histgram(embedding_idx):
    indexes, values = np.unique(embedding_idx.numpy(), return_counts=True)
    return indexes, values