In [None]:
from functools import partial

import numpy as np

In [2]:
import torch 
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed

from torch.optim import Optimizer
from torch.cuda.amp import autocast

In [None]:
from einops import rearrange, repeat, reduce, pack, unpack

In [4]:
from typing import Callable

In [5]:
def pack_one(t, pattern):
  return pack([t], pattern)

In [6]:
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

In [7]:
"""
  Alongside the last dimension of logits, applies argmax
  then one_hot encodes the result.

  Essentially replaces the largest element of the last dimension with 1,
  and all other elements with 0

  Example:
    logits = [[0, 1, 2], [2, 1, 0], [0, 2, 1]]
    The result of argmax will be [2, 0, 1]
    The function will return [2, 0, 1] and [[0, 0, 1], [1, 0, 0], [0, 1, 0]]

"""
def gumbel_sample(
    logits,
    temperature = 1.,
    dim = -1,
    training = True,
):

  dtype, size = logits.dtype, logits.shape[-1]
  sampling_logits = logits

  ind = sampling_logits.argmax(dim = dim)
  one_hot = F.one_hot(ind, size).type(dtype)

  return ind, one_hot  

In [8]:
def l2norm(t):
  """
  Normalizes t by dividing by its L2 norm, alongside its last dimensions 
  t/||t||_2
  """
  return F.normalize(t, p = 2, dim = -1)

In [9]:
"""
  Sample num vectors from a vector of vectors called samples
  If there's not enough vectors in samples, the return values
  May contain repeated vectors
"""
def sample_vectors(samples, num):
  num_samples, device = samples.shape[0], samples.device
  if num_samples >= num:
    indices = torch.randperm(num_samples, device = device)[:num]
  else:
    indices = torch.randint(0, num_samples, (num, ), device = device)

  return samples[indices]

In [10]:
"""
Sampling like in sample_vectors, but the input "samples" is a batch of vectors of vectors
The sampling is applied independently on its vector of vectors
"""
def batched_sample_vectors(samples, num):
  return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)

In [11]:
def batched_bincount(x, *, minlength):
  """
  Used in kmeans to figure out how many samples belong
  to each cluster
  """
  batch, dtype, device = x.shape[0], x.dtype, x.device
  target = torch.zeros(batch, minlength, dtype = dtype, device = device)
  values = torch.ones_like(x)
  # walks through x
  # and does target[x[i]] += values[i]
  target.scatter_add_(-1, x, values)
  return target

In [12]:
def kmeans(
    samples,
    num_clusters,
    num_iters = 10,
    sample_fn = batched_sample_vectors,
):
  """
  Kmeans clustering
  Pick num_clusters means at random
  Do the following num_iters times:
    Assign each element to the closest cluster based on cosine similarity
    Then recompute the means as the normalized arithmetic mean of the
    items in each cluster
  """
  num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device

  means = sample_fn(samples, num_clusters)

  for _ in range(num_iters):
    # dot products between each sample and the current mean
    # dists[i][j] = dot product between sample i and mean j
    dists = samples @ rearrange(means, 'h n d -> h d n')

    # argmax over the dot product dimension
    # to find which of the means the sample belongs to
    buckets = torch.argmax(dists, dim = -1)
    # bins[b][j] = number of items in clusters j of batch b
    bins = batched_bincount(buckets, minlength = num_clusters)

    # some clusters might have 0 elements
    # the number of elements is replaced with 1
    # because later we need to find the mean
    # of the elements in a cluster
    zero_mask = bins == 0
    bins_min_clamped = bins.masked_fill(zero_mask, 1)

    # each cluster has a mean of dimension equal to dim
    new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)

    # sum up the elements in each cluster
    new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)

    # and divide by the number of elements in each cluster
    new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')

    new_means = l2norm(new_means)

    # for clusters with 0 elements
    # we replace the computed new_mean (which will be 0)
    # with the old mean
    means = torch.where(
      rearrange(zero_mask, '... -> ... 1'),
      means,
      new_means
    )

  return means, bins

In [13]:
def noop():
  pass

In [14]:
class VectorQuantize(nn):
  def __init__(
      self,
      dim,
      codebook_size,
      codebook_dim = None,
      heads = 1,
      # always False
      separate_codebook_per_head = False,
      decay = 0.8,
      eps = 1e-5,
      freeze_codebook = False,
      kmeans_init = False,
      kmeans_iters = 10,
      sync_kmeans = True,
      use_cosine_sim = False,
      threshold_ema_dead_code = 0,
      channel_last = True,
      accept_image_fmap = False,
      commitment_weight = 1.,
      commitment_use_cross_entropy_loss = False,
      orthogonal_reg_weight = 0.,
      orthogonal_reg_active_codes_only = False,
      orthogonal_reg_max_codes = None,
      stochastic_sample_code = False,
      sample_codebook_temp = 1.,
      straight_through = False,
      reinmax = False,
      # always None, and becomes False
      # appears connected to distributed computing
      sync_codebook = None,
      sync_affine_param = False,
      # always True
      ema_update = True,
      # always False
      learnable_codebook = False,
      in_place_codebook_topimizer = None,
      affine_param = False,
      affine_param_batch_decay = 0.99,
      affine_param_codebook_decay = 0.9,
      sync_update_v = 0.
  ):
    super().__init__()
    self.dim = dim
    self.heads = heads
    self.separate_codebook_per_head = separate_codebook_per_head

    codebook_dim = codebook_dim if codebook_dim is not None else dim
    codebook_input_dim = codebook_dim * heads

    requires_projection = codebook_input_dim != dim

    self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
    self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

    self.has_projections = requires_projection

    self.eps = eps
    self.commitment_weight = commitment_weight
    self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss

    self.learnable_codebook = learnable_codebook

    has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
    self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
    self.orthogonal_reg_weight = orthogonal_reg_weight
    self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
    self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

    assert not (ema_update and learnable_codebook), "learnable codebook incompatible with ema update"

    assert 0 <= sync_update_v <= 1.
    assert not (sync_update_v >0. and not learnable_codebook), "learnable codebook must be ON if sync_update_v > 0"

    self.sync_update_v = sync_update_v

    # codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
    codebook_class = CosineSimCodebook

