# Dataset Generation


In [None]:
import random
import numpy as np
import json

DATASET_SIZE = 1
OUTPUT_FILE = "demonstrations.txt"

# Item categories
CONTAINERS = ["pot", "pan", "plate", "bowl", "glass", "measuring_cup"]
LIQUID_INGREDIENTS = ["milk", "oil"]
SOLID_INGREDIENTS  = ["tomato", "garlic", "onion", "mushroom", "lettuce",
                        "cheese", "rice", "yoghurt", "strawberries", "banana",
                        "egg", "fish", "chicken", "meat",
                        "salt", "spice1", "spice2", "mixture"]
INGREDIENTS = LIQUID_INGREDIENTS + SOLID_INGREDIENTS
ITEMS = CONTAINERS + INGREDIENTS
CUTTABLES  = ["tomato", "onion", "mushroom", "lettuce", "banana", "strawberries", "chicken", "fish", "cheese"]
GRATABLE   = ["cheese"]
COOKABLES  = ["meat", "egg", "rice", "tomato", "onion", "mushroom", "chicken", "fish", "mixture"]
SEASONINGS = ["salt", "spice1", "spice2", "garlic"]
# Locations (places, not tools)
LOCATIONS = ["storage", "prep_station", "cooking_station", "plating_station", "serving_station", "washing_station", "blending_station"]
# Tools at locations
TOOLS = {"prep_station":["cutting_board", "grater"],    "cooking_station":  ["stove"],      "washing_station":  ["sink"],       "blending_station": ["blender"]}


class StateTracker:
    def __init__(self):
        self.feature_map = {}
        idx = 0
        # 1. Item locations
        for item in ITEMS:
            for loc in LOCATIONS:
                self.feature_map[f"{item}_at_{loc}"] = idx
                idx += 1
        # 2. Containment (container_contains_ingredient)
        for container in CONTAINERS:
            for ingredient in INGREDIENTS:
                self.feature_map[f"{container}_contains_{ingredient}"] = idx
                idx += 1
        # 3. Cut status
        for item in CUTTABLES:
            self.feature_map[f"{item}_cut"] = idx
            idx += 1
        # 4. Grated status
        for item in GRATABLE:
            self.feature_map[f"{item}_grated"] = idx
            idx += 1
        # 5. Cooked status
        for item in COOKABLES:
            self.feature_map[f"{item}_cooked"] = idx
            idx += 1
        # 6. Seasoned status (any ingredient can be seasoned)
        for item in INGREDIENTS:
            self.feature_map[f"{item}_seasoned"] = idx
            idx += 1
        # 7. Washed status
        for item in ITEMS:
            self.feature_map[f"{item}_washed"] = idx
            idx += 1
        # 8. Tool states
        self.feature_map["stove_on"]   = idx; idx += 1
        self.feature_map["sink_on"]    = idx; idx += 1
        self.feature_map["blender_on"] = idx; idx += 1
        # 9. Task completion
        self.feature_map["plate_served"] = idx; idx += 1

        self.n_features = idx
        self.reset()
        print(f"Total state features: {self.n_features}")
        print(
            f"Feature dimensions: "
            f"Locations={len(ITEMS)*len(LOCATIONS)}, "
            f"Containment={len(CONTAINERS)*len(INGREDIENTS)}, "
            f"Cut={len(CUTTABLES)}, Grated={len(GRATABLE)}, "
            f"Cooked={len(COOKABLES)}, Seasoned={len(INGREDIENTS)}, "
            f"Washed={len(ITEMS)}, Tools=3, Served=1"
        )

    def reset(self):
        self.current_state = np.zeros(self.n_features, dtype=int)
        for item in ITEMS: self.set_feature(f"{item}_at_storage", 1)

    def set_feature(self, key, value):
        if key in self.feature_map: self.current_state[self.feature_map[key]] = value

    def get_feature(self, key):
        return self.current_state[self.feature_map.get(key, -1)] if key in self.feature_map else 0

    def get_state_vector(self):
        return self.current_state.copy()

    def get_item_location(self, item):
        for loc in LOCATIONS:
            if self.get_feature(f"{item}_at_{loc}") == 1: return loc
        return None

    def is_contained(self, item):
        for container in CONTAINERS:
            if self.get_feature(f"{container}_contains_{item}") == 1: return container
        return None

    def apply_action(self, action_str):
        action_str = action_str.strip()
        def _parse(s):
            content = s[s.find("(")+1:s.find(")")]
            return [p.strip() for p in content.split(",")]
        def _loc(part): 
            return part.split("=")[1] if "=" in part else part
        if action_str.startswith("transfer"):
            parts    = _parse(action_str)
            item     = parts[0]
            from_loc = _loc(parts[1])
            to_loc   = _loc(parts[2])
            if self.get_feature(f"{item}_at_{from_loc}") != 1:                          raise ValueError(f"Precondition failed: {item} not at {from_loc}")
            if self.is_contained(item):                                                 raise ValueError(f"Precondition failed: {item} is contained, use unload first")
            self.set_feature(f"{item}_at_{from_loc}", 0)
            self.set_feature(f"{item}_at_{to_loc}",   1)

        elif action_str.startswith("load"):
            parts     = _parse(action_str)
            item      = parts[0]
            container = parts[1]
            location  = parts[2]
            if self.get_feature(f"{item}_at_{location}") != 1:                          raise ValueError(f"Precondition failed: {item} not at {location}")
            if self.get_feature(f"{container}_at_{location}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {location}")
            if self.is_contained(item):                                                 raise ValueError(f"Precondition failed: {item} already contained")
            self.set_feature(f"{container}_contains_{item}", 1)

        elif action_str.startswith("unload"):
            parts     = _parse(action_str)
            item      = parts[0]
            container = parts[1]
            location  = parts[2]
            if self.get_feature(f"{container}_contains_{item}") != 1:                   raise ValueError(f"Precondition failed: {item} not in {container}")
            if self.get_feature(f"{container}_at_{location}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {location}")
            self.set_feature(f"{container}_contains_{item}", 0)

        elif action_str.startswith("move_container"):
            parts     = _parse(action_str)
            container = parts[0]
            from_loc  = _loc(parts[1])
            to_loc    = _loc(parts[2])
            if self.get_feature(f"{container}_at_{from_loc}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {from_loc}")
            self.set_feature(f"{container}_at_{from_loc}", 0)
            self.set_feature(f"{container}_at_{to_loc}",   1)
            for ingredient in INGREDIENTS:
                if self.get_feature(f"{container}_contains_{ingredient}") == 1:
                    self.set_feature(f"{ingredient}_at_{from_loc}", 0)
                    self.set_feature(f"{ingredient}_at_{to_loc}",   1)

        elif action_str.startswith("cut"):
            parts    = _parse(action_str)
            item     = parts[0]
            location = parts[1] if len(parts) > 1 else "prep_station"
            if item not in CUTTABLES:                                                   raise ValueError(f"Precondition failed: {item} is not cuttable")
            if self.get_feature(f"{item}_at_{location}") != 1:                          raise ValueError(f"Precondition failed: {item} not at {location}")
            self.set_feature(f"{item}_cut", 1)

        elif action_str.startswith("grate"):
            parts    = _parse(action_str)
            item     = parts[0]
            location = parts[1] if len(parts) > 1 else "prep_station"
            if item not in GRATABLE:                                                    raise ValueError(f"Precondition failed: {item} is not gratable")
            if self.get_feature(f"{item}_at_{location}") != 1:                          raise ValueError(f"Precondition failed: {item} not at {location}")
            self.set_feature(f"{item}_grated", 1)

        elif action_str.startswith("cook ") or action_str.startswith("cook("):
            parts     = _parse(action_str)
            item      = parts[0]
            container = parts[1] if len(parts) > 1 else "pot"
            location  = parts[2] if len(parts) > 2 else "cooking_station"
            if item not in COOKABLES:                                                   raise ValueError(f"Precondition failed: {item} is not cookable")
            if self.get_feature(f"{container}_contains_{item}") != 1:                   raise ValueError(f"Precondition failed: {item} not in {container}")
            if self.get_feature(f"{container}_at_{location}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {location}")
            if location == "cooking_station" and self.get_feature("stove_on") != 1:     raise ValueError("Precondition failed: stove not on")
            self.set_feature(f"{item}_cooked", 1)

        elif action_str.startswith("cook_contents"):
            parts     = _parse(action_str)
            container = parts[0]
            location  = parts[1] if len(parts) > 1 else "cooking_station"
            if self.get_feature(f"{container}_at_{location}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {location}")
            if location == "cooking_station" and self.get_feature("stove_on") != 1:     raise ValueError("Precondition failed: stove not on")
            for ingredient in INGREDIENTS:
                if self.get_feature(f"{container}_contains_{ingredient}") == 1:
                    if ingredient in COOKABLES: self.set_feature(f"{ingredient}_cooked", 1)

        elif action_str.startswith("combine"):
            parts     = _parse(action_str)
            container = parts[0]
            location  = parts[1] if len(parts) > 1 else None
            if location and self.get_feature(f"{container}_at_{location}") != 1:        raise ValueError(f"Precondition failed: {container} not at {location}")
            contained = [ing for ing in INGREDIENTS if ing != "mixture" and self.get_feature(f"{container}_contains_{ing}") == 1]
            if len(contained) < 2:                                                      raise ValueError(f"Precondition failed: combine requires >=2 ingredients in {container}, found {contained}")
            for ing in contained: self.set_feature(f"{container}_contains_{ing}", 0)
            self.set_feature(f"{container}_contains_mixture", 1)
            if location: self.set_feature(f"mixture_at_{location}", 1)

        elif action_str.startswith("season_container"):
            parts     = _parse(action_str)
            container = parts[0]
            seasoning = parts[1]
            location  = parts[2] if len(parts) > 2 else None
            if seasoning not in SEASONINGS:                                             raise ValueError(f"Precondition failed: {seasoning} is not a seasoning")
            if location and self.get_feature(f"{container}_at_{location}") != 1:        raise ValueError(f"Precondition failed: {container} not at {location}")
            for ing in INGREDIENTS:
                if self.get_feature(f"{container}_contains_{ing}") == 1: self.set_feature(f"{ing}_seasoned", 1)

        elif action_str.startswith("season"):
            parts     = _parse(action_str)
            target    = parts[0]
            seasoning = parts[1]
            location  = parts[2] if len(parts) > 2 else None
            if seasoning not in SEASONINGS:                                             raise ValueError(f"Precondition failed: {seasoning} is not a seasoning")
            if location and self.get_feature(f"{target}_at_{location}") != 1:           raise ValueError(f"Precondition failed: {target} not at {location}")
            self.set_feature(f"{target}_seasoned", 1)

        elif action_str.startswith("pour"):
            parts          = _parse(action_str)
            liquid         = parts[0]
            from_container = parts[1]
            to_container   = parts[2]
            location       = parts[3] if len(parts) > 3 else None
            if liquid not in LIQUID_INGREDIENTS:                                        raise ValueError(f"Precondition failed: {liquid} is not a liquid")
            if self.get_feature(f"{from_container}_contains_{liquid}") != 1:            raise ValueError(f"Precondition failed: {liquid} not in {from_container}")
            if location:
                if self.get_feature(f"{from_container}_at_{location}") != 1:            raise ValueError(f"Precondition failed: {from_container} not at {location}")
                if self.get_feature(f"{to_container}_at_{location}") != 1:              raise ValueError(f"Precondition failed: {to_container} not at {location}")
            self.set_feature(f"{from_container}_contains_{liquid}", 0)
            self.set_feature(f"{to_container}_contains_{liquid}",   1)
            if location: self.set_feature(f"{liquid}_at_{location}", 1)

        elif action_str.startswith("turn_on"):
            parts = _parse(action_str)
            tool  = parts[0]
            if   tool == "stove":   self.set_feature("stove_on",   1)
            elif tool == "sink":    self.set_feature("sink_on",    1)
            elif tool == "blender": self.set_feature("blender_on", 1)

        elif action_str.startswith("turn_off"):
            parts = _parse(action_str)
            tool  = parts[0]
            if   tool == "stove":   self.set_feature("stove_on",   0)
            elif tool == "sink":    self.set_feature("sink_on",    0)
            elif tool == "blender": self.set_feature("blender_on", 0)

        elif action_str.startswith("blend"):
            parts     = _parse(action_str)
            container = parts[0]
            location  = parts[1] if len(parts) > 1 else "blending_station"
            if self.get_feature(f"{container}_at_{location}") != 1:                     raise ValueError(f"Precondition failed: {container} not at {location}")
            if self.get_feature("blender_on") != 1:                                     raise ValueError("Precondition failed: blender not on")
            contained = [ing for ing in INGREDIENTS if ing != "mixture" and self.get_feature(f"{container}_contains_{ing}") == 1]
            if len(contained) < 1:                                                      raise ValueError(f"Precondition failed: blend requires >=1 ingredient in {container}")
            for ing in contained: self.set_feature(f"{container}_contains_{ing}", 0)
            self.set_feature(f"{container}_contains_mixture", 1)
            self.set_feature(f"mixture_at_{location}", 1)

        elif action_str.startswith("serve"):
            parts    = _parse(action_str)
            plate    = parts[0]
            location = parts[1] if len(parts) > 1 else "serving_station"
            if self.get_feature(f"{plate}_at_{location}") != 1:                         raise ValueError(f"Precondition failed: {plate} not at {location}")
            self.set_feature("plate_served", 1)

        elif action_str.startswith("wash"):
            parts    = _parse(action_str)
            item     = parts[0]
            location = parts[1] if len(parts) > 1 else "washing_station"
            if self.get_feature(f"{item}_at_{location}") != 1:                          raise ValueError(f"Precondition failed: {item} not at {location}")
            self.set_feature(f"{item}_washed", 1)

        else:                                                                           raise ValueError(f"Unknown action: {action_str}")

