In [None]:
import numpy as np
import torch 
from typing import Any, NamedTuple

from torch.nn import functional as F

nn = torch.nn
f = nn.functional

In [None]:
class Model_Config(NamedTuple):
    """Top-level model specification (immutable, positional or keyword)."""

    num_heads: Any  # number of decoder-classifier pairs
    encoder: Any  # encoder function
    decoder: Any  # decoder function
    head: Any  # classifier function
    input_layer: Any  # task-specific input spec
    n_class: Any  # number of classes
    out_dim: Any  # size of the image
    hard: Any  # True -> argmax, False -> softmax

class Encoder_Config(NamedTuple):
    """Settings consumed by `init_encoder`.

    `kernel`, `stride` and `padding` are read by `init_encoder` for every
    convolution. They are trailing fields with defaults so that existing
    five-argument constructions keep working.
    """

    in_dim: Any  # channels of the first convolution's input
    n_class: Any
    n_dist: Any  # number of categorical distributions
    stack: Any  # internal structure: output channels of each convolution
    tau: Any  # temperature variable
    kernel: Any = 3  # convolution kernel size
    stride: Any = 2  # convolution stride
    padding: Any = 1  # convolution padding

class Decoder_Config(NamedTuple):
    """Settings consumed by `init_decoder`.

    `kernel`, `stride` and `padding` are read by `init_decoder` for every
    (transposed) convolution. They are trailing fields with defaults so that
    existing six-argument constructions keep working.
    """

    n_class: Any
    n_dist: Any
    stack: Any  # channel widths, from the decoder input to its last hidden layer
    latent_square: Any  # size of reshaped sampled logits
    out_dim: Any
    tau: Any
    kernel: Any = 3  # convolution kernel size
    stride: Any = 2  # transposed-convolution stride
    padding: Any = 1  # convolution padding

class Head_Config(NamedTuple):
    """Settings for the classifier head; every field is required."""

    n_class: Any
    base: Any  # task-specific layers
    stack: Any
    dense_activation: Any
    in_dim: Any

In [None]:
def init_encoder(config):
    """Return a zero-argument factory that builds a convolutional encoder.

    The encoder is one `Conv2d` per entry of `config.stack`. Every entry but
    the last is followed by `BatchNorm2d` and `LeakyReLU`; the last
    convolution is left bare.

    :param config: object exposing `in_dim` (input channels) and `stack`
        (non-empty sequence of output channels). `kernel`, `stride` and
        `padding` are optional and default to 3, 2 and 1 when the config
        does not define them.
    :return: callable with no arguments that returns a fresh `nn.Sequential`
    """
    def encoder():
        # Fall back to fixed hyper-parameters so that configs without these
        # attributes do not raise AttributeError.
        kernel = getattr(config, "kernel", 3)
        stride = getattr(config, "stride", 2)
        padding = getattr(config, "padding", 1)

        layers = []
        in_dim = config.in_dim

        for size in config.stack[:-1]:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_dim, out_channels=size,
                              kernel_size=kernel, stride=stride, padding=padding),
                    nn.BatchNorm2d(size),
                    nn.LeakyReLU(),
                )
            )
            in_dim = size

        # Final projection: no normalisation or activation.
        layers.append(
            nn.Sequential(
                nn.Conv2d(in_dim, out_channels=config.stack[-1],
                          kernel_size=kernel, stride=stride, padding=padding)
            )
        )

        return nn.Sequential(*layers)
    return encoder

def init_decoder(config):
    """Return a zero-argument factory that builds a transposed-conv decoder.

    Layout, mirroring the decoder built inside `CategoricalVAE`:

    * one `ConvTranspose2d -> BatchNorm2d -> LeakyReLU` block per consecutive
      pair of `config.stack`;
    * a final block `ConvTranspose2d -> BatchNorm2d -> LeakyReLU ->
      Conv2d(3 channels) -> Tanh`.

    The previous version emitted the `stack[-2] -> stack[-1]` transition
    twice (once in the loop, once on its own), so the second copy received
    `stack[-1]` channels while expecting `stack[-2]`. The duplicate is gone.
    A one-element stack now works as well (only the final block is built).

    :param config: object exposing `stack` (non-empty sequence of channel
        widths). `kernel`, `stride` and `padding` are optional and default to
        3, 2 and 1 when the config does not define them.
    :return: callable with no arguments that returns a fresh `nn.Sequential`
    """
    def decoder():
        kernel = getattr(config, "kernel", 3)
        stride = getattr(config, "stride", 2)
        padding = getattr(config, "padding", 1)
        stack = list(config.stack)

        def upsample(c_in, c_out):
            # One transposed convolution with the shared hyper-parameters.
            return nn.ConvTranspose2d(c_in, c_out,
                                      kernel_size=kernel, stride=stride,
                                      padding=padding, output_padding=1)

        layers = []

        for c_in, c_out in zip(stack[:-1], stack[1:]):
            layers.append(
                nn.Sequential(
                    upsample(c_in, c_out),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(),
                )
            )

        # Last upsampling step, then project to a 3-channel image in [-1, 1].
        layers.append(
            nn.Sequential(
                upsample(stack[-1], stack[-1]),
                nn.BatchNorm2d(stack[-1]),
                nn.LeakyReLU(),
                nn.Conv2d(stack[-1], out_channels=3,
                          kernel_size=kernel, padding=padding),
                nn.Tanh(),
            )
        )

        return nn.Sequential(*layers)
    return decoder

