In [None]:
import torch
import torch.nn as nn
import dynamiqs as dq

class RhoProjector(nn.Module):
    """Map an arbitrary complex square matrix to a valid density matrix.

    The input is Hermitized, its negative eigenvalues are clipped to zero, and the
    remaining eigenvalues are rescaled to sum to one. The result is Hermitian,
    positive semidefinite and has unit trace. If every eigenvalue is clipped (the
    clipped spectrum sums to at most ``eps``), the maximally mixed state ``I / d``
    is returned instead of dividing by ~0, which would produce NaN/inf.

    Args:
        eps: Threshold on the clipped eigenvalue sum below which the fallback is used.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, A):
        """Project ``A`` of shape (..., d, d) onto density matrices (complex64, same shape)."""
        A = A.to(torch.cfloat)
        hermitian = 0.5 * (A + A.conj().transpose(-2, -1))
        eigvals, eigvecs = torch.linalg.eigh(hermitian)
        weights = torch.clamp(eigvals, min=0.0)

        # Tr(V diag(w) V^H) == sum(w) for unitary V. Normalizing the weights directly
        # therefore gives unit trace, and the normalizer is real rather than complex.
        total = weights.sum(dim=-1, keepdim=True)
        degenerate = total <= self.eps
        # Divide by 1 where degenerate so the unused torch.where branch never holds NaN/inf
        # (NaN there would still poison gradients).
        safe_total = torch.where(degenerate, torch.ones_like(total), total)
        uniform = torch.full_like(weights, 1.0 / weights.shape[-1])
        weights = torch.where(degenerate, uniform, weights / safe_total)

        return eigvecs @ torch.diag_embed(weights.to(eigvecs.dtype)) @ eigvecs.conj().transpose(-2, -1)

    
class DummyModel(nn.Module):
    """Linear baseline: flattened square image -> complex matrix -> density matrix.

    Args:
        dim: Dimension d of the predicted (d, d) density matrix.
        grid_size: Side length of the square input image. The default of 32 matches
            the 32x32 Wigner grids used in this notebook.
    """

    def __init__(self, dim, grid_size=32):
        super().__init__()
        # Two real outputs per matrix entry: real part and imaginary part.
        self.linear = nn.Linear(grid_size * grid_size, 2 * dim * dim)
        self.dim = dim
        self.grid_size = grid_size
        self.projector = RhoProjector()

    def forward(self, x):
        """Map ``x`` (batch, ..., grid_size, grid_size) to (batch, dim, dim) density matrices."""
        batch = x.shape[0]
        # flatten also works on non-contiguous inputs, where .view would raise.
        out = self.linear(x.flatten(start_dim=1))
        out = out.view(batch, 2, self.dim, self.dim)
        A = torch.complex(out[:, 0], out[:, 1])  # complex-valued (batch, dim, dim)
        return self.projector(A)



In [10]:
dim = 16
model = DummyModel(dim)
x = torch.randn(8, 1, 32, 32)  # random stand-in for a batch of 8 single-channel 32x32 Wigner images
rho_pred = model(x)

print(rho_pred.shape)          # (8, 16, 16)
print(rho_pred[0].trace())     # Should be ~1.0

# The projector must return valid density matrices: correct shape, unit trace, Hermitian.
traces = rho_pred.diagonal(dim1=-2, dim2=-1).sum(-1)
assert rho_pred.shape == (8, dim, dim)
assert torch.allclose(traces.real, torch.ones_like(traces.real), atol=1e-4)
assert torch.allclose(rho_pred, rho_pred.conj().transpose(-2, -1), atol=1e-5)


torch.Size([8, 16, 16])
tensor(1.-4.4409e-16j, grad_fn=<TraceBackward0>)


In [30]:
import torch
from torch.utils.data import Dataset

class CleanWignerDataset(Dataset):
    """Synthetic (Wigner image, density matrix) pairs for clean Fock / coherent states.

    Items are generated on access. With probability 1/2 the state is a Fock state
    |n><n| with n drawn from [0, min(max_fock, dim)). Otherwise it is a coherent state
    with Re(alpha) in [1, 3) and Im(alpha) in [0, 0.3). The Wigner function is sampled
    on a ``size x size`` grid over [-4, 4]^2.

    Args:
        dim: Hilbert-space truncation; density matrices are ``dim x dim``.
        size: Grid points per axis of the Wigner image.
        n_samples: Length reported by ``__len__``.
        seed: ``None`` (default) draws a fresh random state on every access, as before.
            The draws use a new OS-seeded generator instead of the global ``np.random``
            state, which forked DataLoader workers would otherwise duplicate. With an
            int seed, item ``idx`` is a deterministic function of ``(seed, idx)``, so
            a train/val split refers to fixed samples.
        max_fock: Exclusive upper bound on the Fock level. It is clipped to ``dim`` so
            that ``n`` never exceeds the truncation.
    """

    def __init__(self, dim=40, size=32, n_samples=1000, seed=None, max_fock=10):
        if min(dim, max_fock) < 1:
            raise ValueError("dim and max_fock must both be >= 1")
        self.dim = dim
        self.size = size
        self.n_samples = n_samples
        self.seed = seed
        self.max_fock = max_fock

    def __len__(self):
        return self.n_samples

    def _rng(self, idx):
        """Return the numpy Generator used to draw item ``idx``."""
        import numpy as np

        if self.seed is None:
            return np.random.default_rng()
        return np.random.default_rng([self.seed, int(idx)])

    def _sample_state_params(self, rng):
        """Draw ("fock", n) or ("coherent", alpha) using ``rng``."""
        if rng.random() < 0.5:
            return "fock", int(rng.integers(0, min(self.max_fock, self.dim)))
        real = rng.uniform(1.0, 3.0)
        imag = rng.uniform(0.0, 0.3)
        return "coherent", complex(real, imag)

    def __getitem__(self, idx):
        """Return ``(wigner, rho)``: float32 (1, size, size) and complex64 (dim, dim)."""
        import numpy as np
        import dynamiqs as dq

        xvec = np.linspace(-4, 4, self.size)
        yvec = np.linspace(-4, 4, self.size)

        kind, param = self._sample_state_params(self._rng(idx))
        if kind == "fock":
            rho = dq.fock_dm(self.dim, param)
        else:
            rho = dq.coherent_dm(self.dim, param)

        _, _, w_clean = dq.wigner(rho, xvec=xvec, yvec=yvec)

        w_clean = torch.tensor(np.array(w_clean), dtype=torch.float32).unsqueeze(0)
        rho = torch.tensor(np.array(rho), dtype=torch.complex64)

        return w_clean, rho

def fidelity_loss(rho_pred, rho_true, eps=1e-8):
    """Mean Uhlmann infidelity ``1 - F`` with ``F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))**2``.

    Args:
        rho_pred: Predicted density matrices, shape (..., d, d). Gradients flow through this.
        rho_true: Target density matrices, shape (..., d, d). Treated as a constant target:
            its square root uses an exact sqrt, whose gradient is not finite at zero
            eigenvalues.
        eps: Below this value, eigenvalues of the inner matrix are square-rooted with the
            chord ``x / sqrt(eps)`` instead of ``sqrt(x)``, keeping gradients finite.

    Returns:
        Scalar tensor ``1 - mean(F)`` over the batch.
    """
    rho_pred = rho_pred.to(torch.cfloat)
    rho_true = rho_true.to(torch.cfloat)

    # Hermitize to remove numerical asymmetry before the eigendecompositions.
    rho_pred = 0.5 * (rho_pred + rho_pred.conj().transpose(-2, -1))
    rho_true = 0.5 * (rho_true + rho_true.conj().transpose(-2, -1))

    # sqrt(rho_true): clip only negative round-off. Inflating zero eigenvalues (a pure
    # target has d-1 of them) would bias F upward.
    evals_true, evecs_true = torch.linalg.eigh(rho_true)
    sqrt_evals = torch.sqrt(torch.clamp(evals_true, min=0.0)).to(evecs_true.dtype)
    sqrt_rho = evecs_true @ torch.diag_embed(sqrt_evals) @ evecs_true.conj().transpose(-2, -1)

    inner = sqrt_rho @ rho_pred @ sqrt_rho
    inner = 0.5 * (inner + inner.conj().transpose(-2, -1))

    # Tr sqrt(M) = sum_i sqrt(lambda_i(M)), so only eigenvalues are needed. Gradients of
    # eigvalsh stay finite when eigenvalues coincide. Gradients through eigh's eigenvectors
    # do not, and inner is typically low-rank with many ~0 eigenvalues.
    lam = torch.clamp(torch.linalg.eigvalsh(inner), min=0.0)
    # sqrt has infinite slope at 0. Below eps use the chord x / sqrt(eps), which equals
    # sqrt(eps) at x == eps: the value is exact above eps and the slope stays finite.
    chord_scale = eps ** 0.5 if eps > 0 else 1.0
    root = torch.where(lam > eps, torch.sqrt(torch.clamp(lam, min=eps)), lam / chord_scale)

    fidelity = root.sum(dim=-1) ** 2
    return 1 - fidelity.mean()

def fidelity_proxy_loss(rho_pred, rho_true):
    """Mean of ``1 - Re sum_ij conj(P_ij) T_ij`` over the batch.

    For Hermitian ``P`` this sum is the Hilbert-Schmidt overlap ``Tr(P T)``. When
    ``rho_true`` is a pure state |psi><psi| it equals the fidelity <psi|P|psi>. For mixed
    targets it is only a proxy, and its minimum is not at ``P == T``.

    Args:
        rho_pred: Predicted matrices (..., d, d), assumed Hermitian with unit trace.
        rho_true: Target matrices (..., d, d), assumed Hermitian with unit trace.
    """
    elementwise = rho_pred.conj() * rho_true
    hs_overlap = elementwise.sum(dim=(-2, -1)).real
    return 1 - hs_overlap.mean()


In [32]:
from torch.utils.data import random_split, DataLoader

SPLIT_SEED = 0  # fixes which indices land in train vs. val across re-runs

# Full dataset
full_dataset = CleanWignerDataset(dim=15, size=32, n_samples=10000)

# 80% train, 20% val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# NOTE: the seeded generator fixes the index split only. CleanWignerDataset, as
# constructed here, draws a new random state every time an index is accessed.
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


In [33]:
def _mean_epoch_loss(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean of ``loss_fn``.

    Steps ``optimizer`` on every batch when one is given; otherwise evaluates under
    ``torch.no_grad()``. Returns NaN for an empty loader instead of dividing by zero.
    """
    weighted_sum = 0.0
    n_seen = 0
    grad_ctx = torch.enable_grad() if optimizer is not None else torch.no_grad()
    with grad_ctx:
        for x, rho_true in loader:
            x, rho_true = x.to(device), rho_true.to(device)
            loss = loss_fn(model(x), rho_true)
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Assumes loss_fn returns a batch mean, as the losses in this notebook do.
            # Re-weighting by batch size makes a short final batch count proportionally.
            batch_size = x.shape[0]
            weighted_sum += loss.item() * batch_size
            n_seen += batch_size
    return weighted_sum / n_seen if n_seen else float("nan")


