In [None]:
import numpy as np
import torch 
from typing import Any, NamedTuple

# Short aliases used throughout the notebook: `nn` for the layer classes and
# `f` for the functional API (softmax, mse_loss, ...).
nn = torch.nn
f = nn.functional

In [None]:
class Model_Config(NamedTuple):
    """Top-level model configuration record.

    NOTE(review): the VAE classes in this notebook read `n_dist`, `tau` and
    `latent` (and, for the ensemble, `n_head`) from the config they are given;
    none of those are fields of this record -- confirm which config type is
    actually meant to be passed to them.
    """
    num_heads : Any # How many decoder-classifier pairs
    encoder : Any # Encoder function
    decoder : Any # Decoder function
    head : Any # Classifier function
    input_layer : Any # Task specific input spec
    n_class : Any # Number of Classes
    out_dim : Any # Size of Image
    hard : Any # argmax (T) or softmax (F)

class Encoder_Config(NamedTuple):
    """Encoder configuration record.

    NOTE(review): `init_encoder` in this notebook reads `in_dim`, `stack`,
    `kernel`, `stride` and `padding` from its config; only `stack` is a field
    here -- confirm whether this record is the intended input.
    """
    n_class : Any # Presumably the number of classes per categorical distribution -- TODO confirm
    n_dist : Any # Number of categorical distributions
    stack : Any # Internal Structure
    dense_activation : Any # Activation function
    tau : Any # Temperature (original note says "tf variable"; the notebook itself uses torch)

class Decoder_Config(NamedTuple):
    """Decoder configuration record.

    NOTE(review): `init_decoder` in this notebook reads `stack`, `kernel`,
    `stride` and `padding` from its config; only `stack` is a field here --
    confirm whether this record is the intended input.
    """
    n_class : Any # Presumably the number of classes per categorical distribution -- TODO confirm
    n_dist : Any # Presumably the number of categorical distributions -- TODO confirm
    stack : Any # Presumably the internal layer structure, as in Encoder_Config -- TODO confirm
    dense_activation : Any # Presumably the activation function -- TODO confirm
    latent_square : Any # Size of reshaped sampled logits
    out_dim : Any # Presumably the output image size -- TODO confirm
    tau : Any # Presumably the sampling temperature -- TODO confirm

class Head_Config(NamedTuple):
    """Classifier-head configuration record.

    NOTE(review): `init_head` in this notebook reads `base` and `stack` from
    its config; `base` is not a field here (possibly `intermediate` was
    meant) -- confirm.
    """
    n_class : Any # Presumably the number of output classes -- TODO confirm
    intermediate : Any # Task-specific layers
    stack : Any # Presumably the widths of the dense layers -- TODO confirm
    dense_activation : Any # Presumably the activation function -- TODO confirm
    in_dim : Any # Presumably the input size of the head -- TODO confirm

class Wrapper_Config(NamedTuple):
    """Record bundling a model with training-related settings.

    Not referenced by any other cell visible in this notebook; the field
    meanings below are inferred from the names only and should be confirmed.
    """
    model : Any # presumably the model instance to train
    loss : Any # presumably the loss callable
    optim : Any # presumably the optimizer (or its factory)
    epochs : Any # presumably the number of training epochs
    temp : Any # presumably the sampling temperature (or its schedule)
    acc_metric : Any # presumably the accuracy metric callable

class Encoder_Output(NamedTuple):
    """Record for the values produced by an encoder pass.

    Not constructed by any cell visible in this notebook.
    """
    logits_y : Any # presumably the unnormalised categorical logits -- TODO confirm
    p_y : Any # Fixed Prior

class Decoder_Output(NamedTuple):
    """Record for the values produced by a decoder pass.

    Not constructed by any cell visible in this notebook.
    """
    recons : Any # Reconstructed x
    x_logits : Any # presumably the logits of the reconstruction distribution -- TODO confirm
    gen_y : Any # Generated Logits

class Model_Output(NamedTuple):
    """Record for the values produced by a full model pass.

    Not constructed by any cell visible in this notebook (the VAE classes
    return plain lists from `forward`).
    """
    y_pred : Any # Classifier Output
    x_logits : Any # Reconstructed Distribution
    gen_y : Any # Encoder Output

