In [7]:
import numpy as np
import scipy
import sympy
import data_import
import pandas as pd

# Data Import & Pre-processing

In [8]:
pd_data = data_import.read_files_to_pd_dataframe(
    [f"../outs/grid/itr_{i}.xml.out" for i in range(500)]
)

# Policy Iteration Algo

In [9]:

def get_policy(Q, x, b):
    """
    updates policy to the next policy,
    returns error
    """
    policy = []
    err2 = 0
    y = np.dot(Q, x) + b
    for i in range(len(x)):
        if y[i] < x[i]:
            policy.append(i)
        err2 += min(x[i], y[i]) ** 2    # x should be > 0, while y should be as close to 0 as possible
    return err2 ** 0.5, policy


def update_value(policy, Q, b, use_cg = True):
    """
    returns the next x given the policy
    """
    I = np.eye(len(b))
    policy_matrix = np.zeros((b.shape[0], b.shape[0]))
    for p in policy:
        policy_matrix[p,p] = 1
    A = np.dot(
            np.dot(
                policy_matrix,
                Q
            ),
            policy_matrix
        ) + I - policy_matrix
    rhs = -np.dot(policy_matrix, b) 
    lst_sqr_soln = np.linalg.lstsq(A , rhs)[0]
    cg_soln = scipy.sparse.linalg.cg(A , rhs , tol=1e-12)[0]
    if use_cg:
        return cg_soln
    else:
        return lst_sqr_soln


def flow(Q, N, v0, correct_policy, max_itrs = 100, m_tol = 1e-09, CoR = 1, use_cg = True):
    """
    does policy iteration
    returns a bool/vector pair: (converged, x)
    """
    b = np.dot(N.T, (1 + CoR) * v0)
    x = np.zeros(b.shape[0])
    policy = []
    error = 0.0

    found_correct_control = False

    for n_iter in range(max_itrs):

        # have we stumbled across the correct policy/control?
        if correct_policy == policy:
            found_correct_control = True

        error, policy = get_policy(Q, x, b)
        if error <= m_tol:
            return (True, x, found_correct_control)

        x = update_value(policy, Q, b, use_cg)
    
    return (False, x, found_correct_control)



# Calculate the "correct" control set

In [10]:
def get_ipopt_control(x):
    # anything less than 1e-6 is counted as a 0 since thats the tolerance we set in the xml files
    return [i for i in range(len(x)) if x[i] > 1e-6]

pd_data['correct_control'] = [get_ipopt_control(x) for x in pd_data['ipopt_sol']]
pd_data['correct_control']

0                      [0, 1, 6, 7, 8, 9, 11, 12]
1        [0, 1, 2, 3, 4, 6, 7, 9, 11, 12, 13, 14]
2                            [0, 2, 3, 5, 10, 11]
3                         [0, 1, 2, 3, 9, 14, 15]
4                       [0, 2, 4, 5, 6, 7, 9, 14]
                          ...                    
495                      [0, 2, 3, 7, 10, 11, 12]
496                 [0, 1, 3, 10, 11, 13, 14, 15]
497                [0, 1, 2, 4, 6, 7, 15, 17, 20]
498                  [0, 4, 6, 8, 14, 16, 17, 19]
499    [0, 2, 3, 5, 7, 8, 15, 16, 17, 18, 19, 22]
Name: correct_control, Length: 500, dtype: object

In [12]:
for i in range(500):
    Q = pd_data['Q'][i]
    N = pd_data['N'][i]
    v0 = pd_data['v0'][i]
    correct_control = pd_data['correct_control'][i]

    converged, l, found_correct_control = flow(Q, N, v0, correct_control)
    if not converged:
        print(found_correct_control)

  lst_sqr_soln = np.linalg.lstsq(A , rhs)[0]


False
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
True
True
False
False
True
False
False
False
False
False
True
False
False
False
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
