# Fixed schedule 

In [2]:
import argparse, json, math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def make_schedule(n, passes):
    s=[]
    for _ in range(passes):
        for j in range(1,n):
            s.append((0,j))
    return s

def givens_matrix(n,i,j,theta,phi):
    G=np.eye(n,dtype=np.complex128)
    c=np.cos(theta)
    s=np.sin(theta)
    e=np.exp(1j*phi)
    G[i,i]=c
    G[j,j]=c
    G[i,j]=-e*s
    G[j,i]=np.conj(e)*s
    return G

def unitary_from_params(n,schedule,theta,phi):
    U=np.eye(n,dtype=np.complex128)
    for k,(i,j) in enumerate(schedule):
        U=givens_matrix(n,i,j,theta[k],phi[k])@U
    return U

def canon_phase(U):
    d=np.linalg.det(U)
    a=np.angle(d)/U.shape[0]
    return U*np.exp(-1j*a)

def vec_features(U):
    return np.concatenate([U.real.reshape(-1),U.imag.reshape(-1)],0).astype(np.float32)

def pack_labels(theta,phi):
    cth=np.cos(theta); sth=np.sin(theta); cph=np.cos(phi); sph=np.sin(phi)
    y=np.stack([cth,sth,cph,sph],-1).reshape(-1)
    return y.astype(np.float32)

def unpack_params(y):
    y=y.reshape(-1,4)
    u=y/np.clip(np.linalg.norm(y[:,:2],axis=1,keepdims=True),1e-9,None)
    v=y/np.clip(np.linalg.norm(y[:,2:],axis=1,keepdims=True),1e-9,None)
    cth=u[:,0]; sth=u[:,1]; cph=v[:,2]; sph=v[:,3]
    theta=np.arctan2(sth,cth)
    phi=np.arctan2(sph,cph)
    theta=np.mod(theta,math.pi/2)
    phi=np.mod(phi,2*math.pi)
    return theta,phi

class GenDataset(Dataset):
    def __init__(self,n,passes,size,seed=0):
        self.n=n; self.passes=passes; self.size=size
        rng=np.random.default_rng(seed)
        self.schedule=make_schedule(n,passes)
        L=len(self.schedule)
        self.theta=rng.uniform(0,math.pi/2,size=(size,L)).astype(np.float64)
        self.phi=rng.uniform(0,2*math.pi,size=(size,L)).astype(np.float64)
        X=[]; Y=[]
        for i in range(size):
            U=unitary_from_params(n,self.schedule,self.theta[i],self.phi[i])
            U=canon_phase(U)
            X.append(vec_features(U))
            Y.append(pack_labels(self.theta[i],self.phi[i]))
        self.X=np.stack(X,0)
        self.Y=np.stack(Y,0)
    def __len__(self): return self.size
    def __getitem__(self,idx):
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

class MLP(nn.Module):
    def __init__(self,in_dim,out_dim,h=512):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Linear(h,h),
            nn.ReLU(),
            nn.Linear(h,out_dim)
        )
    def forward(self,x): return self.net(x)

def train_model(n,passes,epochs,train_size,val_size,batch,lr,seed,device):
    torch.manual_seed(seed); np.random.seed(seed)
    schedule=make_schedule(n,passes)
    L=len(schedule)
    train_ds=GenDataset(n,passes,train_size,seed)
    val_ds=GenDataset(n,passes,val_size,seed+1)
    in_dim=2*n*n; out_dim=4*L
    model=MLP(in_dim,out_dim).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=lr)
    best=float('inf'); best_state=None
    train_loader=DataLoader(train_ds,batch_size=batch,shuffle=True,drop_last=True)
    val_loader=DataLoader(val_ds,batch_size=batch,shuffle=False)
    for ep in range(1,epochs+1):
        model.train()
        tot=0.0; cnt=0
        for xb,yb in train_loader:
            xb=xb.to(device); yb=yb.to(device)
            pred=model(xb)
            loss=F.mse_loss(pred,yb)
            opt.zero_grad(); loss.backward(); opt.step()
            tot+=loss.item()*xb.size(0); cnt+=xb.size(0)
        model.eval()
        vtot=0.0; vcnt=0
        with torch.no_grad():
            for xb,yb in val_loader:
                xb=xb.to(device); yb=yb.to(device)
                pred=model(xb)
                vtot+=F.mse_loss(pred,yb).item()*xb.size(0); vcnt+=xb.size(0)
        vloss=vtot/vcnt
        if vloss<best:
            best=vloss
            best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}
    model.load_state_dict(best_state)
    meta={'n':n,'passes':passes,'schedule':schedule}
    torch.save({'state_dict':model.state_dict(),'meta':meta},'givens_predictor.pt')
    return model,meta

def reconstruct_error(U,Uhat):
    n=U.shape[0]
    A=U.conj().T@Uhat
    fid=np.abs(np.trace(A))/n
    err=np.linalg.norm(U-Uhat,'fro')/np.linalg.norm(U,'fro')
    return float(err),float(fid)

def predict_params(model,meta,U):
    n=meta['n']; schedule=meta['schedule']
    U=canon_phase(U)
    x=vec_features(U)
    with torch.no_grad():
        y=model(torch.from_numpy(x).unsqueeze(0)).cpu().numpy().reshape(-1)
    theta,phi=unpack_params(y)
    return theta,phi

def synthesize(model_path,U,refine=False,steps=50):
    ckpt=torch.load(model_path,map_location='cpu')
    meta=ckpt['meta']
    n=meta['n']; schedule=meta['schedule']
    in_dim=2*n*n; out_dim=4*len(schedule)
    model=MLP(in_dim,out_dim)
    model.load_state_dict(ckpt['state_dict']); model.eval()
    theta,phi=predict_params(model,meta,U)
    Uhat=unitary_from_params(n,schedule,theta,phi)
    if refine:
        th=torch.tensor(theta,dtype=torch.float64,requires_grad=True)
        ph=torch.tensor(phi,dtype=torch.float64,requires_grad=True)
        target=torch.from_numpy(canon_phase(U))
        opt=torch.optim.LBFGS([th,ph],max_iter=steps,history_size=10,line_search_fn='strong_wolfe')
        def f():
            opt.zero_grad()
            Uh=torch.eye(n,dtype=torch.complex128)
            for k,(i,j) in enumerate(schedule):
                c=torch.cos(th[k]); s=torch.sin(th[k]); e=torch.exp(1j*ph[k])
                G=torch.eye(n,dtype=torch.complex128)
                G[i,i]=c; G[j,j]=c; G[i,j]=-e*s; G[j,i]=torch.conj(e)*s
                Uh=G@Uh
            L=torch.norm(target-Uh)**2
            L.backward()
            return L
        opt.step(f)
        theta=th.detach().numpy(); phi=ph.detach().numpy()
        Uhat=unitary_from_params(n,schedule,theta,phi)
    e,f=reconstruct_error(canon_phase(U),canon_phase(Uhat))
    seq=[{'pair':schedule[k],'theta':float(theta[k]),'phi':float(phi[k])} for k in range(len(schedule))]
    return {'error_fro':e,'fidelity':f,'sequence':seq,'n':n,'passes':meta['passes']}

p=argparse.ArgumentParser()
p.add_argument('--n',type=int,default=5)
p.add_argument('--passes',type=int,default=3)
p.add_argument('--epochs',type=int,default=20)
p.add_argument('--train_size',type=int,default=20000)
p.add_argument('--val_size',type=int,default=2000)
p.add_argument('--batch',type=int,default=256)
p.add_argument('--lr',type=float,default=1e-3)
p.add_argument('--seed',type=int,default=0)
p.add_argument('--device',type=str,default='cuda' if torch.cuda.is_available() else 'cpu')
p.add_argument('--demo',action='store_true')
p.add_argument('--refine',action='store_true')
args, _ = p.parse_known_args()
model,meta=train_model(args.n,args.passes,args.epochs,args.train_size,args.val_size,args.batch,args.lr,args.seed,args.device)
if args.demo:
    rng=np.random.default_rng(1)
    schedule=meta['schedule']
    L=len(schedule)
    theta=rng.uniform(0,math.pi/2,size=L)
    phi=rng.uniform(0,2*math.pi,size=L)
    U=unitary_from_params(args.n,schedule,theta,phi)
    res=synthesize('givens_predictor.pt',U,refine=args.refine)
    print(json.dumps(res,indent=2))


In [34]:
model, meta = train_model(
    n=8, passes=2, epochs=10,
    train_size=40000, val_size=4000,
    batch=128, lr=1e-2, seed=0, device='cuda'
)


In [35]:
import numpy as np, math, json
rng = np.random.default_rng(1)
L = 16#len(meta['schedule'])
theta_true = rng.uniform(0, math.pi/2, size=L)
phi_true = rng.uniform(0, 2*math.pi, size=L)
print('theta = ',theta_true, '\n phi = ', phi_true)
U = unitary_from_params(meta['n'], meta['schedule'], theta_true, phi_true)
res = synthesize('givens_predictor.pt', U, refine=True)
# print(res)
print(json.dumps({k: res[k] for k in ['error_fro','fidelity']}, indent=2))
print(res['sequence'][:])


