# Controller

Two controllers are implemented:
- Simple controller: threshold-based rules on solar radiation and zone air temperature
- RL controller: reinforcement-learning-based (a trained A2C policy)


In [None]:
import os
import json
import time
from typing import Dict, Any
import numpy as np
import paho.mqtt.client as mqtt
from stable_baselines3 import A2C

# --- Load parameters ---
from parameters import parameter as PARAMS

# Startup banner, then the headline settings read from PARAMS.
print('‚úÖ Controller ‚Äî Parameters loaded')
for _label, _key in (
    ('   fmu_path:', 'fmu_path'),
    ('   step:', 'fmu_step_size'),
    ('   observations:', 'observation_names'),
    ('   actions:', 'action_names'),
):
    print(_label, PARAMS.get(_key))

# Full dump of every parameter; the action bounds are summarised by their
# shape instead of being printed in full.
for k, v in PARAMS.items():
    rendered = (
        f"shape={getattr(v, 'shape', None)}"
        if k in ("action_min", "action_max")
        else f"{v}"
    )
    print(f"  {k}: {rendered}")


def _bound_vector(key):
    """Return PARAMS[key] as a fresh float64 array (empty when the key is absent)."""
    return np.array(PARAMS.get(key, np.array([], dtype=np.float64)), dtype=np.float64)


# Names of the actions and observations configured in PARAMS (empty lists
# when the corresponding key is missing).
ACTION_NAMES = [*PARAMS.get("action_names", [])]
OBS_NAMES = [*PARAMS.get("observation_names", [])]

# Action bounds as float64 arrays.
ACTION_MIN = _bound_vector("action_min")
ACTION_MAX = _bound_vector("action_max")

print("\nAction dimension:", len(ACTION_NAMES))
print("Observation dimension:", len(OBS_NAMES))

In [None]:
# --- MQTT broker and topic configuration ----------------------------------
# Every setting can be overridden through an environment variable; the second
# argument of each lookup is the fallback used when the variable is unset.
MQTT_BROKER_HOST = os.getenv('MQTT_BROKER_HOST', 'mosquitto')
MQTT_BROKER_PORT = int(os.getenv('MQTT_BROKER_PORT', '1883'))

# Base prefix shared by all topics of the platform subscriber.
TOPIC_BASE = os.getenv('MQTT_TOPIC_BASE', 'simulation')

# Topic prefixes for observations and actions, both rooted at TOPIC_BASE.
TOPIC_OBS_PREFIX = TOPIC_BASE + '/observations'
TOPIC_ACT_PREFIX = TOPIC_BASE + '/actions'

print('‚úÖ MQTT config')
for _label, _value in (
    ('   Host:', MQTT_BROKER_HOST),
    ('   Port:', MQTT_BROKER_PORT),
    ('   Observation prefix:', TOPIC_OBS_PREFIX),
    ('   Action prefix:', TOPIC_ACT_PREFIX),
):
    print(_label, _value)

In [None]:
class SimpleController:
    """Threshold-based shade controller.

    For each zone, every shade of that zone is commanded to ``ACTIVE_VALUE``
    when the zone air temperature is above ``TinTrig`` AND the radiation
    reading (``DNI``) is above ``radTrig``; otherwise it is commanded to
    ``IDLE_VALUE``.

    Args:
        radTrig: Radiation threshold compared against the ``DNI`` observation.
        TinTrig: Temperature threshold compared against each zone temperature.
        action_names: Optional list of configured action names. It only gates
            the controller: an empty list puts it in "monitor only" mode
            (``compute_action`` returns ``{}``). When omitted, it defaults to
            the six shade actuators driven by this controller.
    """

    # Values written to a shade actuator when the rule fires / does not fire.
    ACTIVE_VALUE = 7.0
    IDLE_VALUE = 0.0

    # Zone temperature observation -> shade actuators driven by that zone.
    # Insertion order defines the key order of the returned payload.
    # NOTE(review): 'ShadeStatus_Zone2_wall3' uses a lowercase 'w' whereas the
    # RL controller in this file uses 'ShadeStatus_Zone2_Wall3' — confirm which
    # spelling the simulation expects. Kept unchanged here.
    ZONE_SHADES = {
        'Tair_z1': (
            'ShadeStatus_Zone1_Wall2',
            'ShadeStatus_Zone1_Wall8',
            'ShadeStatus_Zone1_Wall9',
        ),
        'Tair_z2': (
            'ShadeStatus_Zone2_Wall2',
            'ShadeStatus_Zone2_wall3',
        ),
        'Tair_z4': (
            'ShadeStatus_Zone4_Wall2',
        ),
    }

    def __init__(self, radTrig=100, TinTrig=24, action_names=None):
        self.radTrig = radTrig
        self.TinTrig = TinTrig
        # compute_action() reads self.action_names, so it must always exist.
        if action_names is None:
            action_names = [
                shade
                for shades in self.ZONE_SHADES.values()
                for shade in shades
            ]
        self.action_names = list(action_names)

    def compute_action(self, obs: Dict[str, Any]) -> Dict[str, float]:
        """Compute the shade commands for one set of observations.

        Args:
            obs: Mapping of observation name to value. The keys read are
                'DNI', 'Tair_z1', 'Tair_z2' and 'Tair_z4'; a missing key is
                treated as 0.0.

        Returns:
            ``{}`` when no actions are configured, otherwise a dict with one
            entry per shade actuator holding ``ACTIVE_VALUE`` or ``IDLE_VALUE``.

        Raises:
            TypeError, ValueError: If an observation that is present cannot be
                converted with ``float()``.
        """
        # If no actions are configured, run in "monitor only" mode.
        if not self.action_names:
            return {}

        radiation_high = float(obs.get('DNI', 0.0)) > self.radTrig

        payload = {}
        for temp_name, shades in self.ZONE_SHADES.items():
            zone_warm = float(obs.get(temp_name, 0.0)) > self.TinTrig
            # Same rule for every zone: warm AND high radiation -> active.
            command = self.ACTIVE_VALUE if (zone_warm and radiation_high) else self.IDLE_VALUE
            for shade in shades:
                payload[shade] = command

        return payload

