In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tqdm import tqdm

# Training configuration.
B = 32  # batch size
D = 10  # number of discrete positions per sequence
S = 2  # number of states each position can take

# Seed torch's global RNG so training data, noise and sampling are reproducible
# on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

class Model(nn.Module):
    """MLP denoiser mapping a noisy sequence ``xt`` and time ``t`` to per-position logits.

    Each of the D positions is embedded, the scalar time is appended to every
    position's embedding as one extra feature, the result is flattened and fed
    through a 3-layer MLP that emits S logits per position.

    Args:
        D: number of positions. ``forward`` inputs must have exactly D columns,
            because the first Linear layer is sized from it.
        S: number of states per position (size of the output logits per position).
        emb_dim: token embedding width (default 16, the original hard-coded value).
        hidden_dim: MLP hidden width (default 128, the original hard-coded value).
    """

    def __init__(self, D, S, emb_dim=16, hidden_dim=128):
        super().__init__()
        # Plain attributes (not parameters/buffers), so state_dict is unchanged.
        self.D = D
        self.S = S
        # S + 1 rows: one per state plus an extra index S that the training code
        # in this notebook never produces (presumably reserved for a mask token).
        self.embedding = nn.Embedding(S + 1, emb_dim)
        self.net = nn.Sequential(
            nn.Linear((emb_dim + 1) * D, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, S * D),
        )

    def forward(self, x, t):
        """Predict logits over the clean sequence.

        Args:
            x: (B, D) integer tensor of states in ``[0, S]``.
            t: (B,) tensor of times, one per batch row.

        Returns:
            (B, D, S) tensor of unnormalised logits.
        """
        B, D = x.shape
        x_emb = self.embedding(x)  # (B, D, emb_dim)
        # Broadcast each row's time to all D positions as one extra channel; cast
        # so a float64 `t` cannot promote the input past the float32 weights.
        t_feat = t[:, None, None].expand(B, D, 1).to(x_emb.dtype)
        net_input = torch.cat([x_emb, t_feat], dim=-1).reshape(B, -1)  # (B, D * (emb_dim + 1))
        # Use the constructor's S rather than a module-level global.
        return self.net(net_input).reshape(B, D, self.S)

NUM_TRAIN_STEPS = 50_000  # number of optimizer steps
LR = 1e-5  # Adam learning rate; NOTE(review): low for Adam — check the loss curve has flattened

model = Model(D, S)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

losses = []
for _ in tqdm(range(NUM_TRAIN_STEPS)):
    # Clean targets x1: `num_ones` leading ones followed by zeros, with num_ones
    # uniform over {0, ..., D}, e.g. [1, 1, 1, 0, 0, 0, 0, 0, 0, 0].
    num_ones = torch.randint(0, D + 1, (B,))
    x1 = (torch.arange(D)[None, :] < num_ones[:, None]).long()  # (B, D)

    optimizer.zero_grad()

    # Noisy input xt: each position is independently replaced by a uniformly
    # random state with probability 1 - t (t=0 -> pure noise, t=1 -> clean x1).
    t = torch.rand((B,))
    uniform_noise = torch.randint(0, S, (B, D))
    corrupt_mask = torch.rand((B, D)) < (1 - t[:, None])
    xt = torch.where(corrupt_mask, uniform_noise, x1)

    logits = model(xt, t)  # (B, D, S)
    # cross_entropy wants the class dimension second: (B, S, D) vs targets (B, D).
    loss = F.cross_entropy(logits.transpose(1, 2), x1, reduction='mean', ignore_index=-1)

    loss.backward()
    optimizer.step()
    losses.append(loss.item())


In [None]:
import matplotlib.pyplot as plt
# Per-step loss is very noisy at this batch size, so overlay a running mean.
LOSS_SMOOTHING_WINDOW = 500

fig, ax = plt.subplots()
ax.plot(losses, alpha=0.3, label="per step")
if len(losses) >= LOSS_SMOOTHING_WINDOW:
    smoothed = torch.tensor(losses).unfold(0, LOSS_SMOOTHING_WINDOW, 1).mean(dim=1)
    ax.plot(
        range(LOSS_SMOOTHING_WINDOW - 1, len(losses)),
        smoothed.numpy(),
        label=f"{LOSS_SMOOTHING_WINDOW}-step mean",
    )