theta =  [0.804  1.493  0.2264 1.4901 0.4898 0.665  1.3002 0.6428 0.8633 0.0433
 1.1836 0.8453 0.5179 1.2385 0.4763 0.7124] 
 phi =  [0.8422 2.5328 1.2783 1.6482 4.7147 1.7619 3.0485 6.1622 6.0423 4.554
 3.4006 1.7398 1.0094 6.0942 3.2426 0.728 ]
{
  "error_fro": 0.04651983384421423,
  "fidelity": 0.9989179525295542
}
[{'pair': (0, 1), 'theta': 0.806848958979458, 'phi': 0.8649081953053227}, {'pair': (0, 2), 'theta': 1.419492420185325, 'phi': 2.564695963199446}, {'pair': (0, 3), 'theta': 0.21704903974166365, 'phi': 1.295170649341052}, {'pair': (0, 4), 'theta': 1.5683619014019317, 'phi': 5.058463348204238}, {'pair': (0, 5), 'theta': -0.5056356922701988, 'phi': 5.200416863077462}, {'pair': (0, 6), 'theta': 0.6251958217001425, 'phi': -1.0993202009371799}, {'pair': (0, 7), 'theta': 1.3065092052472647, 'phi': 0.3866110239607623}, {'pair': (0, 1), 'theta': -0.6424136480184369, 'phi': 0.32946390263442515}, {'pair': (0, 2), 'theta': -0.8561109229702702, 'phi': 0.20672729648353366}, {'pair': (0,

In [36]:
theta_pred = np.array([s['theta'] for s in res['sequence']])
phi_pred = np.array([s['phi'] for s in res['sequence']])

U_true_raw = unitary_from_params(meta['n'], meta['schedule'], theta_true, phi_true)
U_pred_raw = unitary_from_params(meta['n'], meta['schedule'], theta_pred, phi_pred)

U_true = canon_phase(U_true_raw)
U_pred = canon_phase(U_pred_raw)

np.set_printoptions(precision=4, suppress=True)

print("U_true_raw:\n", U_true_raw)
print("U_pred_raw:\n", U_pred_raw)
print("U_true_canon:\n", U_true)
print("U_pred_canon:\n", U_pred)
print("U_true_canon - U_pred_canon:\n", U_true - U_pred)

err = np.linalg.norm(U_true - U_pred, 'fro')/np.linalg.norm(U_true, 'fro')
fid = np.abs(np.trace(U_true.conj().T @ U_pred))/U_true.shape[0]
print("fro_error_canon =", float(err))
print("fidelity_canon  =", float(fid))


U_true_raw:
 [[ 0.028 +0.0209j -0.0427-0.0313j -0.0295-0.1818j -0.0415-0.0044j
   0.0812-0.6616j  0.031 +0.1206j -0.0131-0.6731j -0.2128+0.0442j]
 [ 0.3842-0.43j    0.5551-0.0004j  0.0077-0.0041j -0.0003-0.002j
   0.0219-0.1087j -0.0073+0.0589j  0.0304-0.0941j  0.5774+0.0162j]
 [-0.6145-0.0401j -0.2583+0.425j   0.0588-0.0032j -0.0001-0.002j
   0.0352-0.1069j -0.0145+0.0584j  0.042 -0.0911j  0.5798+0.0864j]
 [-0.0167-0.0003j  0.0122-0.0111j -0.0696-0.2148j  0.9736-0.j
   0.0041+0.001j  -0.0022-0.0004j  0.0035+0.0013j -0.0014+0.0216j]
 [-0.3972-0.3163j  0.5119+0.3335j -0.1949-0.2977j -0.0938-0.0038j
   0.0465+0.0876j -0.0121-0.0461j  0.0054+0.0792j -0.4353+0.1599j]
 [ 0.0647-0.1112j -0.0855+0.1786j -0.5377+0.4082j  0.059 +0.1438j
   0.2882-0.0298j  0.5998-0.0024j -0.0242-0.0005j -0.0366-0.1367j]
 [ 0.0542-0.0204j -0.0829+0.035j  -0.3191-0.074j  -0.0355+0.0659j
  -0.5805-0.0956j  0.0096-0.262j   0.676 -0.0073j  0.028 -0.0554j]
 [ 0.0596+0.0738j -0.0992-0.1049j -0.1248-0.4543j -0.1083+0.00

# Predicting a single rotation with angles and phases

In [38]:
import math, json, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def givens_matrix(n,i,j,theta,phi):
    G=np.eye(n,dtype=np.complex128)
    c=np.cos(theta); s=np.sin(theta); e=np.exp(1j*phi)
    G[i,i]=c; G[j,j]=c; G[i,j]=-e*s; G[j,i]=np.conj(e)*s
    return G

def unitary_from_pair(n,pair,theta,phi):
    return givens_matrix(n,pair[0],pair[1],theta,phi)

def canon_phase(U):
    a=np.angle(np.linalg.det(U))/U.shape[0]
    return U*np.exp(-1j*a)

def vec_features(U):
    return np.concatenate([U.real.reshape(-1),U.imag.reshape(-1)],0).astype(np.float32)

def pack_angles(theta,phi):
    return np.array([np.cos(theta),np.sin(theta),np.cos(phi),np.sin(phi)],dtype=np.float32)

def unpack_angles(y):
    cth,sth,cph,sph=y
    theta=np.arctan2(sth,cth)
    phi=np.arctan2(sph,cph)
    theta=np.mod(theta,math.pi/2)
    phi=np.mod(phi,2*math.pi)
    return theta,phi

def fidelity(U,Uhat):
    n=U.shape[0]
    return float(np.abs(np.trace(U.conj().T@Uhat))/n)

class PairAngleDataset(Dataset):
    def __init__(self,n,pairs,size,seed=0):
        self.n=n; self.pairs=pairs; self.size=size
        rng=np.random.default_rng(seed)
        X=[]; y_pair=[]; y_ang=[]
        for _ in range(size):
            k=rng.integers(0,len(pairs))
            th=rng.uniform(0,math.pi/2); ph=rng.uniform(0,2*math.pi)
            U=unitary_from_pair(n,pairs[k],th,ph)
            U=canon_phase(U)
            X.append(vec_features(U))
            y_pair.append(k)
            y_ang.append(pack_angles(th,ph))
        self.X=np.stack(X,0); self.y_pair=np.array(y_pair,dtype=np.int64); self.y_ang=np.stack(y_ang,0)
    def __len__(self): return self.size
    def __getitem__(self,idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y_pair[idx]), torch.from_numpy(self.y_ang[idx])

class Trunk(nn.Module):
    def __init__(self,in_dim,h=512):
        super().__init__()
        self.net=nn.Sequential(nn.Linear(in_dim,h),nn.ReLU(),nn.Linear(h,h),nn.ReLU())
    def forward(self,x): return self.net(x)

class PairAngleNet(nn.Module):
    def __init__(self,in_dim,num_pairs,h=512):
        super().__init__()
        self.trunk=Trunk(in_dim,h)
        self.head_pair=nn.Linear(h,num_pairs)
        self.head_angles=nn.Linear(h,num_pairs*4)
        self.num_pairs=num_pairs
    def forward(self,x):
        z=self.trunk(x)
        logits=self.head_pair(z)
        ang=self.head_angles(z).view(-1,self.num_pairs,4)
        return logits, ang

@torch.no_grad()
def eval_fid(model,loader,n,pairs,device):
    model.eval()
    N=0; s=0.0; ltot=0.0
    for xb,yb_pair,yb_ang in loader:
        xb=xb.to(device); yb_pair=yb_pair.to(device); yb_ang=yb_ang.to(device)
        logits, angs=model(xb)
        ce=F.cross_entropy(logits,yb_pair)
        idx=yb_pair.view(-1,1,1).expand(-1,1,4)
        sel=torch.gather(angs,1,idx).squeeze(1)
        mse=F.mse_loss(sel,yb_ang)
        l=ce+mse
        preds=logits.argmax(dim=1)
        sel_pred=torch.gather(angs,1,preds.view(-1,1,1).expand(-1,1,4)).squeeze(1).cpu().numpy()
        y_true=yb_ang.cpu().numpy()
        pairs_pred=preds.cpu().numpy()
        for b in range(xb.size(0)):
            theta_t,phi_t=unpack_angles(y_true[b])
            theta_p,phi_p=unpack_angles(sel_pred[b])
            U_t=canon_phase(unitary_from_pair(n,pairs[yb_pair[b].item()],theta_t,phi_t))
            U_p=canon_phase(unitary_from_pair(n,pairs[pairs_pred[b]],theta_p,phi_p))
            s+=fidelity(U_t,U_p); N+=1
        ltot+=l.item()*xb.size(0)
    return ltot/max(1,N), s/max(1,N)

def train_pair_selector(n=4,pairs=[(0,1),(0,2),(0,3)],epochs=15,train_size=16000,val_size=2000,batch=256,lr=1e-3,seed=0,device='cpu'):
    torch.manual_seed(seed); np.random.seed(seed)
    train_ds=PairAngleDataset(n,pairs,train_size,seed)
    val_ds=PairAngleDataset(n,pairs,val_size,seed+1)
    in_dim=2*n*n; P=len(pairs)
    model=PairAngleNet(in_dim,P,512).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=lr)
    train_loader=DataLoader(train_ds,batch_size=batch,shuffle=True,drop_last=True)
    val_loader=DataLoader(val_ds,batch_size=batch,shuffle=False)
    best_fid=-1.0; best_state=None
    for ep in range(1,epochs+1):
        model.train()
        tot=0.0; cnt=0
        for xb,yb_pair,yb_ang in train_loader:
            xb=xb.to(device); yb_pair=yb_pair.to(device); yb_ang=yb_ang.to(device)
            logits, angs=model(xb)
            ce=F.cross_entropy(logits,yb_pair)
            idx=yb_pair.view(-1,1,1).expand(-1,1,4)
            sel=torch.gather(angs,1,idx).squeeze(1)
            mse=F.mse_loss(sel,yb_ang)
            loss=ce+mse
            opt.zero_grad(); loss.backward(); opt.step()
            tot+=loss.item()*xb.size(0); cnt+=xb.size(0)
        train_loss=tot/cnt
        vloss,vfid=eval_fid(model,val_loader,n,pairs,device)
        tloss,tfid=eval_fid(model,DataLoader(train_ds,batch_size=batch,shuffle=False),n,pairs,device)
        if vfid>best_fid:
            best_fid=vfid
            best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}
        print(f"epoch {ep:02d} | train_loss {train_loss:.5f} val_loss {vloss:.5f} | train_fid {tfid:.6f} val_fid {vfid:.6f}")
    model.load_state_dict(best_state)
    meta={'n':n,'pairs':pairs}
    torch.save({'state_dict':model.state_dict(),'meta':meta},'pair_angle_selector.pt')
    return model,meta

@torch.no_grad()
def predict_one(model,meta,U):
    n=meta['n']; pairs=meta['pairs']
    x=torch.from_numpy(vec_features(canon_phase(U))).unsqueeze(0).float()
    logits, angs=model(x)
    k=int(torch.argmax(logits,dim=1).item())
    y=angs[0,k].numpy()
    th,ph=unpack_angles(y)
    Uhat=canon_phase(unitary_from_pair(n,pairs[k],th,ph))
    return {'pair':pairs[k],'theta':float(th),'phi':float(ph),'U_pred':Uhat}

def demo_once(n=4,pairs=[(0,1),(0,2),(0,3)],seed=123):
    rng=np.random.default_rng(seed)
    k=int(rng.integers(0,len(pairs)))
    th=rng.uniform(0,math.pi/2); ph=rng.uniform(0,2*math.pi)
    U=canon_phase(unitary_from_pair(n,pairs[k],th,ph))
    ckpt=torch.load('pair_angle_selector.pt',map_location='cpu')
    in_dim=2*n*n; P=len(pairs)
    model=PairAngleNet(in_dim,P,512)
    model.load_state_dict(ckpt['state_dict']); model.eval()
    pred=predict_one(model,{'n':n,'pairs':pairs},U)
    fid_val=fidelity(U,pred['U_pred'])
    return {'true_pair':pairs[k],'true_theta':float(th),'true_phi':float(ph),'pred_pair':pred['pair'],'pred_theta':pred['theta'],'pred_phi':pred['phi'],'fidelity':fid_val,'U_true':U,'U_pred':pred['U_pred']}


In [52]:
model, meta = train_pair_selector(n=4, pairs=[(0,1),(0,2),(0,3)], epochs=5, train_size=30000, val_size=3000, batch=256, lr=1e-4, seed=0, device='cuda')

epoch 01 | train_loss 1.13806 val_loss 0.74089 | train_fid 0.984408 val_fid 0.984212
epoch 02 | train_loss 0.50484 val_loss 0.34289 | train_fid 0.998844 val_fid 0.998802
epoch 03 | train_loss 0.27286 val_loss 0.21350 | train_fid 0.999095 val_fid 0.999057
epoch 04 | train_loss 0.18329 val_loss 0.15492 | train_fid 0.999263 val_fid 0.999234
epoch 05 | train_loss 0.13798 val_loss 0.12163 | train_fid 0.999576 val_fid 0.999568


In [53]:
res = demo_once(n=4, pairs=[(0,1),(0,2),(0,3)], seed=3)
np.set_printoptions(precision=4, suppress=True)
print(json.dumps({k:res[k] for k in ['true_pair','true_theta','true_phi','pred_pair','pred_theta','pred_phi','fidelity']}, indent=2))
print('U_true:\n', res['U_true'])
print('U_pred:\n', res['U_pred'])

{
  "true_pair": [
    0,
    3
  ],
  "true_theta": 0.37198107390759205,
  "true_phi": 5.034555946803014,
  "pred_pair": [
    0,
    3
  ],
  "pred_theta": 0.3677932918071747,
  "pred_phi": 5.039814472198486,
  "fidelity": 0.9999947154523011
}
U_true:
 [[ 0.9316+0.j      0.    +0.j      0.    +0.j     -0.1151+0.3448j]
 [ 0.    +0.j      1.    +0.j      0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      1.    +0.j      0.    +0.j    ]
 [ 0.1151+0.3448j  0.    +0.j      0.    +0.j      0.9316+0.j    ]]
U_pred:
 [[ 0.9331-0.j      0.    +0.j      0.    +0.j     -0.1156+0.3405j]
 [ 0.    +0.j      1.    -0.j      0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      1.    -0.j      0.    +0.j    ]
 [ 0.1156+0.3405j  0.    +0.j      0.    +0.j      0.9331-0.j    ]]


# Predicts single or double rotations (theta restricted from 0 to pi)

## Two rotatiosn fixed

In [54]:
import math, json, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def givens_matrix(n,i,j,theta,phi):
    G=np.eye(n,dtype=np.complex128)
    c=np.cos(theta); s=np.sin(theta); e=np.exp(1j*phi)
    G[i,i]=c; G[j,j]=c; G[i,j]=-e*s; G[j,i]=np.conj(e)*s
    return G

def apply_two(n,p1,th1,ph1,p2,th2,ph2):
    U=np.eye(n,dtype=np.complex128)
    U=givens_matrix(n,p1[0],p1[1],th1,ph1)@U
    U=givens_matrix(n,p2[0],p2[1],th2,ph2)@U
    return U

def canon_phase(U):
    a=np.angle(np.linalg.det(U))/U.shape[0]
    return U*np.exp(-1j*a)

def vec_features(U):
    return np.concatenate([U.real.reshape(-1),U.imag.reshape(-1)],0).astype(np.float32)

def pack_angles(theta,phi):
    return np.array([np.cos(theta),np.sin(theta),np.cos(phi),np.sin(phi)],dtype=np.float32)

def unpack_angles(y):
    cth,sth,cph,sph=y
    theta=np.arctan2(sth,cth); phi=np.arctan2(sph,cph)
    theta=np.mod(theta,math.pi/2); phi=np.mod(phi,2*math.pi)
    return theta,phi

def fidelity(U,Uhat):
    n=U.shape[0]
    return float(np.abs(np.trace(U.conj().T@Uhat))/n)

class Seq2Dataset(Dataset):
    def __init__(self,n,pairs,size,seed=0):
        self.n=n; self.pairs=pairs; self.size=size
        rng=np.random.default_rng(seed)
        X=[]; y_p1=[]; y_p2=[]; y_a1=[]; y_a2=[]
        for _ in range(size):
            k1=rng.integers(0,len(pairs)); k2=rng.integers(0,len(pairs))
            th1=rng.uniform(0,math.pi/2); ph1=rng.uniform(0,2*math.pi)
            th2=rng.uniform(0,math.pi/2); ph2=rng.uniform(0,2*math.pi)
            U=apply_two(n,pairs[k1],th1,ph1,pairs[k2],th2,ph2)
            U=canon_phase(U)
            X.append(vec_features(U))
            y_p1.append(k1); y_p2.append(k2)
            y_a1.append(pack_angles(th1,ph1)); y_a2.append(pack_angles(th2,ph2))
        self.X=np.stack(X,0).astype(np.float32)
        self.y_p1=np.array(y_p1,dtype=np.int64)
        self.y_p2=np.array(y_p2,dtype=np.int64)
        self.y_a1=np.stack(y_a1,0).astype(np.float32)
        self.y_a2=np.stack(y_a2,0).astype(np.float32)
    def __len__(self): return self.size
    def __getitem__(self,idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y_p1[idx]), torch.tensor(self.y_p2[idx]), torch.from_numpy(self.y_a1[idx]), torch.from_numpy(self.y_a2[idx])

class Trunk(nn.Module):
    def __init__(self,in_dim,h=512):
        super().__init__()
        self.net=nn.Sequential(nn.Linear(in_dim,h),nn.ReLU(),nn.Linear(h,h),nn.ReLU())
    def forward(self,x): return self.net(x)

class Seq2Net(nn.Module):
    def __init__(self,in_dim,num_pairs,h=512):
        super().__init__()
        self.trunk=Trunk(in_dim,h)
        self.head_p1=nn.Linear(h,num_pairs)
        self.head_p2=nn.Linear(h,num_pairs)
        self.head_ang=nn.Linear(h,2*num_pairs*4)
        self.num_pairs=num_pairs
    def forward(self,x):
        z=self.trunk(x)
        l1=self.head_p1(z)
        l2=self.head_p2(z)
        a=self.head_ang(z).view(-1,2,self.num_pairs,4)
        return l1,l2,a

@torch.no_grad()
def eval_fid(model,loader,n,pairs,device):
    model.eval()
    N=0; sf=0.0; ltot=0.0
    for xb,y_p1,y_p2,y_a1,y_a2 in loader:
        xb=xb.to(device); y_p1=y_p1.to(device); y_p2=y_p2.to(device); y_a1=y_a1.to(device); y_a2=y_a2.to(device)
        l1,l2,angs=model(xb)
        ce1=F.cross_entropy(l1,y_p1); ce2=F.cross_entropy(l2,y_p2)
        a1=angs[:,0]; a2=angs[:,1]
        idx1=y_p1.view(-1,1,1).expand(-1,1,4)
        idx2=y_p2.view(-1,1,1).expand(-1,1,4)
        sel1=torch.gather(a1,1,idx1).squeeze(1)
        sel2=torch.gather(a2,1,idx2).squeeze(1)
        mse=F.mse_loss(sel1,y_a1)+F.mse_loss(sel2,y_a2)
        loss=ce1+ce2+mse
        p1=l1.argmax(dim=1); p2=l2.argmax(dim=1)
        pa1=torch.gather(a1,1,p1.view(-1,1,1).expand(-1,1,4)).squeeze(1).cpu().numpy()
        pa2=torch.gather(a2,1,p2.view(-1,1,1).expand(-1,1,4)).squeeze(1).cpu().numpy()
        ta1=y_a1.cpu().numpy(); ta2=y_a2.cpu().numpy()
        p1=p1.cpu().numpy(); p2=p2.cpu().numpy(); y_p1_np=y_p1.cpu().numpy(); y_p2_np=y_p2.cpu().numpy()
        for b in range(xb.size(0)):
            th1_t,ph1_t=unpack_angles(ta1[b]); th2_t,ph2_t=unpack_angles(ta2[b])
            th1_p,ph1_p=unpack_angles(pa1[b]); th2_p,ph2_p=unpack_angles(pa2[b])
            U_t=canon_phase(apply_two(n,pairs[y_p1_np[b]],th1_t,ph1_t,pairs[y_p2_np[b]],th2_t,ph2_t))
            U_p=canon_phase(apply_two(n,pairs[p1[b]],th1_p,ph1_p,pairs[p2[b]],th2_p,ph2_p))
            sf+=fidelity(U_t,U_p); N+=1
        ltot+=loss.item()*xb.size(0)
    return ltot/max(1,N), sf/max(1,N)

def train_seq2_selector(n=4,pairs=[(0,1),(0,2),(0,3)],epochs=20,train_size=24000,val_size=3000,batch=256,lr=1e-3,seed=0,device='cpu'):
    torch.manual_seed(seed); np.random.seed(seed)
    train_ds=Seq2Dataset(n,pairs,train_size,seed)
    val_ds=Seq2Dataset(n,pairs,val_size,seed+1)
    in_dim=2*n*n; P=len(pairs)
    model=Seq2Net(in_dim,P,512).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=lr)
    train_loader=DataLoader(train_ds,batch_size=batch,shuffle=True,drop_last=True)
    val_loader=DataLoader(val_ds,batch_size=batch,shuffle=False)
    best_fid=-1.0; best_state=None
    for ep in range(1,epochs+1):
        model.train()
        tot=0.0; cnt=0
        for xb,y_p1,y_p2,y_a1,y_a2 in train_loader:
            xb=xb.to(device); y_p1=y_p1.to(device); y_p2=y_p2.to(device); y_a1=y_a1.to(device); y_a2=y_a2.to(device)
            l1,l2,angs=model(xb)
            ce1=F.cross_entropy(l1,y_p1); ce2=F.cross_entropy(l2,y_p2)
            a1=angs[:,0]; a2=angs[:,1]
            idx1=y_p1.view(-1,1,1).expand(-1,1,4)
            idx2=y_p2.view(-1,1,1).expand(-1,1,4)
            sel1=torch.gather(a1,1,idx1).squeeze(1)
            sel2=torch.gather(a2,1,idx2).squeeze(1)
            mse=F.mse_loss(sel1,y_a1)+F.mse_loss(sel2,y_a2)
            loss=ce1+ce2+mse
            opt.zero_grad(); loss.backward(); opt.step()
            tot+=loss.item()*xb.size(0); cnt+=xb.size(0)
        train_loss=tot/cnt
        vloss,vfid=eval_fid(model,val_loader,n,pairs,device)
        tloss,tfid=eval_fid(model,DataLoader(train_ds,batch_size=batch,shuffle=False),n,pairs,device)
        if vfid>best_fid:
            best_fid=vfid
            best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}
        print(f"epoch {ep:02d} | train_loss {train_loss:.5f} val_loss {vloss:.5f} | train_fid {tfid:.6f} val_fid {vfid:.6f}")
    model.load_state_dict(best_state)
    meta={'n':n,'pairs':pairs}
    torch.save({'state_dict':model.state_dict(),'meta':meta},'seq2_selector.pt')
    return model,meta

@torch.no_grad()
def predict_seq2(model,meta,U):
    n=meta['n']; pairs=meta['pairs']
    x=torch.from_numpy(vec_features(canon_phase(U))).unsqueeze(0).float()
    l1,l2,angs=model(x)
    k1=int(torch.argmax(l1,dim=1).item())
    k2=int(torch.argmax(l2,dim=1).item())
    y1=angs[0,0,k1].numpy(); y2=angs[0,1,k2].numpy()
    th1,ph1=unpack_angles(y1); th2,ph2=unpack_angles(y2)
    Uhat=canon_phase(apply_two(n,pairs[k1],th1,ph1,pairs[k2],th2,ph2))
    return [{'pair':pairs[k1],'theta':float(th1),'phi':float(ph1)},{'pair':pairs[k2],'theta':float(th2),'phi':float(ph2)}],Uhat

def demo_seq2(n=4,pairs=[(0,1),(0,2),(0,3)],seed=123):
    rng=np.random.default_rng(seed)
    k1=int(rng.integers(0,len(pairs))); k2=int(rng.integers(0,len(pairs)))
    th1=rng.uniform(0,math.pi/2); ph1=rng.uniform(0,2*math.pi)
    th2=rng.uniform(0,math.pi/2); ph2=rng.uniform(0,2*math.pi)
    U=canon_phase(apply_two(n,pairs[k1],th1,ph1,pairs[k2],th2,ph2))
    ckpt=torch.load('seq2_selector.pt',map_location='cpu')
    in_dim=2*n*n; P=len(pairs)
    model=Seq2Net(in_dim,P,512); model.load_state_dict(ckpt['state_dict']); model.eval()
    seq_pred,U_pred=predict_seq2(model,{'n':n,'pairs':pairs},U)
    fid_val=fidelity(U,U_pred)
    return {'true_seq':[{'pair':pairs[k1],'theta':float(th1),'phi':float(ph1)},{'pair':pairs[k2],'theta':float(th2),'phi':float(ph2)}],'pred_seq':seq_pred,'fidelity':fid_val,'U_true':U,'U_pred':U_pred}


In [66]:
model, meta = train_seq2_selector(n=4, pairs=[(0,1),(0,2),(0,3)], epochs=10, train_size=60000, val_size=6000, batch=256, lr=3*1e-3, seed=0, device='cuda')
res = demo_seq2(n=4, pairs=[(0,1),(0,2),(0,3)], seed=7)
np.set_printoptions(precision=4, suppress=True)
print(json.dumps({k:res[k] for k in ['true_seq','pred_seq','fidelity']}, indent=2))
print('U_true:\n', res['U_true'])
print('U_pred:\n', res['U_pred'])


epoch 01 | train_loss 0.91891 val_loss 0.42568 | train_fid 0.967728 val_fid 0.967711
epoch 02 | train_loss 0.35398 val_loss 0.31198 | train_fid 0.975277 val_fid 0.975238
epoch 03 | train_loss 0.28857 val_loss 0.27858 | train_fid 0.969514 val_fid 0.969717
epoch 04 | train_loss 0.25473 val_loss 0.23308 | train_fid 0.976999 val_fid 0.976556
epoch 05 | train_loss 0.23588 val_loss 0.24008 | train_fid 0.970834 val_fid 0.973420
epoch 06 | train_loss 0.22239 val_loss 0.21252 | train_fid 0.970548 val_fid 0.969789
epoch 07 | train_loss 0.21217 val_loss 0.20606 | train_fid 0.971939 val_fid 0.970957
epoch 08 | train_loss 0.20186 val_loss 0.20063 | train_fid 0.981481 val_fid 0.982207
epoch 09 | train_loss 0.19498 val_loss 0.19468 | train_fid 0.980833 val_fid 0.980720
epoch 10 | train_loss 0.19106 val_loss 0.19006 | train_fid 0.976722 val_fid 0.975669
{
  "true_seq": [
    {
      "pair": [
        0,
        3
      ],
      "theta": 1.4093401429126966,
      "phi": 4.873776931938056
    },
    {
 

In [65]:
res = demo_seq2(n=4, pairs=[(0,1),(0,2),(0,3)], seed=1)
np.set_printoptions(precision=4, suppress=True)
print(json.dumps({k:res[k] for k in ['true_seq','pred_seq','fidelity']}, indent=2))
print('U_true:\n', res['U_true'])
print('U_pred:\n', res['U_pred'])

{
  "true_seq": [
    {
      "pair": [
        0,
        2
      ],
      "theta": 1.492984882940679,
      "phi": 0.9057815605287021
    },
    {
      "pair": [
        0,
        2
      ],
      "theta": 1.490135066979192,
      "phi": 1.9592947975887585
    }
  ],
  "pred_seq": [
    {
      "pair": [
        0,
        2
      ],
      "theta": 1.3921685218811035,
      "phi": 2.736405611038208
    },
    {
      "pair": [
        0,
        2
      ],
      "theta": 1.4535804986953735,
      "phi": 1.4361897706985474
    }
  ],
  "fidelity": 0.16564087345883183
}
U_true:
 [[-0.4852-0.8637j  0.    +0.j     -0.0202-0.1349j  0.    +0.j    ]
 [ 0.    +0.j      1.    +0.j      0.    +0.j      0.    +0.j    ]
 [ 0.0202-0.1349j  0.    +0.j     -0.4852+0.8637j  0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      0.    +0.j      1.    +0.j    ]]
U_pred:
 [[-0.2405+0.9418j  0.    +0.j      0.0821-0.2202j  0.    +0.j    ]
 [ 0.    +0.j      1.    +0.j      0.    +0.j      0.    +0.j    ]


## One and Two rotations 

In [105]:
import math, json, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def givens_matrix(n,i,j,theta,phi):
    G=np.eye(n,dtype=np.complex128)
    c=np.cos(theta); s=np.sin(theta); e=np.exp(1j*phi)
    G[i,i]=c; G[j,j]=c; G[i,j]=-e*s; G[j,i]=np.conj(e)*s
    return G

def apply_one(n,p1,th1,ph1):
    U=np.eye(n,dtype=np.complex128)
    U=givens_matrix(n,p1[0],p1[1],th1,ph1)@U
    return U

def apply_two(n,p1,th1,ph1,p2,th2,ph2):
    U=np.eye(n,dtype=np.complex128)
    U=givens_matrix(n,p1[0],p1[1],th1,ph1)@U
    U=givens_matrix(n,p2[0],p2[1],th2,ph2)@U
    return U

def canon_phase(U):
    a=np.angle(np.linalg.det(U))/U.shape[0]
    return U*np.exp(-1j*a)

def vec_features(U):
    return np.concatenate([U.real.reshape(-1),U.imag.reshape(-1)],0).astype(np.float32)

def pack_angles(theta,phi):
    return np.array([np.cos(theta),np.sin(theta),np.cos(phi),np.sin(phi)],dtype=np.float32)

def unpack_angles(y):
    cth,sth,cph,sph=y
    theta=np.arctan2(sth,cth); phi=np.arctan2(sph,cph)
    theta=np.mod(theta,math.pi/2); phi=np.mod(phi,2*math.pi)
    return theta,phi

def fidelity(U,Uhat):
    n=U.shape[0]
    return float(np.abs(np.trace(U.conj().T@Uhat))/n)

class SeqMixDataset(Dataset):
    def __init__(self,n,pairs,size,p_one=0.5,seed=0):
        self.n=n; self.pairs=pairs; self.size=size; self.p_one=p_one; self.STOP=len(pairs)
        rng=np.random.default_rng(seed)
        X=[]; y_p1=[]; y_p2=[]; y_a1=[]; y_a2=[]
        for _ in range(size):
            k1=rng.integers(0,len(pairs))
            th1=rng.uniform(0,math.pi/2); ph1=rng.uniform(0,2*math.pi)
            if rng.random()<p_one:
                U=apply_one(n,pairs[k1],th1,ph1)
                k2=self.STOP; th2=0.0; ph2=0.0
            else:
                k2=rng.integers(0,len(pairs))
                th2=rng.uniform(0,math.pi/2); ph2=rng.uniform(0,2*math.pi)
                U=apply_two(n,pairs[k1],th1,ph1,pairs[k2],th2,ph2)
            U=canon_phase(U)
            X.append(vec_features(U))
            y_p1.append(k1); y_p2.append(k2)
            y_a1.append(pack_angles(th1,ph1))
            y_a2.append(pack_angles(th2,ph2) if k2!=self.STOP else np.array([1.,0.,1.,0.],dtype=np.float32))
        self.X=np.stack(X,0).astype(np.float32)
        self.y_p1=np.array(y_p1,dtype=np.int64)
        self.y_p2=np.array(y_p2,dtype=np.int64)
        self.y_a1=np.stack(y_a1,0).astype(np.float32)
        self.y_a2=np.stack(y_a2,0).astype(np.float32)
    def __len__(self): return self.size
    def __getitem__(self,idx):
        return (torch.from_numpy(self.X[idx]),
                torch.tensor(self.y_p1[idx]),
                torch.tensor(self.y_p2[idx]),
                torch.from_numpy(self.y_a1[idx]),
                torch.from_numpy(self.y_a2[idx]))

class Trunk(nn.Module):
    def __init__(self,in_dim,h=512):
        super().__init__()
        self.net=nn.Sequential(nn.Linear(in_dim,h),nn.ReLU(),nn.Linear(h,h),nn.ReLU())
    def forward(self,x): return self.net(x)

class Seq2StopNet(nn.Module):
    def __init__(self,in_dim,num_pairs,h=512):
        super().__init__()
        self.trunk=Trunk(in_dim,h)
        self.head_p1=nn.Linear(h,num_pairs)
        self.head_p2=nn.Linear(h,num_pairs+1)
        self.head_ang=nn.Linear(h,2*num_pairs*4)
        self.num_pairs=num_pairs
    def forward(self,x):
        z=self.trunk(x)
        l1=self.head_p1(z)
        l2=self.head_p2(z)
        a=self.head_ang(z).view(-1,2,self.num_pairs,4)
        return l1,l2,a

@torch.no_grad()
def eval_stats(model,loader,n,pairs,STOP,device):
    model.eval()
    N=0; sf=0.0; ltot=0.0; used_sum=0.0
    for xb,y_p1,y_p2,y_a1,y_a2 in loader:
        xb=xb.to(device); y_p1=y_p1.to(device); y_p2=y_p2.to(device); y_a1=y_a1.to(device); y_a2=y_a2.to(device)
        l1,l2,angs=model(xb)
        ce1=F.cross_entropy(l1,y_p1); ce2=F.cross_entropy(l2,y_p2)
        a1=angs[:,0]; a2=angs[:,1]
        idx1=y_p1.view(-1,1,1).expand(-1,1,4)
        sel1=torch.gather(a1,1,idx1).squeeze(1)
        mask=(y_p2!=STOP).float().unsqueeze(-1)
        idx2=torch.clamp(y_p2,0,STOP-1).view(-1,1,1).expand(-1,1,4)
        sel2=torch.gather(a2,1,idx2).squeeze(1)
        mse=F.mse_loss(sel1,y_a1)+F.mse_loss(sel2,y_a2,reduction='none')
        mse=(mse*mask).mean()
        loss=ce1+ce2+mse
        probs=torch.softmax(l2,dim=1)
        used=1.0-(probs[:,STOP]).mean().item()
        p1=l1.argmax(dim=1); p2=l2.argmax(dim=1)
        pa1=torch.gather(a1,1,p1.view(-1,1,1).expand(-1,1,4)).squeeze(1).cpu().numpy()
        pa2=torch.gather(a2,1,torch.clamp(p2,0,STOP-1).view(-1,1,1).expand(-1,1,4)).squeeze(1).cpu().numpy()
        ta1=y_a1.cpu().numpy(); ta2=y_a2.cpu().numpy()
        p1=p1.cpu().numpy(); p2=p2.cpu().numpy(); y_p1_np=y_p1.cpu().numpy(); y_p2_np=y_p2.cpu().numpy()
        for b in range(xb.size(0)):
            th1_t,ph1_t=unpack_angles(ta1[b])
            U_t=apply_one(n,pairs[y_p1_np[b]],th1_t,ph1_t)
            if y_p2_np[b]!=STOP:
                th2_t,ph2_t=unpack_angles(ta2[b])
                U_t=apply_two(n,pairs[y_p1_np[b]],th1_t,ph1_t,pairs[y_p2_np[b]],th2_t,ph2_t)
            U_t=canon_phase(U_t)
            th1_p,ph1_p=unpack_angles(pa1[b])
            U_p=apply_one(n,pairs[p1[b]],th1_p,ph1_p)
            if p2[b]!=STOP:
                th2_p,ph2_p=unpack_angles(pa2[b])
                U_p=apply_two(n,pairs[p1[b]],th1_p,ph1_p,pairs[p2[b]],th2_p,ph2_p)
            U_p=canon_phase(U_p)
            sf+=fidelity(U_t,U_p); N+=1
        ltot+=loss.item()*xb.size(0); used_sum+=used*xb.size(0)
    return ltot/max(1,N), sf/max(1,N), used_sum/max(1,N)

def train_seq2_minrot(n=4,pairs=[(0,1),(0,2),(0,3)],epochs=20,train_size=24000,val_size=3000,batch=256,lr=1e-3,p_one=0.5,lam_use=0.01,seed=0,device='cpu'):
    torch.manual_seed(seed); np.random.seed(seed)
    STOP=len(pairs)
    train_ds=SeqMixDataset(n,pairs,train_size,p_one=p_one,seed=seed)
    val_ds=SeqMixDataset(n,pairs,val_size,p_one=p_one,seed=seed+1)
    in_dim=2*n*n; P=len(pairs)
    model=Seq2StopNet(in_dim,P,512).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=lr)
    train_loader=DataLoader(train_ds,batch_size=batch,shuffle=True,drop_last=True)
    val_loader=DataLoader(val_ds,batch_size=batch,shuffle=False)
    best_fid=-1.0; best_state=None
    for ep in range(1,epochs+1):
        model.train()
        tot=0.0; cnt=0
        for xb,y_p1,y_p2,y_a1,y_a2 in train_loader:
            xb=xb.to(device); y_p1=y_p1.to(device); y_p2=y_p2.to(device); y_a1=y_a1.to(device); y_a2=y_a2.to(device)
            l1,l2,angs=model(xb)
            ce1=F.cross_entropy(l1,y_p1); ce2=F.cross_entropy(l2,y_p2)
            a1=angs[:,0]; a2=angs[:,1]
            idx1=y_p1.view(-1,1,1).expand(-1,1,4)
            sel1=torch.gather(a1,1,idx1).squeeze(1)
            mask=(y_p2!=STOP).float().unsqueeze(-1)
            idx2=torch.clamp(y_p2,0,STOP-1).view(-1,1,1).expand(-1,1,4)
            sel2=torch.gather(a2,1,idx2).squeeze(1)
            mse=F.mse_loss(sel1,y_a1)+F.mse_loss(sel2,y_a2,reduction='none')
            mse=(mse*mask).mean()
            probs=torch.softmax(l2,dim=1)
            use_pen=1.0-(probs[:,STOP]).mean()
            loss=ce1+ce2+mse+lam_use*use_pen
            opt.zero_grad(); loss.backward(); opt.step()
            tot+=loss.item()*xb.size(0); cnt+=xb.size(0)
        train_loss=tot/cnt
        vloss,vfid,vused=eval_stats(model,val_loader,n,pairs,STOP,device)
        tloss,tfid,tused=eval_stats(model,DataLoader(train_ds,batch_size=batch,shuffle=False),n,pairs,STOP,device)
        if vfid>best_fid:
            best_fid=vfid
            best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}
        print(f"epoch {ep:02d} | train_loss {train_loss:.5f} val_loss {vloss:.5f} | train_fid {tfid:.6f} val_fid {vfid:.6f} | train_used {tused:.3f} val_used {vused:.3f}")
    model.load_state_dict(best_state)
    meta={'n':n,'pairs':pairs,'STOP':STOP}
    torch.save({'state_dict':model.state_dict(),'meta':meta},'seq2_minrot.pt')
    return model,meta

@torch.no_grad()
def predict_seq2_minrot(model,meta,U):
    n=meta['n']; pairs=meta['pairs']; STOP=meta['STOP']
    x=torch.from_numpy(vec_features(canon_phase(U))).unsqueeze(0).float()
    l1,l2,angs=model(x)
    k1=int(torch.argmax(l1,dim=1).item())
    k2=int(torch.argmax(l2,dim=1).item())
    y1=angs[0,0,k1].numpy()
    th1,ph1=unpack_angles(y1)
    seq=[{'pair':pairs[k1],'theta':float(th1),'phi':float(ph1)}]
    Uhat=apply_one(n,pairs[k1],th1,ph1)
    if k2!=STOP:
        y2=angs[0,1,k2 if k2<STOP else STOP-1].numpy()
        th2,ph2=unpack_angles(y2)
        seq.append({'pair':pairs[k2],'theta':float(th2),'phi':float(ph2)})
        Uhat=apply_two(n,pairs[k1],th1,ph1,pairs[k2],th2,ph2)
    Uhat=canon_phase(Uhat)
    return seq,Uhat


In [95]:
model, meta = train_seq2_minrot(n=4, pairs=[(0,1),(0,2),(0,3)], epochs=15, train_size=60000, val_size=10000, batch=256, lr=1e-3, p_one=0.6, lam_use=0.02, seed=0, device='cpu')

rng=np.random.default_rng(7)
pairs=[(0,1),(0,2),(0,3)]
k1=int(rng.integers(0,len(pairs)))
th1=rng.uniform(0,math.pi/2); ph1=rng.uniform(0,2*math.pi)
if rng.random()<0.5:
    U=canon_phase(apply_one(4,pairs[k1],th1,ph1))
else:
    k2=int(rng.integers(0,len(pairs)))
    th2=rng.uniform(0,math.pi/2); ph2=rng.uniform(0,2*math.pi)
    U=canon_phase(apply_two(4,pairs[k1],th1,ph1,pairs[k2],th2,ph2))

seq_pred,U_pred=predict_seq2_minrot(model,meta,U)
fid=fidelity(U,U_pred)
print(json.dumps({'pred_seq':seq_pred,'fid':fid}, indent=2))


epoch 01 | train_loss 1.27057 val_loss 0.74370 | train_fid 0.972146 val_fid 0.973039 | train_used 0.341 val_used 0.343
epoch 02 | train_loss 0.52166 val_loss 0.40670 | train_fid 0.982740 val_fid 0.983426 | train_used 0.433 val_used 0.435
epoch 03 | train_loss 0.34702 val_loss 0.29866 | train_fid 0.990710 val_fid 0.990824 | train_used 0.385 val_used 0.388
epoch 04 | train_loss 0.28176 val_loss 0.25096 | train_fid 0.990808 val_fid 0.991037 | train_used 0.395 val_used 0.398
epoch 05 | train_loss 0.24745 val_loss 0.24546 | train_fid 0.988825 val_fid 0.989606 | train_used 0.397 val_used 0.399
epoch 06 | train_loss 0.22000 val_loss 0.21045 | train_fid 0.986473 val_fid 0.986839 | train_used 0.394 val_used 0.396
epoch 07 | train_loss 0.20324 val_loss 0.19071 | train_fid 0.990806 val_fid 0.991205 | train_used 0.400 val_used 0.402
epoch 08 | train_loss 0.19467 val_loss 0.19008 | train_fid 0.988446 val_fid 0.989261 | train_used 0.398 val_used 0.401
epoch 09 | train_loss 0.17906 val_loss 0.17791 |

In [104]:
rng=np.random.default_rng(10)
pairs=[(0,1),(0,2),(0,3)]
k1=int(rng.integers(0,len(pairs)))
th1=rng.uniform(0,math.pi/2); ph1=rng.uniform(0,2*math.pi)
if rng.random()<0.5:
    U=canon_phase(apply_one(4,pairs[k1],th1,ph1))
else:
    k2=int(rng.integers(0,len(pairs)))
    th2=rng.uniform(0,math.pi/2); ph2=rng.uniform(0,2*math.pi)
    U=canon_phase(apply_two(4,pairs[k1],th1,ph1,pairs[k2],th2,ph2))

k1 = 2
k2 = 2
th1 = math.pi/4
th2 = math.pi/4
ph1 = 0
ph2 = 0

U=canon_phase(apply_two(4,pairs[k1],th1,ph1,pairs[k2],th2,ph2))
seq_pred,U_pred=predict_seq2_minrot(model,meta,U)
print("True U: \n", U)
print("Pred U: \n", U_pred)
fid=fidelity(U,U_pred)
print(json.dumps({'pred_seq':seq_pred,'fid':fid}, indent=2))


True U: 
 [[ 0.+0.j  0.+0.j  0.+0.j -1.+0.j]
 [ 0.+0.j  1.+0.j  0.+0.j  0.+0.j]
 [ 0.+0.j  0.+0.j  1.+0.j  0.+0.j]
 [ 1.+0.j  0.+0.j  0.+0.j  0.+0.j]]
Pred U: 
 [[ 0.0059+0.j      0.    +0.j      0.    +0.j     -0.9966+0.0824j]
 [ 0.    +0.j      1.    +0.j      0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      1.    +0.j      0.    +0.j    ]
 [ 0.9966+0.0824j  0.    +0.j      0.    +0.j      0.0059+0.j    ]]
{
  "pred_seq": [
    {
      "pair": [
        0,
        3
      ],
      "theta": 1.564945936203003,
      "phi": 6.200651168823242
    }
  ],
  "fid": 0.998289465904236
}


In [28]:
import math, json, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def givens_matrix(n,i,j,theta,phi):
    G=np.eye(n,dtype=np.complex128)
    c=np.cos(theta); s=np.sin(theta); e=np.exp(1j*phi)
    G[i,i]=c; G[j,j]=c; G[i,j]=-e*s; G[j,i]=np.conj(e)*s
    return G

def apply_seq(n,pairs,seq_pairs,seq_theta,seq_phi):
    U=np.eye(n,dtype=np.complex128)
    for k in range(len(seq_pairs)):
        i,j=seq_pairs[k]
        U=givens_matrix(n,i,j,seq_theta[k],seq_phi[k])@U
    return U

def canon_phase(U):
    a=np.angle(np.linalg.det(U))/U.shape[0]
    return U*np.exp(-1j*a)

def vec_features(U):
    return np.concatenate([U.real.reshape(-1),U.imag.reshape(-1)],0).astype(np.float32)

def pack_angles(theta,phi):
    return np.array([np.cos(theta),np.sin(theta),np.cos(phi),np.sin(phi)],dtype=np.float32)

def unpack_angles(y):
    cth,sth,cph,sph=y
    theta=np.arctan2(sth,cth); phi=np.arctan2(sph,cph)
    theta=np.mod(theta,math.pi/2); phi=np.mod(phi,2*math.pi)
    return theta,phi

def fidelity(U,Uhat):
    n=U.shape[0]
    return float(np.abs(np.trace(U.conj().T@Uhat))/n)

def make_pairs(n,topology='star'):
    if topology=='star':
        return [(0,j) for j in range(1,n)]
    if topology=='all':
        return [(i,j) for i in range(n) for j in range(i+1,n)]
    raise ValueError('unknown topology')

class SeqTDataset(Dataset):
    def __init__(self,n,depth,pairs,size,p_continue=0.4,seed=0):
        self.n=n; self.depth=depth; self.pairs=pairs; self.size=size; self.P=len(pairs); self.STOP=self.P
        rng=np.random.default_rng(seed)
        X=[]; Yp=[]; Ya=[]
        for _ in range(size):
            L=1
            for t in range(depth-1):
                if rng.random()<p_continue: L+=1
                else: break
            L=min(L,depth)
            ks=[]; ths=[]; phs=[]
            for t in range(L):
                k=int(rng.integers(0,self.P))
                th=rng.uniform(0,math.pi); ph=rng.uniform(0,2*math.pi)
                ks.append(k); ths.append(th); phs.append(ph)
            U=apply_seq(n,pairs,[pairs[k] for k in ks],np.array(ths),np.array(phs))
            U=canon_phase(U)
            X.append(vec_features(U))
            y_pairs=[ks[t] if t<L else self.STOP for t in range(depth)]
            y_angles=[pack_angles(ths[t],phs[t]) if t<L else np.array([1.,0.,1.,0.],dtype=np.float32) for t in range(depth)]
            Yp.append(y_pairs); Ya.append(y_angles)
        self.X=np.stack(X,0).astype(np.float32)
        self.Yp=np.array(Yp,dtype=np.int64)
        self.Ya=np.array(Ya,dtype=np.float32)
    def __len__(self): return self.size
    def __getitem__(self,idx):
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Yp[idx]), torch.from_numpy(self.Ya[idx])

class Trunk(nn.Module):
    def __init__(self,in_dim,h=1024):
        super().__init__()
        self.net=nn.Sequential(nn.Linear(in_dim,h),nn.ReLU(),nn.Linear(h,h),nn.ReLU())
    def forward(self,x): return self.net(x)

class SeqTNet(nn.Module):
    def __init__(self,in_dim,num_pairs,depth,h=1024):
        super().__init__()
        self.trunk=Trunk(in_dim,h)
        self.head_pairs=nn.ModuleList([nn.Linear(h,num_pairs+1) for _ in range(depth)])
        self.head_angles=nn.ModuleList([nn.Linear(h,num_pairs*4) for _ in range(depth)])
        self.depth=depth; self.P=num_pairs
    def forward(self,x):
        z=self.trunk(x)
        logits=[]; angs=[]
        for t in range(self.depth):
            lt=self.head_pairs[t](z)
            at=self.head_angles[t](z).view(-1,self.P,4)
            logits.append(lt); angs.append(at)
        L=torch.stack(logits,dim=1)
        A=torch.stack(angs,dim=1)
        return L,A
import torch, numpy as np, math
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

def _norm2(x,eps=1e-8):
    return x/torch.clamp(torch.linalg.norm(x,dim=-1,keepdim=True),min=eps)

def _canon_phase_batch(U):
    n=U.size(-1)
    d=torch.linalg.det(U)
    a=torch.angle(d)/n
    f=torch.exp(-1j*a).unsqueeze(-1).unsqueeze(-1)
    return U*f

def _build_G(n,i_idx,j_idx,cth,sth,cph,sph,active):
    device=cth.device
    B,T=cth.shape
    M=B*T
    i=i_idx.reshape(M); j=j_idx.reshape(M)
    c=cth.reshape(M); s=sth.reshape(M)
    e=(cph.reshape(M)+1j*sph.reshape(M)).to(torch.complex64)
    act=active.reshape(M)
    G=torch.eye(n,device=device,dtype=torch.complex64).unsqueeze(0).expand(M,n,n).clone()
    idx=torch.nonzero(act,as_tuple=False).squeeze(1)
    if idx.numel()>0:
        G[idx, i[idx], i[idx]] = c[idx].to(torch.complex64)
        G[idx, j[idx], j[idx]] = c[idx].to(torch.complex64)
        G[idx, i[idx], j[idx]] = -(e[idx]*s[idx])
        G[idx, j[idx], i[idx]] = torch.conj(e[idx])*s[idx]
    return G.view(B,T,n,n)

@torch.no_grad()
def eval_stats(model,loader,n,pairs,STOP,device,max_batches=None):
    model.eval()
    P=len(pairs)
    pi=torch.tensor([p[0] for p in pairs],device=device, dtype=torch.long)
    pj=torch.tensor([p[1] for p in pairs],device=device, dtype=torch.long)
    N=0; s_fid=0.0; s_used=0.0; s_loss=0.0; seen=0
    for xb,y_pairs,y_angles in loader:
        xb=xb.to(device, non_blocking=True).float()
        y_pairs=y_pairs.to(device, non_blocking=True).long()
        y_angles=y_angles.to(device, non_blocking=True).float()
        L,A=model(xb)
        depth=L.size(1)
        ce=0.0
        for t in range(depth):
            ce+=torch.nn.functional.cross_entropy(L[:,t,:],y_pairs[:,t])
        idx=torch.clamp(y_pairs,0,STOP-1).unsqueeze(-1).unsqueeze(-1).expand(-1,-1,1,4)
        sel=torch.gather(A,2,idx).squeeze(2)
        mask=(y_pairs!=STOP).float().unsqueeze(-1)
        mse=torch.nn.functional.mse_loss(sel,y_angles,reduction='none')
        mse=(mse*mask).mean()
        loss=ce/depth+mse
        s_loss+=loss.item()*xb.size(0)
        probs=torch.softmax(L,dim=-1)
        used=(1.0-probs[:,:,STOP]).mean().item()
        s_used+=used*xb.size(0)
        k_pred=L.argmax(dim=-1)
        k_true=y_pairs
        cp_sp=_norm2(sel[...,:2]); cph_sph=_norm2(sel[...,2:])
        cth_t=cp_sp[...,0]; sth_t=cp_sp[...,1]
        cph_t=cph_sph[...,0]; sph_t=cph_sph[...,1]
        active_t=(k_true!=STOP)
        ktc=torch.clamp(k_true,0,STOP-1)
        it=pi[ktc]; jt=pj[ktc]
        Gt=_build_G(n,it,jt,cth_t,sth_t,cph_t,sph_t,active_t)
        B=xb.size(0)
        U_t=torch.eye(n,device=device,dtype=torch.complex64).unsqueeze(0).expand(B,n,n).clone()
        for t in range(depth):
            U_t=torch.bmm(Gt[:,t],U_t)
        U_t=_canon_phase_batch(U_t)
        idxp=torch.clamp(k_pred,0,STOP-1).unsqueeze(-1).unsqueeze(-1).expand(-1,-1,1,4)
        selp=torch.gather(A,2,idxp).squeeze(2)
        cp_sp=_norm2(selp[...,:2]); cph_sph=_norm2(selp[...,2:])
        cth_p=cp_sp[...,0]; sth_p=cp_sp[...,1]
        cph_p=cph_sph[...,0]; sph_p=cph_sph[...,1]
        active_p=(k_pred!=STOP)
        itp=pi[torch.clamp(k_pred,0,STOP-1)]; jtp=pj[torch.clamp(k_pred,0,STOP-1)]
        Gp=_build_G(n,itp,jtp,cth_p,sth_p,cph_p,sph_p,active_p)
        U_p=torch.eye(n,device=device,dtype=torch.complex64).unsqueeze(0).expand(B,n,n).clone()
        for t in range(depth):
            U_p=torch.bmm(Gp[:,t],U_p)
        U_p=_canon_phase_batch(U_p)
        M=torch.matmul(torch.conj(U_t.transpose(-2,-1)),U_p)
        tr=torch.diagonal(M,dim1=-2,dim2=-1).sum(-1)
        fid=(tr.abs()/n).mean().item()
        s_fid+=fid*xb.size(0)
        N+=xb.size(0); seen+=1
        if max_batches is not None and seen>=max_batches: break
    return s_loss/max(1,N), s_fid/max(1,N), s_used/max(1,N)

def train_seqT_minrot(n=6,depth=5,topology='star',pairs=None,epochs=20,train_size=60000,val_size=8000,batch=4096,lr=2e-3,p_continue=0.4,lam_use=0.02,h=2048,seed=0,device='cuda',num_workers=0,val_max_batches=20):
    torch.manual_seed(seed); np.random.seed(seed)
    torch.backends.cudnn.benchmark=True
    if torch.cuda.is_available(): torch.set_float32_matmul_precision('high')
    if pairs is None: pairs=make_pairs(n,topology)
    P=len(pairs); STOP=P
    train_ds=SeqTDataset(n,depth,pairs,train_size,p_continue=p_continue,seed=seed)
    val_ds=SeqTDataset(n,depth,pairs,val_size,p_continue=p_continue,seed=seed+1)
    in_dim=2*n*n
    model=SeqTNet(in_dim,P,depth,h).to(device)
    opt=torch.optim.AdamW(model.parameters(),lr=lr)
    scaler=GradScaler(enabled=(device.startswith('cuda') and torch.cuda.is_available()))
    train_loader=DataLoader(train_ds,batch_size=batch,shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=(num_workers>0),pin_memory=True)
    val_loader=DataLoader(val_ds,batch_size=batch,shuffle=False,num_workers=num_workers,persistent_workers=False,pin_memory=True)
    best_fid=-1.0; best_state=None
    for ep in range(1,epochs+1):
        model.train()
        tot=0.0; cnt=0
        for xb,y_pairs,y_angles in train_loader:
            xb=xb.pin_memory().to(device, non_blocking=True).float()
            y_pairs=y_pairs.pin_memory().to(device, non_blocking=True).long()
            y_angles=y_angles.pin_memory().to(device, non_blocking=True).float()
            with autocast(enabled=(device.startswith('cuda') and torch.cuda.is_available())):
                L,A=model(xb)
                depth_=L.size(1)
                ce=0.0
                for t in range(depth_):
                    ce+=torch.nn.functional.cross_entropy(L[:,t,:],y_pairs[:,t])
                idx=torch.clamp(y_pairs,0,STOP-1).unsqueeze(-1).unsqueeze(-1).expand(-1,-1,1,4)
                sel=torch.gather(A,2,idx).squeeze(2)
                mask=(y_pairs!=STOP).float().unsqueeze(-1)
                mse=torch.nn.functional.mse_loss(sel,y_angles,reduction='none')
                mse=(mse*mask).mean()
                probs=torch.softmax(L,dim=-1)
                use_pen=(1.0-probs[:,:,STOP]).mean()
                loss=ce/depth_+mse+lam_use*use_pen
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)
            tot+=loss.detach().float().item()*xb.size(0); cnt+=xb.size(0)
        train_loss=tot/cnt
        vloss,vfid,vused=eval_stats(model,val_loader,n,pairs,STOP,device,max_batches=val_max_batches)
        if vfid>best_fid:
            best_fid=vfid
            best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}
        print(f"epoch {ep:02d} | train_loss {train_loss:.5f} val_loss {vloss:.5f} | val_fid {vfid:.6f} | val_used {vused:.3f}")
    model.load_state_dict(best_state)
    meta={'n':n,'pairs':pairs,'depth':depth,'STOP':STOP,'h':h}
    torch.save({'state_dict':model.state_dict(),'meta':meta},'seqT_minrot.pt')
    return model,meta