def train_model(model, train_loader, val_loader, optimizer, loss_fn, epochs=20, device="cpu"):
    """Train ``model`` for ``epochs`` epochs, printing train/val loss after each one.

    Args:
        model: Module mapping an input batch to predictions for ``loss_fn``.
        train_loader: Iterable of ``(x, target)`` batches used for optimization.
        val_loader: Iterable of ``(x, target)`` batches used for evaluation only.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(pred, target) -> scalar tensor``.
        epochs: Number of passes over ``train_loader``.
        device: Device that the model and batches are moved to.

    Returns:
        A list with one dict per epoch, holding keys ``epoch``, ``train_loss`` and
        ``val_loss`` (per-sample means; NaN if the corresponding loader is empty).
    """
    model.to(device)
    history = []

    for epoch in range(epochs):
        model.train()
        avg_train_loss = _mean_epoch_loss(model, train_loader, loss_fn, device, optimizer)

        # --- Validation ---
        model.eval()
        avg_val_loss = _mean_epoch_loss(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")
        history.append({"epoch": epoch + 1, "train_loss": avg_train_loss, "val_loss": avg_val_loss})

    return history


In [35]:
torch.manual_seed(0)  # reproducible weight init and DataLoader shuffling
model = DummyModel(dim=15)  # dim must match the dataset's Hilbert-space truncation
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Loss options defined earlier in this notebook:
#   fidelity_proxy_loss - 1 - Re Tr(rho_pred rho_true); equals the infidelity for pure targets
#   fidelity_loss       - full Uhlmann infidelity, also meaningful for mixed targets
loss_fn = fidelity_proxy_loss

# NOTE: the recorded run was interrupted after 5 of 50 epochs. Every sample is
# regenerated with dq.wigner on access, so consider fewer epochs.
history = train_model(model, train_loader, val_loader, optimizer, loss_fn, epochs=50)


Epoch 01 | Train Loss: 0.022837 | Val Loss: 0.000151
Epoch 02 | Train Loss: 0.000109 | Val Loss: 0.000109
Epoch 03 | Train Loss: 0.000239 | Val Loss: 0.000057
Epoch 04 | Train Loss: 0.000100 | Val Loss: 0.000014
Epoch 05 | Train Loss: 0.000086 | Val Loss: 0.000126


KeyboardInterrupt: 

In [37]:
import numpy as np
import dynamiqs as dq

dim = 15   # must match the truncation the model was trained with (DummyModel(dim=15))
size = 32  # Wigner grid points per axis; the model expects 32x32 inputs
n = 3      # test Fock state |n⟩ (training drew n from [0, 10))

# --- Generate input ---
xvec = np.linspace(-4, 4, size)
yvec = np.linspace(-4, 4, size)
rho_true = dq.fock_dm(dim, n)
_, _, w_clean = dq.wigner(rho_true, xvec=xvec, yvec=yvec)

# --- Prepare tensors ---
w_tensor = torch.tensor(np.array(w_clean), dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, 32, 32)
rho_true_tensor = torch.tensor(np.array(rho_true), dtype=torch.complex64).unsqueeze(0)     # (1, dim, dim)

# --- Model prediction ---
model = model.to("cpu").eval()
with torch.no_grad():
    rho_pred = model(w_tensor)

# Re Tr(rho_pred rho_true) is the fidelity here because the Fock target is pure.
overlap = torch.real((rho_pred.conj() * rho_true_tensor).sum(dim=(-2, -1))).item()
print(f"Fidelity with |{n}>: {overlap:.4f}")

In [38]:
# --- Extract and print diagonal ---
# Drop the size-1 batch axis, then keep the real part of each diagonal entry <k|rho|k>.
single = rho_pred.squeeze(0)
diag = single.diagonal().real
print("Predicted diagonal of rho:")
print(diag.numpy())

Predicted diagonal of rho:
[3.9189155e-07 1.0800873e-06 7.5538242e-06 9.9997795e-01 2.7833298e-07
 2.1053340e-06 1.1167782e-06 1.3070627e-06 9.5266267e-07 2.0959187e-06
 9.7388443e-07 4.4181979e-07 6.1974487e-07 7.8401229e-07 2.3119180e-06]


In [48]:
# NOTE: CleanWignerDataset drew Fock levels from [0, 10) only, so n = 11 probes
# out-of-distribution behaviour. A poor reconstruction here reflects the training
# range rather than necessarily a bug.
n = 11  # test Fock state |n⟩

# --- Generate input ---
xvec = np.linspace(-4, 4, size)
yvec = np.linspace(-4, 4, size)
rho_true = dq.fock_dm(dim, n)
_, _, w_clean = dq.wigner(rho_true, xvec=xvec, yvec=yvec)

# --- Prepare tensors ---
w_tensor = torch.tensor(np.array(w_clean), dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, 32, 32)
rho_true_tensor = torch.tensor(np.array(rho_true), dtype=torch.complex64).unsqueeze(0)     # (1, dim, dim)

# --- Model prediction ---
model = model.to("cpu").eval()
with torch.no_grad():
    rho_pred = model(w_tensor)

# Re Tr(rho_pred rho_true) is the fidelity here because the Fock target is pure.
overlap = torch.real((rho_pred.conj() * rho_true_tensor).sum(dim=(-2, -1))).item()
print(f"Fidelity with |{n}>: {overlap:.4f}")

In [49]:
# --- Extract and print diagonal ---
# rho_pred carries a leading batch axis of size 1; squeeze it, then take Re <k|rho|k>.
matrix = torch.squeeze(rho_pred, 0)
diag = torch.real(matrix.diagonal())
print("Predicted diagonal of rho:")
print(diag.numpy())

Predicted diagonal of rho:
[0.02247453 0.01118544 0.00531338 0.01118751 0.00202209 0.00276185
 0.01937376 0.00459194 0.13395298 0.6242219  0.01238723 0.0163946
 0.0811761  0.01389427 0.03906231]