class RecipeGenerator:
    def __init__(self):
        self.demos   = []
        self.tracker = StateTracker()

    def _record_trajectory(self, actions):
        self.tracker.reset()
        trajectory = []
        for action in actions:
            state_vector = self.tracker.get_state_vector().tolist()
            trajectory.append({"state": state_vector, "action": action})
            try:
                self.tracker.apply_action(action)
            except ValueError as e:
                print(f"ERROR applying action '{action}': {e}")
                raise
        final_state_vector = self.tracker.get_state_vector().tolist()
        trajectory.append({"state": final_state_vector, "action": "stop"})
        self.demos.append(trajectory)


    def generate_grilled_steak(self): return [
            "transfer (pan, from=storage, to=cooking_station)",
            "load (meat, bowl, storage)",
            "move_container (bowl, from=storage, to=cooking_station)",
            "unload (meat, bowl, cooking_station)",
            "load (meat, pan, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pan, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (meat, pan, cooking_station)",
            "load (meat, bowl, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "unload (meat, bowl, plating_station)",
            "load (meat, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pan, from=cooking_station, to=washing_station)",
            "wash (pan, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_boiled_eggs(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "load (egg, bowl, storage)",
            "move_container (bowl, from=storage, to=cooking_station)",
            "unload (egg, bowl, cooking_station)",
            "load (egg, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (pot, from=cooking_station, to=plating_station)",
            "unload (egg, pot, plating_station)",
            "load (egg, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pot, from=plating_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=cooking_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_boiled_rice(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "load (rice, bowl, storage)",
            "move_container (bowl, from=storage, to=cooking_station)",
            "unload (rice, bowl, cooking_station)",
            "load (rice, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (rice, pot, cooking_station)",
            "load (rice, bowl, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "unload (rice, bowl, plating_station)",
            "load (rice, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pot, from=cooking_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_simple_salad(self): return [
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (lettuce, from=storage, to=prep_station)",
            "cut (lettuce, prep_station)",
            "load (lettuce, bowl, prep_station)",
            "transfer (onion, from=storage, to=prep_station)",
            "cut (onion, prep_station)",
            "load (onion, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=prep_station, to=plating_station)",
            "unload (mixture, bowl, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (plate, from=serving_station, to=washing_station)",
            "wash (plate, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_burger(self): return [
            "transfer (pan, from=storage, to=cooking_station)",
            "load (meat, bowl, storage)",
            "move_container (bowl, from=storage, to=cooking_station)",
            "unload (meat, bowl, cooking_station)",
            "load (meat, pan, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pan, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (meat, pan, cooking_station)",
            "load (meat, bowl, cooking_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "unload (meat, bowl, plating_station)",
            "load (meat, plate, plating_station)",
            "transfer (lettuce, from=storage, to=prep_station)",
            "move_container (bowl, from=plating_station, to=prep_station)",
            "cut (lettuce, prep_station)",
            "load (lettuce, bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=plating_station)",
            "unload (lettuce, bowl, plating_station)",
            "load (lettuce, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pan, from=cooking_station, to=washing_station)",
            "wash (pan, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)"]

    def generate_tomato_soup(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (tomato, from=storage, to=prep_station)",
            "cut (tomato, prep_station)",
            "load (tomato, bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (tomato, bowl, cooking_station)",
            "load (tomato, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (tomato, pot, cooking_station)",
            "load (tomato, bowl, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "unload (tomato, bowl, plating_station)",
            "load (tomato, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pot, from=cooking_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_tomato_onion_soup_v1(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (tomato, from=storage, to=prep_station)",
            "cut (tomato, prep_station)",
            "load (tomato, bowl, prep_station)",
            "transfer (onion, from=storage, to=prep_station)",
            "cut (onion, prep_station)",
            "load (onion, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (mixture, bowl, cooking_station)",
            "load (mixture, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (pot, from=cooking_station, to=plating_station)",
            "unload (mixture, pot, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (pot, from=plating_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=cooking_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_tomato_onion_soup_v2(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (onion, from=storage, to=prep_station)",
            "cut (onion, prep_station)",
            "load (onion, bowl, prep_station)",
            "transfer (tomato, from=storage, to=prep_station)",
            "cut (tomato, prep_station)",
            "load (tomato, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (mixture, bowl, cooking_station)",
            "load (mixture, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (pot, from=cooking_station, to=plating_station)",
            "unload (mixture, pot, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (pot, from=plating_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=cooking_station, to=washing_station)",
            "wash (bowl, washing_station)"]
    def generate_mushroom_soup(self): return [
            "transfer (pot, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (mushroom, from=storage, to=prep_station)",
            "cut (mushroom, prep_station)",
            "load (mushroom, bowl, prep_station)",
            "transfer (onion, from=storage, to=prep_station)",
            "cut (onion, prep_station)",
            "load (onion, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (mixture, bowl, cooking_station)",
            "load (mixture, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (pot, from=cooking_station, to=plating_station)",
            "unload (mixture, pot, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (pot, from=plating_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=cooking_station, to=washing_station)",
            "wash (bowl, washing_station)"]

    # new recipes
    def generate_seasoned_chicken(self):
        return [
            "transfer (pan, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (chicken, from=storage, to=prep_station)",
            "season (chicken, salt, prep_station)",
            "season (chicken, spice1, prep_station)",
            "load (chicken, bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (chicken, bowl, cooking_station)",
            "load (chicken, pan, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pan, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (chicken, pan, cooking_station)",
            "load (chicken, bowl, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "unload (chicken, bowl, plating_station)",
            "load (chicken, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pan, from=cooking_station, to=washing_station)",
            "wash (pan, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)",
        ]
    def generate_garlic_fish(self):
        return [
            "transfer (pan, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (fish, from=storage, to=prep_station)",
            "season (fish, garlic, prep_station)",
            "season (fish, spice2, prep_station)",
            "load (fish, bowl, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (fish, bowl, cooking_station)",
            "load (fish, pan, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pan, cooking_station)",
            "turn_off (stove, cooking_station)",
            "unload (fish, pan, cooking_station)",
            "load (fish, bowl, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=cooking_station, to=plating_station)",
            "unload (fish, bowl, plating_station)",
            "load (fish, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "transfer (pan, from=cooking_station, to=washing_station)",
            "wash (pan, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)",
        ]
    def generate_seasoned_mixture_soup(self):
        return [
            "transfer (pot, from=storage, to=cooking_station)",
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (tomato, from=storage, to=prep_station)",
            "cut (tomato, prep_station)",
            "load (tomato, bowl, prep_station)",
            "transfer (onion, from=storage, to=prep_station)",
            "cut (onion, prep_station)",
            "load (onion, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "season_container (bowl, salt, prep_station)",
            "season_container (bowl, spice1, prep_station)",
            "move_container (bowl, from=prep_station, to=cooking_station)",
            "unload (mixture, bowl, cooking_station)",
            "load (mixture, pot, cooking_station)",
            "turn_on (stove, cooking_station)",
            "cook_contents (pot, cooking_station)",
            "turn_off (stove, cooking_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (pot, from=cooking_station, to=plating_station)",
            "unload (mixture, pot, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (pot, from=plating_station, to=washing_station)",
            "wash (pot, washing_station)",
            "transfer (bowl, from=cooking_station, to=washing_station)",
            "wash (bowl, washing_station)",
        ]
    def generate_grated_cheese_salad(self):
        return [
            "transfer (bowl, from=storage, to=prep_station)",
            "transfer (lettuce, from=storage, to=prep_station)",
            "cut (lettuce, prep_station)",
            "load (lettuce, bowl, prep_station)",
            "transfer (cheese, from=storage, to=prep_station)",
            "grate (cheese, prep_station)",
            "load (cheese, bowl, prep_station)",
            "combine (bowl, prep_station)",
            "transfer (plate, from=storage, to=plating_station)",
            "move_container (bowl, from=prep_station, to=plating_station)",
            "unload (mixture, bowl, plating_station)",
            "load (mixture, plate, plating_station)",
            "move_container (plate, from=plating_station, to=serving_station)",
            "serve (plate, serving_station)",
            "move_container (plate, from=serving_station, to=washing_station)",
            "wash (plate, washing_station)",
            "transfer (bowl, from=plating_station, to=washing_station)",
            "wash (bowl, washing_station)",
        ]
    def generate_smoothie(self):
        return [
            "transfer (glass, from=storage, to=blending_station)",
            "transfer (measuring_cup, from=storage, to=blending_station)",
            "transfer (banana, from=storage, to=prep_station)",
            "cut (banana, prep_station)",
            "transfer (strawberries, from=storage, to=prep_station)",
            "cut (strawberries, prep_station)",
            "transfer (banana, from=prep_station, to=blending_station)",
            "transfer (strawberries, from=prep_station, to=blending_station)",
            "transfer (milk, from=storage, to=blending_station)",
            "load (milk, measuring_cup, blending_station)",
            "pour (milk, measuring_cup, glass, blending_station)",
            "load (banana, glass, blending_station)",
            "load (strawberries, glass, blending_station)",
            "turn_on (blender, blending_station)",
            "blend (glass, blending_station)",
            "turn_off (blender, blending_station)",
            "transfer (glass, from=blending_station, to=serving_station)",
            "serve (glass, serving_station)",
            "transfer (measuring_cup, from=blending_station, to=washing_station)",
            "wash (measuring_cup, washing_station)",
        ]
    def generate_yoghurt_smoothie(self):
        return [
            "transfer (glass, from=storage, to=blending_station)",
            "transfer (measuring_cup, from=storage, to=blending_station)",
            "transfer (banana, from=storage, to=prep_station)",
            "cut (banana, prep_station)",
            "transfer (banana, from=prep_station, to=blending_station)",
            "transfer (yoghurt, from=storage, to=blending_station)",
            "load (yoghurt, glass, blending_station)",
            "transfer (milk, from=storage, to=blending_station)",
            "load (milk, measuring_cup, blending_station)",
            "pour (milk, measuring_cup, glass, blending_station)",
            "load (banana, glass, blending_station)",
            "turn_on (blender, blending_station)",
            "blend (glass, blending_station)",
            "turn_off (blender, blending_station)",
            "transfer (glass, from=blending_station, to=serving_station)",
            "serve (glass, serving_station)",
            "transfer (measuring_cup, from=blending_station, to=washing_station)",
            "wash (measuring_cup, washing_station)",
        ]

    def generate_random_dataset(self, count):
        available_recipes = [self.generate_tomato_onion_soup_v1]
        print(f"Generating {count} PDDL-style demonstrations...")
        for i in range(count):
            recipe_func = random.choice(available_recipes)
            actions = recipe_func()
            self._record_trajectory(actions)

    def save_to_file(self):
        with open(OUTPUT_FILE, "w") as f:
            for demo in self.demos: f.write(json.dumps(demo) + "\n")
        print(f"Saved {len(self.demos)} trajectories to {OUTPUT_FILE}")
        print(f"State Vector Size: {self.tracker.n_features}")


