In [5]:
import scipy
import numpy as np
import data_import
import matplotlib.pyplot as plt

In [21]:
def get_policy(Q, x, b, policy):
    """
    Refresh the diagonal of `policy` in place from the current iterate and
    report how far the iterate is from satisfying the complementarity condition.

    With y = Q.x + b, entry (i, i) of `policy` is set to 1 when y[i] < x[i]
    and to 0.0 otherwise; off-diagonal entries are not touched.

    Returns the square root of sum_i min(x[i], y[i])**2.
    """
    residual = np.dot(Q, x) + b
    sq_err = 0
    for idx in range(len(x)):
        xi, yi = x[idx], residual[idx]
        # pick whichever of the two quantities is currently the smaller one
        policy[idx, idx] = 1 if yi < xi else 0.0
        # x should be > 0, while y should be as close to 0 as possible
        sq_err += min(xi, yi) ** 2
    return sq_err ** 0.5

def update_value(policy, Q, b, use_cg = True):
    """
    Returns the next x given the policy.

    Builds and solves the linear system A x = rhs with
        A   = policy . Q . policy + I - policy
        rhs = -policy . b

    Parameters
    ----------
    policy : square array, same size as len(b); only used through the
             matrix products above.
    Q      : square array multiplied on both sides by `policy`.
    b      : 1-D array; its length fixes the size of the system.
    use_cg : if True solve with scipy's conjugate gradient, otherwise with
             numpy's least-squares solver.

    Only the requested solver is run. When CG is used and reports a
    non-zero status, "CG DID NOT CONVERGE!" is printed and the last CG
    iterate is still returned.
    """
    # Imported explicitly: a bare `import scipy` exposes scipy.sparse.linalg
    # only on SciPy versions that lazily load subpackages.
    import scipy.sparse.linalg

    I = np.eye(len(b))
    A = np.dot(np.dot(policy, Q), policy) + I - policy
    rhs = -np.dot(policy, b)

    if not use_cg:
        # rcond is scaled by the largest entry of A, as in the original call
        return np.linalg.lstsq(A, rhs, rcond=1e-10 / np.max(A))[0]

    # maxiter is left at SciPy's default (len(rhs) * 10 according to
    # https://github.com/scipy/scipy/blob/v1.9.3/scipy/sparse/linalg/_isolve/iterative.py#L298-L385)
    try:
        # newer SciPy spells the relative tolerance `rtol`
        cg_soln, cg_conv = scipy.sparse.linalg.cg(A, rhs, rtol=1e-10)
    except TypeError:
        # older SciPy (e.g. 1.9.x) only knows `tol`
        cg_soln, cg_conv = scipy.sparse.linalg.cg(A, rhs, tol=1e-10)
    if cg_conv != 0:
        print("CG DID NOT CONVERGE!")
    return cg_soln


def flow(Q, N, v0, max_itrs = 100, m_tol = 1e-09, CoR = 1, use_cg = True):
    """
    does policy iteration

    Starting from x = 0, alternates between get_policy (choose the policy
    for the current x and measure its error) and update_value (solve for
    the next x under that policy) for at most `max_itrs` rounds.

    Parameters
    ----------
    Q        : matrix passed through to get_policy / update_value.
    N, v0    : combined into the constant term b = N.T . ((1 + CoR) * v0).
    max_itrs : maximum number of policy/value rounds.
    m_tol    : the iterate is accepted once get_policy's error is <= m_tol.
    CoR      : scalar used only in the (1 + CoR) factor of b.
    use_cg   : inner solver whose result drives the iteration (True: CG,
               False: least squares). The other solver is still run each
               round, purely for comparison.

    Returns
    -------
    a triple (converged, x, diff_policy):
        converged   : bool, True if the returned x meets m_tol.
        x           : the last iterate.
        diff_policy : True if, in any round, the two inner solvers produced
                      values that lead to different policies.
    """
    b = np.dot(N.T, (1 + CoR) * v0)
    n = b.shape[0]
    x = np.zeros(n)
    policy = np.zeros((n, n))

    diff_policy = False

    for _ in range(max_itrs):
        error = get_policy(Q, x, b, policy)
        if error <= m_tol:
            return (True, x, diff_policy)

        x = update_value(policy, Q, b, use_cg)
        x_prime = update_value(policy, Q, b, not use_cg)

        # if our choices of x are different depending on the inner solver we used,
        # only then do we want to check if they resulted in different policies
        if not np.allclose(x, x_prime):
            policy1 = np.zeros((n, n))
            policy2 = np.zeros((n, n))
            # only the resulting policies matter here, the errors are discarded
            get_policy(Q, x, b, policy1)
            get_policy(Q, x_prime, b, policy2)
            if not np.allclose(policy1, policy2):
                diff_policy = True

    # The loop tests x *before* each update, so the iterate produced by the
    # last update has not been evaluated yet; check it instead of
    # unconditionally reporting failure.
    converged = bool(get_policy(Q, x, b, policy) <= m_tol)
    return (converged, x, diff_policy)


In [7]:
# One solver dump per grid run, itr_0 .. itr_499, read into a single frame.
grid_run_files = [f"../outs/grid/itr_{i}.xml.out" for i in range(500)]
pd_data = data_import.read_files_to_pd_dataframe(grid_run_files)

In [25]:

# Run policy iteration on every recorded problem twice, once driven by CG and
# once by least squares, and report every case where the two runs disagree.
np.set_printoptions(precision=2, linewidth=1000)
pi_cg_soln = []
pi_ls_soln = []
pi_cg_converged = []
pi_ls_converged = []
for i in range(500):
    Q = pd_data['Q'][i]
    N = pd_data['N'][i]
    v0= pd_data['v0'][i]
    cg_converged, cg_soln, _ = flow(Q, N, v0, use_cg=True)
    ls_converged, ls_soln, diff_policy = flow(Q, N, v0, use_cg=False)

    # keep the per-problem results; these lists were previously left empty
    pi_cg_soln.append(cg_soln)
    pi_ls_soln.append(ls_soln)
    pi_cg_converged.append(cg_converged)
    pi_ls_converged.append(ls_converged)

    if diff_policy:
        print("Diff policy!", i)
    if cg_converged != ls_converged:
        print("difference in convergence!", i, "cg converged:", cg_converged)

    if not np.allclose(cg_soln, ls_soln):
        print("Diff solns", i)
        print("   cg_converged", cg_converged, "ls converged", ls_converged)
        print("   solns:")
        tol = 1e-07
        # set small vals = 0 for display only; the stored solutions stay untouched
        print(np.where(abs(cg_soln) < tol, 0, cg_soln))
        print(np.where(abs(ls_soln) < tol, 0, ls_soln))


Diff policy! 65
Diff solns 65
   cg_converged False ls converged False
   solns:
[ 1011.36129836 -1010.65060439    56.66265666     0.             0.             0.             0.          1020.77096783     0.             0.            -7.19172141     0.           -14.88153994     0.             0.             0.            50.73286006     0.             0.            27.97760728     0.             0.             7.81144687    34.89871406     0.             0.            25.27791315    23.81332873     7.3602681 ]
[  28.97619255    0.            0.           12.93419262    0.            0.            0.            0.          -56.21486478 -154.30921817    0.            0.          236.1837368     0.            0.         -365.44263327   17.90434443  375.36834468    0.           27.46568561    0.           26.05342476  -34.95579005   41.88332969    0.         -293.02121152 -229.16329821  -25.53848511    0.        ]
Diff policy! 91
difference in convergence! 91 cg converged: True
Diff soln