In [None]:
def init_encoder(config):
    """Return a zero-argument factory that builds the convolutional encoder.

    `config` is only read when the returned factory is called. It must provide
    `in_dim` (input channels), `stack` (output channels of each conv stage) and
    `kernel`, `stride`, `padding`, which are shared by every convolution.

    The built network is an `nn.Sequential` of per-stage `nn.Sequential`
    blocks: every stage except the last is Conv2d -> BatchNorm2d -> LeakyReLU,
    and the last stage is a bare Conv2d.
    """
    def encoder():
        def conv(c_in, c_out):
            # All stages share the same kernel/stride/padding settings.
            return nn.Conv2d(c_in, out_channels=c_out,
                             kernel_size=config.kernel,
                             stride=config.stride,
                             padding=config.padding)

        # Channel widths along the network: the input followed by every stage.
        widths = [config.in_dim, *config.stack]

        # Every transition but the last is normalised and activated.
        stages = [
            nn.Sequential(conv(c_in, c_out), nn.BatchNorm2d(c_out), nn.LeakyReLU())
            for c_in, c_out in zip(widths[:-2], widths[1:-1])
        ]
        # The final transition is left as a bare convolution.
        stages.append(nn.Sequential(conv(widths[-2], widths[-1])))

        return nn.Sequential(*stages)
    return encoder

def init_decoder(config):
    """Return a zero-argument factory that builds the transposed-conv decoder.

    `config` is only read when the returned factory is called. It must provide
    `stack` (channel widths; `stack[0]` is the channel count of the decoder
    input) and `kernel`, `stride`, `padding`, which are shared by every layer.

    Layout of the built `nn.Sequential`:
      1. one ConvTranspose2d -> BatchNorm2d -> LeakyReLU block for every
         consecutive pair of widths in `stack`;
      2. a bare ConvTranspose2d operating at `stack[-1]` channels;
      3. an output block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU ->
         Conv2d down to 3 channels -> Tanh.

    Raises:
        IndexError: (from the factory) if `config.stack` is empty.
    """
    def decoder():
        def up(c_in, c_out):
            # All transposed convolutions share kernel/stride/padding and use
            # output_padding=1.
            return nn.ConvTranspose2d(c_in, out_channels=c_out,
                                      kernel_size=config.kernel,
                                      stride=config.stride,
                                      padding=config.padding,
                                      output_padding=1)

        stack = config.stack
        last = stack[-1]

        layers = [
            nn.Sequential(up(c_in, c_out), nn.BatchNorm2d(c_out), nn.LeakyReLU())
            for c_in, c_out in zip(stack[:-1], stack[1:])
        ]

        # The blocks above already end at `stack[-1]` channels, so this layer
        # must take `stack[-1]` as its input width. It previously declared
        # `stack[-2]`, which broke the forward pass whenever the last two
        # widths differed and raised IndexError for a one-element stack.
        layers.append(nn.Sequential(up(last, last)))

        layers.append(
            nn.Sequential(
                up(last, last),
                nn.BatchNorm2d(last),
                nn.LeakyReLU(),
                # Project to a 3-channel image; Tanh bounds it to [-1, 1].
                nn.Conv2d(last, out_channels=3,
                          kernel_size=config.kernel, padding=config.padding),
                nn.Tanh(),
            )
        )

        return nn.Sequential(*layers)
    return decoder

