In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [36]:
import numpy as np, torch, torch.nn as nn, torch.nn.functional as F, math, time, random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0); np.random.seed(0); random.seed(0)
dim=4
allowed_edges = torch.tensor([[0,1],[0,2],[0,3]], dtype=torch.long, device=device)
def op_with_phase(i,j,dim,phi):
    M = torch.zeros(dim,dim, dtype=torch.complex64, device=device)
    M[i,j] = torch.exp(1j*phi)
    M[j,i] = torch.exp(-1j*phi)
    return M
def step_unitary(edge_w, frac, phase):
    phi = math.pi*phase
    H = torch.zeros(dim,dim, dtype=torch.complex64, device=device)
    for e in range(allowed_edges.shape[0]):
        i,j = allowed_edges[e]
        H = H + edge_w[e]*op_with_phase(int(i),int(j),dim,phi)
    theta = math.pi*frac
    A = (-1j*0.5*theta)*H
    return torch.linalg.matrix_exp(A)
def unitary_from_seq(edges_idx, fracs, phases):
    U = torch.eye(dim, dtype=torch.complex64, device=device)
    for e, f, p in zip(edges_idx, fracs, phases):
        w = torch.zeros(allowed_edges.shape[0], device=device); w[int(e)] = 1.0
        U = step_unitary(w, f, p) @ U
    return U
def phase_align(U,V):
    X = V.conj().transpose(-2,-1) @ U
    tr = X.diagonal(dim1=-2,dim2=-1).sum(-1)
    ang = torch.atan2(tr.imag, tr.real).view(-1,1,1)
    return U * torch.exp(-1j*ang)

def frob_loss(U,V):
    d = U.shape[-1]
    U2 = phase_align(U,V)
    diff = U2 - V
    sf = (diff.real**2 + diff.imag**2).sum(dim=(-2,-1))
    return (sf/(d*d)).mean()

def infidelity(U,V):
    d = U.shape[-1]
    X = V.conj().transpose(-2,-1) @ U
    tr = X.diagonal(dim1=-2,dim2=-1).sum(-1).abs()
    return (1.0 - tr/d).mean()
def physics_loss(U_pred, U_true, fracs, w_frob=1.0, w_fid=1.0, w_len=0.1):
    Lf = frob_loss(U_pred, U_true)
    Li = infidelity(U_pred, U_true)
    L1 = fracs.abs().sum()/fracs.numel()
    return w_frob*Lf + w_fid*Li + w_len*L1, (Lf,Li,L1)
def rand_seq(L, phase_set=(0.5,1.5), fmax=2.0):
    edges = torch.randint(0,3,(L,), device=device)
    fracs = fmax*torch.rand(L, device=device)
    phases = torch.tensor(random.choices(phase_set,k=L), device=device)
    return edges, fracs, phases

def make_batch(B,L, fmax=2.0):
    Us=[]; seqs=[]
    for _ in range(B):
        e,f,p = rand_seq(L, fmax=fmax)
        U = unitary_from_seq(e,f,p)
        Us.append(U.unsqueeze(0)); seqs.append((e,f,p))
    U = torch.cat(Us,0)
    return U, seqs


In [37]:
def pack_targets(seqs, L):
    edges_true = torch.stack([s[0][:L] for s in seqs],0).long().to(device)
    fracs_true = torch.stack([s[1][:L] for s in seqs],0).float().to(device)
    phases_true = torch.stack([s[2][:L] for s in seqs],0).float().to(device)
    return edges_true, fracs_true, phases_true


In [38]:
import importlib
TRITON_OK = importlib.util.find_spec("triton") is not None

class InvNet(nn.Module):
    def __init__(self, d=4, Lmax=5):
        super().__init__()
        inp = 2*d*d
        h=1024
        self.net = nn.Sequential(nn.Linear(inp,h), nn.GELU(), nn.Linear(h,h), nn.GELU(), nn.Linear(h,h), nn.GELU())
        self.head_edges = nn.Linear(h, Lmax*3)
        self.head_fracs = nn.Linear(h, Lmax)
        self.head_phases = nn.Linear(h, Lmax)
        self.Lmax=Lmax
    def forward(self, x):
        z = self.net(x)
        logits = self.head_edges(z).view(-1,self.Lmax,3)
        fracs = 2.0*torch.sigmoid(self.head_fracs(z))
        phases = 2.0*torch.sigmoid(self.head_phases(z))
        return logits, fracs, phases

model = InvNet(d=dim, Lmax=5).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4, fused=(device=='cuda'))

if TRITON_OK:
    try:
        model = torch.compile(model)
        print("torch.compile enabled")
    except Exception:
        print("torch.compile not enabled")
else:
    print("No Triton detected; running without torch.compile")


No Triton detected; running without torch.compile


In [39]:
def step_unitary_batched(w, frac, phase):
    B = w.shape[0]
    ephi = torch.exp(1j*torch.pi*phase)
    H = torch.zeros(B, dim, dim, dtype=torch.complex64, device=device)
    for e in range(allowed_edges.shape[0]):
        i, j = allowed_edges[e]
        H[:, i, j] += w[:, e] * ephi
        H[:, j, i] += w[:, e] * ephi.conj()
    theta = torch.pi * frac
    A = (-1j*0.5) * theta.view(-1,1,1) * H
    return torch.linalg.matrix_exp(A)