@torch.no_grad()
def predict_seqT(model,meta,U):
    n=meta['n']; pairs=meta['pairs']; depth=meta['depth']; STOP=meta['STOP']
    x=torch.from_numpy(vec_features(canon_phase(U))).unsqueeze(0).float()
    L,A=model(x)
    L=L[0]; A=A[0]
    seq=[]; ths=[]; phs=[]
    for t in range(depth):
        k=int(torch.argmax(L[t]).item())
        if k==STOP: break
        y=A[t,k].numpy()
        th,ph=unpack_angles(y)
        seq.append({'pair':pairs[k],'theta':float(th),'phi':float(ph)})
        ths.append(th); phs.append(ph)
    Uhat=canon_phase(apply_seq(n,pairs,[s['pair'] for s in seq],np.array(ths) if ths else np.array([],dtype=float),np.array(phs) if phs else np.array([],dtype=float)))
    return seq,Uhat

def demo_seqT(n=6,depth=5,topology='star',pairs=None,seed=123):
    if pairs is None: pairs=make_pairs(n,topology)
    rng=np.random.default_rng(seed)
    L=4
    for t in range(depth-1):
        if rng.random()<0.2: L+=1
        else: break
    ks=[int(rng.integers(0,len(pairs))) for _ in range(L)]
    ths=[rng.uniform(0,math.pi/2) for _ in range(L)]
    phs=[rng.uniform(0,2*math.pi) for _ in range(L)]
    U=canon_phase(apply_seq(n,pairs,[pairs[k] for k in ks],np.array(ths),np.array(phs)))
    print(U)
    ckpt=torch.load('seqT_minrot.pt',map_load_map='cpu') if False else torch.load('seqT_minrot.pt',map_location='cpu')
    in_dim=2*n*n; P=len(pairs)
    model=SeqTNet(in_dim,P,depth,ckpt['meta'].get('h',512))
    model.load_state_dict(ckpt['state_dict']); model.eval()
    seq_pred,U_pred=predict_seqT(model,{'n':n,'pairs':pairs,'depth':depth,'STOP':P},U)
    print(U_pred)
    fid=fidelity(U,U_pred)
    return {'true_len':L,'true_seq':[{'pair':pairs[ks[i]],'theta':float(ths[i]),'phi':float(phs[i])} for i in range(L)],'pred_seq':seq_pred,'fid':fid,'U_true':U,'U_pred':U_pred}


