In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple

class HelperModule(torch.nn.Module):
    """Base ``nn.Module`` that forwards all constructor arguments to ``build``.

    Subclasses override ``build`` rather than ``__init__``. ``nn.Module`` is
    initialised first, so ``build`` can freely register submodules,
    parameters and buffers.
    """

    def __init__(self, *args, **kwargs):
        super(HelperModule, self).__init__()
        # Delegate all layer construction to the subclass hook.
        self.build(*args, **kwargs)

    def build(self, *args, **kwargs):
        """Create the module's layers; every subclass must override this."""
        raise NotImplementedError

def get_parameter_count(net: torch.nn.Module) -> int:
    """Return the total number of scalar parameters in ``net`` that require gradients."""
    total = 0
    for param in net.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total
def get_device(cpu):
    """Return ``cuda`` unless ``cpu`` is truthy or CUDA is unavailable, else ``cpu``."""
    # `and` short-circuits, so CUDA availability is only queried when cpu is falsy.
    use_cuda = (not cpu) and torch.cuda.is_available()
    return torch.device('cuda' if use_cuda else 'cpu')

In [21]:
class ReZero(HelperModule):
    """Residual block computing ``x + alpha * f(x)`` with a learnable scalar gate.

    ``f`` is Linear -> BatchNorm1d -> ReLU and keeps the feature width
    unchanged. ``alpha`` starts at 0, so the block is initially the identity.
    """

    def build(self, in_features: int):
        linear = nn.Linear(in_features, in_features)
        norm = nn.BatchNorm1d(in_features)
        act = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(linear, norm, act)
        # Zero-initialised gate: the residual branch is switched off at the start.
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        branch = self.layers(x)
        return x + self.alpha * branch

In [19]:
class ResidualStack(HelperModule):
    """``nb_layers`` width-preserving ``ReZero`` blocks applied one after another."""

    def build(self, in_features: int, nb_layers: int):
        self.stack = nn.Sequential()
        for idx in range(nb_layers):
            # Same child names ("0", "1", ...) that nn.Sequential(*blocks) would assign.
            self.stack.add_module(str(idx), ReZero(in_features))

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        out = self.stack(x)
        return out

