In [72]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
from copy import deepcopy
import random
import mpmath as mp

mp.dps = 200

In [119]:
def norm(a):
    return [b / mp.norm(a) for b in a]

def reconstruct(mps_in):
    state_vec = []

    mps = deepcopy(mps_in)

    N = len(mps)

    # this converts to a different index order
    # where the physical index is the outermost index and the bond index is contained
    # also puts into canonical form we are used to seeing
    mps[-1] = np.swapaxes(mps[-1], 0, 1)
    for i in range(N-1):
        mps[i][0] = np.swapaxes(mps[i][0], 0, 1)

    for i in range(2**N):
        ind = [int(bit) for bit in format(i, '0{}b'.format(N))]
        prod = [[row[ind[-1]]] for row in mps[-1]]
        for j in range(N-2, -1, -1):
            tensors = mps[j]
            prod = tensors[1] @ prod
            prod = tensors[0][ind[j]] @ prod

        state_vec.append(prod[0])

    return norm(state_vec)

def mps(state_v, chi=1000000):
    mps_v = []
    N = int(np.log2(len(state_v)))
    right = np.reshape(state_v, (2, 2**(N-1)))
    for i in range(N):
        if i == N-1:
            mps_v.append(right)
            continue
        gamma, S, right = mp.svd(mp.matrix(right), full_matrices=False)
        gamma = np.array(gamma.tolist())
        right = np.array(right.tolist())
        # left and right most gammas are our MPS caps, only have one bond index
        if i > 0 and i < N-1:
            if chi < len(S) and chi >= 1:
                gamma = gamma[:, :chi]
                S = S[:chi]
                right = right[:chi, :]

            gamma = np.reshape(gamma, (int(gamma.shape[0]/2), 2, gamma.shape[1]))

        # don't want to reshape this way when there is only 1 column
        if right.shape[1] > 2:
            right = np.reshape(right, (int(right.shape[0]*2), int(right.shape[1]/2)))

        S = norm(S)
        lambd = np.diag(S)
        mps_v.append([gamma, lambd])
    return mps_v

def haar_state(N):
    # for some reason, when we use complex amplitudes the fidelity caps out at ~0.85
    # this is likely not an issue with the code, since manually creating simple complex
    # valued states leads to 1.0 fidelity. Instead, something about randomly generating
    # more complex states seems to mess the precision
    # state = [complex(np.random.normal(), 0) for _ in range(2**N)]
    state = [mp.mpf(mp.rand()) for _ in range(2**N)]
    return norm(state)

In [137]:
st = haar_state(7)
print(st)
a = mps(st)
# print(a)
b = reconstruct(a)
np.dot(np.abs(st), np.abs(b))

[mpf('0.03684769306598322'), mpf('0.11060383554455504'), mpf('0.12587992515612484'), mpf('0.006280644077789937'), mpf('0.057423536317673891'), mpf('0.029273529062218119'), mpf('0.13381976109385763'), mpf('0.015435907953503244'), mpf('0.091932120794942951'), mpf('0.14917471105820354'), mpf('0.03955406694888073'), mpf('0.051934360978476451'), mpf('0.053446425046399473'), mpf('0.13496521444724396'), mpf('0.12148169646596932'), mpf('0.062604250269056444'), mpf('0.12007601141558182'), mpf('0.076997571594487849'), mpf('0.021466230100043789'), mpf('0.10969446671968837'), mpf('0.14885988915991932'), mpf('0.026757474803940924'), mpf('0.14230757074389677'), mpf('0.1166919166850018'), mpf('0.14050785785678374'), mpf('0.04313508369057098'), mpf('0.14479796150852689'), mpf('0.13867688619016397'), mpf('0.094047283115252789'), mpf('0.12710155644254523'), mpf('0.12033077096555972'), mpf('0.084791446827964856'), mpf('0.03049732979371949'), mpf('0.026462616851963851'), mpf('0.081281105066550327'), mpf('

mpf('0.99999999999999978')