# IRL

In [2]:
import numpy as np
import json

def load_demonstrations(filepath):
    """Load demonstrations from JSON file"""
    with open(filepath, 'r') as f: demo_lists = [json.loads(line) for line in f.read().strip().split('\n')]
    
    demonstrations = []
    all_actions = set()
    
    for demo in demo_lists:
        trajectory = [(tuple(step['state']), step['action']) for step in demo]
        demonstrations.append(trajectory)
        all_actions.update(step['action'] for step in demo)
    return demonstrations, sorted(list(all_actions))

def create_state_action_mappings(demonstrations, unique_actions):
    """Create bidirectional mappings between states/actions and indices"""
    unique_states = set()
    for trajectory in demonstrations: unique_states.update(state for state, _ in trajectory)
    
    state_to_idx = {state: idx for idx, state in enumerate(sorted(unique_states))}
    idx_to_state = {idx: state for state, idx in state_to_idx.items()}
    action_to_idx = {action: idx for idx, action in enumerate(unique_actions)}
    idx_to_action = {idx: action for action, idx in action_to_idx.items()}
    return state_to_idx, idx_to_state, action_to_idx, idx_to_action

def create_enhanced_feature_matrix(idx_to_state):
    """
    Create task-discriminative features for the expanded PDDL-style kitchen environment.

    State vector layout (371 dims):
      0   – 181 : Item locations        (26 items × 7 locations)
      182 – 301 : Containment           (6 containers × 20 ingredients)
      302 – 310 : Cut status            (9 cuttables)
      311        : Grated status        (cheese)
      312 – 320 : Cooked status         (9 cookables)
      321 – 340 : Seasoned status       (20 ingredients)
      341 – 366 : Washed status         (26 items)
      367        : stove_on
      368        : sink_on
      369        : blender_on
      370        : plate_served
    """
    # --- dimension constants (must match StateTracker) ---
    n_items       = len(ITEMS)        # 26
    n_locations   = len(LOCATIONS)    # 7
    n_containers  = len(CONTAINERS)   # 6
    n_ingredients = len(INGREDIENTS)  # 20

    # --- section offsets ---
    LOCATION_OFFSET    = 0
    CONTAINMENT_OFFSET = n_items * n_locations                              # 182
    CUT_OFFSET         = CONTAINMENT_OFFSET + n_containers * n_ingredients  # 302
    GRATED_OFFSET      = CUT_OFFSET  + len(CUTTABLES)                       # 311
    COOKED_OFFSET      = GRATED_OFFSET + len(GRATABLE)                      # 312
    SEASONED_OFFSET    = COOKED_OFFSET + len(COOKABLES)                     # 321
    WASHED_OFFSET      = SEASONED_OFFSET + n_ingredients                    # 341
    STOVE_IDX          = WASHED_OFFSET + n_items                            # 367
    SINK_IDX           = STOVE_IDX  + 1                                     # 368
    BLENDER_IDX        = SINK_IDX   + 1                                     # 369
    SERVED_IDX         = BLENDER_IDX + 1                                    # 370

    # --- item / container / location indices ---
    pot_idx    = CONTAINERS.index("pot")
    pan_idx    = CONTAINERS.index("pan")
    plate_idx  = CONTAINERS.index("plate")
    bowl_idx   = CONTAINERS.index("bowl")
    glass_idx  = CONTAINERS.index("glass")
    mc_idx     = CONTAINERS.index("measuring_cup")

    prep_idx   = LOCATIONS.index("prep_station")
    cook_idx   = LOCATIONS.index("cooking_station")
    plat_idx   = LOCATIONS.index("plating_station")
    serve_idx  = LOCATIONS.index("serving_station")
    wash_idx   = LOCATIONS.index("washing_station")
    blend_idx  = LOCATIONS.index("blending_station")

    # helper: index of item i at location j
    def iloc(item_i, loc_j): return LOCATION_OFFSET + item_i * n_locations + loc_j
    # helper: index of containment cont_i contains ing_j
    def cfeat(cont_i, ing_j): return CONTAINMENT_OFFSET + cont_i * n_ingredients + ing_j
    # item indices in ITEMS (ITEMS = CONTAINERS + INGREDIENTS)
    item_idx = {item: i for i, item in enumerate(ITEMS)}

    features = []

    for idx in range(len(idx_to_state)):
        state = np.array(idx_to_state[idx])
        feat  = []

        # === Location-based features (7 features) ===
        # Count items at each location
        for loc_j in range(n_locations):
            indices = [iloc(i, loc_j) for i in range(n_items)]
            feat.append(np.sum(state[indices]))

        # === Container-specific location features (16 features) ===
        feat.append(state[iloc(pot_idx,   cook_idx)])   # pot at cooking
        feat.append(state[iloc(pan_idx,   cook_idx)])   # pan at cooking
        feat.append(state[iloc(plate_idx, plat_idx)])   # plate at plating
        feat.append(state[iloc(plate_idx, serve_idx)])  # plate at serving
        feat.append(state[iloc(plate_idx, wash_idx)])   # plate at washing
        feat.append(state[iloc(bowl_idx,  prep_idx)])   # bowl at prep
        feat.append(state[iloc(bowl_idx,  cook_idx)])   # bowl at cooking
        feat.append(state[iloc(bowl_idx,  plat_idx)])   # bowl at plating
        feat.append(state[iloc(glass_idx, blend_idx)])  # glass at blending
        feat.append(state[iloc(mc_idx,    blend_idx)])  # measuring_cup at blending
        feat.append(state[iloc(glass_idx, serve_idx)])  # glass at serving
        # Containers at washing
        feat.append(state[iloc(pot_idx,  wash_idx)])
        feat.append(state[iloc(pan_idx,  wash_idx)])
        feat.append(state[iloc(bowl_idx, wash_idx)])
        feat.append(state[iloc(glass_idx, wash_idx)])
        feat.append(state[iloc(mc_idx,   wash_idx)])
        cooking_vessel_at_cook = (state[iloc(pot_idx, cook_idx)] + state[iloc(pan_idx, cook_idx)])
        feat.append(cooking_vessel_at_cook)             # any cooking vessel at cook
        items_at_prep = np.sum([state[iloc(i, prep_idx)] for i in range(n_items)])
        feat.append(items_at_prep)                      # items at prep station
        items_at_blend = np.sum([state[iloc(i, blend_idx)] for i in range(n_items)])
        feat.append(items_at_blend)                     # items at blending station

        # === Containment features (8 features) ===
        total_contained = np.sum(state[CONTAINMENT_OFFSET:CONTAINMENT_OFFSET + n_containers * n_ingredients])
        feat.append(total_contained)
        items_in_pot  = np.sum(state[cfeat(pot_idx,  0):cfeat(pot_idx,  0) + n_ingredients])
        items_in_pan  = np.sum(state[cfeat(pan_idx,  0):cfeat(pan_idx,  0) + n_ingredients])
        items_in_plate= np.sum(state[cfeat(plate_idx,0):cfeat(plate_idx,0) + n_ingredients])
        items_in_bowl = np.sum(state[cfeat(bowl_idx, 0):cfeat(bowl_idx, 0) + n_ingredients])
        items_in_glass= np.sum(state[cfeat(glass_idx,0):cfeat(glass_idx,0) + n_ingredients])
        items_in_mc   = np.sum(state[cfeat(mc_idx,   0):cfeat(mc_idx,   0) + n_ingredients])
        feat.append(items_in_pot)
        feat.append(items_in_pan)
        feat.append(items_in_plate)
        feat.append(items_in_bowl)
        feat.append(items_in_glass)
        feat.append(items_in_mc)
        # mixture present anywhere
        mixture_idx = INGREDIENTS.index("mixture")
        mixture_anywhere = np.sum([state[cfeat(c, mixture_idx)] for c in range(n_containers)])
        feat.append(mixture_anywhere)

        # === Processing state features ===
        # Cut status (9 features + summary)
        feat.extend(state[CUT_OFFSET:CUT_OFFSET + len(CUTTABLES)])
        total_cut = np.sum(state[CUT_OFFSET:CUT_OFFSET + len(CUTTABLES)])
        feat.append(total_cut)
        # Grated status (1 feature)
        total_grated = np.sum(state[GRATED_OFFSET:GRATED_OFFSET + len(GRATABLE)])
        feat.append(total_grated)
        # Cooked status (9 features + summary)
        feat.extend(state[COOKED_OFFSET:COOKED_OFFSET + len(COOKABLES)])
        total_cooked = np.sum(state[COOKED_OFFSET:COOKED_OFFSET + len(COOKABLES)])
        feat.append(total_cooked)
        # Seasoned status (summary only to keep dim manageable)
        total_seasoned = np.sum(state[SEASONED_OFFSET:SEASONED_OFFSET + n_ingredients])
        feat.append(total_seasoned)
        # Individual seasoning flags for key ingredients
        for ing_name in ["chicken", "fish", "mixture"]:
            ing_j = INGREDIENTS.index(ing_name)
            feat.append(state[SEASONED_OFFSET + ing_j])
        # Washed status (summary)
        total_washed = np.sum(state[WASHED_OFFSET:WASHED_OFFSET + n_items])
        feat.append(total_washed)

        # === Tool usage features (6 features) ===
        stove_on   = state[STOVE_IDX]
        sink_on    = state[SINK_IDX]
        blender_on = state[BLENDER_IDX]
        feat.append(stove_on)
        feat.append(stove_on * cooking_vessel_at_cook)
        feat.append(stove_on * total_contained)
        feat.append(sink_on)
        feat.append(blender_on)
        feat.append(blender_on * items_in_glass)

        # === Workflow progress features (8 features) ===
        served = state[SERVED_IDX]
        feat.append(int(total_cut > 0))             # prep started
        feat.append(int(total_cooked > 0))          # cooking complete
        feat.append(int(items_in_plate > 0))        # plating in progress
        feat.append(int(items_in_glass > 0))        # blending in progress
        feat.append(served)                         # served
        feat.append(int(total_washed > 0))          # cleanup in progress
        feat.append(int(total_grated > 0))          # grating done
        feat.append(int(total_seasoned > 0))        # seasoning applied
        # Full workflow complete
        workflow_complete = int(total_cut > 0) * int(total_cooked > 0) * int(served) * int(total_washed > 0)
        feat.append(workflow_complete)

        features.append(feat)

    feature_matrix = np.array(features, dtype=float)
    print(f"Feature matrix shape: {feature_matrix.shape}")
    print(f"Feature stats - min: {feature_matrix.min(axis=0)[:10]}, max: {feature_matrix.max(axis=0)[:10]}")
    return feature_matrix

