In [1]:
from typing  import Tuple, Optional, List
from math import inf

In [2]:
x_space = [i for i in range(-2,3,1)]
u_space = [i for i in range(-1,2,1)]



In [None]:
def update_state(x_n, u_n):
    if (-2 <= -x_n + 1 + u_n  <= 2):
        x_n1 = -x_n + 1 + u_n 
    elif (-x_n + 1 + u_n  > 2):
        x_n1 = 2
    else:
        x_n1 = -2
    return x_n1

def cost(x, u):
    return 2*abs(x) + abs(u)



In [4]:
class Opt_Policy:
    def __init__(self, state: int, cost: float, control: Optional[int]):
        self.state = state  # The state value
        self.cost = cost    # The minimum cost for this state
        self.control = control  # The optimal control for this state

class DP:
    def __init__(self, num_stages: int, states: List[int]):
        # Initialize memo table as a list of lists, with each list representing a stage
        self.memo = [[None for _ in states] for _ in range(num_stages)]
        self.num_stages = num_stages
        self.states = states

    def get_memo_entry(self, stage: int, x: int) -> Optional[Opt_Policy]:
        # Retrieve the memo entry for a specific stage and state
        state_index = self.states.index(x)
        return self.memo[stage][state_index]

    def set_memo_entry(self, stage: int, x: int, cost: float, control: Optional[int]):
        # Create and store a memo entry for a specific stage and state
        state_index = self.states.index(x)
        self.memo[stage][state_index] = Opt_Policy(state=x, cost=cost, control=control)

    def cost_to_go(self, stage: int, x: int) -> Tuple[float, Optional[int]]:
        """
        Computes the minimum cost to reach the final stage from the given stage and state,
        along with the optimal control input.
        """
        # Check if the cost for the current stage and state is already memoized
        memo_entry = self.get_memo_entry(stage, x)
        if memo_entry is not None:
            return memo_entry.cost, memo_entry.control

        if stage == self.num_stages - 1:
            # Base case: final stage, cost is x squared, no control input
            j_min = x ** 2
            u_opt = None
            self.set_memo_entry(stage, x, j_min, u_opt)
            return j_min, u_opt

        # Initialize minimum cost to infinity and no optimal control
        j_min = inf
        u_opt = None

        # Iterate through all possible control inputs to find the optimal one
        for u_i in u_space:
            # Compute the next state based on current state and control input
            next_x = update_state(x_n=x, u_n=u_i)
            # Compute the total cost: current cost + cost to go from the next state
            j_i = cost(x=x, u=u_i) + self.cost_to_go(stage=stage + 1, x=next_x)[0]
            # Update the minimum cost and optimal control if a lower cost is found
            if j_i < j_min:
                j_min = j_i
                u_opt = u_i

        # Memoize the computed minimum cost and optimal control for the current stage and state
        self.set_memo_entry(stage, x, j_min, u_opt)
        return j_min, u_opt

    def display_costs_and_controls(self):
        # Define the header with columns for each stage
        header = ["State"]
        for stage in range(self.num_stages):
            header.append(f"J_{stage}")
            header.append(f"u_{stage}")

        # Print the header with formatted spacing
        print("{:<6} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}".format(*header))

        # Iterate through each possible state and compute costs and optimal controls
        for x in self.states:
            row = [x]
            for stage in range(self.num_stages):
                j, u = self.cost_to_go(stage=stage, x=x)
                row.append(j)
                row.append(u)

            # Format and print the row with aligned columns
            print("{:<6} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10} {:<10}".format(*row))


# Example of usage
# Assuming X_POSS and U_POSS are defined
calculator = CostToGoCalculator(4, x_space)
calculator.display_costs_and_controls()


State  J_0        u_0        J_1        u_1        J_2        u_2        J_3       
-2     10         0          9          0          8          0          4         
-1     6          -1         5          -1         4          -1         1         
0      3          -1         2          -1         1          -1         0         
1      4          0          3          0          2          0          1         
2      7          1          6          1          5          0          4         