In [44]:
n=6
model, meta = train_seqT_minrot(n=n, depth=5, topology='star', epochs=20, train_size=300000, val_size=8000, batch=4096, lr=6e-3, p_continue=0.4, lam_use=0.005, h=2048, seed=0, device='cuda', num_workers=0, val_max_batches=20)



  scaler=GradScaler(enabled=(device.startswith('cuda') and torch.cuda.is_available()))
  with autocast(enabled=(device.startswith('cuda') and torch.cuda.is_available())):


epoch 01 | train_loss 0.70021 val_loss 0.42421 | val_fid 0.832150 | val_used 0.335
epoch 02 | train_loss 0.35375 val_loss 0.31391 | val_fid 0.908742 | val_used 0.327
epoch 03 | train_loss 0.27987 val_loss 0.27102 | val_fid 0.925020 | val_used 0.323
epoch 04 | train_loss 0.24041 val_loss 0.24134 | val_fid 0.930382 | val_used 0.335
epoch 05 | train_loss 0.22443 val_loss 0.22520 | val_fid 0.939794 | val_used 0.330
epoch 06 | train_loss 0.20476 val_loss 0.21419 | val_fid 0.939902 | val_used 0.329
epoch 07 | train_loss 0.19318 val_loss 0.20350 | val_fid 0.943523 | val_used 0.332
epoch 08 | train_loss 0.18220 val_loss 0.20225 | val_fid 0.944295 | val_used 0.333
epoch 09 | train_loss 0.17848 val_loss 0.19692 | val_fid 0.944152 | val_used 0.328
epoch 10 | train_loss 0.16945 val_loss 0.19038 | val_fid 0.945609 | val_used 0.332
epoch 11 | train_loss 0.16543 val_loss 0.18780 | val_fid 0.946885 | val_used 0.331
epoch 12 | train_loss 0.16074 val_loss 0.18228 | val_fid 0.947518 | val_used 0.332
epoc

