In [1]:
import numpy as np
import scipy
import scipy.sparse.linalg

import data_import

In [21]:
# Another policy-iteration (PI) implementation

def get_policy(Q, x, b, policy):
    """
    Refresh the diagonal of `policy` in place for the iterate `x` and
    report how far `x` is from passing the termination test.

    With y = Q.x + b, policy[i, i] is set to 1 where y[i] < x[i] and to
    0.0 otherwise; off-diagonal entries of `policy` are not touched.

    Returns the Euclidean norm of the element-wise min(x[i], y[i]).
    """
    y = np.dot(Q, x) + b
    sum_sq = 0
    for i in range(len(x)):
        x_i, y_i = x[i], y[i]
        # x is meant to stay > 0 while y is driven towards 0, so the
        # smaller of the pair is the quantity that measures the error.
        sum_sq += min(x_i, y_i) ** 2
        policy[i, i] = 1 if y_i < x_i else 0.0
    return sum_sq ** 0.5

def update_value(policy, Q, b, use_cg = True):
    """
    Solve the linear system fixed by `policy` and return the next x.

    The system is A.x = rhs with

        A   = policy.Q.policy + I - policy
        rhs = -policy.b

    When `policy` is the 0/1 diagonal matrix filled in by get_policy, a row
    whose policy entry is 0 reduces to x[i] = 0.

    Parameters
    ----------
    policy : square array with the same dimension as `b`
    Q      : square array with the same dimension as `b`
    b      : 1-D array
    use_cg : if True (default) solve with scipy.sparse.linalg.cg at a
             relative tolerance of 1e-12, otherwise with np.linalg.lstsq.

    Only the requested solver is run. The convergence flag returned by cg
    is discarded, so a non-converged CG iterate is returned silently.
    """
    identity = np.eye(len(b))
    A = np.dot(np.dot(policy, Q), policy) + identity - policy
    rhs = -np.dot(policy, b)

    if not use_cg:
        # rcond is given explicitly: leaving it out raises a FutureWarning on
        # NumPy 1.14+ (1.x series); None selects the cutoff NumPy 2 uses by
        # default.
        return np.linalg.lstsq(A, rhs, rcond=None)[0]

    try:
        # SciPy >= 1.12 names the relative tolerance `rtol`.
        return scipy.sparse.linalg.cg(A, rhs, rtol=1e-12)[0]
    except TypeError as exc:
        if "rtol" not in str(exc):
            raise
        # Older SciPy only understands the former keyword `tol`.
        return scipy.sparse.linalg.cg(A, rhs, tol=1e-12)[0]


def flow(Q, N, v0, initial_guess = np.array([]), max_itrs = 100, m_tol = 1e-09, CoR = 1, use_cg = True):
    """
    Run policy iteration and return a (converged, x) pair.

    The constant term is b = N.T . ((1 + CoR) * v0). Each iteration asks
    get_policy for the policy matching the current x together with its
    error, stops if that error is <= m_tol, and otherwise replaces x with
    update_value(policy, Q, b, use_cg).

    Parameters
    ----------
    Q, N, v0      : problem data; N.T . v0 must have the dimension of Q.
    initial_guess : starting x. An empty sequence (the default) or None
                    starts from the zero vector.
    max_itrs      : maximum number of value updates.
    m_tol         : error threshold below which x is accepted.
    CoR           : scalar entering b through the factor (1 + CoR).
    use_cg        : forwarded to update_value to choose the linear solver.

    Returns
    -------
    (True, x) as soon as an iterate meets the tolerance, otherwise
    (False, x) with the last iterate computed.
    """
    b = np.dot(N.T, (1 + CoR) * v0)
    size = b.shape[0]

    if initial_guess is not None and len(initial_guess) > 0:
        x = initial_guess
    else:
        x = np.zeros(size)

    policy = np.zeros((size, size))

    for _ in range(max_itrs):
        if get_policy(Q, x, b, policy) <= m_tol:
            return (True, x)
        x = update_value(policy, Q, b, use_cg)

    # The x produced by the final update has not been tested yet; check it
    # so that convergence on the very last iteration is still reported.
    converged = bool(get_policy(Q, x, b, policy) <= m_tol)
    return (converged, x)



In [3]:
# Read ../K4.out through data_import and keep element 0 of the 'Q', 'N'
# and 'v0' entries as the inputs for flow().
pd_data = data_import.read_file_to_pd_dataframe("../K4.out")
Q, N, v0 = (pd_data[key][0] for key in ("Q", "N", "v0"))

In [23]:
# Policy iteration on the loaded data, started from (1, 0, 1, 1, 0, 1).
flow(
    Q,
    N,
    v0,
    initial_guess=np.array([1, 0, 1, 1, 0, 1]),
)

  lst_sqr_soln = np.linalg.lstsq(A , rhs)[0]


(True,
 array([ 2.00000000e+00, -7.66542385e-11,  2.00000000e+00,  2.00000000e+00,
        -7.66544606e-11,  2.00000000e+00]))

In [25]:
# Same problem started from (0, 5, 0, 0, 5, 0) to see where this guess ends up.
flow(
    Q,
    N,
    v0,
    initial_guess=np.array([0, 5, 0, 0, 5, 0]),
)

  lst_sqr_soln = np.linalg.lstsq(A , rhs)[0]


(True, array([0.      , 2.828428, 0.      , 0.      , 2.828428, 0.      ]))