TypeError: module() takes at most 2 arguments (3 given)

In [None]:
class CosineSimCodebook(nn.Module):
  def __init__(
      self,
      dim,
      codebook_size,
      num_codebooks = 1,
      # always True
      kmeans_init = False,
      kmeans_iters = 10,
      sync_kmeans = True,
      decay = 0.8,
      eps = 1e-5,
      threshold_ema_dead_code = 2,
      reset_cluster_size = None,
      # always False
      use_ddp = False,
      learnable_codebook = False,
      gumbel_sample = gumbel_sample,
      sample_codebook_temp = 1.,
      ema_update = True
  ):
    super().__init__()
    self.transform_input = l2norm

    self.ema_update = ema_update
    self.decay = decay

    embed = torch.zeros(num_codebooks, codebook_size, dim)

    self.codebook_size = codebook_size
    self.num_codebooks = num_codebooks

    self.kmeans_iters = kmeans_iters
    self.eps = eps
    self.threshold_ema_dead_code = threshold_ema_dead_code
    self.reset_cluster_size = reset_cluster_size if reset_cluster_size is not None else threshold_ema_dead_code

    self.gumbel_sample = gumbel_sample
    self.sample_codebook_temp = sample_codebook_temp

    self.sample_fn = batched_sample_vectors
    self.kmeans_all_reduce_fn = noop
    self.all_reduce_fn = noop

    self.register_buffer('initted', torch.Tensor([not kmeans_init]))
    self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
    self.register_buffer('embed_avg', embed.clone())

    self.learnable_codebook = learnable_codebook

    self.embed = nn.Parameter(embed)

  @torch.jit.ignore
  def init_embed_(self, data):
    
    # run kmeans clustering on the data
    embed, cluster_size = kmeans(
      data,
      self.codebook_size,
      self.kmeans_iters,
      sample_fn = batched_sample_vectors
    )

    # embed (which are the centroids found by kmeans)
    # are saved as a non-learnable parameter
    # centroid of a cluster = normalized mean of elements in the cluster
    # the normalized sum of the elements is saved in embed_avg
    embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
    self.embed.data.copy_(embed)
    self.embed_avg.data.copy_(embed_sum)
    self.cluster_size.data.copy_(cluster_size)

  @autocast(enabled = False)
  def forward(
    self,
    x,
    # always None
    sample_codebook_temp = None,
    # always None
    mask = None,
    # always False
    freeze_codebook = False,
  ):
    
    # Q: assume x[i][j][k] is the kth elements of the jth vector from the ith batch
    needs_codebook_dim = x.ndim < 4
    
    # appears equal to 1
    sample_codebook_temp = self.sample_codebook_temp

    x = x.float()
    if needs_codebook_dim:
      x = rearrange(x, '... -> 1 ...')

    dtype = x.dtype

    flatten, ps = pack_one(x, 'h * d')

    print(ps)

    # flatten appears to be the same samples, described above as x[i][j][k]

    self.init_embed_(flatten)

    embed = self.embed.detach()

    # dist[n][c] = sum over d of flatten[n][d]*embed[c][d]
    # dist[n][c] = dot product between vector n and centroid c
    dist = einsum('h n d, h c d -> h n c', flatten, embed)

    print(dist)

    embed_ind, embed_onehot = self.gumbel_sample(
      dist, 
      dim = -1, 
      temperature = sample_codebook_temp,
      training = self.training
    )

    print(embed)
    print(embed_ind)
    print(embed_onehot)

    # only seems to add another dimension
    embed_ind = unpack_one(embed_ind, ps, 'h *')

    if self.training:
      # only appears to add another dimension
      unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
      # quantize[h][b][n][d] = sum over c of unpacked_onehot[h][b][n][c] * embed[h][c][d]
      quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)

In [None]:
x = (torch.arange(100)**3 % 99).reshape(1, 20, 5).double()

In [None]:
codebook = CosineSimCodebook(5, 4)

In [None]:
codebook(x)

In [None]:
t = torch.tensor([[2, 1, 3, 3, 1, 1, 0, 3, 2, 1, 3, 3, 1, 2, 1, 1, 2, 3, 3, 3]])

In [None]:
unpack_one(t, ps, 'h *')

In [None]:
ps = [torch.Size([1, 20])]

In [None]:
t = torch.tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 1., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.]]])

In [None]:
t

In [None]:
unpack_one(t, ps, 'h * c')