ax.set(title="Training loss", xlabel="optimizer step", ylabel="cross-entropy loss")
ax.legend()
plt.show()

In [None]:
import numpy as np
# Sampling

# Euler sampling from t=0 (uniform noise) to t=1 (data) with the trained model.
NUM_SAMPLING_STEPS = 1000
dt = 1.0 / NUM_SAMPLING_STEPS
num_samples = 1000
noise = 1  # extra-noise level N in the rate formula below (0 disables it)

xt = torch.randint(0, S, (num_samples, D))  # t=0: uniform noise
xt_hist = [xt]  # trajectory; stacked to (NUM_SAMPLING_STEPS + 1, num_samples, D) below

with torch.no_grad():  # pure inference: don't build autograd graphs
    for step in range(NUM_SAMPLING_STEPS):
        # Derive t from an integer counter. Accumulating `t += dt` drifts in
        # floating point and can run an extra step with t just below 1, where
        # 1 / (1 - t) explodes and forces spurious state flips.
        t = step * dt
        # Don't add noise on the final step (there dt / (1 - t) == 1, so the
        # step samples directly from the predicted x1 distribution).
        N = 0 if step == NUM_SAMPLING_STEPS - 1 else noise

        logits = model(xt, t * torch.ones((num_samples,)))  # (num_samples, D, S)
        x1_probs = F.softmax(logits, dim=-1)  # (num_samples, D, S)
        x1_probs_at_xt = torch.gather(x1_probs, -1, xt[:, :, None])  # (num_samples, D, 1)

        # Off-diagonal step probabilities (rate * dt), each capped at 1.
        step_probs = (
            dt * ((1 + N + N * (S - 1) * t) / (1 - t)) * x1_probs
            + dt * N * x1_probs_at_xt
        ).clamp(max=1.0)  # (num_samples, D, S)

        # On-diagonal ("stay") probabilities:
        # 1) zero out the entry for the current state,
        step_probs.scatter_(-1, xt[:, :, None], 0.0)
        # 2) set it so each row sums to 1 (floored at 0 if the jumps already exceed 1;
        #    Categorical normalises the row in that case).
        stay_probs = (1.0 - step_probs.sum(dim=-1, keepdim=True)).clamp(min=0.0)
        step_probs.scatter_(-1, xt[:, :, None], stay_probs)

        # sample() returns a fresh tensor, so storing it in the history is safe.
        xt = Categorical(step_probs).sample()  # (num_samples, D)
        xt_hist.append(xt)

xt_hist = torch.stack(xt_hist)  # (NUM_SAMPLING_STEPS + 1, num_samples, D)

In [None]:
# First 10 generated sequences (last expression -> rich display).
xt[:10]

In [None]:
# Sum of states per generated sample (= number of ones when S == 2). Training
# targets draw num_ones uniformly from {0, ..., D}, so a model that fits the
# data should give a roughly flat histogram.
counts = xt.sum(dim=1).float()
fig, ax = plt.subplots()
# Edges at k - 0.5 so each integer count k gets its own bar centred on k.
ax.hist(counts.numpy(), bins=[k - 0.5 for k in range(D + 2)])
ax.set(
    title="Number of ones per generated sample",
    xlabel="number of ones",
    ylabel="number of samples",
)
plt.show()

In [None]:
import matplotlib.pyplot as plt
# State of one position of one sample across the sampling trajectory.
TRAJ_SAMPLE_IDX = 0
TRAJ_POSITION_IDX = 2

print(xt_hist.shape)  # (num_sampling_steps + 1, num_samples, D)
fig, ax = plt.subplots()
ax.plot(xt_hist[:, TRAJ_SAMPLE_IDX, TRAJ_POSITION_IDX])
ax.set(
    title=f"Sampling trajectory: sample {TRAJ_SAMPLE_IDX}, position {TRAJ_POSITION_IDX}",
    xlabel="sampling step",
    ylabel="state",
)
plt.show()