In [1]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# DeepSynergy imports
from deepsynergy import decoders
from deepsynergy.utils_training import train_decoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Decoder sanity-check notebook

Each section constructs a *synthetic* conditional distribution $p(x \mid z)$, then trains the corresponding DeepSynergy **decoder** to reproduce that distribution.

If training succeeds, the average cross-entropy returned by the decoder should match the analytic conditional entropy $H(X \mid Z)$.

## 1 · BinaryDecoder — binary-symmetric channel

* $Z \in \{0,1\}$ with $\Pr(Z = 1) = 0.5$  
* A bit-flip occurs with probability $\varepsilon$

The conditional entropy is

$$
H(X \mid Z)\;=\;
-\varepsilon \,\log_2 \varepsilon \;-\; (1-\varepsilon)\,\log_2(1-\varepsilon).
$$

In [2]:
# Parameters
samples   = 5_000
epsilon   = 0.10    # flip probability

# Analytical H(X|Z)
H_theory  = -(epsilon*np.log2(epsilon) + (1-epsilon)*np.log2(1-epsilon))

# Synthetic data
Z = np.random.randint(0, 2, size=(samples, 1))
X = (Z ^ (np.random.rand(samples, 1) < epsilon)).astype(np.float32)

Z = torch.FloatTensor(Z)
X = torch.FloatTensor(X)
dataloader = DataLoader(TensorDataset(Z, X), batch_size=samples)

# Decoder
decoder = decoders.BinaryDecoder(
    nn.Sequential(
        nn.Linear(1, 8), nn.SELU(),
        nn.Linear(8, 8), nn.SELU(),
        nn.Linear(8, 1)
    )
).to(device)

optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)

decoder_results = train_decoder(
    model       = decoder,
    dataloader  = dataloader,
    optimizer   = optim,
    show_progress = True,
    device      = device,
    epochs      = 1000
)

H_decoder = decoder_results['loss'][0]

print(f"H(X|Z)  analytic : {H_theory:.3f} bits")
print(f"H(X|Z)  decoder  : {H_decoder:.3f} bits")

100%|██████████| 1000/1000 [00:44<00:00, 22.25it/s, loss=[0.460641]] 

H(X|Z)  analytic : 0.469 bits
H(X|Z)  decoder  : 0.461 bits





## 2 · CategoricalDecoder — $N$-ary symmetric channel

* $Z$ is uniform on $\{0,\dots,N-1\}$.  
* With probability $\varepsilon$ the output class is replaced
  by a random *wrong* class.

The entropy is

$$
H(X \mid Z)=
-(1-\varepsilon)\,\log_2(1-\varepsilon)
-\varepsilon\,\log_2\!\left(\frac{\varepsilon}{N-1}\right).
$$

In [3]:
# Parameters
samples = 5_000
N = 5          # number of classes
epsilon = 0.20 # error probability

# Analytical H(X|Z)
H_theory = -(1 - epsilon) * np.log2(1 - epsilon) - epsilon * np.log2(epsilon / (N - 1))

# Synthetic data
Z = np.random.randint(N, size=(samples, 1))
X = Z.copy()
flip = np.random.rand(samples, 1) < epsilon
X[flip] = (Z[flip] + np.random.randint(1, N, size=flip.sum())) % N

Z = torch.FloatTensor(Z)
X = torch.FloatTensor(X)
dataloader = DataLoader(TensorDataset(Z, X), batch_size=samples)

# Decoder
decoder = decoders.CategoricalDecoder(
    nn.Sequential(
        nn.Linear(1, 8), nn.SELU(),
        nn.Linear(8, 16), nn.SELU(),
        nn.Linear(16, 8), nn.SELU(),
        nn.Linear(8, N)
    ),
    num_classes = N,
).to(device)

optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)

decoder_results = train_decoder(
    model       = decoder,
    dataloader  = dataloader,
    optimizer   = optim,
    show_progress = True,
    device      = device,
    epochs      = 1000,
)

H_decoder = decoder_results['loss'][0]

print(f"H(X|Z)  analytic : {H_theory:.3f} bits")
print(f"H(X|Z)  decoder  : {H_decoder:.3f} bits")

