# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Uniform
import numpy as np
EPS = 1e-8

class MLPEncoder(nn.Module):
    """Fully connected encoder with three linear output heads.

    The trunk is a stack of ``Linear -> ReLU`` pairs, one pair per entry of
    ``hidden_dims``. Its output feeds three independent linear heads, each of
    width ``latent_dim``: a concentration head, a mean head and a
    log-variance head.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        self.net = nn.Sequential(*trunk)
        feature_dim = widths[-1]
        # Head producing the raw (pre-softplus) concentration values.
        self.alpha_layer = nn.Linear(feature_dim, latent_dim)
        self.mu_layer = nn.Linear(feature_dim, latent_dim)
        self.logvar_layer = nn.Linear(feature_dim, latent_dim)

    def forward(self, x):
        """Return ``(alpha_hat, mu, logvar)`` for the input batch ``x``.

        ``alpha_hat`` is made positive with a softplus, shifted by 1e-6 and
        then clamped to the interval [1e-3, 50]. ``mu`` and ``logvar`` are the
        unconstrained outputs of their linear heads.
        """
        features = self.net(x)
        mu = self.mu_layer(features)
        logvar = self.logvar_layer(features)
        positive_alpha = F.softplus(self.alpha_layer(features)) + 1e-6
        alpha_hat = torch.clamp(positive_alpha, min=1e-3, max=50.0)
        return alpha_hat, mu, logvar

class BernoulliDecoder(nn.Module):
    """Fully connected decoder that maps a latent code to Bernoulli logits.

    The network is a stack of ``Linear -> ReLU`` pairs (one per entry of
    ``hidden_dims``) followed by a final linear layer of width ``output_dim``.
    """

    def __init__(self, latent_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [latent_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Return raw logits for ``z``; no sigmoid is applied here.

        The logits are meant to be consumed by a logits-based binary
        cross-entropy loss (e.g. ``BCEWithLogitsLoss``).
        """
        return self.net(z)


