In [1]:
import math
import argparse
import numpy as np

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

In [31]:
class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        self._need_init = False  # 只初始化第一次
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [b,t,h,w,c]->[b*t*h*w, c]
        y = self._tile(flat_inputs)

        d = y.shape[0]
        # [n_codes, embed_dim], integer. 随机选择n_codes个作为embedding的初始值
        # 保证在最初的mapping时，每个embedding vector都尽量被用到？
        
        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] 
        
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, c, t, h, w], z_e(x), encoder output
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [b,t,h,w,c]->[b*t*h*w,c]
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
                    - 2 * flat_inputs @ self.embeddings.t() \
                    + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)

        encoding_indices = torch.argmin(distances, dim=1) # [b*t*h*w]
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) # [b*t*h*w, n_codes]
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) # [b, t,h,w]

        # [b, t, h, w, c] 根据indices获取对应的embedding, z_q(x)
        embeddings = F.embedding(encoding_indices, self.embeddings) 
        embeddings = shift_dim(embeddings, -1, 1)  # [b, c, t, h, w]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) # 将encoder output与对应的embedding拉近

        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0) # [b*t*h*w, n_codes] -> [n_codes]，每个codes被选中的数量
            print("n_total: ", n_total.shape, n_total)
            # 这一步很关键，就是把新得到的 flat_inputs（也就是encoder_output），重新赋予到encoder_sum中
            # 这个过程就是index对应的flat_inputs中的vector进行累加
            
            encode_sum = flat_inputs.t() @ encode_onehot  # [c, n_codes]
            print("encode_sum: ", encode_sum.shape)
            
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01) # 每个codes被选中的数量的滑动平均值
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) # [n_coders, c] embedding的滑动平均值

            n = self.N.sum() # 所有codes的总数
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n # 这算的还是每个codes的滑动平均值呀
            print("weights: ", weights)
            print("n:", n)
            print("self.N: ", self.N)
            
            encode_normalized = self.z_avg / weights.unsqueeze(1) # 除以每个codes被用过的次数，就是归一化的embedding
            self.embeddings.data.copy_(encode_normalized)  # 到这里是不是就结束了，下面的是在干嘛？？？？

            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        embeddings_st = (embeddings - z).detach() + z # stright-through

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

In [32]:
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim

    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x

### init embedding

In [33]:
z = torch.randn(5, 128, 4, 32, 32)
codebook = Codebook(2000, 128)
out = codebook(z)

n_total:  torch.Size([2000]) tensor([ 7.,  1.,  8.,  ..., 10., 41.,  5.])
encode_sum:  torch.Size([128, 2000])
weights:  tensor([1.0600, 1.0000, 1.0700,  ..., 1.0900, 1.4000, 1.0400])
n: tensor(2184.8000)
self.N:  tensor([1.0600, 1.0000, 1.0700,  ..., 1.0900, 1.4000, 1.0400])


### some details explain

In [13]:
torch.randperm(10)

tensor([2, 8, 4, 5, 0, 7, 1, 9, 3, 6])

In [17]:
indices = torch.LongTensor([1,2,3,3,2])
encode_onehot = F.one_hot(indices, 5)
encode_onehot

tensor([[0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0]])

In [18]:
flat_inputs = torch.ones(5, 8)  # bs, embed_dim[]
flat_inputs

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [21]:
(flat_inputs.t() @ encode_onehot.float()).t()  # [c_codes, embed_dim]
# index对应的vector进行累加

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])