def init_head(config):
    """Return a zero-argument factory that builds the classifier head.

    `config` is only read when the returned factory is called. It must provide
    `base` (a module applied first) and `stack` (output widths of the dense
    layers; `stack[-1]` is the size of the final output).

    Layout of the built `nn.Sequential`:
      1. `config.base` followed by `nn.Flatten()`;
      2. one `LazyLinear` per entry of `stack[:-1]`;
      3. `LazyLinear(stack[-1])` followed by a softmax over the last dim.

    NOTE(review): `config.base` is inserted as-is, so every head built from
    the same config holds the same `base` module object.
    NOTE(review): no activation is placed between the dense layers -- confirm
    this is intended.
    """
    def head():
        # `nn.Flatten()` is the module form of flattening; the previous
        # `torch.flatten()` was a function call without a tensor and raised
        # TypeError while the head was being built.
        layers = [nn.Sequential(config.base, nn.Flatten())]

        layers.extend(
            nn.Sequential(torch.nn.LazyLinear(width))
            for width in config.stack[:-1]
        )

        layers.append(nn.Sequential(
            torch.nn.LazyLinear(config.stack[-1]),
            # `nn.softmax` does not exist; the module is `nn.Softmax`, and the
            # class dimension is the last one after flattening.
            nn.Softmax(dim=-1),
        ))
        return nn.Sequential(*layers)
    return head

In [None]:
class SequentialVAE(nn.Module):
    """Categorical (Gumbel-Softmax) VAE whose reconstruction feeds a classifier.

    Pipeline of `forward`: encoder -> categorical logits -> relaxed sample ->
    decoder -> classifier head applied to the reconstruction.

    `config` must provide:
        n_class, n_dist : sizes of the categorical latent (n_dist distributions
            with n_class categories each).
        tau : softmax temperature used when sampling.
        latent : channel count of the encoder output / decoder input. The
            encoder output is expected to flatten to `latent * 4` features,
            and the decoder receives a `(batch, latent, 2, 2)` tensor.
        encoder, decoder, head : zero-argument callables returning modules.
        alpha (optional) : weight of the reconstruction term, default 1.0.
    """
    eps = 1e-20

    def __init__(self, config) -> None:
        # nn.Module has to be initialised before submodules can be assigned.
        super().__init__()

        self.n_class = config.n_class
        self.n_dist = config.n_dist
        self.tau = config.tau
        # Read back by `decode` and `loss`; previously never stored.
        self.latent = config.latent
        self.categorical_dim = self.n_class
        self.alpha = getattr(config, "alpha", 1.0)

        self.encoder = config.encoder()
        self.decoder = config.decoder()
        self.head = config.head()

        # Dense bridges between the flattened conv features and the latent.
        self.fc_z = nn.Linear(config.latent * 4, self.n_dist * self.n_class)
        self.scale = nn.Linear(self.n_dist * self.n_class, config.latent * 4)

    def set_tau(self, value) -> None:
        """Replace the sampling temperature."""
        self.tau = value

    def encode(self, input):
        """Encode `input` into logits of shape (batch, n_dist, n_class).

        Returns a one-element list holding the logits tensor.
        """
        latent = self.encoder(input)
        latent = torch.flatten(latent, start_dim=1)

        z = self.fc_z(latent)
        z = z.view(-1, self.n_dist, self.n_class)

        return [z]

    def decode(self, z):
        """Decode a flat (batch, n_dist * n_class) sample into a reconstruction."""
        x = self.scale(z)
        x = x.view(-1, self.latent, 2, 2)

        return self.decoder(x)

    def reparameterize(self, z):
        """Draw a relaxed categorical sample from logits `z`.

        Adds Gumbel noise to the logits, applies a temperature-scaled softmax
        over the class dimension and flattens to (batch, n_dist * n_class).
        """
        u = torch.rand_like(z)
        # eps keeps both logarithms finite when u is 0 or close to 1.
        g = -torch.log(-torch.log(u + self.eps) + self.eps)

        y = f.softmax((z + g) / self.tau, dim=-1)
        return y.view(-1, self.n_dist * self.n_class)

    def sample(self):
        """Not implemented; returns None."""
        return None

    def generate(self):
        """Not implemented; returns None."""
        return None

    def forward(self, input):
        """Run the full pipeline.

        Returns `[recons, input, q, y_pred]`, the positional arguments that
        `loss` expects.
        """
        q = self.encode(input)[0]
        z = self.reparameterize(q)
        recons = self.decode(z)
        y_pred = self.predict(recons)
        return [recons, input, q, y_pred]

    def predict(self, input):
        """Apply the classifier head to `input`."""
        return self.head(input)

    def loss(self, *args, **kwargs):
        """Compute the VAE objective from the outputs of `forward`.

        Positional arguments are `recons, input, q[, y_pred]`; `y_pred` is
        accepted but not used, so no classification term is included.

        Returns a dict with a scalar 'loss', the 'Reconstruction_Loss' (mean
        squared error) and 'KLD' (the negated KL term, as before).
        """
        recons = args[0]
        input = args[1]
        q = args[2]

        q_p = f.softmax(q, dim=-1)

        recons_loss = f.mse_loss(recons, input, reduction='mean')

        # KL(q || uniform) per sample. Each distribution ranges over n_class
        # categories, so the uniform prior probability is 1 / n_class.
        log_prior = float(np.log(1. / self.categorical_dim + self.eps))
        h1 = q_p * torch.log(q_p + self.eps)
        h2 = q_p * log_prior
        # Sum over distributions and classes, then average over the batch, so
        # that the result is a scalar that can be back-propagated.
        kl = torch.mean(torch.sum(h1 - h2, dim=(1, 2)), dim=0)

        loss = self.alpha * recons_loss + kl

        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kl}