100%|██████████| 1000/1000 [00:43<00:00, 23.22it/s, loss=[1.205383]]

H(X|Z)  analytic : 1.122 bits
H(X|Z)  decoder  : 1.205 bits





## 3 · GaussianDecoder — scale depends on $|Z|$

* $Z \sim \mathcal N(0,1)$  
* $X \mid Z=z \sim \mathcal N\!\bigl(0,\, z^{2}\bigr)$

The differential entropy is

$$
H(X \mid Z)
= \tfrac12\log_2(2\pi e)
- \tfrac12(\gamma+\log 2)/\log 2
\;\approx\; 1.131\ \text{bits},
$$

where $\gamma$ is the Euler–Mascheroni constant.

In [5]:
# Parameters
samples = 5_000

# Analytical H(X|Z)  (bits)
H_theory = (
    0.5 * np.log(2 * np.pi * np.e)           # ½ log(2πe)
    - 0.5 * (np.log(2) + np.euler_gamma)     # −½(γ + log 2)
) / np.log(2)

# Synthetic data
Z = np.random.randn(samples, 1)
X = np.random.randn(samples, 1) * np.abs(Z)          # σ = |Z|

Z = torch.FloatTensor(Z)
X = torch.FloatTensor(X)
dataloader = DataLoader(TensorDataset(Z, X), batch_size=samples)

# Decoder
decoder = decoders.GaussianDecoder(
    nn.Sequential(
        nn.Linear(1, 8), nn.SELU(),
        nn.Linear(8, 8), nn.SELU(),
    ),
    output_dim = 1,
).to(device)

optim = torch.optim.Adam(decoder.parameters(), lr=3e-3)

decoder_results = train_decoder(
    model         = decoder,
    dataloader    = dataloader,
    optimizer     = optim,
    show_progress = True,
    device        = device,
    epochs        = 1_000,
)

H_decoder = decoder_results['loss'][0]

print(f"H(X|Z)  analytic : {H_theory:.3f} bits")
print(f"H(X|Z)  decoder  : {H_decoder:.3f} bits")

100%|██████████| 1000/1000 [00:44<00:00, 22.23it/s, loss=[1.2163494]]

H(X|Z)  analytic : 1.131 bits
H(X|Z)  decoder  : 1.216 bits





## 4 · GaussianMixtureDecoder — Laplace scale from $Z$

* $Z \sim \mathrm{Exp}(1)$  
* $X \mid Z=z \sim \mathrm{Laplace}(0,\, z)$

Exact entropy:

$$
H(X \mid Z)=
\bigl[1+\log 2-\gamma\bigr]/\log 2
\;\approx\; 1.608\ \text{bits}.
$$

A mixture with $K=5$ Gaussian components should approximate this well.

In [6]:
# Parameters
samples        = 5_000
K              = 5        # number of mixture components

# Analytical H(X|Z)  (bits)
H_theory = (1 + np.log(2) - np.euler_gamma) / np.log(2)

# Synthetic data
Z = np.random.exponential(scale=1.0, size=(samples, 1))
X = np.random.laplace(loc=0.0, scale=Z)               # Laplace(scale = Z)

Z = torch.FloatTensor(Z)
X = torch.FloatTensor(X)
dataloader = DataLoader(TensorDataset(Z, X), batch_size=samples)

# Decoder
decoder = decoders.GaussianMixtureDecoder(
    nn.Sequential(
        nn.Linear(1, 8),  nn.SELU(),
        nn.Linear(8, 16), nn.SELU(),
        nn.Linear(16, 8), nn.SELU(),
    ),
    output_dim     = 1,
    num_components = K,
).to(device)

optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)

decoder_results = train_decoder(
    model         = decoder,
    dataloader    = dataloader,
    optimizer     = optim,
    show_progress = True,
    device        = device,
    epochs        = 1_000,
)

H_decoder = decoder_results['loss'][0]

print(f"H(X|Z)  analytic : {H_theory:.3f} bits")
print(f"H(X|Z)  decoder  : {H_decoder:.3f} bits")

100%|██████████| 1000/1000 [00:46<00:00, 21.34it/s, loss=[1.6496264]]

H(X|Z)  analytic : 1.610 bits
H(X|Z)  decoder  : 1.650 bits