class Encoder(HelperModule):
    """MLP encoder: an optional input projection followed by a ``ResidualStack``.

    A Linear projection is only inserted when ``in_features`` differs from
    ``hidden_features``; otherwise the input feeds the residual stack directly.
    """

    def build(self,
            in_features: int, hidden_features: int, nb_res_layers: int
        ):
        needs_projection = in_features != hidden_features
        projection = [nn.Linear(in_features, hidden_features)] if needs_projection else []
        self.layers = nn.Sequential(
            *projection,
            ResidualStack(hidden_features, nb_res_layers),
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        encoded = self.layers(x)
        return encoded


In [84]:
class Decoder(HelperModule):
    """MLP decoder: optional input projection, a ``ResidualStack``, optional output projection.

    Linear layers are only added where neighbouring widths differ, mapping
    ``embed_dim -> hidden_features -> out_dim``.
    """

    def build(self,
            embed_dim: int, hidden_features: int, out_dim:int, nb_res_layers: int
        ):
        in_proj = [nn.Linear(embed_dim, hidden_features)] if embed_dim != hidden_features else []
        body = [ResidualStack(hidden_features, nb_res_layers)]
        out_proj = [nn.Linear(hidden_features, out_dim)] if out_dim != hidden_features else []
        self.layers = nn.Sequential(*in_proj, *body, *out_proj)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        decoded = self.layers(x)
        return decoded

In [79]:
class CodeLayer(HelperModule):
    """Vector-quantisation layer with an exponential-moving-average (EMA) codebook.

    Inputs are projected by ``linear_in`` to ``embed_dim`` features. Each row
    is then replaced by its nearest codebook entry under squared Euclidean
    distance. In training mode the codebook is re-estimated from EMA
    statistics of the rows assigned to each entry, not by gradient descent.

    ``build`` arguments:
        in_features: width of the incoming 2-D ``(batch, in_features)`` tensor.
        embed_dim: width of each codebook entry and of the quantised output.
        nb_entries: number of codebook entries.
        decay: EMA decay for the codebook statistics (default 0.99).
        eps: additive smoothing for per-entry counts, which avoids dividing by
            zero for entries that never receive a row (default 1e-5).
    """

    def build(self, in_features: int, embed_dim: int, nb_entries: int,
              decay: float = 0.99, eps: float = 1e-5):
        self.linear_in = nn.Linear(in_features, embed_dim)

        self.dim = embed_dim
        self.n_embed = nb_entries
        self.decay = decay
        self.eps = eps

        # Codebook stored column-wise: embed[:, k] is entry k.
        embed = torch.randn(embed_dim, nb_entries, dtype=torch.float32)
        self.register_buffer("embed", embed)
        # EMA of how many rows were assigned to each entry.
        self.register_buffer("cluster_size", torch.zeros(nb_entries, dtype=torch.float32))
        # EMA of the sum of rows assigned to each entry (same layout as embed).
        self.register_buffer("embed_avg", embed.clone())

    # Quantise in float32 even when the caller runs under CUDA mixed precision.
    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.Tensor, torch.LongTensor]:
        """Quantise ``x`` against the codebook.

        Args:
            x: 2-D ``(batch, in_features)`` tensor.

        Returns:
            quantize: ``(batch, embed_dim)`` tensor whose value is the selected
                codebook entry. Gradients pass straight through to the
                projected input.
            diff: scalar tensor, the mean squared distance between the projected
                input and its (detached) codebook entry.
            embed_ind: ``(batch,)`` long tensor of selected codebook indices.
        """
        x = self.linear_in(x.float())

        # ||x_i - e_k||^2 expanded as ||x_i||^2 - 2 x_i.e_k + ||e_k||^2 -> (batch, nb_entries).
        dist = (
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)

        quantize = self.embed_code(embed_ind)

        if self.training:
            self._ema_update(x, embed_ind)

        diff = (quantize.detach() - x).pow(2).mean()
        # Straight-through estimator: the forward value is the code, the gradient flows to x.
        quantize = x + (quantize - x).detach()

        return quantize, diff, embed_ind

    def _ema_update(self, x: torch.Tensor, embed_ind: torch.LongTensor) -> None:
        """Move the codebook toward the EMA mean of the rows assigned to each entry."""
        with torch.no_grad():
            x = x.detach()
            # One-hot dtype follows the projected input (previously read from an unrelated global).
            embed_onehot = F.one_hot(embed_ind, self.n_embed).type(x.dtype)
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = x.transpose(0, 1) @ embed_onehot

            # NOTE(review): for multi-process training these two statistics would need a
            # torch.distributed all-reduce before the EMA update; single-process use needs none.

            self.cluster_size.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            # Laplace-smoothed counts, rescaled so they still sum to n.
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            self.embed.copy_(self.embed_avg / cluster_size.unsqueeze(0))

    def embed_code(self, embed_id: torch.LongTensor) -> torch.FloatTensor:
        """Look up codebook entries by index; returns ``embed_id.shape + (embed_dim,)``."""
        return F.embedding(embed_id, self.embed.transpose(0, 1))