In [None]:
class EnsembleVAE(nn.Module):
    """Categorical (Gumbel-Softmax) VAE with one shared encoder and several
    decoder/classifier pairs.

    `forward` encodes once, draws an independent relaxed sample per pair,
    decodes each sample with its own decoder and classifies each
    reconstruction with the matching head.

    `config` must provide:
        n_class, n_dist : sizes of the categorical latent (n_dist distributions
            with n_class categories each).
        n_head : number of decoder/classifier pairs.
        tau : softmax temperature used when sampling.
        latent : channel count of the encoder output / decoder input. The
            encoder output is expected to flatten to `latent * 4` features,
            and each decoder receives a `(batch, latent, 2, 2)` tensor.
        encoder, decoder, head : zero-argument callables returning modules.
        alpha (optional) : weight of the reconstruction term, default 1.0.
    """
    eps = 1e-20

    def __init__(self, config) -> None:
        # nn.Module has to be initialised before submodules can be assigned.
        super().__init__()

        self.n_class = config.n_class
        self.n_dist = config.n_dist
        self.n_head = config.n_head
        self.tau = config.tau
        # Read back by `decode` and `loss`; previously never stored.
        self.latent = config.latent
        self.categorical_dim = self.n_class
        self.alpha = getattr(config, "alpha", 1.0)

        self.encoder = config.encoder()
        # ModuleList (rather than a plain list) registers every member, so
        # their parameters show up in `.parameters()` and follow `.to()`.
        self.decoder = nn.ModuleList([config.decoder() for _ in range(config.n_head)])
        self.head = nn.ModuleList([config.head() for _ in range(config.n_head)])

        # Dense bridges between the flattened conv features and the latent.
        self.fc_z = nn.Linear(config.latent * 4, self.n_dist * self.n_class)
        self.scale = nn.Linear(self.n_dist * self.n_class, config.latent * 4)

    def set_tau(self, value) -> None:
        """Replace the sampling temperature."""
        self.tau = value

    def encode(self, input):
        """Encode `input` into logits of shape (batch, n_dist, n_class).

        Returns a one-element list holding the logits tensor.
        """
        latent = self.encoder(input)
        latent = torch.flatten(latent, start_dim=1)

        z = self.fc_z(latent)
        z = z.view(-1, self.n_dist, self.n_class)

        return [z]

    def decode(self, z, index=0):
        """Decode a flat (batch, n_dist * n_class) sample with decoder `index`."""
        x = self.scale(z)
        x = x.view(-1, self.latent, 2, 2)

        return self.decoder[index](x)

    def reparameterize(self, z):
        """Draw a relaxed categorical sample from logits `z`.

        Adds Gumbel noise to the logits, applies a temperature-scaled softmax
        over the class dimension and flattens to (batch, n_dist * n_class).
        """
        u = torch.rand_like(z)
        # eps keeps both logarithms finite when u is 0 or close to 1.
        g = -torch.log(-torch.log(u + self.eps) + self.eps)

        y = f.softmax((z + g) / self.tau, dim=-1)
        return y.view(-1, self.n_dist * self.n_class)

    def sample(self):
        """Not implemented; returns None."""
        return None

    def generate(self):
        """Not implemented; returns None."""
        return None

    def forward(self, input):
        """Run the full pipeline for every decoder/classifier pair.

        Returns `[X, input, q, Y]` where `X` and `Y` are lists (one entry per
        pair) of reconstructions and predictions.
        """
        q = self.encode(input)[0]
        Z = [self.reparameterize(q) for _ in range(self.n_head)]
        X = [self.decode(z, i) for i, z in enumerate(Z)]
        Y = [self.predict(x, i) for i, x in enumerate(X)]
        return [X, input, q, Y]

    def predict(self, input, index=0):
        """Apply classifier head `index` to `input`."""
        return self.head[index](input)

    def loss(self, *args, **kwargs):
        """Compute the VAE objective from the outputs of `forward`.

        Positional arguments are `recons, x, q[, y]`, where `recons` is the
        list of per-pair reconstructions (a single tensor is also accepted).
        `y` is accepted but not used, so no classification term is included.

        Returns a dict with a scalar 'loss', the 'Reconstruction_Loss' (mean
        squared error averaged over the pairs) and 'KLD' (the negated KL
        term, as before).
        """
        recons = args[0]
        x = args[1]
        q = args[2]

        if torch.is_tensor(recons):
            recons = [recons]

        q_p = f.softmax(q, dim=-1)

        # Compare every reconstruction against the real input `x`; the old
        # code passed the builtin `input` and a whole list to mse_loss.
        recons_loss = torch.stack(
            [f.mse_loss(r, x, reduction='mean') for r in recons]
        ).mean()

        # KL(q || uniform) per sample. Each distribution ranges over n_class
        # categories, so the uniform prior probability is 1 / n_class.
        log_prior = float(np.log(1. / self.categorical_dim + self.eps))
        h1 = q_p * torch.log(q_p + self.eps)
        h2 = q_p * log_prior
        # Sum over distributions and classes, then average over the batch, so
        # that the result is a scalar that can be back-propagated.
        kl = torch.mean(torch.sum(h1 - h2, dim=(1, 2)), dim=0)

        loss = self.alpha * recons_loss + kl

        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kl}