def pred_unitary_from_outputs(logits, fracs, phases, L, tau=1.0, hard=True):
    B = logits.shape[0]
    U = torch.eye(dim, dtype=torch.complex64, device=device).unsqueeze(0).repeat(B,1,1)
    for t in range(L):
        w = F.gumbel_softmax(logits[:, t, :], tau=tau, hard=hard)
        f = fracs[:, t]
        p = phases[:, t]
        U_step = step_unitary_batched(w, f, p)
        U = U_step @ U
    return U


def train_epoch(curr_L, steps=200, B=64, tau=1.0, w_frob=1.0, w_fid=1.0, w_len=0.1, use_sup=True, lam_edge=2.0, lam_frac=1.0, lam_phase=1.0, lam_ent=0.01, clip=1.0, fmax=2.0):
    model.train()
    mloss=0.0
    for _ in range(steps):
        U_true, seqs = make_batch(B, curr_L)
        x = encode_unitary(U_true).to(device)
        logits, fracs, phases = model(x)
        U_pred = pred_unitary_from_outputs(logits, fracs, phases, curr_L, tau=tau, hard=True)
        L,(Lf,Li,L1) = physics_loss(U_pred, U_true, fracs[:,:curr_L], w_frob, w_fid, w_len)
        probs = logits[:,:curr_L,:].softmax(-1)
        ent = -(probs.clamp_min(1e-8)*probs.clamp_min(1e-8).log()).sum(-1).mean()
        L = L + lam_ent*ent
        if use_sup:
            e_t, f_t, p_t = pack_targets(seqs, curr_L)
            ce = F.cross_entropy(logits[:,:curr_L,:].reshape(-1,3), e_t.reshape(-1))
            mse_f = F.mse_loss(fracs[:,:curr_L], f_t)
            mse_p = F.mse_loss(phases[:,:curr_L], p_t)
            L = L + lam_edge*ce + lam_frac*mse_f + lam_phase*mse_p
        opt.zero_grad(set_to_none=True)
        L.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
        mloss += L.item()
    return mloss/steps

def eval_epoch(curr_L, B=64, tau=0.5, w_frob=1.0, w_fid=1.0, w_len=0.1):
    model.eval()
    with torch.no_grad():
        U_true,_ = make_batch(B, curr_L)
        x = encode_unitary(U_true)
        logits, fracs, phases = model(x)
        U_pred = pred_unitary_from_outputs(logits, fracs, phases, curr_L, tau=tau, hard=True)
        L,(Lf,Li,L1) = physics_loss(U_pred, U_true, fracs[:,:curr_L], w_frob, w_fid, w_len)
    return L.item(), Lf.item(), Li.item()


In [40]:
cfg = dict(w_frob=1.0, w_fid=1.0, w_len=0.02, epochs_per_L=6, total_L=5, steps_per_epoch=120, B=256, tau_start=0.8, tau_end=0.1, use_sup=True, lam_edge=2.0, lam_frac=1.0, lam_phase=1.0, lam_ent=0.02)
hist = []
for L in range(1, cfg['total_L']+1):
    for e in range(cfg['epochs_per_L']):
        t = 0 if cfg['epochs_per_L']==1 else e/(cfg['epochs_per_L']-1)
        tau = cfg['tau_start'] + (cfg['tau_end']-cfg['tau_start'])*t
        fmax = 0.8 + 1.2*min(1.0, (L-1)/4.0)
        tr = train_epoch(L, steps=cfg['steps_per_epoch'], B=cfg['B'], tau=tau, w_frob=cfg['w_frob'], w_fid=cfg['w_fid'], w_len=cfg['w_len'], use_sup=cfg['use_sup'], lam_edge=cfg['lam_edge'], lam_frac=cfg['lam_frac'], lam_phase=cfg['lam_phase'], lam_ent=cfg['lam_ent'], clip=1.0, fmax=fmax)
        vl, vf, vi = eval_epoch(L, B=cfg['B'], tau=0.05, w_frob=cfg['w_frob'], w_fid=cfg['w_fid'], w_len=cfg['w_len'])
        hist.append((L, e, tr, vl, vf, vi))
        print(f"L={L} epoch={e} train={tr:.4f} val_total={vl:.4f} val_frob={vf:.4f} val_inf={vi:.4f}")


L=1 epoch=0 train=0.7331 val_total=0.0459 val_frob=0.0085 val_inf=0.0169
L=1 epoch=1 train=0.2631 val_total=0.0378 val_frob=0.0058 val_inf=0.0115
L=1 epoch=2 train=0.2351 val_total=0.0302 val_frob=0.0037 val_inf=0.0074
L=1 epoch=3 train=0.2293 val_total=0.0318 val_frob=0.0039 val_inf=0.0078
L=1 epoch=4 train=0.1934 val_total=0.0292 val_frob=0.0032 val_inf=0.0063
L=1 epoch=5 train=0.1942 val_total=0.0347 val_frob=0.0049 val_inf=0.0097
L=2 epoch=0 train=3.5463 val_total=0.4640 val_frob=0.1492 val_inf=0.2984
L=2 epoch=1 train=2.0221 val_total=0.3641 val_frob=0.1157 val_inf=0.2313
L=2 epoch=2 train=1.4064 val_total=0.2678 val_frob=0.0837 val_inf=0.1674
L=2 epoch=3 train=1.2207 val_total=0.2844 val_frob=0.0891 val_inf=0.1781
L=2 epoch=4 train=1.1140 val_total=0.2327 val_frob=0.0718 val_inf=0.1437
L=2 epoch=5 train=1.0521 val_total=0.2182 val_frob=0.0670 val_inf=0.1339
L=3 epoch=0 train=3.5955 val_total=0.6472 val_frob=0.2104 val_inf=0.4207
L=3 epoch=1 train=2.5884 val_total=0.5548 val_frob=