In [45]:
res = demo_seqT(n=n, depth=5, topology='star', seed=3123)
np.set_printoptions(precision=4, suppress=True)
print(json.dumps({k:res[k] for k in ['true_len','true_seq','pred_seq','fid']}, indent=2))

[[-0.3145-0.1473j  0.7214-0.1587j  0.    +0.j      0.5244+0.2426j
   0.    +0.j      0.    +0.j    ]
 [ 0.0305+0.9044j  0.2703+0.3287j  0.    +0.j      0.    +0.j
   0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      1.    -0.j      0.    +0.j
   0.    +0.j      0.    +0.j    ]
 [-0.062 +0.2379j -0.2035-0.4817j  0.    +0.j     -0.2021+0.7908j
   0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      0.    +0.j      0.    +0.j
   1.    -0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      0.    +0.j      0.    +0.j
   0.    +0.j      1.    -0.j    ]]
[[ 0.0748-0.0053j  0.2097+0.3682j  0.    +0.j     -0.1928+0.2275j
  -0.0959+0.8466j  0.    +0.j    ]
 [-0.4252+0.8882j  0.1743-0.j      0.    +0.j      0.    +0.j
   0.    +0.j      0.    +0.j    ]
 [ 0.    +0.j      0.    +0.j      1.    -0.j      0.    +0.j
   0.    +0.j      0.    +0.j    ]
 [ 0.0151+0.0179j -0.0541+0.1208j  0.    +0.j      0.9521+0.0679j
  -0.2311+0.1321j  0.    +0.j    ]
 [ 0.0285+0.1

In [47]:
ckpt = torch.load('seqT_minrot.pt', map_location='cpu')
pairs = ckpt['meta']['pairs']; n = ckpt['meta']['n']; depth = ckpt['meta']['depth']; h = ckpt['meta'].get('h',512)

model = SeqTNet(2*n*n, len(pairs), depth, h)
model.load_state_dict(ckpt['state_dict']); model.eval()

seq_true = [
    {'pair': (0,1), 'theta': math.pi/2, 'phi': 5.1},
    {'pair': (0,2), 'theta': math.pi/3, 'phi': 5.1},
    {'pair': (0,3), 'theta': math.pi/4, 'phi': 5.1},
    {'pair': (0,2), 'theta': math.pi/3, 'phi': 5.1},
    {'pair': (0,1), 'theta': math.pi/2, 'phi': 5.1},
]

seq_pairs = [s['pair'] for s in seq_true]
ths = np.array([s['theta'] for s in seq_true])
phs = np.array([s['phi'] for s in seq_true])

U = canon_phase(apply_seq(n, pairs, seq_pairs, ths, phs))
print(U)
H2 = (1/np.sqrt(2))*np.array([[1,1],[1,-1]], dtype=complex)
U = np.kron(H2, H2)
print(U)

import numpy as np

def H4_dft():
    j,k=np.ogrid[:4,:4]
    omega=np.exp(2j*np.pi/4)
    return (omega**(j*k))/2

def H4_real():
    return 0.5*np.array([[1,1,1,1],
                         [1,-1,1,-1],
                         [1,1,-1,-1],
                         [1,-1,-1,1]],dtype=float)

def embed_4_in_6(H4, idx=(0,1,2,3)):
    U=np.eye(6, dtype=complex)
    idx=list(idx)
    comp=[x for x in range(6) if x not in idx]
    P=np.eye(6)[:, idx+comp]
    B=np.block([[H4, np.zeros((4,2))],
                [np.zeros((2,4)), np.eye(2)]])
    return P @ B @ P.T

U = embed_4_in_6(H4_dft(), idx=(0,1,2,3))
# or: H6 = embed_4_in_6(H4_real(), idx=(0,1,2,3))
print(U)

print(len(pairs), pairs)
seq_pred, U_pred = predict_seqT(model, {'n':n,'pairs':pairs,'depth':depth,'STOP':len(pairs)}, U)
fid = fidelity(U, U_pred)
print(U_pred)
print(json.dumps({'true_seq':seq_true,'pred_seq':seq_pred,'fid':fid}, indent=2))



[[-1.    -0.j -0.    +0.j -0.    +0.j -0.    +0.j  0.    +0.j  0.    +0.j]
 [ 0.    +0.j  0.5732+0.j -0.7392-0.j -0.3536+0.j  0.    +0.j  0.    +0.j]
 [ 0.    +0.j -0.7392+0.j -0.2803+0.j -0.6124+0.j  0.    +0.j  0.    +0.j]
 [ 0.    +0.j -0.3536-0.j -0.6124-0.j  0.7071+0.j  0.    +0.j  0.    +0.j]
 [ 0.    +0.j  0.    +0.j  0.    +0.j  0.    +0.j  1.    +0.j  0.    +0.j]
 [ 0.    +0.j  0.    +0.j  0.    +0.j  0.    +0.j  0.    +0.j  1.    +0.j]]
[[ 0.5+0.j  0.5+0.j  0.5+0.j  0.5+0.j]
 [ 0.5+0.j -0.5+0.j  0.5+0.j -0.5+0.j]
 [ 0.5+0.j  0.5+0.j -0.5+0.j -0.5+0.j]
 [ 0.5+0.j -0.5+0.j -0.5+0.j  0.5-0.j]]
[[ 0.5+0.j   0.5+0.j   0.5+0.j   0.5+0.j   0. +0.j   0. +0.j ]
 [ 0.5+0.j   0. +0.5j -0.5+0.j  -0. -0.5j  0. +0.j   0. +0.j ]
 [ 0.5+0.j  -0.5+0.j   0.5-0.j  -0.5+0.j   0. +0.j   0. +0.j ]
 [ 0.5+0.j  -0. -0.5j -0.5+0.j   0. +0.5j  0. +0.j   0. +0.j ]
 [ 0. +0.j   0. +0.j   0. +0.j   0. +0.j   1. +0.j   0. +0.j ]
 [ 0. +0.j   0. +0.j   0. +0.j   0. +0.j   0. +0.j   1. +0.j ]]
5 [(0, 1), (0

In [24]:
H2 = (1/np.sqrt(2))*np.array([[1,1],[1,-1]], dtype=complex)
U = np.kron(H2, H2)
print(U)

[[ 0.5+0.j  0.5+0.j  0.5+0.j  0.5+0.j]
 [ 0.5+0.j -0.5+0.j  0.5+0.j -0.5+0.j]
 [ 0.5+0.j  0.5+0.j -0.5+0.j -0.5+0.j]
 [ 0.5+0.j -0.5+0.j -0.5+0.j  0.5-0.j]]
