# Unitary decomposition

In [97]:
import torch
import torch.nn as nn
import numpy as np

def apply_single_site_unitary(vec, U, i):
    """
    Apply a single-qubit unitary to qubit ``i`` of a many-qubit state vector.

    Basis index ``k`` of the flattened state encodes qubit ``q`` as bit ``q``
    of ``k`` (qubit 0 is the least-significant bit). After
    ``vec.reshape([2] * L)``, qubit ``q`` therefore lives on tensor axis
    ``L - 1 - q``; e.g. ``psi.reshape((2, 2, 2, 2, 2))[0, 1, 0, 1, 1]`` is the
    amplitude with qubits 0..4 equal to 1, 1, 0, 1, 0 ("11010", reverse order).

    Example (flip qubit 1 of the 4-qubit state with qubits 0 and 1 set):

    >>> vec = torch.zeros(2**4)
    >>> vec[3] = 1  # same entry as vec.reshape(2, 2, 2, 2)[0, 0, 1, 1]
    >>> new_vec = apply_single_site_unitary(vec, torch.tensor([[0., 1], [1, 0]]), 1)
    >>> new_vec[1].item()
    1.0

    Args:
        vec (torch.Tensor): State with ``2**L`` entries (flattened in
            row-major order if it is not already 1-D).
        U (torch.Tensor): 2x2 matrix to apply; ``U[a, b]`` maps qubit value
            ``b`` to ``a``. Should have the same dtype as ``vec``.
        i (int): Qubit to act on, ``0 <= i < L``.

    Returns:
        torch.Tensor: New flattened state with ``2**L`` entries; ``vec`` itself
        is not modified.

    Raises:
        ValueError: If ``vec`` does not have ``2**L`` entries with ``L >= 1``,
            or ``i`` is not a valid qubit index.
    """
    n = vec.numel()
    # A power of two has exactly one set bit; fewer than 2 entries means no qubits.
    if n < 2 or n & (n - 1):
        raise ValueError(f"state must have 2**L entries with L >= 1, got {n}")
    L = n.bit_length() - 1
    # Without this check, i in [L, 2L) gives a negative axis below, which
    # silently wraps around and acts on a different qubit.
    if not 0 <= i < L:
        raise ValueError(f"qubit index {i} out of range for {L} qubits")
    axis = L - 1 - i  # qubit i is bit i of the index, i.e. tensor axis L-1-i
    psi = vec.reshape([2] * L)
    # Contract U's column index with the qubit axis; tensordot places U's row
    # index first, so move it back to where the qubit axis was.
    psi = torch.tensordot(U, psi, dims=([1], [axis]))
    return torch.moveaxis(psi, 0, axis).flatten()

def apply_two_site_unitary(vec, U, i, j):
    """
    Apply a two-qubit unitary to qubits ``i`` and ``j`` of a many-qubit state.

    The flattened state encodes qubit ``q`` as bit ``q`` of the basis index
    (qubit 0 is the least-significant bit), so after ``vec.reshape([2] * L)``
    qubit ``q`` lives on tensor axis ``L - 1 - q``. ``U`` acts on the ordered
    pair ``(i, j)``: its row/column index is ``2 * bit_i + bit_j``. For
    example, ``[[1,0,0,0],[0,1,0,0],[0,0,0,1],[0,0,1,0]]`` is a CNOT with
    qubit ``i`` as control and qubit ``j`` as target.

    Args:
        vec (torch.Tensor): State with ``2**L`` entries (flattened in
            row-major order if it is not already 1-D).
        U (torch.Tensor): 4x4 matrix to apply. Should have the same dtype
            as ``vec``.
        i (int): Qubit mapped to the high bit of ``U``'s index, ``0 <= i < L``.
        j (int): Qubit mapped to the low bit of ``U``'s index, ``0 <= j < L``,
            ``j != i``.

    Returns:
        torch.Tensor: New flattened state with ``2**L`` entries; ``vec`` itself
        is not modified.

    Raises:
        ValueError: If ``vec`` does not have ``2**L`` entries with ``L >= 2``,
            if ``i`` or ``j`` is out of range, or if ``i == j``.
    """
    n = vec.numel()
    # A power of two has exactly one set bit; two qubits need at least 4 entries.
    if n < 4 or n & (n - 1):
        raise ValueError(f"state must have 2**L entries with L >= 2, got {n}")
    L = n.bit_length() - 1
    if not (0 <= i < L and 0 <= j < L):
        raise ValueError(f"qubit indices ({i}, {j}) out of range for {L} qubits")
    if i == j:
        raise ValueError(f"qubit indices must differ, got i == j == {i}")
    axes = (L - 1 - i, L - 1 - j)
    # Bring qubit i to axis 0 and qubit j to axis 1 in a single permutation;
    # moving them one at a time shifts the remaining axis indices in between.
    psi = torch.moveaxis(vec.reshape([2] * L), axes, (0, 1))
    # Rows of the (4, rest) view are indexed by 2 * bit_i + bit_j.
    psi = (U @ psi.reshape(4, -1)).reshape([2] * L)
    # Exact inverse of the permutation above.
    return torch.moveaxis(psi, (0, 1), axes).flatten()

class RotationLinear(nn.Module):
    """Linear map ``x -> x @ Q`` whose matrix ``Q`` is a learned rotation.

    ``Q`` is the orthogonal factor of the QR decomposition of an unconstrained
    ``dim x dim`` parameter. Its last column is negated when needed so that
    ``det(Q) = +1``.
    """

    def __init__(self, dim):
        super().__init__()
        # Unconstrained random matrix to learn; forward() maps it to a rotation.
        self.param = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        """Return ``x @ Q``; the last dimension of ``x`` must equal ``dim``."""
        # QR decomposition to get an orthogonal matrix (R is not needed).
        q, _ = torch.linalg.qr(self.param)

        # Enforce a right-handed coordinate system (det = +1) by flipping the
        # last column. This is done out of place: autograd saves q for the QR
        # backward pass, so the in-place `q[:, -1] *= -1` made backward() fail.
        if torch.det(q) < 0:
            q = torch.cat((q[:, :-1], -q[:, -1:]), dim=1)

        return x @ q

In [None]:
# Scratch check of matmul shapes: (2, 3) @ (3, 2) -> (2, 2).
torch.matmul(
    torch.tensor([[1, 2, 3], [4, 3, 2]]),
    torch.tensor([[1, 2], [4, 3], [1, 1]]),
)

torch.Size([2, 3])

In [100]:
# Sanity check for apply_two_site_unitary on a 3-qubit register.
# Basis index k encodes qubit q as bit q of k, so vec[1] = 1 is the state with
# only qubit 0 set (the same entry as vec.reshape(2, 2, 2)[0, 0, 1]).
vec = torch.zeros(2**3)
vec[1] = 1
# Swaps rows |10> and |11> of the (i, j) pair: a CNOT with qubit i as control
# and qubit j as target, assuming U's index is 2 * bit_i + bit_j.
cnot = torch.tensor([[1., 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
new_vec = apply_two_site_unitary(vec, cnot, 0, 1)
# Expected: only basis index 3 is nonzero (qubits 0 and 1 set).
new_vec

tensor([0., 0., 0., 1., 0., 0., 0., 0.])