# In [3]:
from pymtl3 import *

class QUpdateUnit( Component ):
  """Skeleton Q-learning update unit: ports, Q-table storage and a 4-state FSM.

  The FSM cycles IDLE -> READ_MAX -> CALC_UPDATE -> WRITE_BACK -> IDLE, leaving
  IDLE only when `start` is sampled high. The per-state datapath (max-Q search,
  Q-update arithmetic, write-back) is not implemented in this iteration; only
  the sequencing and the `done` handshake are.
  """

  def construct( s ):
    # 1. Define Ports
    s.curr_row    = InPort( Bits3 )
    s.curr_col    = InPort( Bits3 )
    s.curr_action = InPort( Bits2 )
    s.reward      = InPort( Bits16 )  # not used yet
    s.next_row    = InPort( Bits3 )   # not used yet
    s.next_col    = InPort( Bits3 )   # not used yet
    s.start       = InPort( Bits1 )   # sampled only while in IDLE
    s.done        = OutPort( Bits1 )  # registered; one-cycle pulse after WRITE_BACK

    # 2. Q-table memory: 5 x 5 x 4 entries of 16 bits
    s.Q_table = [ [ [ Wire( Bits16 ) for _ in range(4) ] for _ in range(5) ] for _ in range(5) ]

    # 3. State Machine
    s.state = Wire( Bits2 )

    IDLE        = Bits2( 0 )
    READ_MAX    = Bits2( 1 )
    CALC_UPDATE = Bits2( 2 )
    WRITE_BACK  = Bits2( 3 )

    @update_ff
    def fsm_logic():
      if s.reset:
        s.state <<= IDLE
        s.done  <<= 0
      else:
        if s.state == IDLE:
          # Clear the completion flag set by WRITE_BACK on the previous cycle,
          # so `done` is high for exactly one cycle per operation.
          s.done <<= 0
          if s.start:
            s.state <<= READ_MAX
        elif s.state == READ_MAX:
          s.state <<= CALC_UPDATE
        elif s.state == CALC_UPDATE:
          s.state <<= WRITE_BACK
        elif s.state == WRITE_BACK:
          # Previously `done` was declared but never driven; signal completion.
          s.done  <<= 1
          s.state <<= IDLE

    # 4. Action inside each stage
    # (Later we will add max Q calculation, Q-update math, and SRAM writeback)



# In [5]:
from pymtl3 import *

class QUpdateUnit( Component ):
  """Q-learning update unit: max-Q search over the next state, placeholder update.

  Sequence after `start` is sampled high in IDLE (Q_max_next is seeded with
  action 0 of the next state on that same edge):
    READ_MAX    (4 cycles) -- compare actions 1..3 of Q_table[next_row][next_col]
                              against the running maximum held in Q_max_next.
    CALC_UPDATE (1 cycle)  -- latch Q_table[curr_row][curr_col][curr_action]
                              into Q_sa.
    WRITE_BACK  (1 cycle)  -- write Q_sa back to that same entry and set `done`.
  `done` is registered, so it is high during the first IDLE cycle after
  WRITE_BACK and is cleared again by IDLE on the next edge.

  NOTE(review): the Q-update arithmetic is not implemented yet. WRITE_BACK
  stores the value it just read, and neither `reward` nor `Q_max_next` affects
  the table.
  """
  def construct( s ):

    # Define Ports
    # NOTE(review): row/col ports are Bits3 (values 0-7) but Q_table has only 5
    # entries per dimension; an index of 5-7 would run past the Python list
    # (IndexError in simulation). Presumably callers keep them < 5 -- confirm.
    s.curr_row    = InPort( Bits3 )
    s.curr_col    = InPort( Bits3 )
    s.curr_action = InPort( Bits2 )
    s.reward      = InPort( Bits16 )
    s.next_row    = InPort( Bits3 )
    s.next_col    = InPort( Bits3 )
    s.start       = InPort( Bits1 )
    s.done        = OutPort( Bits1 )

    # Q-table memory: 5x5x4 entries, indexed [row][col][action]
    s.Q_table = [ [ [ Wire( Bits16 ) for _ in range(4) ] for _ in range(5) ] for _ in range(5) ]

    # FSM State Register
    s.state = Wire( Bits2 )

    # FSM States
    IDLE        = Bits2( 0 )
    READ_MAX    = Bits2( 1 )
    CALC_UPDATE = Bits2( 2 )
    WRITE_BACK  = Bits2( 3 )

    # Temporary Variables (registers written only by the fsm block below)
    # Q_sa is not reset, but CALC_UPDATE always loads it before WRITE_BACK reads it.
    s.Q_sa        = Wire( Bits16 )
    # Running maximum over Q_table[next_row][next_col][*].
    s.Q_max_next  = Wire( Bits16 )
    s.counter     = Wire( Bits2 ) # For finding max Q over 4 actions

    # FSM Controller
    # All updates use <<= inside update_ff, so every read in a cycle sees the
    # values latched on the previous clock edge.
    # NOTE(review): next_row/next_col are re-read on every READ_MAX cycle and
    # curr_* in CALC_UPDATE and WRITE_BACK, so the caller presumably holds them
    # stable from `start` until `done` -- confirm.
    @update_ff
    def fsm():
      if s.reset:
        s.state <<= IDLE
        s.done  <<= 0
        s.counter <<= 0
        s.Q_max_next <<= 0
      else:
        if s.state == IDLE:
          # Clears the one-cycle `done` pulse set by WRITE_BACK.
          s.done <<= 0
          if s.start:
            s.state <<= READ_MAX
            s.counter <<= 0
            # Seed the running max with action 0 of the next state.
            s.Q_max_next <<= s.Q_table[s.next_row][s.next_col][0]

        elif s.state == READ_MAX:
          # counter = 0, 1, 2 examine actions 1, 2, 3; counter == 3 leaves the
          # state. READ_MAX therefore lasts 4 cycles and counter+1 never exceeds 3.
          # Strict `>` keeps the lowest-index action among equal values.
          # NOTE(review): `>` on Bits16 is an unsigned comparison; if Q-values
          # are meant to be signed, this maximum is wrong -- confirm the encoding.
          if s.counter < 3:
            s.counter <<= s.counter + 1
            if s.Q_table[s.next_row][s.next_col][s.counter+1] > s.Q_max_next:
              s.Q_max_next <<= s.Q_table[s.next_row][s.next_col][s.counter+1]
          else:
            s.state <<= CALC_UPDATE

        elif s.state == CALC_UPDATE:
          # Read Q(s,a)
          s.Q_sa <<= s.Q_table[s.curr_row][s.curr_col][s.curr_action]

          # Q_new = (1-alpha)*Q_sa + alpha*(r + gamma*Q_max_next)
          # Simplified here; detailed math in next step
          s.state <<= WRITE_BACK

        elif s.state == WRITE_BACK:
          # Placeholder: writes back the unmodified Q_sa, so the entry's value
          # does not change yet.
          s.Q_table[s.curr_row][s.curr_col][s.curr_action] <<= s.Q_sa  # Placeholder
          s.done <<= 1
          s.state <<= IDLE