class CC_VAE(nn.Module):
    """VAE whose latent code is sampled from a Continuous Categorical (CC).

    The encoder's positive concentration output is used as the unnormalised
    CC parameter lambda, a point on the simplex is drawn with an ordered
    rejection sampler, and the decoder maps that point to Bernoulli logits.
    """

    def __init__(self, input_dim, enc_hidden_dims, dec_hidden_dims, latent_dim, prior_lambda=None):
        """
        input_dim: flattened input size (e.g. 28*28)
        enc_hidden_dims: list of encoder hidden sizes
        dec_hidden_dims: list of decoder hidden sizes
        latent_dim: K
        prior_lambda: CC prior lambda. ``None`` selects a symmetric prior of
            0.98; a scalar (number or one-element tensor) is replicated K
            times; anything else must convert to a tensor of shape (K,).

        Raises ValueError if ``prior_lambda`` has more than one element and
        is not of shape (K,).
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = MLPEncoder(input_dim, enc_hidden_dims, latent_dim)
        self.decoder = BernoulliDecoder(latent_dim, dec_hidden_dims, input_dim)
        if prior_lambda is None:
            # default weak symmetric prior; user can override
            prior = torch.full((latent_dim,), 0.98)
        else:
            # Clone so the registered buffer never aliases the caller's tensor.
            prior = torch.as_tensor(prior_lambda, dtype=torch.float32).detach().clone()
            if prior.numel() == 1:
                prior = prior.reshape(1).repeat(latent_dim)
            elif tuple(prior.shape) != (latent_dim,):
                raise ValueError(
                    f"prior_lambda must be a scalar or have shape ({latent_dim},), "
                    f"got shape {tuple(prior.shape)}"
                )
        self.register_buffer('prior_lambda', prior)

    @staticmethod
    def _tilted_icdf(u, ratio, tol=1e-4):
        """Inverse CDF of the density proportional to ``ratio ** x`` on [0, 1].

        For ``ratio != 1`` the CDF is ``(ratio**x - 1) / (ratio - 1)``, whose
        inverse is ``log(1 + u * (ratio - 1)) / log(ratio)``. As ``ratio``
        approaches 1 the density becomes uniform and the formula degenerates
        to 0/0, so ``u`` is returned directly within ``tol`` of 1.
        """
        near_one = (ratio - 1.0).abs() < tol
        # Substitute a harmless value where the closed form is unstable so that
        # neither the forward value nor its gradient produces NaN/inf there.
        safe = torch.where(near_one, torch.full_like(ratio, 0.5), ratio)
        tilted = torch.log1p(u * (safe - 1.0)) / torch.log(safe)
        return torch.where(near_one, u, tilted)

    def sample_cc_from_lambda(self, lambda_hat, max_tries=100000):
        """
        Ordered rejection sampler for the Continuous Categorical CC(λ).

        lambda_hat: strictly positive, unnormalised parameters of shape
            (batch, K); a 1-D tensor of shape (K,) is treated as a batch of 1.
        max_tries: maximum number of proposal rounds before giving up.
        Returns z of shape (batch, K); every row is non-negative and sums to 1.

        Per row, the largest lambda is taken as the reference coordinate and
        every other coordinate i is proposed independently from the density
        proportional to (lambda_i / lambda_max) ** x on [0, 1]. A proposal is
        accepted when the proposed coordinates sum to at most 1; the
        reference coordinate then takes the remainder.

        The acceptance rate falls quickly as K grows when the lambdas are of
        similar size, so RuntimeError is raised if any row is still
        unaccepted after ``max_tries`` rounds. ValueError is raised if
        ``lambda_hat`` is not 1-D or 2-D.

        NOTE(review): the result is differentiable w.r.t. ``lambda_hat`` with
        the accepted uniform draws held fixed; the accept/reject decision
        itself is not differentiated, so treat the gradient as a heuristic
        pathwise estimate.
        """
        lam = lambda_hat.unsqueeze(0) if lambda_hat.dim() == 1 else lambda_hat
        if lam.dim() != 2:
            raise ValueError(
                f"lambda_hat must have shape (K,) or (batch, K), got {tuple(lambda_hat.shape)}"
            )
        order = torch.argsort(lam, dim=1, descending=True)
        lam_sorted = torch.gather(lam, 1, order)
        # Odds of each remaining coordinate against the largest one, in (0, 1].
        ratio = lam_sorted[:, 1:] / lam_sorted[:, :1]

        pending = torch.ones(lam.shape[0], dtype=torch.bool, device=lam.device)
        u_accepted = torch.zeros_like(ratio)
        # The search only has to find accepted uniforms, so it runs without
        # autograd; the differentiable transform is applied once afterwards.
        with torch.no_grad():
            for _ in range(max_tries):
                u = torch.rand_like(ratio)
                proposal = self._tilted_icdf(u, ratio)
                accepted = pending & (proposal.sum(dim=1) <= 1.0)
                u_accepted = torch.where(accepted.unsqueeze(1), u, u_accepted)
                pending = pending & ~accepted
                if not pending.any():
                    break
            else:
                raise RuntimeError("Sampler failed to converge")

        tail = self._tilted_icdf(u_accepted, ratio)
        head = 1.0 - tail.sum(dim=1, keepdim=True)
        z_sorted = torch.cat([head, tail], dim=1)
        # Undo the descending sort so columns line up with lambda_hat again.
        return torch.gather(z_sorted, 1, torch.argsort(order, dim=1))

    def forward(self, x):
        """
        x: flattened input (batch, input_dim) with values in [0,1] for Bernoulli decoding
        returns: reconstruction logits, z, lambda_hat

        The encoder returns (alpha_hat, mu, logvar); only the positive
        alpha_hat output is used here, as the CC parameter lambda_hat.
        """
        lambda_hat, _mu, _logvar = self.encoder(x)  # (batch, K)
        z = self.sample_cc_from_lambda(lambda_hat)
        logits = self.decoder(z)  # (batch, input_dim)
        return logits, z, lambda_hat
    
    
def multi_gamma_kl(alpha_hat, prior_alpha, reduction='batchmean'):
    """
    KL divergence from MultiGamma(alpha_hat, beta=1) to MultiGamma(prior_alpha, beta=1).

    For each latent dimension k the contribution is
      log Gamma(alpha_k) - log Gamma(alpha_hat_k) + (alpha_hat_k - alpha_k) * psi(alpha_hat_k)
    and the contributions are summed over dim 1.

    alpha_hat: (batch, K) posterior concentrations
    prior_alpha: (K,) or (batch, K) prior concentrations; a 1-D prior is
        expanded to the shape of alpha_hat
    reduction: 'batchmean' averages the per-example KL, 'sum' adds it up;
        any other value returns the per-example KL unreduced
    """
    target = prior_alpha
    if target.dim() == 1:
        target = target.unsqueeze(0).expand_as(alpha_hat)

    log_gamma_gap = torch.lgamma(target) - torch.lgamma(alpha_hat)
    digamma_part = torch.digamma(alpha_hat) * (alpha_hat - target)
    per_example = (log_gamma_gap + digamma_part).sum(dim=1)

    if reduction == 'batchmean':
        return per_example.mean()
    if reduction == 'sum':
        return per_example.sum()
    return per_example
    
    
def dirvae_elbo_loss(model, x, reduction='mean'):
    """
    Compute the negative ELBO (loss to minimize) for a Bernoulli decoder.

    model: callable returning a 4-tuple whose first element is the
        reconstruction logits and whose third element is alpha_hat; it must
        also expose a ``prior_alpha`` attribute.
    x: (batch, input_dim) values in {0,1} or [0,1]
    reduction: 'mean' averages both terms over the batch; any other value
        sums both terms over the batch.
    returns loss (scalar), recon_loss (scalar), kl (scalar)
    """
    logits, _z, alpha_hat, _v = model(x)
    # Reconstruction: Bernoulli likelihood -> BCE on logits, summed over the
    # input dimensions to get one negative log-likelihood per example.
    bce = F.binary_cross_entropy_with_logits(logits, x, reduction='none')
    recon_per_sample = bce.sum(dim=1)
    if reduction == 'mean':
        recon_loss = recon_per_sample.mean()
        kl_reduction = 'batchmean'
    else:
        recon_loss = recon_per_sample.sum()
        # Reduce the KL the same way as the reconstruction term; mixing a
        # batch sum with a batch mean would put the two on different scales.
        kl_reduction = 'sum'
    # KL between the MultiGamma posterior (alpha_hat) and the MultiGamma prior
    kl = multi_gamma_kl(alpha_hat, model.prior_alpha, reduction=kl_reduction)
    # ELBO = E_q[log p(x|z)] - KL -> loss = -ELBO = recon_loss + KL
    loss = recon_loss + kl
    return loss, recon_loss, kl


# ----------------------------
# Example training loop skeleton
# ----------------------------
if __name__ == "__main__":
    # Quick usage example for MNIST-like data (flattened, binary)
    import torchvision
    import torchvision.transforms as T
    from torch.utils.data import DataLoader
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert each PIL image to a tensor, binarize at 0.5 and flatten to 784.
    # Without a transform the dataset yields PIL images, which the default
    # DataLoader collate function cannot batch.
    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda t: (t > 0.5).float().view(-1)),
    ])
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                          transform=transform)
    loader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
    input_dim = 28 * 28
    latent_dim = 50

    # NOTE(review): DirVAE is not defined in the visible part of this file;
    # it is assumed to be provided elsewhere (e.g. another notebook cell).
    model_DIR = DirVAE(input_dim=input_dim,
                   enc_hidden_dims=[500,500],
                   dec_hidden_dims=[500],
                   latent_dim=latent_dim,
                   prior_alpha=0.98).to(device)
    optimizer_DIR = torch.optim.Adam(model_DIR.parameters(), lr=1e-3)
    for epoch in range(1, 50):
        model_DIR.train()
        # Running sums weighted by batch size, so the epoch averages below
        # stay correct even if the last batch is smaller.
        tot_loss = 0.0
        tot_recon = 0.0
        tot_kl = 0.0

        for xb, _ in loader:
            xb = xb.to(device)
            optimizer_DIR.zero_grad()
            loss, recon, kl = dirvae_elbo_loss(model_DIR, xb, reduction='mean')
            loss.backward()
            optimizer_DIR.step()
            tot_loss += loss.item() * xb.size(0)
            tot_recon += recon.item() * xb.size(0)
            tot_kl += kl.item() * xb.size(0)

        n = len(loader.dataset)
        print(f"Epoch {epoch:02d} DIR Loss {tot_loss/n:.4f} Recon {tot_recon/n:.4f} KL {tot_kl/n:.4f}")