In [None]:
class decode_config:
    """Stub decoder config; defines only `x` and `y`.

    NOTE(review): `init_decoder` reads `stack`, `kernel`, `stride` and
    `padding` from its config, none of which are set here -- presumably a
    placeholder to be filled in.
    """
    x = 1
    y = 2
class encode_config:
    """Stub encoder config; defines only `x` and `y`.

    NOTE(review): `init_encoder` reads `in_dim`, `stack`, `kernel`, `stride`
    and `padding` from its config, none of which are set here -- presumably a
    placeholder to be filled in.
    """
    x = 1
    y = 2
class head_config:
    """Stub classifier-head config; defines only `x` and `y`.

    NOTE(review): `init_head` reads `base` and `stack` from its config,
    neither of which is set here -- presumably a placeholder to be filled in.
    """
    x = 1
    y = 2

In [None]:
# Zero-argument builders for each sub-network. Nothing is constructed (and the
# configs are not read) until the returned callables are invoked.
encoder_fn = init_encoder(encode_config)
# The decoder and head builders were both created with `init_encoder`
# (copy-paste slip); each config now goes to its matching initializer.
decoder_fn = init_decoder(decode_config)
head_fn = init_head(head_config)

In [None]:
class model_config:
    """Stub model config; defines only `x` and `y`.

    NOTE(review): `SequentialVAE.__init__` reads `n_class`, `n_dist`, `tau`,
    `encoder`, `decoder`, `head` and `latent` from its config, none of which
    are set here -- presumably a placeholder to be filled in.
    """
    x = 1
    y = 2

In [None]:
# NOTE(review): as written, `model_config` only defines `x` and `y`, while
# SequentialVAE.__init__ reads `config.n_class` (among others), so this line
# raises AttributeError until the config is filled in.
model = SequentialVAE(model_config)