In [138]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
from copy import deepcopy
import random
import mpmath as mp

mp.dps = 200

In [158]:
def norm(a):
    return [b / mp.norm(a) for b in a]

def reconstruct(mps_in):
    state_vec = []

    mps = deepcopy(mps_in)

    N = len(mps)

    # this converts to a different index order
    # where the physical index is the outermost index and the bond index is contained
    # also puts into canonical form we are used to seeing
    mps[-1] = np.swapaxes(mps[-1], 0, 1)
    for i in range(N-1):
        mps[i][0] = np.swapaxes(mps[i][0], 0, 1)

    for i in range(2**N):
        ind = [int(bit) for bit in format(i, '0{}b'.format(N))]
        prod = [[row[ind[-1]]] for row in mps[-1]]
        for j in range(N-2, -1, -1):
            tensors = mps[j]
            prod = tensors[1] @ prod
            prod = tensors[0][ind[j]] @ prod

        state_vec.append(prod[0])

    return norm(state_vec)

def mps(state_v, chi=1000000):
    mps_v = []
    N = int(np.log2(len(state_v)))
    right = np.reshape(state_v, (2, 2**(N-1)))
    for i in range(N):
        if i == N-1:
            mps_v.append(right)
            continue
        gamma, S, right = mp.svd(mp.matrix(right), full_matrices=False)
        gamma = np.array(gamma.tolist())
        right = np.array(right.tolist())
        # left and right most gammas are our MPS caps, only have one bond index
        if i > 0 and i < N-1:
            if chi < len(S) and chi >= 1:
                gamma = gamma[:, :chi]
                S = S[:chi]
                right = right[:chi, :]

            gamma = np.reshape(gamma, (int(gamma.shape[0]/2), 2, gamma.shape[1]))

        # don't want to reshape this way when there is only 1 column
        if right.shape[1] > 2:
            right = np.reshape(right, (int(right.shape[0]*2), int(right.shape[1]/2)))

        S = norm(S)
        lambd = np.diag(S)
        mps_v.append([gamma, lambd])
    return mps_v

def haar_state(N):
    # for some reason, when we use complex amplitudes the fidelity caps out at ~0.85
    # this is likely not an issue with the code, since manually creating simple complex
    # valued states leads to 1.0 fidelity. Instead, something about randomly generating
    # more complex states seems to mess the precision
    # state = [complex(np.random.normal(), 0) for _ in range(2**N)]
    state = [mp.mpc(mp.rand()) for _ in range(2**N)]
    return norm(state)

In [169]:
st = haar_state(5)
print(st)
a = mps(st)
# print(a)
b = reconstruct(a)
np.dot(np.abs(st), np.abs(b))

[mpc(real='0.013612417346171916', imag='0.0'), mpc(real='0.079847906914325278', imag='0.0'), mpc(real='0.011026112306426977', imag='0.0'), mpc(real='0.22920773494115143', imag='0.0'), mpc(real='0.14449234445694473', imag='0.0'), mpc(real='0.09922853514160733', imag='0.0'), mpc(real='0.24973143431553815', imag='0.0'), mpc(real='0.2535282740557091', imag='0.0'), mpc(real='0.26848079825873616', imag='0.0'), mpc(real='0.27935702856878297', imag='0.0'), mpc(real='0.1578160112028241', imag='0.0'), mpc(real='0.1186042612020792', imag='0.0'), mpc(real='0.14793883388064433', imag='0.0'), mpc(real='0.17048832496829205', imag='0.0'), mpc(real='0.20265863367502199', imag='0.0'), mpc(real='0.30896422008443802', imag='0.0'), mpc(real='0.24957357485987031', imag='0.0'), mpc(real='0.19515319968619754', imag='0.0'), mpc(real='0.16811497110250215', imag='0.0'), mpc(real='0.14210398085595963', imag='0.0'), mpc(real='0.1255994061224936', imag='0.0'), mpc(real='0.24419015556944093', imag='0.0'), mpc(real='

mpf('0.80889057974930423')