# In [8]:
import gym
from gym import spaces
import numpy as np
import yaml
import pandas as pd 
import os 
import random


# In [3]:
def load_config(file_path='config.yaml'):
    """Parse a YAML file and return its contents.

    Args:
        file_path: Path of the YAML file to read. Defaults to
            ``config.yaml`` (resolved against the current working
            directory by ``open``).

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    with open(file_path, 'r') as stream:
        # safe_load only builds plain Python objects from the document.
        return yaml.safe_load(stream)

# In [9]:
# Module-level configuration, read once when this module is imported.
# NOTE(review): the path is absolute and specific to one machine; importing
# this module elsewhere will fail unless the same file exists at this location.
CONFIG = load_config(
    file_path="/home/julian/git-repo/juliangdz/GovernanceIRP/Autonomous-Governance-in-Disaster-Management/rl_decision_maker/configs/config.yaml",
)

class TaskSequenceEnv(gym.Env):
    """Gym environment that steps an agent through a configured task sequence.

    The task order comes from ``CONFIG['tasks']``. For each task a record is
    drawn at random (without repetition inside an episode) from that task's
    inference-results CSV. The record's ``prediction_conf`` value is handed to
    the agent as the observation and its ``ground_truth`` value is the action
    that counts as correct.

    Rewards, as implemented in ``step``:
        * action ``4``: ``-1``; a new record is drawn for the same task.
        * action equal to the ground truth: ``+1``; advance to the next task,
          the episode ends once every task has been answered.
        * any other action: ``-5``; the episode ends immediately.

    NOTE(review): ``observation_space`` is declared as a float32 ``Box`` of
    shape ``(4,)``, but observations are returned exactly as stored in the
    ``prediction_conf`` column. Whether that column matches the declared
    space cannot be determined from this file -- confirm against the CSVs.
    """

    # Task names that have a dedicated ``<name>_dataset`` / ``seen_<name>``
    # attribute pair on the instance.
    _DATASET_TASKS = ("info", "human", "damage", "satellite", "drone")
    # Dataset used when a task name is not one of ``_DATASET_TASKS``.
    _FALLBACK_TASK = "info"
    # Action that requests a different record for the current task.
    _RESAMPLE_ACTION = 4

    def __init__(self):
        super().__init__()
        self.CONFIG = CONFIG
        self.tasks = self.CONFIG['tasks']
        # Load the dataset for each task.
        self.info_dataset = self._get_data_based_on_task("info")
        self.human_dataset = self._get_data_based_on_task("human")
        self.damage_dataset = self._get_data_based_on_task("damage")
        self.satellite_dataset = self._get_data_based_on_task("satellite")
        self.drone_dataset = self._get_data_based_on_task("drone")
        # Per-dataset lists of record indexes already drawn this episode.
        self._reset_seen_records()

        self.task_index = 0
        self.tree_counter = 0
        self.failed_tree_counter = 0
        self.current_task_info = None
        self.ground_truth = None
        self.tree_score = 0

        self.action_space = spaces.Discrete(5)  # Actions 0, 1, 2, 3, 4
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def _reset_seen_records(self):
        """Forget which records have been drawn from every dataset."""
        self.seen_info = []
        self.seen_human = []
        self.seen_damage = []
        self.seen_satellite = []
        self.seen_drone = []

    def _random_select_record_from_dataset(self, task: str):
        """Draw one not-yet-seen record from the dataset belonging to ``task``.

        The drawn index is appended to the matching ``seen_<task>`` list so it
        is not drawn again until the lists are cleared. A task name outside
        ``_DATASET_TASKS`` uses the ``info`` dataset and its seen list.

        Args:
            task: Name of the task whose dataset should be sampled.

        Returns:
            tuple: ``(ground_truth, prediction_conf)`` of the selected record.

        Raises:
            ValueError: If the dataset is empty, or every record in it has
                already been drawn.
        """
        key = task if task in self._DATASET_TASKS else self._FALLBACK_TASK
        dataset = getattr(self, f"{key}_dataset")
        seen = getattr(self, f"seen_{key}")

        if dataset.empty:
            raise ValueError("The dataset is empty")

        remaining_indices = list(set(dataset.index) - set(seen))
        if not remaining_indices:
            # Only reachable when a single episode draws more records from one
            # dataset than it contains; the seen lists are cleared in reset().
            raise ValueError("All records have been seen")

        selected_idx = random.choice(remaining_indices)
        seen.append(selected_idx)

        record = dataset.loc[selected_idx]
        return record["ground_truth"], record["prediction_conf"]

    def _get_data_based_on_task(self, task: str):
        """Load the inference-results CSV for ``task``.

        The file is read from
        ``<CONFIG['data_path']>/<task>/<CONFIG['phase']>_inference_results.csv``.

        Args:
            task: Name of the task sub-directory to read from.

        Returns:
            pandas.DataFrame: Only the ``prediction_conf`` and
            ``ground_truth`` columns of the CSV.
        """
        csv_path = os.path.join(
            self.CONFIG['data_path'],
            task,
            f"{self.CONFIG['phase']}_inference_results.csv",
        )
        dataset = pd.read_csv(csv_path)
        return dataset[["prediction_conf", "ground_truth"]]

    def get_task_data(self):
        """Draw a fresh record for the task at ``self.task_index``.

        Returns:
            tuple: ``(observation, ground_truth)`` where ``observation`` is the
            record's ``prediction_conf`` value.

        Raises:
            ValueError: Propagated from the record selection when the dataset
                is empty or exhausted.
        """
        task = self.tasks[self.task_index]
        # The selector returns (ground_truth, prediction_conf); swap the pair
        # so the confidence becomes the observation and the label stays hidden.
        ground_truth, prediction_conf = self._random_select_record_from_dataset(task)
        return prediction_conf, ground_truth

    def reset(self):
        """Start a new episode at the first task.

        Clears the seen-record lists and the per-episode counters, then draws
        the first record.

        Returns:
            The observation for the first task.
        """
        # Seen records are tracked per episode only.
        self._reset_seen_records()

        self.task_index = 0
        self.tree_counter = 0
        self.failed_tree_counter = 0
        self.tree_score = 0
        self.current_task_info, self.ground_truth = self.get_task_data()
        return self.current_task_info

    def step(self, action):
        """Apply ``action`` to the current task.

        Args:
            action: Integer action; ``4`` requests a different record for the
                current task, any other value is compared with the ground
                truth of the current record.

        Returns:
            tuple: ``(observation, reward, done, info)`` with an empty
            ``info`` dict.

        Raises:
            ValueError: If a new record is needed but the task's dataset is
                empty or exhausted.
        """
        done = False

        if action == self._RESAMPLE_ACTION:
            # Penalised request for another record of the same task.
            reward = -1
            self.current_task_info, self.ground_truth = self.get_task_data()
        elif action == self.ground_truth:
            reward = 1
            self.task_index += 1
            if self.task_index >= len(self.tasks):
                # Every task in the sequence was answered correctly.
                self.tree_counter += 1
                done = True
            else:
                self.current_task_info, self.ground_truth = self.get_task_data()
        else:
            # A wrong answer ends the episode.
            reward = -5
            self.failed_tree_counter += 1
            done = True

        self.tree_score += reward
        return self.current_task_info, reward, done, {}

    def render(self, mode='human'):
        """No rendering is implemented."""
        pass

    def close(self):
        """No resources need releasing."""
        pass
