# Decoding via Discrete Flow Matching

In [15]:
import torch
import matplotlib.pyplot as plt
from torch import nn, Tensor
from typing import Tuple

## Channel Model and Generator Matrix

In [16]:
class BinarySymmetricChannel:
    def __init__(self, p):
        assert ((p >= 0) and (p <= 1))
        # Probability of bit flip
        self.p = p

    def simulate_output(self, x) -> Tensor:
        size = x.size(dim=0)

        # Generate a random binary sequence of length 'size' 
        ber_prob = torch.mul(torch.ones(size), self.p) 
        z = torch.bernoulli(ber_prob)

        # Compute channel output
        y = torch.remainder(torch.add(x, z), 2).long()

        return y

In [162]:
# Right now just writing these matrices down by hand
class GeneratorMatrix:
    Hamming74 = torch.tensor([
        [1, 1, 0, 1],
        [1, 0, 1, 1],
        [1, 0, 0, 0],
        [0, 1, 1, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])

    ReedMuller13 = torch.tensor([
        [1,0,0,0,0,1],
        [1,0,0,0,1,0],
        [1,0,0,1,0,0],
        [1,0,0,1,1,1],
        [1,0,0,0,0,0],
        [1,0,0,0,1,1],
        [1,0,0,1,0,1],
        [1,0,0,1,1,0],
        [1,0,1,0,0,0],
        [1,0,1,0,1,1],
        [1,0,1,1,0,1],
        [1,0,1,1,1,0],
        [1,0,1,0,0,1],
        [1,0,1,0,1,0],
        [1,0,1,1,0,0],
        [1,0,1,1,1,1],
        [1,1,0,0,0,0],
        [1,1,0,0,1,1],
        [1,1,0,1,0,1],
        [1,1,0,1,1,0],
        [1,1,0,0,0,1],
        [1,1,0,0,1,0],
        [1,1,0,1,0,0],
        [1,1,0,1,1,1],
        [1,1,1,0,0,1],
        [1,1,1,0,1,0],
        [1,1,1,1,0,0],
        [1,1,1,1,1,1],
        [1,1,1,0,0,0],
        [1,1,1,0,1,1],
        [1,1,1,1,0,1],
        [1,1,1,1,1,0]
    ])

In [166]:
# Generates a batch of samples, for both training and evaluation
# Input: batch_size, channel model, generator matrix
# Ouput: tuple of (x_0, x_1)
def generate_source_and_target(batch_size, channel, generator_matrix) -> Tuple[Tensor, Tensor]:
        n = generator_matrix.size(dim = 0) 
        k = generator_matrix.size(dim = 1)
        model_dim = n + k

        x_1 =  torch.round(torch.rand(batch_size, model_dim)).long()
        x_0 = torch.clone(x_1)

        # The first 4 bits are the message bits. Compute codeword then simulate channel output.
        for i in range(0, batch_size):
                # Codeword and channel output
                c = torch.matmul(generator_matrix, x_1[i, range(0,k)]).long()
                y = channel.simulate_output(c)
                
                # Add channel output to x_1
                x_1[i, range(k, k+n)] = y

                # For the source x_0, we add the proper channel output but replace the message bits with random bits
                x_0[i, range(0, k)] = torch.round(torch.rand(k)).long()
                x_0[i, range(k, k+n)] = y
                
        return (x_0, x_1)

# Discrete Flow Matching Module

In [172]:
class DiscreteFlow(nn.Module):
    def __init__(self, dim: int = 11, h: int = 1024, v: int = 128):
        super().__init__()
        self.v = v
        self.embed = nn.Embedding(v, h)
        self.net = nn.Sequential(
            nn.Linear(dim * h + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim * v))
    
    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        return self.net(torch.cat((t[:, None], self.embed(x_t).flatten(1, 2)), -1)).reshape(list(x_t.shape) + [self.v])

## DFM Decoding via Sampling

In [148]:
# Runs the decoder over a batch of inputs
def dfm_decode(x_0, model, vocab_size, num_trials) -> Tensor:
    x_t = x_0
    t = 0.0
    while t < 1.0 - 1e-5:
        p1 = torch.softmax(model(x_t, torch.ones(num_trials) * t), dim=0)
        h = min(0.001, 1.0 - t)
        one_hot_x_t = nn.functional.one_hot(x_t, vocab_size).float()
        u = (p1 - one_hot_x_t) / (1.0 - t)

        # x_t = torch.distributions.Categorical(probs=one_hot_x_t + h * u).sample()

        # Sometimes there are numerical stability issues. Do something quite janky.
        try:
            x_t = torch.distributions.Categorical(probs=one_hot_x_t + h * u).sample()
        except:
            print("Ran into that ValueError again...")
            print("one_hot_x_t:", one_hot_x_t)
            print("h*u: ", h*u)
            print("The full thing:", one_hot_x_t + h * u)

        t += h

    return x_t

# Training

In [173]:
batch_size = 512
vocab_size = 2

p = 0.05
bsc = BinarySymmetricChannel(p)

generator_matrix = GeneratorMatrix.ReedMuller13
n = generator_matrix.size(dim = 0) 
k = generator_matrix.size(dim = 1)
model_dim = n + k

model = DiscreteFlow(v=vocab_size, dim = model_dim)
optim = torch.optim.Adam(model.parameters(), lr=0.001) 

for iter in range(20000):
    # Generate batches of data
    (x_0, x_1) = generate_source_and_target(batch_size, bsc, generator_matrix)
    
    # Training on random time stamps between 0 and 1
    t = torch.rand(batch_size)

    # Assign to x_1 with prob t, and x_0 with prob 1-t
    x_t = torch.where(torch.rand(batch_size,1) <  t[:, None], x_1, x_0)

    # Train model
    logits = model(x_t, t)

    # Compute the target conditional velocity
    loss = nn.functional.cross_entropy(logits.flatten(0, 1), x_1.flatten(0, 1)).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()

    if iter % 100 == 0:
        print("Iteration: ", iter ,", Loss: ", loss)


Iteration:  0 , Loss:  tensor(0.6933, grad_fn=<MeanBackward0>)
Iteration:  100 , Loss:  tensor(0.0006, grad_fn=<MeanBackward0>)
Iteration:  200 , Loss:  tensor(0.0002, grad_fn=<MeanBackward0>)
Iteration:  300 , Loss:  tensor(0.0001, grad_fn=<MeanBackward0>)
Iteration:  400 , Loss:  tensor(0.0001, grad_fn=<MeanBackward0>)


KeyboardInterrupt: 

## Evaluation

In [174]:
# Run 100 test trials
num_trials = 1000

# Generate a fresh batch of samples for evaluation
(x_0, x_1) = generate_source_and_target(num_trials, bsc, generator_matrix)
correctly_decoded = 0
dec = dfm_decode(x_0, model, vocab_size, num_trials)

for i in range(num_trials):
    if torch.equal(dec[i][range(k)], x_1[i][range(k)]):
        correctly_decoded += 1
    
print("Correctly decoded", correctly_decoded ," out of " , num_trials, " trials.")

Correctly decoded 971  out of  1000  trials.