In [None]:
class MyRLController:
    """Controller that asks a trained A2C policy for the shade commands.

    Args:
        model_path: Path handed to ``A2C.load`` to restore the trained model.
    """

    def __init__(self, model_path="trained_a2c_model"):
        # Restore the trained policy.
        self.model = A2C.load(model_path)

        # The six shade variables: they are the model's outputs and also the
        # trailing entries of its observation vector.
        shade_names = [
            'ShadeStatus_Zone1_Wall2',
            'ShadeStatus_Zone1_Wall8',
            'ShadeStatus_Zone1_Wall9',
            'ShadeStatus_Zone2_Wall2',
            'ShadeStatus_Zone2_Wall3',
            'ShadeStatus_Zone4_Wall2',
        ]

        # Observation order the model was trained on.
        self.observation_names = [
            'Tair_z1',
            'Tair_z2',
            'Tair_z4',
            'T_out',
            'DNI',
            'DistrictHeating',
            'DistrictCooling',
        ] + shade_names

        # Action names, in the order they appear in the model output.
        self.action_names = list(shade_names)

    def compute_action(self, measurements: dict) -> dict:
        """Turn one set of measurements into shade commands.

        Args:
            measurements: Mapping of observation name to value; names missing
                from the mapping are fed to the model as 0.0.

        Returns:
            ``{}`` when no action names are configured, otherwise a dict
            mapping each action name to 0.0 (model output 0) or 7.0 (any
            other model output).
        """
        # Nothing to command: "monitor only" mode.
        if not self.action_names:
            return {}

        # Assemble the observation vector in training order.
        ordered_values = []
        for obs_name in self.observation_names:
            ordered_values.append(float(measurements.get(obs_name, 0.0)))
        obs = np.array(ordered_values, dtype=np.float32)

        # Deterministic prediction; the second returned element is not used.
        action, _unused = self.model.predict(obs, deterministic=True)

        # Discrete output 0 -> 0.0, anything else -> 7.0.
        physical_actions = np.where(action == 0, 0.0, 7.0)

        return dict(zip(self.action_names, map(float, physical_actions)))

In [None]:
# --- MQTT runtime: receive observations and publish actions (if any) ---

# Observation values collected so far, keyed by variable name (the last
# segment of the message topic). Filled by on_message and emptied by the main
# loop after an action has been computed.
obs_buffer = {}
# Timestamp carried by the most recent observation message that had one;
# on_message empties obs_buffer when a message arrives with a different value.
_last_ts = None

def buffer_is_ready():
    """Tell whether enough observations have arrived to compute an action.

    Returns:
        False while the buffer is empty. Otherwise True when every name in
        OBS_NAMES is present in the buffer (trivially True when OBS_NAMES is
        empty).
    """
    return bool(obs_buffer) and all(name in obs_buffer for name in OBS_NAMES)

