In [26]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

def Wigner(rho):
    """accepts density matrix and returns Wigner function
    over d x d toroidal lattice phase space"""

    print(rho)
    d = np.shape(rho)[0]
    root_d = np.exp(1j * 2 * np.pi / d)

    # build generalized Pauli X & Z matrices
    # Pauli X
    X = np.zeros((d,d), dtype=complex)
    X[0,-1] = 1
    for i in range(d-1):
        X[i+1,i] = 1

    # Pauli Z
    Z = np.zeros((d,d), dtype=complex)
    for i in range(d):
        Z[i,i] = root_d**i

    # define Weyl displacement operators from Paulis
    def Weyl(vec):
        """accepts vector vec = [a,b] and returns 
        Weyl displacement operator T_[a,b]"""

        #  extract displacement vector and enforce periodic boundary
        # (i.e. stay on the toroidal lattice)
        v1, v2 = vec
        v1 = np.mod(v1, d)
        v2 = np.mod(v2, d)

        T = root_d**(- v1 * v2 / 2) * ( np.linalg.matrix_power(Z, v1) @ np.linalg.matrix_power(X, v2) )

        return T

    # build un-displaced parity operator
    T0 = np.zeros((d,d), dtype=complex)
    for v1 in range(d):
        for v2 in range(d):
            T0 += 1/d * Weyl([v1,v2])

    # evaluate Wigner function (note the double use of v1 and v2
    # as both vectors and array indices)
    wigner = np.zeros((d,d))
    for v1 in range(d):
            for v2 in range(d):
                value = 1/d * np.trace( rho @ Weyl([v1,v2]) @ T0 @ Weyl([v1,v2]).conj().T )
                # print(value)
                if np.abs(value.imag) < 1e-1: value = value.real
                else: raise Exception("imaginary value in Wigner function")
                wigner[v1,v2] = value

    return wigner

def main():
    dim = 9
    state = np.zeros(dim, dtype=complex)
    state[0] = 1/np.sqrt(2)
    state[-2] = 1/np.sqrt(2)
    rho = np.outer(state, np.conjugate(state))
    wigner = Wigner(rho)

    maxval = abs(wigner).max()
    im = plt.imshow(wigner, cmap = 'RdBu_r', vmax=maxval, vmin=-maxval, origin='lower')
    #plt.colorbar(im)
    #plt.show()
    print(wigner)

if __name__ == "__main__":
    main()

[[0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0.5+0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0.5+0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]]


AttributeError: module 'matplotlib.pyplot' has no attribute 'im'