In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tqdm import tqdm

# Training hyperparameters (read as module-level globals by the code below)
B = 32 # batch size
D = 10 # number of positions (dimensions) per sample
S = 2 # number of states each position can take

class Model(nn.Module):
    """MLP denoiser producing per-position logits over the clean state x1.

    Each token of the noisy input is embedded into 16 dimensions, the scalar
    time ``t`` is appended to every position's embedding (17 features per
    position), the result is flattened and passed through a 3-layer MLP that
    emits ``S`` logits for each of the ``D`` positions.

    Args:
        D: number of positions in each sample.
        S: number of states a position can take (number of output classes).
    """

    def __init__(self, D, S):
        super().__init__()
        # Keep the sizes on the instance so forward() never has to rely on
        # module-level globals that may differ from the constructor arguments.
        self.D = D
        self.S = S
        emb_dim = 16
        # The table has S+1 rows, i.e. one index more than the S states.
        self.embedding = nn.Embedding(S+1, emb_dim)
        self.net = nn.Sequential(
            nn.Linear((emb_dim + 1) * D, 128),  # +1 feature per position for t
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, S*D),
        )

    def forward(self, x, t):
        """Compute logits for x1 given the noisy sample and its time.

        Args:
            x: integer tensor (B, D) of noisy tokens.
            t: tensor (B,) with one time value per batch element.

        Returns:
            Tensor (B, D, S) of unnormalised logits.
        """
        B, D = x.shape
        x_emb = self.embedding(x) # (B, D, 16)
        # Broadcast t to one extra feature per position, matching the
        # embedding dtype so the concatenation never mixes float types.
        t_feat = t.to(x_emb.dtype)[:, None, None].expand(B, D, 1) # (B, D, 1)
        net_input = torch.cat([x_emb, t_feat], dim=-1).reshape(B, -1) # (B, D * 17)
        # Use the S this model was built with, not a global.
        return self.net(net_input).reshape(B, D, self.S) # (B, D, S)

# Build the denoising network and the optimiser that updates its weights.
model = Model(D, S)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)


# One scalar loss value is appended here per training step.
losses = list()

def sample_p_xt_g_x1(x1, t, num_states=None):
    """Sample x_t ~ p(x_t | x_1) under the uniform corruption process.

    Every position of ``x1`` is independently replaced, with probability
    ``1 - t``, by a state drawn uniformly from ``[0, num_states)``; otherwise
    it is kept. Shapes and device are taken from ``x1``, so any batch size
    works (not only the global ``B``).

    Args:
        x1: integer tensor (B, D) of clean samples; not modified.
        t: tensor (B,) of times; with t == 1 the result equals ``x1``.
        num_states: size of the state space. Defaults to the module-level
            ``S`` when omitted.

    Returns:
        xt: tensor with the same shape and dtype as ``x1``.
    """
    if num_states is None:
        num_states = S

    # uniform
    xt = x1.clone()
    uniform_noise = torch.randint(0, num_states, x1.shape, device=x1.device)
    # rand() is in [0, 1), so a position is corrupted with probability 1 - t.
    corrupt_mask = torch.rand(x1.shape, device=x1.device) < (1 - t[:, None])
    xt[corrupt_mask] = uniform_noise[corrupt_mask]

    # masking
    # xt = x1.clone()
    # xt[torch.rand(x1.shape) < (1 - t[:, None])] = num_states-1

    return xt


for _ in tqdm(range(50000)):
    # Synthetic targets: each row is a run of ones followed by zeros, with the
    # run length drawn uniformly from 0..D,
    # e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] or [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    num_ones = torch.randint(0, D+1, (B,))
    x1 = torch.arange(D).unsqueeze(0).lt(num_ones.unsqueeze(1)).long()

    # Corrupt the clean batch at one random time per example.
    t = torch.rand((B,))
    xt = sample_p_xt_g_x1(x1, t)

    # Train the model to recover x1 from (xt, t).
    optimizer.zero_grad()
    logits = model(xt, t) # (B, D, S)
    # cross_entropy expects the class axis second: (B, S, D) against targets (B, D)
    loss = F.cross_entropy(logits.permute(0, 2, 1), x1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())


In [None]:
import matplotlib.pyplot as plt
# Training curve: one loss value per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Training step", ylabel="Cross-entropy loss")
plt.show()

In [None]:
import numpy as np
# Sampling

def dt_p_xt_g_xt(x1, t):
    """Time derivative of the uniform conditional path, for every x_t value.

    Args:
        x1: integer tensor (B, D) of clean states.
        t: float time; unused because the uniform path is linear in t.

    Returns:
        Tensor (B, D, S) whose last axis indexes the candidate x_t value.
    """
    # uniform: d/dt [t * onehot(x1) + (1 - t) / S] = onehot(x1) - 1/S
    uniform_mass = 1/S
    indicator = F.one_hot(x1, num_classes=S) # (B, D, S)
    return indicator - uniform_mass

    # masking
    # x1_onehot = F.one_hot(x1, num_classes=S) # (B, D, S)
    # M_onehot = F.one_hot(torch.tensor([S-1]), num_classes=S)[None, :, :] # (1, 1, S)
    # return x1_onehot - M_onehot