def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: subscribe to the observation topics.

    Args:
        client: The MQTT client that connected; used to issue the subscription.
        userdata: Unused.
        flags: Unused.
        rc: Connection result code, echoed in the log line.
    """
    print(f"‚úÖ Connected to MQTT broker {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT} (rc={rc})")
    # Observations arrive on per-variable topics (<prefix>/<name>), which is the
    # only shape on_message accepts. An exact-match subscription to the bare
    # prefix never delivers those, so use the multi-level wildcard; it matches
    # the prefix itself as well as every topic below it.
    topic_filter = f"{TOPIC_OBS_PREFIX}/#"
    client.subscribe(topic_filter, qos=1)
    print(f"‚úÖ Subscribed to topic: {topic_filter}")

def on_disconnect(client, userdata, rc):
    """MQTT disconnect callback: log the disconnection and its result code.

    Args:
        client: Unused.
        userdata: Unused.
        rc: Disconnect result code, echoed in the log line.
    """
    notice = f'‚ÑπÔ∏è Controlled Sim ‚Äî Disconnected (rc={rc})'
    print(notice)

def on_message(client, userdata, msg):
    """MQTT message callback: store one observation value in ``obs_buffer``.

    Handled schema: topic ``<base>/observations/<name>`` with a JSON object
    payload containing ``"value"`` and, optionally, ``"timestamp"``. Any other
    message that decodes as JSON is reported as unhandled and ignored.

    Args:
        client: Unused.
        userdata: Unused.
        msg: Incoming message; ``msg.topic`` and ``msg.payload`` are read.
    """
    global obs_buffer, _last_ts

    topic = msg.topic
    try:
        data = json.loads(msg.payload.decode("utf-8"))
    except Exception as e:
        # Callback boundary: log and drop the message instead of raising.
        print(f"‚ö†Ô∏è JSON decode error on topic={topic}: {e}")
        return

    parts = topic.split("/")
    is_observation_topic = len(parts) >= 3 and parts[-2] == "observations"

    if is_observation_topic and isinstance(data, dict) and "value" in data:
        name = parts[-1]
        val = data["value"]
        ts = data.get("timestamp")

        # A different timestamp marks a new simulation step: start a fresh buffer.
        if ts is not None and ts != _last_ts:
            obs_buffer = {}
            _last_ts = ts

        # Store the value as a float when possible, otherwise keep it raw.
        try:
            obs_buffer[name] = float(val)
        except (TypeError, ValueError, OverflowError):
            obs_buffer[name] = val
        return

    # Reached for non-observation topics AND for observation topics whose
    # payload is not a {"value": ...} object (previously dropped silently).
    print(f"‚ö†Ô∏è Unhandled message format on topic={topic}: type={type(data)}")

# Create the MQTT client, register the callbacks defined above and connect to
# the configured broker.
# NOTE(review): paho-mqtt 2.x expects a callback API version argument in
# mqtt.Client(...) — confirm which paho-mqtt version is installed.
client = mqtt.Client()
client.on_connect = on_connect
client.on_disconnect = on_disconnect
client.on_message = on_message
client.connect(MQTT_BROKER_HOST, MQTT_BROKER_PORT, keepalive=60)
# Start paho's network loop so the callbacks fire while the main loop runs.
client.loop_start()

# Instance of the controller used by the main loop.
controller = SimpleController()

print("\nController running.")
# The mode announced here depends only on ACTION_NAMES (from PARAMS), not on
# the controller instance.
if len(ACTION_NAMES) == 0:
    print("üîé Open-loop mode: action_names is empty -> no actions will be published.")
else:
    print(f"üéÆ Closed-loop mode: will publish {len(ACTION_NAMES)} actions to {TOPIC_ACT_PREFIX}")

In [None]:
# Main loop: whenever a complete set of observations is buffered, compute an
# action, publish it (if any) and wait for the next set.
try:
    while True:
        if buffer_is_ready():
            # Log a short sample (first 6 entries) of the buffered observations.
            short_obs = {k: obs_buffer.get(k) for k in list(obs_buffer.keys())[:6]}
            print(f"[OBS] keys={len(obs_buffer)} sample={short_obs}")

            action_dict = controller.compute_action(obs_buffer)

            # Publish the action to the MQTT broker only if actions exist.
            # NOTE(review): the whole action dict is sent as one JSON message
            # on TOPIC_ACT_PREFIX — confirm the subscriber does not expect
            # per-variable action topics instead.
            if action_dict:
                client.publish(TOPIC_ACT_PREFIX, json.dumps(action_dict), qos=1)
                print(f"[ACT] {action_dict}")

            # Clear the observation buffer so that we wait for a new observation
            # before computing the next action.
            obs_buffer = {}

        # Small sleep to avoid busy-looping while waiting for new messages.
        time.sleep(0.05)

except KeyboardInterrupt:
    # Graceful shutdown on Ctrl+C.
    print("Stopping controller...")
finally:
    # Always stop the network loop and close the connection, including when the
    # loop dies from an unexpected exception (which still propagates).
    client.loop_stop()
    client.disconnect()