def max_ent_irl(demonstrations, feature_matrix, state_to_idx, action_to_idx, n_iterations=100, learning_rate=0.05, temperature=2.0, gamma=0.9):
    """
    Maximum Entropy IRL with improved convergence.
    """
    n_states   = feature_matrix.shape[0]
    n_features = feature_matrix.shape[1]
    n_actions  = len(action_to_idx)
    
    print(f"\n=== Starting MaxEnt IRL ===")
    print(f"States: {n_states}, Actions: {n_actions}, Features: {n_features}")
    
    # Initialize reward weights
    reward_weights = np.random.randn(n_features) * 0.1
    
    # Compute empirical feature expectations
    empirical_feature_expectations = np.zeros(n_features)
    state_action_pairs = set()
    transition_model = {}
    
    for trajectory in demonstrations:
        traj_features = np.zeros(n_features)
        for i, (state, action) in enumerate(trajectory):
            s_idx = state_to_idx[state]
            traj_features += (gamma ** i) * feature_matrix[s_idx]
        empirical_feature_expectations += traj_features
    empirical_feature_expectations /= len(demonstrations)
    
    # Build state-action pairs and transition model
    for trajectory in demonstrations:
        for i in range(len(trajectory)):
            state, action = trajectory[i]
            s_idx, a_idx = state_to_idx[state], action_to_idx[action]
            state_action_pairs.add((s_idx, a_idx))
            if i < len(trajectory) - 1: transition_model[(s_idx, a_idx)] = state_to_idx[trajectory[i + 1][0]]
    
    print(f"Found {len(state_action_pairs)} unique (state, action) pairs")
    print(f"Empirical feature expectation norm: {np.linalg.norm(empirical_feature_expectations):.4f}")
    
    # Gradient descent with momentum
    best_diff   = float('inf')
    patience    = 0
    best_weights = reward_weights.copy()
    momentum    = np.zeros(n_features)
    momentum_beta = 0.9
    
    for iteration in range(n_iterations):
        current_lr = learning_rate * (0.95 ** (patience // 5))
        rewards    = feature_matrix @ reward_weights
        
        # Value iteration
        q_values = {}
        values   = rewards.copy()
        
        for _ in range(30):
            new_values = rewards.copy()
            for (s_idx, a_idx) in state_action_pairs:
                if (s_idx, a_idx) in transition_model:  q_values[(s_idx, a_idx)] = rewards[s_idx] + gamma * values[transition_model[(s_idx, a_idx)]]
                else:                                    q_values[(s_idx, a_idx)] = rewards[s_idx]
            for s_idx in range(n_states):
                available_qs = [q_values[(s_idx, a_idx)] for a_idx in range(n_actions) if (s_idx, a_idx) in q_values]
                if available_qs:
                    available_qs = np.array(available_qs)
                    max_q        = np.max(available_qs)
                    new_values[s_idx] = max_q + temperature * np.log(np.sum(np.exp((available_qs - max_q) / temperature)))
            if np.max(np.abs(new_values - values)) < 1e-6: break
            values = new_values
        
        # Compute soft policy
        policy = {}
        for s_idx in range(n_states):
            state_qs = [(a_idx, q_values[(s_idx, a_idx)]) for a_idx in range(n_actions) if (s_idx, a_idx) in q_values]
            if state_qs:
                state_actions, state_q_vals = zip(*state_qs)
                state_q_vals = np.array(state_q_vals)
                max_q        = np.max(state_q_vals)
                probs        = np.exp((state_q_vals - max_q) / temperature)
                probs       /= np.sum(probs)
                for a_idx, prob in zip(state_actions, probs): policy[(s_idx, a_idx)] = prob
        
        # Monte Carlo feature expectations
        expected_feature_counts = np.zeros(n_features)
        for _ in range(100):
            s_idx        = state_to_idx[demonstrations[np.random.randint(len(demonstrations))][0][0]]
            traj_features = np.zeros(n_features)
            
            for t in range(20):
                traj_features += (gamma ** t) * feature_matrix[s_idx]
                available = [(a_idx, policy[(s_idx, a_idx)]) for a_idx in range(n_actions) if (s_idx, a_idx) in policy]
                if not available: break
                
                actions_list, probs = zip(*available)
                probs  = np.array(probs) / np.sum(probs)
                a_idx  = np.random.choice(actions_list, p=probs)
                
                if (s_idx, a_idx) not in transition_model: break
                s_idx = transition_model[(s_idx, a_idx)]
            expected_feature_counts += traj_features
        expected_feature_counts /= 100
        
        # Gradient update
        gradient  = empirical_feature_expectations - expected_feature_counts
        gradient -= 0.01 * reward_weights
        grad_norm = np.linalg.norm(gradient)
        momentum  = momentum_beta * momentum + (1 - momentum_beta) * gradient
        reward_weights += current_lr * momentum
        
        if grad_norm < best_diff:
            best_diff    = grad_norm
            best_weights = reward_weights.copy()
            patience     = 0
        else: patience += 1
        
        if iteration % 10 == 0:
            print(f"Iteration {iteration}: Gradient norm = {grad_norm:.6f}, Reward range = [{np.min(reward_weights):.2f}, {np.max(reward_weights):.2f}], Best = {best_diff:.6f}, LR = {current_lr:.4f}")
        if grad_norm < 0.01 or patience > 20:
            print(f"Converged at iteration {iteration}")
            break
    print(f"Final best gradient norm: {best_diff:.6f}")
    recovered_rewards = feature_matrix @ best_weights
    return best_weights, recovered_rewards

def predict_action(state, reward_weights, feature_matrix, state_to_idx, action_to_idx, idx_to_action, transition_model, temperature=2.0, gamma=0.9):
    """Predict most likely action from a given state"""
    if state not in state_to_idx: return None, {}
    s_idx   = state_to_idx[state]
    rewards = feature_matrix @ reward_weights
    valid_actions  = []
    q_values_list  = []
    
    for a_idx in range(len(action_to_idx)):
        if (s_idx, a_idx) in transition_model:
            s_next_idx = transition_model[(s_idx, a_idx)]
            q_value    = rewards[s_idx] + gamma * rewards[s_next_idx]
            valid_actions.append(a_idx)
            q_values_list.append(q_value)
    
    if not valid_actions: return None, {}
    
    q_values_array     = np.array(q_values_list)
    max_q              = np.max(q_values_array)
    action_probs_array = np.exp((q_values_array - max_q) / temperature)
    action_probs_array /= np.sum(action_probs_array)
    
    action_probs    = {idx_to_action[a_idx]: prob for a_idx, prob in zip(valid_actions, action_probs_array)}
    best_action_idx = valid_actions[np.argmax(action_probs_array)]
    return idx_to_action[best_action_idx], action_probs


# Pref

In [12]:
import numpy as np
from typing import List, Dict, Tuple, Set, Optional, Callable
from collections import defaultdict

class PreferenceBasedRecipeModifier:
    """
    Modifies recipes to show different behavioral preferences.
    
    Key Principle: Preferences change HOW tasks are done (timing/sequence),
    not WHETHER they're done.
    """
    
    def __init__(self, state_tracker):
        self.tracker = state_tracker
        self.preference_names = [
            "plating_ingredients",
            "washing_plates", 
            "delivering_dishes",
            "chopping_ingredients",
            "potting_rice",
            "grilling_meat_mushroom",
            "taking_mushroom_from_dispenser",
            "taking_rice_from_dispenser",
            "taking_meat_from_dispenser"
        ]
    
    # ===== RECIPE MODIFICATION FUNCTIONS =====    
    def modify_recipe(self, actions: List[str], preferences: Dict[str, bool]) -> List[str]:
        """
        Modify recipe according to preferences.
        
        Args:
            actions: Base recipe action sequence
            preferences: Dict of {preference_name: prefer_bool}
        
        Returns:
            Modified action sequence
        """
        modified = actions.copy()
        # Apply each preference modification
        if "washing_plates" in preferences:
            modified = self._modify_washing(modified, preferences["washing_plates"])
        if "chopping_ingredients" in preferences:
            modified = self._modify_chopping(modified, preferences["chopping_ingredients"])
        if "plating_ingredients" in preferences:
            modified = self._modify_plating(modified, preferences["plating_ingredients"])
        if "delivering_dishes" in preferences:
            modified = self._modify_delivery(modified, preferences["delivering_dishes"])
        if "potting_rice" in preferences:
            modified = self._modify_potting(modified, preferences["potting_rice"])
        if "grilling_meat_mushroom" in preferences:
            modified = self._modify_grilling(modified, preferences["grilling_meat_mushroom"])
        if "taking_mushroom_from_dispenser" in preferences:
            modified = self._modify_mushroom_retrieval(modified, preferences["taking_mushroom_from_dispenser"])
        if "taking_rice_from_dispenser" in preferences:
            modified = self._modify_rice_retrieval(modified, preferences["taking_rice_from_dispenser"])
        if "taking_meat_from_dispenser" in preferences:
            modified = self._modify_meat_retrieval(modified, preferences["taking_meat_from_dispenser"])
        return modified
    
    def _modify_washing(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: wash immediately. Prefer=False: batch at end"""
        if prefer:
            return actions  # Already immediate in base recipes
        
        # Batch washing at end
        modified = []
        washing_actions = [] 
        for action in actions:
            if "wash (" in action:  washing_actions.append(action)
            else:                   modified.append(action)
        
        # Add all washing after serving
        return modified + washing_actions
    
    def _modify_chopping(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: chop early. Prefer=False: chop just-in-time"""
        if prefer:
            return actions  # Already early in base recipes
        # Delay chopping until right before use
        modified = []
        cut_actions = {}  # item -> cut action
        for action in actions:
            if action.startswith("cut ("):
                # Extract item name
                item = action.split("(")[1].split(",")[0].strip()
                cut_actions[item] = action
            else:
                # Check if this action needs a cut item
                needs_cutting = False
                for item in cut_actions.keys():
                    if item in action and ("load (" in action or "combine (" in action):
                        # Insert cut right before this action
                        modified.append(cut_actions.pop(item))
                        needs_cutting = True
                        break
                
                modified.append(action)
        
        return modified
    
    def _modify_plating(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: plate immediately. Prefer=False: delay plating"""
        if prefer:
            return actions
        # Delay plating actions
        modified = []
        plating_actions = []
        for action in actions:
            if "plating_station" in action and action.startswith("load ("):     plating_actions.append(action)
            else:                                                               modified.append(action)
        
        # Insert plating right before serving
        serve_idx = next((i for i, a in enumerate(modified) if "serve (" in a), len(modified))
        return modified[:serve_idx] + plating_actions + modified[serve_idx:]
    
    def _modify_delivery(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: serve quickly. Prefer=False: inspect/delay"""
        # This is more subtle - for now, keep as-is
        return actions
    
    def _modify_potting(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: pot early. Prefer=False: pot just before cooking"""
        if prefer:
            return actions
        # Delay potting
        modified = []
        pot_loads = []
        for action in actions:
            if action.startswith('load (') and 'pot' in action:
                pot_loads.append(action)
            else:
                modified.append(action)
                # Add pot loading right before cooking
                if "turn_on (stove" in action and pot_loads:
                    modified = modified[:-1] + pot_loads + [modified[-1]]
                    pot_loads = []
        
        return modified
    
    def _modify_grilling(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: use pan. Prefer=False: use pot"""
        if prefer:
            # Substitute pot with pan where appropriate
            return [a.replace("pot", "pan") if ("cook_contents" in a or a.startswith("cook (")) else a for a in actions]
        return actions
    
    def _modify_mushroom_retrieval(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: get early. Prefer=False: get late"""
        return self._modify_retrieval(actions, prefer, "mushroom")
    
    def _modify_rice_retrieval(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: get early. Prefer=False: get late"""
        return self._modify_retrieval(actions, prefer, "rice")
    
    def _modify_meat_retrieval(self, actions: List[str], prefer: bool) -> List[str]:
        """Prefer=True: get early. Prefer=False: get late"""
        return self._modify_retrieval(actions, prefer, "meat")
    
    def _modify_retrieval(self, actions: List[str], prefer: bool, item: str) -> List[str]:
        """Generic retrieval modifier"""
        if prefer:
            return actions
        
        # Delay retrieval
        modified = []
        retrieval_action = None
        for action in actions:
            if item in action and "transfer (" in action and "storage" in action:
                retrieval_action = action
            else:
                modified.append(action)
                # Insert right before use
                if retrieval_action and item in action and (action.startswith("cut (") or action.startswith("load (")):
                    modified.insert(-1, retrieval_action)
                    retrieval_action = None
        return modified
    
    # ===== PREFERENCE DETECTION =====    
    def detect_preferences(self, actions: List[str]) -> Dict[str, bool]:
        """
        Detect which preferences are shown in a demonstration.
        
        Returns:
            Dict of {preference_name: detected_bool}
        """
        return {
            "plating_ingredients": self._detect_plating(actions),
            "washing_plates": self._detect_washing(actions),
            "delivering_dishes": self._detect_delivery(actions),
            "chopping_ingredients": self._detect_chopping(actions),
            "potting_rice": self._detect_potting(actions),
            "grilling_meat_mushroom": self._detect_grilling(actions),
            "taking_mushroom_from_dispenser": self._detect_item_retrieval(actions, "mushroom"),
            "taking_rice_from_dispenser": self._detect_item_retrieval(actions, "rice"),
            "taking_meat_from_dispenser": self._detect_item_retrieval(actions, "meat"),
        }
    
    def _detect_washing(self, actions: List[str]) -> bool:
        """Spread out = prefer, batched = don't prefer"""
        wash_indices = [i for i, a in enumerate(actions) if "wash (" in a]
        if len(wash_indices) < 2: return True
        
        gaps = [wash_indices[i+1] - wash_indices[i] for i in range(len(wash_indices)-1)]
        avg_gap = np.mean(gaps) if gaps else 0
        return avg_gap > 3
    
    def _detect_chopping(self, actions: List[str]) -> bool:
        """Early = prefer, late = don't prefer"""
        cut_indices = [i for i, a in enumerate(actions) if a.startswith("cut (")]
        if not cut_indices: return True
        
        avg_position = np.mean(cut_indices) / len(actions)
        return avg_position < 0.4
    
    def _detect_plating(self, actions: List[str]) -> bool:
        """Early = prefer, late = don't prefer"""
        plate_loads = [i for i, a in enumerate(actions) if "plating_station" in a and "load (" in a]
        if not plate_loads: return True
        
        avg_position = np.mean(plate_loads) / len(actions)
        return avg_position < 0.7
    
    def _detect_delivery(self, actions: List[str]) -> bool:
        """Quick = prefer"""
        return True  # Simplified
    
    def _detect_potting(self, actions: List[str]) -> bool:
        """Early potting = prefer"""
        pot_loads = [i for i, a in enumerate(actions) if "load (" in a and "pot" in a]
        cook_starts = [i for i, a in enumerate(actions) if "turn_on (stove" in a]
        if not pot_loads or not cook_starts: return True
        
        return min(pot_loads) < min(cook_starts) - 2
    
    def _detect_grilling(self, actions: List[str]) -> bool:
        """Pan usage = prefer"""
        pan_uses = sum(1 for a in actions if "pan" in a and ("cook" in a or "load" in a))
        pot_uses = sum(1 for a in actions if "pot" in a and ("cook" in a or "load" in a))
        return pan_uses > pot_uses
    
    def _detect_item_retrieval(self, actions: List[str], item: str) -> bool:
        """Early retrieval = prefer"""
        get_idx = next((i for i, a in enumerate(actions) if item in a and "transfer (" in a and "storage" in a), -1)
        use_idx = next((i for i, a in enumerate(actions) if item in a and ("cut (" in a or "load (" in a)), -1)
        if get_idx == -1 or use_idx == -1: return True
        
        return get_idx < use_idx - 3

class AdaptiveIRL:
    """
    Adaptive IRL system with prediction-error-based novelty detection.
    
    The robot watches a full demonstration step-by-step, predicting each action.
    If the error rate exceeds the novelty threshold, the demo is novel:
    store it, retrain, and report before/after accuracy.
    """

    def __init__(self, novelty_threshold=0.20):
        """
        Args:
            novelty_threshold: error rate above which a demo is considered novel.
                               0.20 means >20% mispredictions → new recipe/preference.
        """
        self.novelty_threshold = novelty_threshold
        self.modifier = PreferenceBasedRecipeModifier(StateTracker())
        self.train_demos = []
        self._model = None  # (weights, feat_mat, s2i, a2i, i2a, trans)

    # ── public API ────────────────────────────────────────────────────────

    def make_demo(self, recipe_func, preferences=None):
        """Generate a (state, action) trajectory with optional preference shaping."""
        tracker = StateTracker()
        base_actions = recipe_func()
        actions = self.modifier.modify_recipe(base_actions, preferences) if preferences else base_actions
        tracker.reset()
        traj = []
        for action in actions:
            traj.append((tuple(tracker.get_state_vector().tolist()), action))
            try: tracker.apply_action(action)
            except Exception as e: print(f"  Warning: {e}")
        traj.append((tuple(tracker.get_state_vector().tolist()), "stop"))
        return traj

    def initialize(self, recipe_func, preferences=None):
        """Train on a single base demonstration."""
        demo = self.make_demo(recipe_func, preferences)
        self.train_demos = [demo]
        self._retrain()
        print(f"Initialized | States: {len(self._model[2])} | Actions: {len(self._model[3])}\n")

    def observe(self, demo, label="", show_steps=False):
        """
        Watch a full demonstration step-by-step, predicting each action.
        
        Novelty is determined by prediction error rate AFTER watching the full demo,
        matching the scenario: 'it knows it's a new recipe because it mispredicts actions'.
        
        If novel:  store demo, retrain, report before AND after accuracy.
        If known:  report accuracy only (no retraining needed).
        
        Returns (accuracy_before, accuracy_after) — after is None if not novel.
        """
        tag = label or "Demo"
        print(f"\n{'─'*60}")
        print(f"{tag}")

        # Step through the full demo, predicting each action
        weights, feat_mat, s2i, a2i, i2a, trans = self._model
        correct, total, unknown = 0, 0, 0

        for i, (state, true_action) in enumerate(demo[:-1]):
            total += 1
            if state not in s2i:
                unknown += 1
                if show_steps:      print(f"    Step {i+1:>2}: {true_action}  →  [unknown state] ✗")
                continue
            pred, _ = predict_action(state, weights, feat_mat, s2i, a2i, i2a, trans)
            ok = pred == true_action
            if ok: correct += 1
            if show_steps:          print(f"    Step {i+1:>2}: {true_action} \n Step {pred} {'✓' if ok else '✗'}")

        acc_before = correct / total if total > 0 else 0
        error_rate = 1.0 - acc_before
        # Unknown states count as errors for novelty purposes
        effective_error_rate = (total - correct) / total if total > 0 else 1.0
        is_novel = effective_error_rate > self.novelty_threshold

        print(f"  Accuracy (before): {acc_before:.1%} ({correct}/{total})  "
              f"Unknown states: {unknown}  Error rate: {effective_error_rate:.1%}")
        print(f"  Novelty threshold: >{self.novelty_threshold:.0%} errors  →  "
              f"{'NOVEL — retraining' if is_novel else 'Known — no retraining'}")

        if not is_novel:
            return acc_before, None

        # Novel: store and retrain
        self.train_demos.append(demo)
        self._retrain()
        acc_after = self._evaluate(demo, show_steps=False)
        print(f"  Accuracy (after):  {acc_after:.1%}  Improvement: {acc_after - acc_before:+.1%}")
        return acc_before, acc_after

    def evaluate(self, demo, label="", show_steps=False):
        """Evaluate current model on a demo (for final sweep)."""
        acc = self._evaluate(demo, show_steps=show_steps)
        if label:
            print(f"{label}Accuracy: {acc:.1%}")
        return acc

    # ── internals ─────────────────────────────────────────────────────────

    def _retrain(self):
        demos = self.train_demos
        all_actions = sorted({a for traj in demos for _, a in traj})
        s2i, i2s, a2i, i2a = create_state_action_mappings(demos, all_actions)
        feat_mat = create_enhanced_feature_matrix(i2s)
        weights, _ = max_ent_irl(demos, feat_mat, s2i, a2i, n_iterations=100)
        trans = {}
        for traj in demos:
            for i in range(len(traj) - 1):
                s, a = traj[i]; ns = traj[i+1][0]
                if s in s2i and ns in s2i and a in a2i:
                    trans[(s2i[s], a2i[a])] = s2i[ns]
        self._model = (weights, feat_mat, s2i, a2i, i2a, trans)

    def _evaluate(self, demo, show_steps=False):
        weights, feat_mat, s2i, a2i, i2a, trans = self._model
        correct, total = 0, 0
        for i, (state, true_action) in enumerate(demo[:-1]):
            total += 1
            if state not in s2i: continue
            pred, _ = predict_action(state, weights, feat_mat, s2i, a2i, i2a, trans)
            ok = pred == true_action
            if ok: correct += 1
            if show_steps:          print(f"    Step {i+1:>2}: {true_action} \n     Step {pred} {'✓' if ok else '✗'}")
        return correct / total if total > 0 else 0

In [None]:
SHOW_STEPS = True  # flip to True for step-by-step predictions
gen = RecipeGenerator()
system = AdaptiveIRL(novelty_threshold=0.20)

# Base preferences (matching image: 3 categories, 9 preferences)
base_prefs = {
    "plating_ingredients":            True,
    "washing_plates":                 False,
    "delivering_dishes":              True,
    "chopping_ingredients":           True,
    "potting_rice":                   False,
    "grilling_meat_mushroom":         False,
    "taking_mushroom_from_dispenser": False,
    "taking_rice_from_dispenser":     False,
    "taking_meat_from_dispenser":     False,
}

# ── PHASE 1: Initialize on one recipe ────────────────────────────────
print("="*60, "PHASE 1: INITIAL TRAINING", "="*60)
system.initialize(gen.generate_tomato_onion_soup_v1, base_prefs)


# ── PHASE 2: Observe new demonstrations ──────────────────────────────
print("="*60, "PHASE 2: OBSERVE & ADAPT", "="*60)

# Same recipe, same prefs → should skip
same_demo = system.make_demo(gen.generate_tomato_onion_soup_v1, base_prefs)
# for i, demo in enumerate(system.train_demos):
#     print(f"\nDemo {i+1}:")
#     system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
system.observe(same_demo, label="Same recipe, same prefs (expect: skip)")
# for i, demo in enumerate(system.train_demos):
#     print(f"\nDemo {i+1}:")
#     system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
print("\n\n\n\n\n", "Same recipe, same prefs", "\n\n\n\n\n")

# Same recipe, different prefs → novel
diff_pref_demo = system.make_demo(gen.generate_tomato_onion_soup_v1, {**base_prefs,     "washing_plates":  True,        "chopping_ingredients":  False,         "plating_ingredients":  False,})
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
system.observe(diff_pref_demo, label="Same recipe, different prefs (expect: novel)")
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
print("\n\n\n\n\n", "Same recipe, different prefs", "\n\n\n\n\n")

# Different recipe, same prefs → novel
v2_demo = system.make_demo(gen.generate_tomato_onion_soup_v2, base_prefs)
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
system.observe(v2_demo, label="Different recipe – Tomato Onion Soup v2 (expect: novel)")
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
print("\n\n\n\n\n", "Different recipe, same prefs", "\n\n\n\n\n")

# Different recipe, different prefs → novel
mushroom_demo = system.make_demo(gen.generate_mushroom_soup, {**base_prefs,     "taking_mushroom_from_dispenser": True,     "grilling_meat_mushroom":         True,})
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
system.observe(mushroom_demo, label="Different recipe + different prefs – Mushroom Soup (expect: novel)")
for i, demo in enumerate(system.train_demos):
    print(f"\nDemo {i+1}:")
    system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)
print("\n\n\n\n\n", "Different recipe, different prefs", "\n\n\n\n\n")

# ── PHASE 3: Final sweep on all collected demos ───────────────────────
print(f"\n{'='*60} PHASE 3: FINAL EVALUATION ({len(system.train_demos)} demos in training set) {'='*60}")
# for i, demo in enumerate(system.train_demos):
#     print(f"\nDemo {i+1}:")
#     system.evaluate(demo, label="  ", show_steps=SHOW_STEPS)

Total state features: 356
Feature dimensions: Locations=175, Containment=114, Cut=9, Grated=1, Cooked=9, Seasoned=19, Washed=25, Tools=3, Served=1
Total state features: 356
Feature dimensions: Locations=175, Containment=114, Cut=9, Grated=1, Cooked=9, Seasoned=19, Washed=25, Tools=3, Served=1
Total state features: 356
Feature dimensions: Locations=175, Containment=114, Cut=9, Grated=1, Cooked=9, Seasoned=19, Washed=25, Tools=3, Served=1
Feature matrix shape: (26, 75)
Feature stats - min: [20.  0.  0.  0.  0.  0.  0.  0.  0.  0.], max: [25.  4.  3.  3.  2.  2.  0.  1.  0.  1.]

=== Starting MaxEnt IRL ===
States: 26, Actions: 26, Features: 75
Found 26 unique (state, action) pairs
Empirical feature expectation norm: 208.0753
Iteration 0: Gradient norm = 11.815292, Reward range = [-0.23, 0.22], Best = 11.815292, LR = 0.0500
Iteration 10: Gradient norm = 11.790857, Reward range = [-0.22, 2.88], Best = 11.790857, LR = 0.0500
Iteration 20: Gradient norm = 11.743955, Reward range = [-0.22, 7.

# Training and Testing

In [5]:
# Train on Recipe 1 only
print("PHASE 1: INITIAL TRAINING")
# Generate training data from Recipe 1
gen = RecipeGenerator()
gen.generate_random_dataset(1)  # Generates tomato_onion_soup_v1

# Convert to IRL format
train_demonstrations = []
all_actions = set()
for demo in gen.demos:
    trajectory = [(tuple(step['state']), step['action']) for step in demo]
    train_demonstrations.append(trajectory)
    all_actions.update(step['action'] for step in demo)
unique_actions = sorted(list(all_actions))
print(f"\nTraining Recipe: Tomato-Onion Soup v1")
print(f"  Demonstrations: {len(train_demonstrations)}")
print(f"  Unique actions: {len(unique_actions)}")

# Create mappings
state_to_idx, idx_to_state, action_to_idx, idx_to_action = create_state_action_mappings(train_demonstrations, unique_actions)
# Extract features
feature_matrix = create_enhanced_feature_matrix(idx_to_state)
# Train IRL
reward_weights, recovered_rewards = max_ent_irl(train_demonstrations, feature_matrix, state_to_idx, action_to_idx, n_iterations=100)
# Build transition model
transition_model = {}
for trajectory in train_demonstrations:
    for i in range(len(trajectory) - 1):
        state, action = trajectory[i]
        next_state = trajectory[i + 1][0]
        s_idx = state_to_idx[state]
        a_idx = action_to_idx[action]
        transition_model[(s_idx, a_idx)] = state_to_idx[next_state]
print("\n Initial training complete")
print(f"  Training set size: {len(train_demonstrations)} recipe")
print("\n\n\n\n\n")




# Show model new recipes, collect failures, then retrain
showPred = True
print("PHASE 2: ADAPTIVE LEARNING")
# Define new recipes to learn
new_recipes = [
    (gen.generate_tomato_onion_soup_v2, "Tomato-Onion Soup v2"),
    (gen.generate_mushroom_soup, "Mushroom Soup"),
    (gen.generate_tomato_soup, "Tomato Soup"),
]
# Test each new recipe and collect demonstrations
new_demonstrations = []

for recipe_func, recipe_name in new_recipes:
    print(f"\n\nRecipe: {recipe_name}")
    # Generate trajectory
    test_gen = RecipeGenerator()
    test_gen._record_trajectory(recipe_func())
    test_trajectory = [(tuple(step['state']), step['action']) for step in test_gen.demos[0]]
    # Test with current model
    correct = 0
    total = len(test_trajectory) - 1
    unknown_states = 0
    errors = []
    
    for i, (state, true_action) in enumerate(test_trajectory[:-1]):
        if state not in state_to_idx:
            unknown_states += 1
            errors.append(f"Step {i+1}: Unknown state")
            continue
        predicted_action, action_probs = predict_action(state, reward_weights, feature_matrix, state_to_idx, action_to_idx, idx_to_action, transition_model)
        
        is_correct = (predicted_action == true_action)
        if is_correct:  correct += 1
        else:           errors.append(f"Step {i+1}: Predicted '{predicted_action}' but true was '{true_action}'")
        if showPred: print(f"Step {i+1:>2}: {true_action}\n   Pred: {predicted_action} {'✓' if is_correct else '✗'}")
    accuracy = correct / total if total > 0 else 0
    print(f"Performance BEFORE learning:")
    print(f"  Accuracy: {accuracy:.1%} ({correct}/{total})")
    print(f"  Unknown states: {unknown_states}")
    print(f"  Errors: {len(errors)}")

    # Store this demonstration for retraining
    if accuracy < 0.95:
        new_demonstrations.append(test_trajectory)
        print(f"✓ Stored {recipe_name} for retraining")
print(f"Collected {len(new_demonstrations)} new recipe demonstrations")
print("\n\n\n\n\n")




# Retrain IRL on expanded dataset
print("PHASE 3: RETRAINING ON EXPANDED DATASET")
# Add new demonstrations to training set
print(f"Before: {len(train_demonstrations)} demonstrations")
train_demonstrations.extend(new_demonstrations)
print(f"After: {len(train_demonstrations)} demonstrations")
print(f"Added: {len(new_demonstrations)} new recipes\n")

# Collect all unique actions
all_actions = set()
for traj in train_demonstrations: all_actions.update(action for _, action in traj)
unique_actions = sorted(list(all_actions))
print(f"Total unique actions: {len(unique_actions)}")

# Recreate mappings with expanded data
state_to_idx, idx_to_state, action_to_idx, idx_to_action = create_state_action_mappings(train_demonstrations, unique_actions)
# Extract features for all states
feature_matrix = create_enhanced_feature_matrix(idx_to_state)
# Retrain IRL
print("\nRetraining IRL")
reward_weights, recovered_rewards = max_ent_irl(train_demonstrations, feature_matrix, state_to_idx, action_to_idx, n_iterations=100)

# Rebuild transition model
transition_model = {}
for trajectory in train_demonstrations:
    for i in range(len(trajectory) - 1):
        state, action = trajectory[i]
        next_state = trajectory[i + 1][0]
        s_idx = state_to_idx[state]
        a_idx = action_to_idx[action]
        transition_model[(s_idx, a_idx)] = state_to_idx[next_state]

print("\n✓ Retraining complete")
print(f"  State space size: {len(state_to_idx)}")
print(f"  Action space size: {len(action_to_idx)}")
print("\n\n\n\n\n")




# Test retrained model on all 4 recipes
print("PHASE 4: FINAL EVALUATION")
# Recipe names for display
print(f"\nEvaluating retrained model on {len(train_demonstrations)} training demonstrations\n")
evaluation_results = []

for idx, test_trajectory in enumerate(train_demonstrations):
    # Test with retrained model
    correct = 0
    total = len(test_trajectory) - 1
    unknown_states = 0
    
    for i, (state, true_action) in enumerate(test_trajectory[:-1]):
        if state not in state_to_idx:
            unknown_states += 1
            continue
        predicted_action, action_probs = predict_action(state, reward_weights, feature_matrix, state_to_idx, action_to_idx, idx_to_action, transition_model)
        is_correct = (predicted_action == true_action)
        if is_correct: correct += 1
        if showPred: print(f"Step {i+1:>2}: {true_action}\n   Pred: {predicted_action} {'✓' if is_correct else '✗'}")
    accuracy = correct / total if total > 0 else 0
    print("\n\n")
    evaluation_results.append({'idx': idx, 'accuracy': accuracy, 'correct': correct, 'total': total, 'unknown_states': unknown_states})

# Print results
print("\n\nFINAL RESULTS")
for result in evaluation_results:
    status = "✓" if result['accuracy'] > 0.9 else "✗"
    print(f"{status} {result['idx']} - Accuracy: {result['accuracy']:5.1%}  ({result['correct']}/{result['total']})  Unknown: {result['unknown_states']}")

avg_accuracy = sum(r['accuracy'] for r in evaluation_results) / len(evaluation_results)
print(f"Average Accuracy: {avg_accuracy:.1%}")

PHASE 1: INITIAL TRAINING
Total state features: 356
Feature dimensions: Locations=175, Containment=114, Cut=9, Grated=1, Cooked=9, Seasoned=19, Washed=25, Tools=3, Served=1
Generating 1 PDDL-style demonstrations...

Training Recipe: Tomato-Onion Soup v1
  Demonstrations: 1
  Unique actions: 26
Feature matrix shape: (26, 75)
Feature stats - min: [20.  0.  0.  0.  0.  0.  0.  0.  0.  0.], max: [25.  4.  3.  3.  2.  2.  0.  1.  0.  1.]

=== Starting MaxEnt IRL ===
States: 26, Actions: 26, Features: 75
Found 26 unique (state, action) pairs
Empirical feature expectation norm: 208.0814
Iteration 0: Gradient norm = 11.821327, Reward range = [-0.26, 0.20], Best = 11.821327, LR = 0.0500
Iteration 10: Gradient norm = 11.796879, Reward range = [-0.21, 2.84], Best = 11.796879, LR = 0.0500
Iteration 20: Gradient norm = 11.749953, Reward range = [-0.21, 7.48], Best = 11.749953, LR = 0.0500
Iteration 30: Gradient norm = 11.695342, Reward range = [-0.21, 12.78], Best = 11.695342, LR = 0.0500
Iteration