def p_xt_g_x1(x1, t):
    """Conditional probability p(x_t | x_1) of the uniform path, for every x_t value.

    Args:
        x1: integer tensor (B, D) of clean states.
        t: float time.

    Returns:
        Tensor (B, D, S) whose last axis indexes the candidate x_t value.
    """
    # uniform: mix a point mass on x1 (weight t) with the uniform distribution
    indicator = F.one_hot(x1, num_classes=S) # (B, D, S)
    data_part = t * indicator
    noise_part = (1-t) * (1/S)
    return data_part + noise_part

    # masking
    # x1_onehot = F.one_hot(x1, num_classes=S) # (B, D, S)
    # M_onehot = F.one_hot(torch.tensor([S-1]), num_classes=S)[None, :, :] # (1, 1, S)
    # return t * x1_onehot + (1-t) * M_onehot


def sample_prior(num_samples, D):
    """Draw starting samples for generation.

    Args:
        num_samples: number of rows to draw.
        D: number of positions per row.

    Returns:
        Integer tensor (num_samples, D) with entries uniform in [0, S).
    """
    # uniform
    shape = (num_samples, D)
    return torch.randint(low=0, high=S, size=shape)

    # masking
    # return (S-1) * torch.ones((num_samples, D)).long()

# Sampling configuration
t = 0.0
dt = 0.001
num_samples = 1000
xt = sample_prior(num_samples, D)

# Fix the number of steps up front instead of accumulating `t += dt` in a
# `while t < 1.0` loop: repeated float addition drifts, which makes both the
# visited times and the iteration count depend on rounding error.
num_steps = max(1, round(1.0 / dt))

# Generation only needs forward passes; disabling autograd avoids building a
# graph for every model call.
with torch.no_grad():
    for step in range(num_steps):
        t = step * dt

        # Predict x1 from the current state and draw one x1 per position.
        logits = model(xt, torch.full((num_samples,), t)) # (B, D, S)
        x1_probs = F.softmax(logits, dim=-1) # (B, D, S)
        x1 = Categorical(x1_probs).sample() # (B, D)


        # Calculate R_t^*
        # For p(x_t | x_1) > 0 and p(j | x_1) > 0
        # R_t^*(x_t, j | x_1) = Relu( dtp(j | x_1) - dtp(x_t | x_1)) / (Z_t * p(x_t | x_1))
        # For p(x_t | x_1) = 0 or p(j | x_1) = 0 we have R_t^* = 0

        # We ignore issues with the diagonal entries here because they are
        # overwritten below so that each row of step probabilities sums to one.

        dt_p_vals = dt_p_xt_g_xt(x1, t) # (B, D, S)
        dt_p_vals_at_xt = dt_p_vals.gather(-1, xt[:, :, None]).squeeze(-1) # (B, D)

        # Numerator of R_t^*
        R_t_numer = F.relu(dt_p_vals - dt_p_vals_at_xt[:, :, None]) # (B, D, S)

        pt_vals = p_xt_g_x1(x1, t) # (B, D, S)
        # Z_t counts the states with non-zero probability under p(. | x_1)
        Z_t = torch.count_nonzero(pt_vals, dim=-1) # (B, D)
        pt_vals_at_xt = pt_vals.gather(-1, xt[:, :, None]).squeeze(-1) # (B, D)

        # Denominator of R_t^*
        R_t_denom = Z_t * pt_vals_at_xt # (B, D)

        # A zero denominator yields inf/nan here; those entries are reset below.
        R_t = R_t_numer / R_t_denom[:, :, None] # (B, D, S)

        # Set p(x_t | x_1) = 0 or p(j | x_1) = 0 cases to zero
        R_t[ (pt_vals_at_xt == 0.0)[:, :, None].repeat(1, 1, S)] = 0.0
        R_t[ pt_vals == 0.0] = 0.0


        # Calculate the off-diagonal step probabilities
        step_probs = (R_t * dt).clamp(max=1.0) # (B, D, S)


        # Calculate the on-diagonal step probabilities
        # 1) Zero out the diagonal entries
        step_probs.scatter_(-1, xt[:, :, None], 0.0)
        # 2) Calculate the diagonal entries such that the probability row sums to 1
        step_probs.scatter_(-1, xt[:, :, None], (1.0 - step_probs.sum(dim=-1, keepdim=True)).clamp(min=0.0))

        # Jump (or stay) independently in every position.
        xt = Categorical(step_probs).sample() # (B, D)


In [None]:
# Show the first ten generated samples.
print(xt[:10])

In [None]:
# Distribution of the number of ones in each generated sample (0..D).
counts = xt.sum(dim=1).float()
fig, ax = plt.subplots()
# Bin edges are the integers 0..D+1; align="left" centres each bar on its
# integer count instead of drawing it between k and k+1.
ax.hist(counts.numpy(), bins=range(D+2), align="left")
ax.set_xticks(range(D+1))
ax.set(title="Ones per generated sample", xlabel="Number of ones", ylabel="Number of samples")
plt.show()