def init_head(config):
    """Return a zero-argument factory that builds the classifier head.

    Layout: optional `config.base`, flatten, then one `LazyLinear` per entry
    of `config.stack`. Every entry but the last is followed by `ReLU`; the
    last one is followed by a softmax over the class dimension.

    :param config: object exposing `base` (task-specific module, or a falsy
        value for none) and `stack` (non-empty sequence of layer widths)
    :return: callable with no arguments that returns a fresh `nn.Sequential`
    """
    def head():
        layers = []
        # nn.Flatten is the module form: it keeps the batch dimension and
        # flattens the rest. torch.flatten is a function that needs a tensor
        # and cannot be placed inside nn.Sequential.
        if config.base:
            layers.append(nn.Sequential(
                config.base,
                nn.Flatten()
            ))
        else:
            layers.append(nn.Sequential(nn.Flatten()))
        for i in range(len(config.stack) - 1):
            layers.append(nn.Sequential(
                nn.LazyLinear(config.stack[i]),
                nn.ReLU()
            ))
        # The output is a probability vector. If it is fed to a loss that
        # expects raw logits (e.g. F.cross_entropy), softmax is applied twice.
        layers.append(nn.Sequential(
            nn.LazyLinear(config.stack[-1]),
            nn.Softmax(dim=-1)
        ))
        return nn.Sequential(*layers)
    return head

In [None]:
class CategoricalVAE(nn.Module):
    """Categorical VAE (Gumbel-softmax relaxation) with a classifier head.

    The encoder maps an image to `latent_dim` categorical distributions over
    `categorical_dim` categories, a relaxed sample is decoded back to a
    3-channel image, and `head` classifies that reconstruction.

    The linear layers around the latent code assume the encoder output is
    2 x 2 spatially (`hidden_dims[-1] * 4` features), i.e. the input side
    length is `2 ** (len(hidden_dims) + 1)` for the stride-2 convolutions.

    :param in_channels: (Int) Channels of the input image
    :param latent_dim: (Int) Number of categorical latent variables (D)
    :param head_func: An `nn.Module` to use as the head, or a zero-argument
        callable returning one (such as the factory returned by `init_head`)
    :param categorical_dim: (Int) Categories per latent variable (Q)
    :param hidden_dims: (List[Int]) Encoder channel widths; the decoder
        mirrors them. The list passed in is not modified.
    :param temperature: (Float) Gumbel-softmax temperature
    :param anneal_rate: (Float) Exponential decay rate used when annealing
    :param anneal_interval: (Int) Anneal every this many batches
    :param alpha: (Float) Weight of the reconstruction term in the loss
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 head_func,
                 categorical_dim: int = 40, # Num classes
                 hidden_dims = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 anneal_interval: int = 100, # every 100 batches
                 alpha: float = 30.,
                 **kwargs) -> None:
        super(CategoricalVAE, self).__init__()

        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        # NOTE(review): the floor equals the starting temperature, so for a
        # non-negative anneal_rate the clamp in loss_function keeps `temp`
        # at its initial value. Confirm whether a lower floor was intended.
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = anneal_interval
        self.alpha = alpha

        # Work on a private copy: the list is reversed below and must not
        # change under the caller's feet.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Channel count at the encoder output / decoder input; decode() uses
        # it to reshape the latent projection.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(self.last_hidden_dim * 4,
                              self.latent_dim * self.categorical_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim,
                                       self.last_hidden_dim * 4)

        hidden_dims = hidden_dims[::-1]

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

        # Use the head supplied by the caller: either a ready module or a
        # zero-argument factory that builds one.
        if isinstance(head_func, nn.Module):
            self.head = head_func
        else:
            self.head = head_func()

        # NOTE(review): not used by any method of this class. The probs
        # tensor has shape (categorical_dim, 1); confirm before relying on it.
        self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))

    def encode(self, input):
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (List[Tensor]) One-element list holding the latent
                 logits [B x D x Q]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Project to one logit per (latent variable, category) pair
        z = self.fc_z(result)
        z = z.view(-1, self.latent_dim, self.categorical_dim)
        return [z]

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) Flattened latent sample [B x (D * Q)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Reshape with the configured channel count instead of a fixed 512
        # so that custom hidden_dims work.
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, z, eps:float = 1e-7):
        """
        Gumbel-softmax trick to sample from Categorical Distribution
        :param z: (Tensor) Latent Codes [B x D x Q]
        :param eps: (Float) Small constant keeping the logarithms finite
        :return: (Tensor) Relaxed one-hot sample, flattened to [B x (D * Q)]
        """
        # Sample from Gumbel
        u = torch.rand_like(z)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample
        s = F.softmax((z + g) / self.temp, dim=-1)
        s = s.view(-1, self.latent_dim * self.categorical_dim)
        return s


    def forward(self, input, **kwargs):
        """
        Encodes, samples, decodes and classifies the reconstruction.
        :param input: (Tensor) [B x C x H x W]
        :return: (List) [reconstruction, input, latent logits, head output]
        """
        q = self.encode(input)[0]
        z = self.reparameterize(q)
        x = self.decode(z)
        y_pred = self.head(x)
        return  [x, input, q, y_pred]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the categorical VAE loss plus the classification loss.

        The KL term is taken against a uniform categorical prior:
        KL(q || U) = sum_k q_k * (log q_k - log(1 / Q))

        :param args: (recons, input, q, y_pred, y), i.e. the list returned by
                     forward() followed by the target labels
        :param kwargs: 'M_N' (weight of the KL term) and 'batch_idx'
                       (index of the current batch, drives the annealing)
        :return: (dict) with keys 'loss', 'Reconstruction_Loss', 'KLD', 'CCE'
        """
        recons = args[0]
        input = args[1]
        q = args[2]
        y_pred = args[3]
        y = args[4]

        q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        batch_idx = kwargs['batch_idx']

        # Anneal the temperature at regular intervals
        if batch_idx % self.anneal_interval == 0 and self.training:
            self.temp = float(np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
                                         self.min_temp))

        recons_loss = F.mse_loss(recons, input, reduction='mean')
        # cross_entropy(input, target): predictions first, labels second
        cce_loss = F.cross_entropy(y_pred, y, reduction='mean')

        # KL divergence between gumbel-softmax distribution
        eps = 1e-7

        # Entropy of the logits
        h1 = q_p * torch.log(q_p + eps)

        # Cross entropy with the categorical distribution
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0)

        # kld_weight = 1.2
        loss = self.alpha * recons_loss + kld_weight * kld_loss + cce_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD': -kld_loss, 'CCE': cce_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs):
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        # [S x D x Q]

        # One uniformly drawn one-hot row per (sample, latent variable) pair
        M = num_samples * self.latent_dim
        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)
        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1
        np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim])
        z = torch.from_numpy(np_y)

        # z = self.sampling_dist.sample((num_samples * self.latent_dim, ))
        z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x, **kwargs):
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [None]:
# Notebook-level hyper-parameters. None of them is read by the cells visible
# here; the names presumably map to the config fields n_class, n_dist, stack
# and latent_square -- TODO confirm.
N_CLASS: int = 10
N_DIST: int = 20
STACK: list = []
LATENT_SQUARE: int = 2

