In [None]:
# import gym
import gymnasium as gym
from gymnasium import spaces
from collections import defaultdict


class DeliveryEnv(gym.Env):
    """Gymnasium environment that replays one day of historical delivery orders.

    Each ``step`` offers the current order to one driver from that order's
    driver pool at the commission rate given by the action. If the driver
    accepts, the order is assigned and the driver moves to the dropoff
    location; otherwise the next driver in the pool is tried. When the pool is
    exhausted the order is left unassigned. An episode ends once every order of
    the current day has been processed.

    NOTE(review): ``reset`` returns only the observation and ``step`` returns
    the legacy 4-tuple ``(obs, reward, done, info)``. Gymnasium's API expects
    ``(obs, info)`` and ``(obs, reward, terminated, truncated, info)``; the
    legacy shapes are kept so existing callers keep working.
    """

    def __init__(
        self,
        orders: list,
        driver_data: pd.DataFrame,
        schedule_data: pd.DataFrame,
        weather_service: WeatherService,
    ):
        """
        Args:
            orders: Order objects; sorted by ``datetime`` and grouped by calendar day.
            driver_data: historical assignment attempts with columns ``order_id``,
                ``driver_id``, ``datetime``, ``driver_lat``, ``driver_lon``,
                ``driver_area`` and ``work_time_minutes``.
            schedule_data: driver schedule with columns ``driver_id``, ``date``, ``hour``.
            weather_service: object exposing ``get_weather_code(datetime)``.
        """
        super().__init__()

        self.orders = sorted(orders, key=lambda o: o.datetime)
        self.weather = weather_service
        self.driver_schedule = self._load_driver_schedule(schedule_data)
        self.driver_attempts = self._load_driver_attempts(driver_data)
        self.drivers_by_id = {}  # Cache of all drivers by ID; filled by _group_drivers_by_area
        self.area_drivers = self._group_drivers_by_area()
        # Starting position of every driver, restored by reset() so that moves
        # made during one episode do not leak into the next one.
        self._initial_driver_positions = {
            driver_id: (driver.current_lat, driver.current_lon, driver.current_area)
            for driver_id, driver in self.drivers_by_id.items()
        }
        # Drivers whose location was changed by the simulation this episode
        self.drivers_with_updated_location = set()
        # Action: platform commission rate, continuous in [0, 1]
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)

        # Observation: order attributes + attributes of the driver being offered the order
        self.observation_space = spaces.Dict(
            {
                "customer_price": spaces.Box(
                    low=0.0, high=1.0, shape=(1,), dtype=np.float32
                ),
                "pickup_area": spaces.Discrete(501),
                "dropoff_area": spaces.Discrete(501),
                "hour_of_day": spaces.Discrete(24),
                "day_of_week": spaces.Discrete(7),
                "weather": spaces.Discrete(4),
                "driver_area": spaces.Discrete(501),
                "working_status": spaces.Discrete(2),
            }
        )

        self.orders_by_day = defaultdict(list)
        for order in self.orders:
            self.orders_by_day[order.datetime.date()].append(order)

        # Tracking variables
        self.assigned_order = 0  # Number of orders accepted by a driver this episode
        # Index into orders_by_day's keys used by reset().
        # NOTE(review): never advanced inside this class, so every reset
        # replays the same day unless a caller changes it — confirm intended.
        self.current_day_index = 0
        self.current_order_index = 0  # Current order within the day
        self.current_driver_index = 0  # Current driver within the order's pool
        self.current_day = None  # Current date being trained; set by reset()
        self.updated_drivers = set()  # Drivers who accepted at least one order this episode
        self.next_order = False  # True when the driver pool must be reloaded
        self.episode_rewards = 0
        self.episode_steps = 0
        self.total_driver_commission = 0.0
        self.max_steps = 30000  # NOTE(review): currently unused (no truncation)
        self.current_day_order_driver_pairs = 0

        # Populated by reset(). Do NOT index orders_by_day with None here: it is a
        # defaultdict and the lookup would insert a spurious empty "None" day.
        self.current_orders = []
        self.current_order_driver_pool = []

    def _load_driver_schedule(self, schedule_data: pd.DataFrame):
        """Build a ``(driver_id, date) -> set of scheduled hours`` lookup from a DataFrame.

        NOTE(review): lookups use ``order.datetime.date()``, so the ``date``
        column presumably holds ``datetime.date`` values; strings or Timestamps
        would never match — confirm upstream.
        """
        schedule = defaultdict(set)
        for _, row in schedule_data.iterrows():
            schedule[(row["driver_id"], row["date"])].add(row["hour"])
        return schedule

    def _load_driver_attempts(self, driver_data: pd.DataFrame):
        """Group historical assignment attempts by order.

        Returns:
            ``order_id -> list of (driver_id, attempt_time, lat, lon, area,
            work_time_minutes)`` tuples, in DataFrame row order.
        """
        attempts = defaultdict(list)
        for _, row in driver_data.iterrows():
            attempts[row["order_id"]].append(
                (
                    row["driver_id"],
                    row["datetime"],
                    row["driver_lat"],
                    row["driver_lon"],
                    row["driver_area"],
                    row["work_time_minutes"],
                )
            )
        return attempts

    def _group_drivers_by_area(self):
        """Create one ``Driver`` per distinct driver_id and group them by area.

        Each driver is placed at the position of the first attempt seen for it
        while iterating ``self.driver_attempts``. Also fills ``self.drivers_by_id``.
        """
        area_drivers = defaultdict(list)
        for attempts in self.driver_attempts.values():
            for driver_id, _attempt_time, lat, lon, area, work_time_minutes in attempts:
                if driver_id in self.drivers_by_id:
                    continue
                driver = Driver(
                    driver_id=driver_id,
                    current_lat=lat,
                    current_lon=lon,
                    current_area=area,
                    work_time_minutes=work_time_minutes,
                )
                self.drivers_by_id[driver_id] = driver
                area_drivers[area].append(driver)
        return area_drivers

    def _restore_driver_positions(self):
        """Move every driver back to its starting position and rebuild ``area_drivers``.

        NOTE(review): only lat/lon/area are restored; any other state that
        ``Driver.update_location`` may maintain is left untouched.
        """
        area_drivers = defaultdict(list)
        for driver_id, driver in self.drivers_by_id.items():
            lat, lon, area = self._initial_driver_positions[driver_id]
            driver.current_lat, driver.current_lon, driver.current_area = lat, lon, area
            area_drivers[area].append(driver)
        self.area_drivers = area_drivers

    def _get_driver_pool(self, order: Order):
        """Return the drivers that can be offered ``order``.

        The pool combines, in order:
          * the drivers of the historical attempts recorded for ``order.order_id``;
          * drivers the simulation moved (they accepted an earlier order this
            episode) who are now in the pickup area and scheduled for the
            order's hour.

        Moved drivers whose current area is not the pickup area are dropped.
        Driver locations are never modified here.
        """
        candidate_ids = [
            attempt[0] for attempt in self.driver_attempts.get(order.order_id, [])
        ]
        seen_ids = set(candidate_ids)

        # Add drivers that the simulation moved into the pickup area
        for driver in self.area_drivers.get(order.pickup_area, []):
            if driver.driver_id in seen_ids:
                continue  # Avoid duplicates with the historical attempts
            if driver.driver_id not in self.drivers_with_updated_location:
                continue
            if not self._is_driver_working(driver, order.datetime):
                continue
            seen_ids.add(driver.driver_id)
            candidate_ids.append(driver.driver_id)

        valid_drivers = []
        for driver_id in candidate_ids:
            driver = self.drivers_by_id.get(driver_id)
            if driver is None:
                continue
            # A moved driver keeps its simulated location; exclude it unless
            # that location is the pickup area.
            if (
                driver_id in self.drivers_with_updated_location
                and driver.current_area != order.pickup_area
            ):
                continue
            valid_drivers.append(driver)

        if not valid_drivers:
            print(
                f"WARNING: No available drivers for Order {order.order_id} in pickup area {order.pickup_area}!"
            )
        return valid_drivers

    def _sync_driver_pool(self):
        """Load the driver pool for the current order, skipping orders with no drivers.

        Advances ``current_order_index`` past orders whose pool is empty, so
        ``current_order_driver_pool[current_driver_index]`` is valid whenever
        orders remain. The pool is left empty once the day is exhausted.
        """
        self.current_order_driver_pool = []
        self.current_driver_index = 0
        while self.current_order_index < len(self.current_orders):
            pool = self._get_driver_pool(self.current_orders[self.current_order_index])
            if pool:
                self.current_order_driver_pool = pool
                break
            self.current_order_index += 1  # Nobody can take this order; move on
        self.next_order = False

    def _normalize_state(self, order: Order):
        """Extract order attributes, normalizing them to keep RL training stable."""
        return {
            "pickup_area": int(order.pickup_area),
            "dropoff_area": int(order.dropoff_area),
            "hour_of_day": int(order.hour_of_day),
            "weather_code": int(self.weather.get_weather_code(order.datetime)),
            "customer_price": order.customer_price
            / 10100000.0,  # Normalized by max price threshold
            # NOTE(review): step() overwrites commissionPercent with a fraction in
            # [0, 1], so this /100 only fits the original percent values — confirm.
            # Not used in the observation.
            "commissionPercent": order.commissionPercent / 100.0,
        }

    def _get_observation(self):
        """Build the observation for the current (order, driver) pair.

        Returns ``None`` once every order of the day has been processed.
        """
        if self.next_order:
            self._sync_driver_pool()

        if (
            self.current_order_index >= len(self.current_orders)
            or not self.current_order_driver_pool
        ):
            return None

        order = self.current_orders[self.current_order_index]
        driver = self.current_order_driver_pool[self.current_driver_index]
        normalized_state = self._normalize_state(order)

        return {
            "customer_price": np.array(
                [normalized_state["customer_price"]], dtype=np.float32
            ),
            "pickup_area": normalized_state["pickup_area"],
            "dropoff_area": normalized_state["dropoff_area"],
            "hour_of_day": normalized_state["hour_of_day"],
            "day_of_week": order.datetime.weekday(),
            "weather": normalized_state["weather_code"],
            # Discrete space: cast like the other area fields
            "driver_area": int(driver.current_area),
            "working_status": 1 if driver.available else 0,
        }

    def _is_done(self):
        """Return True (and print a summary) once all orders of the day are processed."""
        if self.current_order_index >= len(self.current_orders):
            print("Current training day completed!")
            print("# of assigned orders ", self.assigned_order)
            return True
        return False

    def _is_driver_working(self, driver: Driver, when: "datetime"):
        """Return True if ``driver`` is scheduled for the hour of ``when`` on that date."""
        scheduled_hours = self.driver_schedule.get((driver.driver_id, when.date()), set())
        return when.hour in scheduled_hours

    def _get_next_order_time(self):
        """Return the timestamp of the next order, or of the current one if it is the last."""
        next_index = self.current_order_index + 1
        if next_index >= len(self.current_orders):
            next_index = self.current_order_index
        return self.current_orders[next_index].datetime

    def reset(self, *, seed=None, options=None):
        """Start an episode on day ``current_day_index``.

        Driver positions are restored to their starting values and
        moved-driver tracking is cleared.

        Args:
            seed: optional RNG seed forwarded to ``gym.Env.reset``.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            The first observation, or ``None`` if no order of the day has any
            driver to offer it to.
        """
        super().reset(seed=seed)

        self.current_day = list(self.orders_by_day.keys())[self.current_day_index]
        self.current_orders = self.orders_by_day[self.current_day]

        self.current_order_index = 0
        self.current_driver_index = 0
        self.updated_drivers.clear()
        self.drivers_with_updated_location.clear()
        self._restore_driver_positions()

        self.episode_rewards = 0
        self.episode_steps = 0
        self.assigned_order = 0
        self.total_driver_commission = 0.0

        # Count order-driver pairs for the whole day
        self.current_day_order_driver_pairs = sum(
            len(self._get_driver_pool(order)) for order in self.current_orders
        )
        print(
            f"DEBUG: Resetting for day {self.current_day}, Orders: {len(self.current_orders)}, Order-Driver Pairs: {self.current_day_order_driver_pairs}"
        )

        if not self.current_orders:
            print("DEBUG: no order at all!")
        self._sync_driver_pool()

        return self._get_observation()

    def step(self, action):
        """Offer the current order to the current driver at the commission in ``action``.

        Args:
            action: array-like whose first element is the platform commission
                rate; it is clipped to [0, 1].

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is
            ``customer_price * commission`` if the driver accepts, else 0.
            ``obs`` is ``None`` once the day is finished, and
            ``info["episode"]`` then holds the episode totals.
        """
        self.next_order = False

        order = self.current_orders[self.current_order_index]
        driver = self.current_order_driver_pool[self.current_driver_index]

        # Offer commission rate
        order.commissionPercent = np.clip(action[0], 0.0, 1.0)
        order.driver_commission = order.customer_price * (1 - order.commissionPercent)
        self.total_driver_commission += order.driver_commission
        weather_code = self.weather.get_weather_code(order.datetime)

        accepted = driver.decide_acceptance(order, weather_code)

        # Update working status for the next order whether or not this one was accepted
        driver.available = self._is_driver_working(driver, self._get_next_order_time())

        if accepted:
            self.assigned_order += 1
            reward = order.customer_price * order.commissionPercent

            old_area = driver.current_area
            driver.update_location(
                order.dropoff_lat, order.dropoff_lon, order.dropoff_area
            )
            self.updated_drivers.add(driver.driver_id)
            # Lets _get_driver_pool track this driver's simulated location
            self.drivers_with_updated_location.add(driver.driver_id)

            # Keep area_drivers in sync with the driver's new area
            if driver in self.area_drivers[old_area]:
                self.area_drivers[old_area].remove(driver)
            if driver not in self.area_drivers[driver.current_area]:
                self.area_drivers[driver.current_area].append(driver)

            self.current_order_index += 1
            self.next_order = True
        else:
            # Try the next driver for the same order, or give up on the order
            reward = 0
            self.current_driver_index += 1
            if self.current_driver_index >= len(self.current_order_driver_pool):
                self.current_order_index += 1
                self.next_order = True

        if self.next_order:
            # Load the next order's pool (skipping orders with no drivers)
            # before deciding whether the day is over.
            self._sync_driver_pool()

        self.episode_steps += 1
        self.episode_rewards += reward

        done = self._is_done()

        info = {}
        if done:
            info["episode"] = {
                "r": self.episode_rewards,  # Report FINAL totals
                "l": self.episode_steps,
                "a": self.assigned_order,
                "o": len(self.current_orders),
                "c": self.total_driver_commission,
            }

        obs = self._get_observation()

        return obs, reward, done, info