In [9]:
class VQVAE(HelperModule):
    """Hierarchical VQ-VAE assembled from the MLP ``Encoder``, ``CodeLayer`` and ``Decoder``.

    Level 0 is closest to the input and level ``nb_levels - 1`` is the top.

    - Encoders run bottom-up, each consuming the previous encoder's output.
    - Quantisation and decoding run top-down. Below the top level, each
      codebook sees its encoder features concatenated with the decoder output
      of the level above.
    - Each decoder sees its own code followed by the codes of every level
      above it.

    All tensors are 2-D ``(batch, features)``, so codes from higher levels are
    concatenated directly with no spatial upscaling.
    """

    def build(self,
            in_channels: int                = 3,
            hidden_channels: int            = 128,
            res_channels: int               = 32,
            nb_res_layers: int              = 2,
            nb_levels: int                  = 3,
            embed_dim: int                  = 64,
            nb_entries: int                 = 512,
            scaling_rates                   = None,
        ):
        """Create the encoders, codebooks and decoders for every level.

        Args:
            in_channels: number of input features; also the reconstruction width.
            hidden_channels: width of every encoder/decoder residual stack.
            res_channels: unused by the MLP blocks; accepted for backward compatibility.
            nb_res_layers: ReZero blocks per residual stack.
            nb_levels: number of quantisation levels (must be >= 1).
            embed_dim: width of each codebook entry.
            nb_entries: entries per codebook.
            scaling_rates: unused by the MLP blocks; accepted for backward
                compatibility. If given, its length must equal ``nb_levels``.
                Defaults to None instead of a shared mutable list.
        """
        if nb_levels < 1:
            raise ValueError("nb_levels must be at least 1")
        if scaling_rates is not None:
            assert len(scaling_rates) == nb_levels, "Number of scaling rates not equal to number of levels!"
        self.nb_levels = nb_levels

        # Bottom encoder reads the raw input; higher encoders refine hidden features.
        self.encoders = nn.ModuleList(
            [Encoder(in_channels, hidden_channels, nb_res_layers)]
            + [Encoder(hidden_channels, hidden_channels, nb_res_layers) for _ in range(nb_levels - 1)]
        )

        # Non-top codebooks also receive the embed_dim-wide decoder output of the level above.
        self.codebooks = nn.ModuleList(
            [CodeLayer(hidden_channels + embed_dim, embed_dim, nb_entries) for _ in range(nb_levels - 1)]
            + [CodeLayer(hidden_channels, embed_dim, nb_entries)]
        )

        # The decoder at level l consumes the codes of levels l..nb_levels-1, i.e.
        # embed_dim * (nb_levels - l) features. Level 0 reconstructs the input;
        # the others emit embed_dim features that condition the level below.
        self.decoders = nn.ModuleList(
            [Decoder(embed_dim * nb_levels, hidden_channels, in_channels, nb_res_layers)]
            + [Decoder(embed_dim * (nb_levels - level), hidden_channels, embed_dim, nb_res_layers)
               for level in range(1, nb_levels)]
        )

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            A tuple ``(reconstruction, diffs, encoder_outputs, decoder_outputs, id_outputs)``.
            ``encoder_outputs`` is ordered bottom level first. ``diffs``,
            ``decoder_outputs`` and ``id_outputs`` are ordered top level first.
            ``reconstruction`` is ``decoder_outputs[-1]``.
        """
        encoder_outputs = []
        h = x
        for enc in self.encoders:
            h = enc(h)
            encoder_outputs.append(h)

        code_outputs = []
        decoder_outputs = []
        id_outputs = []
        diffs = []

        for l in range(self.nb_levels - 1, -1, -1):
            codebook, decoder = self.codebooks[l], self.decoders[l]

            if decoder_outputs:  # condition on the decoder output of the level above
                code_in = torch.cat([encoder_outputs[l], decoder_outputs[-1]], dim=1)
            else:
                code_in = encoder_outputs[l]
            code_q, code_d, emb_id = codebook(code_in)
            diffs.append(code_d)
            id_outputs.append(emb_id)

            decoder_outputs.append(decoder(torch.cat([code_q, *code_outputs], dim=1)))
            code_outputs.append(code_q)

        return decoder_outputs[-1], diffs, encoder_outputs, decoder_outputs, id_outputs

    def decode_codes(self, *cs):
        """Decode from codebook indices.

        Args:
            *cs: one index tensor per level, where ``cs[l]`` belongs to level l
                (bottom first). This is the reverse of ``forward``'s
                ``id_outputs`` order, so pass ``*reversed(id_outputs)``.

        Returns:
            The level-0 decoder output (input-width reconstruction).
        """
        code_outputs = []
        decoded = None

        for l in range(self.nb_levels - 1, -1, -1):
            # Codes are 2-D (batch, embed_dim); no channel permutation is needed.
            code_q = self.codebooks[l].embed_code(cs[l])
            decoded = self.decoders[l](torch.cat([code_q, *code_outputs], dim=1))
            code_outputs.append(code_q)

        return decoded


In [14]:
import os
import sys

# Make the project root (one directory up) importable. It is inserted first so the
# local `datasets` module wins over any installed package with the same name.
if ".." not in sys.path:
    sys.path.insert(0, "..")

from datasets import OrganoidDataset

# Set ORGANOID_DATA_DIR to point at the data on another machine; the default is the
# path this notebook was originally run with.
DATA_DIR = os.environ.get("ORGANOID_DATA_DIR", "/home/egor/PycharmProjects/deep_dr/data/organoids")
BATCH_SIZE = 1024

data = OrganoidDataset(data_dir=DATA_DIR)

X_train, y_train = data.train
X_val, y_val = data.val

# Fixed-size mini-batches along dim 0 (the last batch may be smaller).
X_train_batches = torch.split(torch.Tensor(X_train), split_size_or_sections=BATCH_SIZE)
X_val_batches = torch.split(torch.Tensor(X_val), split_size_or_sections=BATCH_SIZE)

In [16]:
# Shape of the first training mini-batch: (batch_size, n_features).
X_train_batches[0].size()

torch.Size([1024, 41])

In [22]:
# Input width comes from the data (41 features here) instead of being hard-coded.
enc = Encoder(in_features=X_train_batches[0].shape[1], hidden_features=32, nb_res_layers=3)

In [36]:
x = X_train_batches[0]  # first training mini-batch, the running example below
x.size()

torch.Size([1024, 41])

In [37]:
# Encode the first training batch. Reading from X_train_batches rather than the
# previous `x` makes the cell safe to re-run. Calling the module (not .forward)
# runs nn.Module hooks.
x = enc(X_train_batches[0])
x.shape

torch.Size([1024, 32])

In [80]:
# Quantiser for the 32-d encoder output: 256 codebook entries of width 16.
code = CodeLayer(in_features=32, nb_entries=256, embed_dim=16)

In [83]:
# First element of CodeLayer's output: the straight-through quantised batch.
# In training mode every call also updates the codebook's EMA buffers.
code(x)[0]

tensor([[-0.3911,  0.6807, -0.2599,  ...,  0.6381,  0.6656,  0.3845],
        [-0.1961,  0.8031, -0.0063,  ...,  0.8630,  0.1975,  0.3984],
        [-0.1961,  0.8031, -0.0063,  ...,  0.8630,  0.1975,  0.3984],
        ...,
        [-0.1961,  0.8031, -0.0063,  ...,  0.8630,  0.1975,  0.3984],
        [-0.3911,  0.6807, -0.2599,  ...,  0.6381,  0.6656,  0.3845],
        [-0.1961,  0.8031, -0.0063,  ...,  0.8630,  0.1975,  0.3984]],
       grad_fn=<AddBackward0>)

In [86]:
# Decode 16-d codes back to the input width, taken from the data (41 here) instead of hard-coded.
dec = Decoder(embed_dim=16, hidden_features=32, out_dim=X_train_batches[0].shape[1], nb_res_layers=3)

In [88]:
# End-to-end shape check (encoded x -> quantiser -> decoder); also updates the codebook in training mode.
dec(code(x)[0]).size()

torch.Size([1024, 41])

In [26]:
# Scratch codebook for stepping through CodeLayer.forward by hand: 256 column entries of width 32.
embed = torch.randn(size=(32, 256), dtype=torch.float32)

In [28]:
embed.size()  # (embed_dim, n_entries)

torch.Size([32, 256])

In [38]:
# Despite the name this is a (features, batch) transpose view of 2-D x, not a flatten.
flatten = x.t()

In [39]:
flatten.size()

torch.Size([32, 1024])

In [42]:
# Per-row squared norm ||x_i||^2, kept as a column for broadcasting.
(x ** 2).sum(dim=1, keepdim=True).size()

torch.Size([1024, 1])

In [44]:
# Row-by-entry inner products: (batch, n_entries).
torch.matmul(x, embed).size()

torch.Size([1024, 256])

In [45]:
# Squared Euclidean distance ||x_i - e_k||^2 = ||x_i||^2 - 2 x_i.e_k + ||e_k||^2.
dist = x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)

In [48]:
dist.size()  # (batch, n_entries)

torch.Size([1024, 256])

In [49]:
# Index of the nearest codebook entry per row. Taking `.indices` avoids assigning to `_`,
# which would shadow IPython's last-output variable.
embed_ind = (-dist).max(dim=1).indices

In [51]:
embed_ind.size()  # one index per row

torch.Size([1024])

In [54]:
# One-hot assignment matrix (batch, n_entries). The dtype follows x directly rather than
# the unrelated scratch variable `flatten`, and the entry count follows `embed`.
embed_onehot = F.one_hot(embed_ind, embed.shape[1]).to(x.dtype)
embed_onehot.shape

torch.Size([1024, 256])

In [62]:
# Look up each row's nearest codebook entry (entries are columns of embed).
quantize = F.embedding(input=embed_ind, weight=embed.t())

In [64]:
quantize.size()  # (batch, embed_dim)

torch.Size([1024, 32])

In [65]:
# Number of batch rows assigned to each codebook entry.
embed_onehot_sum = embed_onehot.sum(dim=0)

In [67]:
embed_onehot_sum.size()

torch.Size([256])

In [68]:
# Per-entry sum of the assigned rows, laid out (features, n_entries) like embed.
embed_sum = torch.matmul(x.t(), embed_onehot)

In [70]:
embed_sum.size()

torch.Size([32, 256])

In [71]:

# Mean squared distance between x and its (frozen) code; gradients reach x only.
diff = ((quantize.detach() - x) ** 2).mean()
# Straight-through estimator: the forward value is the code, the gradient flows to x.
quantize = x + (quantize - x).detach()

In [77]:
# Scalar mean squared distance between x and its nearest codebook entry.
diff

tensor(1.5068, grad_fn=<MeanBackward0>)

In [78]:
# Nearest-codebook index for each row of the batch.
embed_ind

tensor([245, 245, 191,  ..., 245, 204, 245])