In [None]:
# Head_Config declares no defaults, so all five fields must be passed.
# init_head only reads `base` and `stack`; the remaining fields are filled
# with placeholders.
# NOTE(review): a single N_CLASS-wide output layer is the smallest valid
# stack -- replace with the intended hidden sizes.
head_config = Head_Config(
    n_class=N_CLASS,
    base=None,
    stack=[N_CLASS],
    dense_activation=None,
    in_dim=None,
)

In [None]:
# A head config belongs with the head factory; init_encoder would look for
# convolution settings that Head_Config does not carry.
head_fn = init_head(head_config)

In [None]:
class model_config:
    """Scratch namespace holding two integer class attributes."""

    x, y = 1, 2

In [None]:
# CategoricalVAE requires in_channels, latent_dim and head_func.
# in_channels is 3 because the decoder always emits 3-channel images and the
# reconstruction loss compares them with the input.
# NOTE(review): with the default hidden_dims the model expects 64 x 64
# inputs; N_DIST / N_CLASS are assumed to be the number of latent variables
# and categories per variable -- confirm.
model = CategoricalVAE(
    in_channels=3,
    latent_dim=N_DIST,
    head_func=head_fn,
    categorical_dim=N_CLASS,
)

In [None]:
# Training placeholders: no optimiser or loss object yet, and zero epochs,
# so the loop over range(EPOCHS) does not execute.
optim, loss = None, None
EPOCHS = 0

In [None]:
# Weight of the KL term passed to loss_function as 'M_N'.
# NOTE(review): placeholder value -- typically batch_size / dataset_size.
KLD_WEIGHT = 1.0

# NOTE(review): `train` is not defined in the visible cells; it must yield
# (batch_index, (inputs, labels)) pairs, e.g. enumerate(data_loader).
for epoch in range(EPOCHS):
    for i, data in train:
        x, y = data
        # The model returns [recons, input, q, y_pred]; the labels are
        # appended as one more list element (list + tensor is not an append).
        pred = model(x) + [y]
        # loss_function needs the KL weight and the batch index as keywords.
        loss_val = model.loss_function(*pred, M_N=KLD_WEIGHT, batch_idx=i)
        # TODO: back-propagate loss_val['loss'] and step the optimiser once
        # `optim` is defined.