In [None]:
import numpy as np
from scipy.stats import unitary_group
from collections import deque

In [None]:
def Dagger(x):
    return np.conj(np.transpose(x))
def Kron(*mats):
    """
    Extend the standard Kronecker product, to a list of matrices, where the Kronecker product is recursively taken from left to right.
    """
    if len(mats) < 2:
        return mats[0]
    return np.kron(mats[0], Kron(*(mats[1:])))
def circular_shift(li,start,end,direction="right"):
    r"""
    Circular shifts a part of the list `li` between start and end indices
    """
    d = deque(li[start:end+1])
    if(direction=="right"):
        d.rotate(1)
    else:
        d.rotate(-1)
    return li[:start] + list(d) + li[end+1:]


In [391]:
def fix_index_after_tensor(tensor,indices_changed):
    r"""
    Tensor product alters the order of indices. This function helps reorder to fix them back. 
    """
    n = len(tensor.shape)-1
    perm_list = list(range(len(tensor.shape)))
    n_changed = len(indices_changed)
#     print("Fixing indices order: ", indices_changed)
    for i in range(len(indices_changed)):
        index = indices_changed[i]
        perm_list = circular_shift(perm_list,index,n-n_changed+i+1,"right")
#         print(perm_list)
    return(np.transpose(tensor,perm_list))

    

In [405]:
def get_PTMelem_ij(krausdict,Pi,Pj,n_qubits):
    r"""
    Assumes Paulis Pi,Pj to be a tensor on n_qubits
    Calculates Tr(Pj Eps(Pi))
    
    Assumes qubits
    Assumes kraus ops to be square with dim 2**(n_qubits in support)
    """
    for key,val in krausdict.items():
        indices = key + tuple(map(lambda x:x+n_qubits,key))
        for kraus in val:
            kraus_reshape_dims = [2]*(2*int(np.log2(kraus.shape[0])))
            indices_Pi = indices[len(indices)//2:] 
            indices_kraus = range(len(kraus_reshape_dims)//2)
            Pi = np.tensordot(Pi,Dagger(kraus).reshape(kraus_reshape_dims),(indices_Pi,indices_kraus))
            Pi = fix_index_after_tensor(Pi,indices_Pi)
            indices_Pi = indices[:len(indices)//2]
            indices_kraus = range(len(kraus_reshape_dims))[len(kraus_reshape_dims)//2:]
            Pi = np.tensordot(Pi,kraus.reshape(kraus_reshape_dims),(indices_Pi,indices_kraus))
            Pi = fix_index_after_tensor(Pi,indices_Pi)
        # take dot product with Pj and trace
        indices_Pi = indices[len(indices)//2:] 
        indices_Pj = indices[:len(indices)//2]
        Pi_times_Pj = np.tensordot(Pi,Pj,(indices_Pi,indices_Pi))
        # Take trace
        einsum_inds = list(range(len(Pi_times_Pj.shape)//2))+list(range(len(Pi_times_Pj.shape)//2))
        trace_val = np.einsum(Pi_times_Pj,einsum_inds)
    return trace_val
    
    

In [411]:
n_qubits = 4
Pi = Kron(*[np.eye(2)]*nqubits)
Pi = Pi.reshape([2]*(2*int(np.log2(Pi.shape[0]))))
Pj = Pi
theta = 0.2
Sztheta = [[np.exp(1j*theta),0],[0,np.exp(-1j*theta)]]
Uprime = Kron(Sztheta,Sztheta)
#Kraus dict has the format ("support": list of kraus ops on the support)
krausdict = {(0,1):[Uprime,Uprime],(1,2):[Uprime]}
PTMij = get_PTMelem_ij(krausdict,Pi,Pj,n_qubits)
print(PTMij)

(16+0j)
