In [None]:
import numpy as np
import networkx as nx

def givens_rotation(i, j, theta, phi, d):
    R = np.eye(d, dtype=complex)
    c, s = np.cos(theta/2), np.sin(theta/2)*np.exp(1j*phi)
    R[[i, j]][:, [i, j]] = [[c, -np.conj(s)], [s, np.conj(c)]]
    return R

def taqr(U, G):
    U = U.copy().astype(complex)
    d = U.shape[0]
    rotations = []
    order = list(nx.bfs_tree(G, 0))
    for r in range(d-1, 0, -1):
        for c in range(r):
            if abs(U[r, c]) < 1e-12: continue
            path = nx.shortest_path(G, source=r, target=c)
            for (i, j) in zip(path[:-1], path[1:]):
                a, b = U[j, c], U[i, c]
                rnorm = np.hypot(abs(a), abs(b))
                if rnorm == 0: continue
                theta = 2*np.arctan2(abs(b), abs(a))
                phi = np.angle(b) - np.angle(a)
                R = givens_rotation(i, j, theta, phi, d)
                U = R.conj().T @ U
                rotations.append((i, j, theta, phi))
    phases = np.angle(np.diag(U))
    return rotations, phases



In [28]:
import matplotlib.pyplot as plt

d = 4
# U = np.linalg.qr(np.random.randn(d, d) + 1j*np.random.randn(d, d))[0]  # random unitary
H2 = (1/np.sqrt(2))*np.array([[1,1],[1,-1]], dtype=complex)
U = np.kron(H2, H2)
G = nx.Graph()
G.add_nodes_from(range(d))
for i in range(1, d):
    G.add_edge(0, i)  # star topology

rots, phases = taqr(U,G)
print(  "Rotations = \n", np.array(rots))
print(  "Phases = ", phases)
# print("Unitary = \n", U)
# plt.figure()
# plt.imshow(np.real(U), cmap='viridis')
# plt.show()
print(len(rots), "rotations")


Rotations = 
 [[ 3.          0.          1.57079633  0.        ]
 [ 3.          0.          1.57079633  3.14159265]
 [ 0.          1.          1.57079633 -3.14159265]
 [ 3.          0.          1.57079633  3.14159265]
 [ 0.          2.          1.57079633 -3.14159265]
 [ 2.          0.          1.57079633  0.        ]
 [ 2.          0.          1.57079633  0.        ]
 [ 0.          1.          1.57079633 -3.14159265]
 [ 1.          0.          1.57079633  0.        ]]
Phases =  [0.         3.14159265 3.14159265 0.        ]
9 rotations


In [29]:
np.identity(4)

array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]])