In [1]:
#!pip install gymnasium

In [2]:


# # Add the project root directory to Python path
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.append(project_root)

#--------------------------------------------------------------------------------------------#


In [1]:
import sys
import os
import json
import gymnasium as gym
from langchain.output_parsers import RegexParser
from langchain.schema import (
    HumanMessage,
    SystemMessage,
)
import numpy as np
from langchain_openai import ChatOpenAI
from utils.LLM_utils import get_completion_gpt4
from utils.jsonparser_utils import clean_llm_output, json_to_dataframe
from actor_agents.document_extractor import document_extractor_agent, baseline_extractor_agent 
from actor_agents.document_classifier import classify_document_with_llm
from actor_agents.schema_builder import schema_building_with_llm
from src.action_space.meta_prompting_agent import adjust_prompt
from evaluation.scoring import calculate_exact_match, calculate_similarity




### Base Environment

In [5]:
class DataExtractionEnv(gym.Env):
    """
    Custom Gymnasium environment for the Agentic Data Extraction process.

    On every step the meta-prompting agent (``adjust_prompt``) rewrites the
    base prompt, the extraction is re-run with the rewritten prompt, and the
    output is scored against the ground truth.

    Observation: ``np.float32`` array ``[exact_match, similarity]``.
    Action: integer from ``Discrete(5)`` selecting a prompt adjustment strategy.

    ``step`` returns the 4-tuple ``(observation, reward, done, info)``; this is
    kept for compatibility with existing callers that unpack four values.
    """

    def __init__(self, baseprompt, document_type, document, schema, groundtruth):
        """
        Args:
            baseprompt: Prompt template the episode starts from.
            document_type: Document category passed to the extractor agent.
            document: Raw document text to extract from.
            schema: Target output schema passed to the extractor agent.
            groundtruth: Reference extraction used for scoring.
        """
        super().__init__()
        self.action_space = gym.spaces.Discrete(5)  # 5 possible prompt adjustments
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)  # [Exact Match, Similarity]

        # Store document extraction related data
        self.baseprompt = baseprompt
        self.document_type = document_type
        self.document = document
        self.schema = schema
        self.groundtruth = groundtruth
        self.task_type = "form-like document data extraction"

        # Episode state. Initialised here so step() never needs hasattr() checks;
        # the real values are filled in by reset()/step().
        self.current_prompt = baseprompt
        self.resolved_current_prompt = None
        self.last_output = None
        self.state = None

    def _run_extraction(self, prompt):
        """Resolve ``prompt`` for the stored document, query the LLM and return ``(resolved_prompt, cleaned_output)``."""
        resolved_prompt = document_extractor_agent(prompt, self.document_type, self.document, self.schema)
        completion = get_completion_gpt4([{"role": "user", "content": resolved_prompt}])
        return resolved_prompt, clean_llm_output(completion.choices[0].message.content)

    def _score(self, output):
        """Score ``output`` against the ground truth and return ``(exact_match, similarity)``."""
        exact_match_score = calculate_exact_match(output, self.groundtruth)
        similarity_score = calculate_similarity(output, self.groundtruth)
        return exact_match_score, similarity_score

    def step(self, action):
        """
        Apply one prompt adjustment and re-run the extraction.

        Args:
            action: Integer action forwarded to ``adjust_prompt``.

        Returns:
            ``(state, reward, done, info)`` where ``state`` is the new
            ``[exact_match, similarity]`` array and ``done`` is True once
            exact match >= 0.90 and similarity >= 0.95.
        """
        print("\n.............................NEW ITERATION BEGINS.....................................")

        # Get updated prompt using meta-prompting agent.
        # NOTE(review): the adjustment always starts from the base prompt, not from
        # the previously adjusted prompt - confirm this is intended.
        updated_prompt = adjust_prompt(
            actor_prompt=self.baseprompt,
            task_type=self.task_type,
            state=self.state,
            action=action,
            generated_output=self.last_output,  # None when step() is called before reset()
            groundtruth=self.groundtruth
        )
        self.current_prompt = updated_prompt

        # Generate new output using the updated prompt
        resolved_updated_prompt, self.last_output = self._run_extraction(updated_prompt)

        print(f"\nUpdated Prompt: {resolved_updated_prompt}")
        print("Updated Output:", self.last_output)

        # Calculate scores
        exact_match_score, similarity_score = self._score(self.last_output)

        # Update state
        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)

        # Reward: product of both scores minus a penalty that grows with the
        # action's distance from action 2.
        reward = exact_match_score * similarity_score - abs(action - 2) * 0.05

        # Check if task is complete
        done = bool(exact_match_score >= 0.90 and similarity_score >= 0.95)

        # Create info dictionary with useful information
        info = {
            'exact_match': exact_match_score,
            'similarity': similarity_score,
            'State' : self.state,
            'Updated Prompt': resolved_updated_prompt,
            'Updated Output': self.last_output
        }

        print(f"State: {self.state}")
        print(f"Done: {done}")

        return self.state, reward, done, info

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment to its initial state.

        Runs the extraction once with the unmodified base prompt so the first
        observation reflects the baseline quality.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` (Gymnasium API).
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(state, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)

        # Reset prompt to initial state
        self.current_prompt = self.baseprompt
        self.resolved_current_prompt, self.last_output = self._run_extraction(self.current_prompt)

        print(f"Start Prompt:\n {self.resolved_current_prompt}")
        print(f"Start Output:\n {self.last_output}")

        # Baseline scores become the initial observation
        exact_match_score, similarity_score = self._score(self.last_output)

        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)
        return self.state, {}  # Return state and empty info dict for gymnasium compatibility
    



### Iterative Best Score Environment

In [25]:
class DataExtractionEnv(gym.Env):
    """
    Custom Gymnasium environment for the Agentic Data Extraction process.

    The prompt is refined iteratively: each step adjusts the *current* prompt,
    re-runs the extraction and keeps track of the best-scoring prompt/output
    seen so far. The episode ends after ``max_non_improvements`` consecutive
    steps that fail to beat the best combined score.

    Observation: ``np.float32`` array ``[exact_match, similarity]``.
    Action: integer from ``Discrete(5)`` selecting a prompt adjustment strategy.

    ``step`` returns the 4-tuple ``(observation, reward, done, info)``; this is
    kept for compatibility with existing callers that unpack four values.
    """

    def __init__(self, baseprompt, document_type, document, schema, groundtruth):
        """
        Args:
            baseprompt: Prompt template the episode starts from.
            document_type: Document category passed to the extractor agent.
            document: Raw document text to extract from.
            schema: Target output schema passed to the extractor agent.
            groundtruth: Reference extraction used for scoring.
        """
        super().__init__()
        self.action_space = gym.spaces.Discrete(5)  # 5 possible prompt adjustments
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=(2,),  # [Exact Match, Similarity]
            dtype=np.float32
        )

        # Store document extraction related data
        self.baseprompt = baseprompt
        self.document_type = document_type
        self.document = document
        self.schema = schema
        self.groundtruth = groundtruth
        self.current_prompt = baseprompt
        self.task_type = "form-like document data extraction"

        # Track best scores and results
        self.best_exact_match = 0.0
        self.best_similarity = 0.0
        self.best_output = None
        self.best_prompt = None

        # Track consecutive non-improvements
        self.non_improvement_count = 0
        self.max_non_improvements = 2  # Stop after 2 consecutive non-improvements
        self.current_step = 0

        # Episode state. Initialised here so step() never needs hasattr() checks;
        # the real values are filled in by reset()/step().
        self.resolved_current_prompt = None
        self.last_output = None
        self.state = None

    def _run_extraction(self, prompt):
        """Resolve ``prompt`` for the stored document, query the LLM and return ``(resolved_prompt, cleaned_output)``."""
        resolved_prompt = document_extractor_agent(prompt, self.document_type, self.document, self.schema)
        completion = get_completion_gpt4([{"role": "user", "content": resolved_prompt}])
        return resolved_prompt, clean_llm_output(completion.choices[0].message.content)

    def _score(self, output):
        """Score ``output`` against the ground truth and return ``(exact_match, similarity)``."""
        exact_match_score = calculate_exact_match(output, self.groundtruth)
        similarity_score = calculate_similarity(output, self.groundtruth)
        return exact_match_score, similarity_score

    def step(self, action):
        """
        Adjust the current prompt, re-run the extraction and update the best-so-far record.

        Args:
            action: Integer action forwarded to ``adjust_prompt``.

        Returns:
            ``(state, reward, done, info)``. ``reward`` is the combined score
            (exact match + similarity) of this step minus the best combined
            score *before* this step, so it is positive only on improvement.
        """
        self.current_step += 1

        print("\n.............................NEW ITERATION BEGINS.....................................")

        # Get updated prompt using meta-prompting agent (builds on the current prompt)
        updated_prompt = adjust_prompt(
            actor_prompt=self.current_prompt,
            task_type=self.task_type,
            state=self.state,
            action=action,
            generated_output=self.last_output,  # None when step() is called before reset()
            groundtruth=self.groundtruth
        )
        self.current_prompt = updated_prompt

        # Generate new output
        resolved_updated_prompt, self.last_output = self._run_extraction(updated_prompt)

        # Calculate scores
        exact_match_score, similarity_score = self._score(self.last_output)

        # Update state
        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)

        # Combined scores; the best is captured BEFORE it is possibly overwritten
        # below, so the reward compares against the previous best.
        current_combined_score = exact_match_score + similarity_score
        best_combined_score = self.best_exact_match + self.best_similarity

        # Check if current scores are better than best scores
        improved = current_combined_score > best_combined_score
        if improved:
            self.best_exact_match = exact_match_score
            self.best_similarity = similarity_score
            self.best_output = self.last_output
            self.best_prompt = self.current_prompt
            self.non_improvement_count = 0
        else:
            self.non_improvement_count += 1

        # Calculate reward
        reward = current_combined_score - best_combined_score

        # Determine if we should terminate
        done = self.non_improvement_count >= self.max_non_improvements

        # Create info dictionary with useful information
        info = {
            'exact_match': exact_match_score,
            'similarity': similarity_score,
            'best_exact_match': self.best_exact_match,
            'best_similarity': self.best_similarity,
            'improved': improved,
            'non_improvement_count': self.non_improvement_count,
            'steps': self.current_step
        }

        print(f"\nStep {self.current_step}")
        print(f"\nUpdated Prompt: {resolved_updated_prompt}")
        print("Updated Output:\n", self.last_output)
        print(f"Current Scores - Exact Match: {exact_match_score:.4f}, Similarity: {similarity_score:.4f}")
        print(f"Best Scores    - Exact Match: {self.best_exact_match:.4f}, Similarity: {self.best_similarity:.4f}")

        if done:
            print("\nTerminating due to no improvements in last two updates")
            print("Best Results Achieved:")
            print(f"Exact Match: {self.best_exact_match:.4f}")
            print(f"Similarity: {self.best_similarity:.4f}")
            print(f"Best Prompt:\n {self.best_prompt}")
            print(f"\nBest Output:\n {self.best_output}")

        return self.state, reward, done, info

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment to its initial state.

        Runs the extraction once with the unmodified base prompt; the resulting
        scores seed both the first observation and the best-so-far record.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` (Gymnasium API).
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(state, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)

        self.current_step = 0
        self.non_improvement_count = 0
        self.current_prompt = self.baseprompt

        # Generate initial output
        self.resolved_current_prompt, self.last_output = self._run_extraction(self.current_prompt)

        print(f"Start Prompt:\n {self.resolved_current_prompt}")
        print(f"Start Output:\n {self.last_output}")

        # Calculate initial scores
        exact_match_score, similarity_score = self._score(self.last_output)

        # Initialize best scores with initial scores
        self.best_exact_match = exact_match_score
        self.best_similarity = similarity_score
        self.best_output = self.last_output
        self.best_prompt = self.current_prompt

        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)
        return self.state, {}

    def get_best_results(self):
        """Return the best scores, output and prompt recorded during the episode."""
        return {
            'best_exact_match': self.best_exact_match,
            'best_similarity': self.best_similarity,
            'best_output': self.best_output,
            'best_prompt': self.best_prompt
        }

### Step Count Environment

In [10]:
class DataExtractionEnv(gym.Env):
    """
    Custom Gymnasium environment for the Agentic Data Extraction process.

    Each step asks the meta-prompting agent (``adjust_prompt``) to rewrite the
    base prompt, re-runs the extraction and scores the output. The episode
    ends when both score thresholds are met or after ``max_steps`` steps.

    Observation: ``np.float32`` array ``[exact_match, similarity]``.
    Action: integer from ``Discrete(5)`` selecting a prompt adjustment strategy.

    ``step`` returns the 4-tuple ``(observation, reward, done, info)``; this is
    kept for compatibility with existing callers that unpack four values.
    """

    def __init__(self, baseprompt, document_type, document, schema, groundtruth):
        """
        Args:
            baseprompt: Prompt template the episode starts from.
            document_type: Document category passed to the extractor agent.
            document: Raw document text to extract from.
            schema: Target output schema passed to the extractor agent.
            groundtruth: Reference extraction used for scoring.
        """
        super().__init__()
        self.action_space = gym.spaces.Discrete(5)  # 5 possible prompt adjustments
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)  # [Exact Match, Similarity]

        # Store document extraction related data
        self.baseprompt = baseprompt
        self.document_type = document_type
        self.document = document
        self.schema = schema
        self.groundtruth = groundtruth
        self.current_prompt = baseprompt  # Initial prompt
        self.task_type = "form-like document data extraction"

        # Episode state. Initialised here so step() never needs hasattr() checks;
        # the real values are filled in by reset()/step().
        self.resolved_current_prompt = None
        self.last_output = None
        self.state = None

        # Termination thresholds and step budget
        self.exact_match_threshold = 0.95
        self.similarity_threshold = 0.95
        self.max_steps = 5
        self.current_step = 0

    def _run_extraction(self, prompt):
        """Resolve ``prompt`` for the stored document, query the LLM and return ``(resolved_prompt, cleaned_output)``."""
        resolved_prompt = document_extractor_agent(prompt, self.document_type, self.document, self.schema)
        completion = get_completion_gpt4([{"role": "user", "content": resolved_prompt}])
        return resolved_prompt, clean_llm_output(completion.choices[0].message.content)

    def _score(self, output):
        """Score ``output`` against the ground truth and return ``(exact_match, similarity)``."""
        exact_match_score = calculate_exact_match(output, self.groundtruth)
        similarity_score = calculate_similarity(output, self.groundtruth)
        return exact_match_score, similarity_score

    def step(self, action):
        """
        Execute one step in the environment.

        Args:
            action: Integer action forwarded to ``adjust_prompt``.

        Returns:
            ``(state, reward, done, info)``. ``done`` is True when both score
            thresholds are reached or the step budget is exhausted;
            ``info['success']`` and ``info['max_steps_reached']`` tell which.
        """
        print("\n.............................NEW ITERATION BEGINS.....................................")

        self.current_step += 1

        # Get updated prompt using meta-prompting agent.
        # NOTE(review): the adjustment always starts from the base prompt, not from
        # the previously adjusted prompt - confirm this is intended.
        updated_prompt = adjust_prompt(
            actor_prompt=self.baseprompt,
            task_type=self.task_type,
            state=self.state,
            action=action,
            generated_output=self.last_output,  # None when step() is called before reset()
            groundtruth=self.groundtruth
        )
        self.current_prompt = updated_prompt

        # Generate new output using the updated prompt
        resolved_updated_prompt, self.last_output = self._run_extraction(updated_prompt)

        print(f"\nStep {self.current_step}/{self.max_steps}")
        print(f"\nUpdated Prompt: {resolved_updated_prompt}")
        print("Updated Output:", self.last_output)

        # Calculate scores
        exact_match_score, similarity_score = self._score(self.last_output)

        # Update state
        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)

        # Reward: product of both scores minus a penalty that grows with the
        # action's distance from action 2.
        reward = exact_match_score * similarity_score - abs(action - 2) * 0.05

        # Check termination conditions
        success = exact_match_score >= self.exact_match_threshold and similarity_score >= self.similarity_threshold
        max_steps_reached = self.current_step >= self.max_steps
        done = bool(success or max_steps_reached)

        # Add info dict with useful debugging information
        info = {
            'exact_match': exact_match_score,
            'similarity': similarity_score,
            'steps': self.current_step,
            'max_steps_reached': max_steps_reached,
            'success': success
        }

        print(f"State: {self.state}")
        print(f"Done: {done}")

        return self.state, reward, done, info

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment to its initial state.

        Clears the step counter and runs the extraction once with the
        unmodified base prompt so the first observation reflects the baseline.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` (Gymnasium API).
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(state, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)

        self.current_step = 0  # Reset step counter

        self.current_prompt = self.baseprompt
        self.resolved_current_prompt, self.last_output = self._run_extraction(self.current_prompt)

        print(f"Start Prompt:\n {self.resolved_current_prompt}")
        print(f"Start Output:\n {self.last_output}")

        # Baseline scores become the initial observation
        exact_match_score, similarity_score = self._score(self.last_output)

        self.state = np.array([exact_match_score, similarity_score], dtype=np.float32)
        return self.state, {}  # Return state and empty info dict for gymnasium compatibility
    

### Gymnasium Agent

In [11]:
class GymnasiumAgent:
    """
    LLM-backed policy that drives a Gymnasium-style environment.

    On every turn the current observation is sent to the chat model, the
    reply is parsed for ``Action: <number>`` and the resulting action is
    passed to ``env.step``. Replies that cannot be parsed, or that fall
    outside 0-4, fall back to action 0.
    """

    @classmethod
    def get_docs(cls, env):
        """Return the docstring of the unwrapped environment."""
        return env.unwrapped.__doc__

    def __init__(self, model, env):
        """
        Args:
            model: Callable chat model; called with a list of messages and
                expected to return an object with a ``content`` attribute.
            env: Environment exposing ``reset()``, ``step(action)`` and ``unwrapped``.
        """
        self.model = model
        self.env = env
        self.docs = self.get_docs(env)

        self.instructions = """
Your goal is to maximize your return, i.e., the sum of the rewards you receive.
I will give you an observation, reward, termination flag, truncation flag, and the return so far, formatted as:

Observation: <observation>
Reward: <reward>
Termination: <termination>
Truncation: <truncation>
Return: <sum_of_rewards>

You will respond with an action, formatted as:

Action: <action>

where you replace <action> with your actual action.
"""
        self.action_parser = RegexParser(
            regex=r"Action: (.*)", output_keys=["action"],
        )

    def _parse_action(self, response_text):
        """Extract a valid action (0-4) from the model reply, defaulting to 0 on any parse or range failure."""
        try:
            action = int(self.action_parser.parse(response_text)['action'])
            if not (0 <= action <= 4):
                raise ValueError("Action must be between 0 and 4")
        except (ValueError, KeyError, TypeError):
            print(f"Invalid response from model: {response_text}")
            print("Defaulting to action 0")
            return 0
        return action

    def interact(self):
        """
        Run one episode: reset the environment, then query the model and step
        until the environment signals the end of the episode.

        Accepts both the 4-tuple ``(obs, reward, done, info)`` and the
        Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)`` from
        ``env.step``; for the latter the episode ends on either flag.
        """
        observation, _ = self.env.reset()
        terminated = False
        total_reward = 0

        while not terminated:
            # Format observation for better readability
            obs_dict = {
                'Exact Match': observation[0],
                'Similarity': observation[1]
            }
            print("\nCurrent State:")
            for metric, value in obs_dict.items():
                print(f"{metric}: {value:.4f}")

            # Generate a response (action) using the model
            response = self.model([
                SystemMessage(content=self.instructions),
                HumanMessage(content=f"""
Current observation: {obs_dict}
Current total reward: {total_reward:.4f}
Task completed: {terminated}

Based on these metrics, what action (0-4) would you take to improve the extraction?
Remember to respond ONLY with "Action: <number>"
""")
            ])

            # Parse exactly once; the validated (or defaulted) action is what gets executed.
            action = self._parse_action(response.content)

            # Perform action in the environment
            step_result = self.env.step(action)
            if len(step_result) == 5:
                observation, reward, terminated, truncated, info = step_result
                terminated = bool(terminated or truncated)
            else:
                observation, reward, terminated, info = step_result
            total_reward += reward

            print(f"\nAction taken: {action}")
            print(f"Reward: {reward:.4f}")
            print(f"\nTerminated: {terminated}")
            print(f"Total Return: {total_reward:.4f}")
            print("Metrics:", info)

        print("\nTask completed successfully!")

In [12]:
def load_prompt_from_file(filename):
    """
    Read a prompt template and return its text.

    The file is looked up at ``<parent of cwd>/src/actor_agents/Prompts/<filename>``.

    Raises:
        FileNotFoundError: if the file does not exist; the message contains
            the full path that was tried.
    """
    # The project root is taken to be the parent of the current working directory.
    parent_of_cwd = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
    template_location = os.path.join(parent_of_cwd, 'src', 'actor_agents', 'Prompts', filename)

    try:
        handle = open(template_location, 'r', encoding='utf-8')
    except FileNotFoundError:
        raise FileNotFoundError(f"Prompt file not found at: {template_location}")

    with handle:
        return handle.read()

# Pick the document type and read its base prompt template from disk.
document_type = 'invoice'
try:
    base_prompt = load_prompt_from_file('Invoice_prompt.txt')
except FileNotFoundError as e:
    # Report the problem instead of raising; base_prompt stays undefined in this case.
    print(f"Error loading prompt: {e}")
else:
    print("Successfully loaded prompt template")
    print(base_prompt)

Successfully loaded prompt template
### Instructions:
You are a data extraction tool capable of extracting data from each page of an invoice.

1. Please extract the data in this invoice and format it to the given output JSON schema.

2. Extract all key-value pairs from the invoice.

3. If there are tables in the invoice, capture all of the rows and columns in the JSON object. 
Even if a column is blank, include it as a key in the JSON object with a null value.

4. If a row is blank denote missing fields with "null" values.

5. If the page contains no charge data, please output an empty JSON object and don't make up any data.

6. Don't interpolate or make up data.

7. Please maintain the table structure of the charges, i.e. capture all of the rows and columns in the JSON object.

8. Ensuring the order of key-value pairs and tabular data aligns with the original text.


The language model must interpret and execute these extraction and formatting instructions accurately.

Perform the tas

In [13]:
## INPUT PARAMS TEST

if __name__ == "__main__":
    # Sample inputs used to exercise the environment below.

    # Raw invoice text handed to the extractor (plain string; no interpolation needed).
    invoice = """
-----------------Invoice------------------
                              Page 1 of 3

Invoice Number: INV-12345
Customer: XYZ Corp
Invoice Date: 2024-06-01


Item    Quantity    Price     Total
item_1     5         $100      500
item_2     10        $50       500
item_3     6         $10       60

					Subtotal: 1060
					Total GST: 500
					Total Amount: $1560
--------------------------------------------
"""
    # Target output schema: field name -> expected value type/format.
    schema = {
    "invoice_number": "string",
    "customer": "string",
    "invoice_date": "yyyy-mm-dd",
    "sub_total": "number",
    "total_GST": "number",
    "total_amount": "number",
    "Line_Items": [
        {
            "item": "string",
            "quantity": "number",
            "price": "number",
            "total": "number"
        }
    ] 
    }
   
    # Example of an imperfect extraction: "customer", "total_GST" and the third
    # line item are missing, and prices are strings with a "$" prefix.
    generated_output = {
    "invoice_number": "INV-12345",
    "invoice_date": "2024-06-01",
    "sub_total": 1060,
    "total_amount": 1560, 
    "Line_Items": [
        {
        "item": "item_1",
        "quantity": 5,
        "price": "$100",
        "total": 500
        },
        {
        "item": "item_2",
        "quantity": 10,
        "price": "$50",
        "total": 500
        }
    ]
    } 
    # Reference extraction the environment scores against.
    groundtruth = {
    "invoice_number": "INV-12345",
    "customer": "XYZ Corp",
    "invoice_date": "2024-06-01",
    "sub_total": 1060,
    "total_GST":500,
    "total_amount": 1560,
    "Line_Items": [
        {
        "item": "item_1",
        "quantity": 5,
        "price": 100,
        "total": 500
        },
        {
        "item": "item_2",
        "quantity": 10,
        "price": 50,
        "total": 500
        },
        {
        "item": "item_3",
        "quantity": 6,
        "price": 10,
        "total": 60
        }
    ]
    }
    
    # Quick manual checks (uncomment to run):
    # print(document_extractor_agent(base_prompt, document_type, invoice, schema))
    # print(calculate_exact_match(generated_output,groundtruth))

In [14]:
# Build the extraction environment for the sample invoice.
# Positional order: baseprompt, document_type, document, schema, groundtruth.
env = DataExtractionEnv(base_prompt, document_type, invoice, schema, groundtruth)
# Optional overrides (these attributes exist only on the step-count variant of the environment):
# env.exact_match_threshold = 0.95
# env.similarity_threshold = 0.95
# env.max_steps = 10

# Pair the environment with an LLM policy (model, env) and run one full episode.
agent = GymnasiumAgent(ChatOpenAI(temperature=0.2), env)
agent.interact()

Start Prompt:
 
### Instructions:
You are a data extraction tool capable of extracting data from each page of an invoice.

1. Please extract the data in this invoice and format it to the given output JSON schema.

2. Extract all key-value pairs from the invoice.

3. If there are tables in the invoice, capture all of the rows and columns in the JSON object. 
Even if a column is blank, include it as a key in the JSON object with a null value.

4. If a row is blank denote missing fields with "null" values.

5. If the page contains no charge data, please output an empty JSON object and don't make up any data.

6. Don't interpolate or make up data.

7. Please maintain the table structure of the charges, i.e. capture all of the rows and columns in the JSON object.

8. Ensuring the order of key-value pairs and tabular data aligns with the original text.


The language model must interpret and execute these extraction and formatting instructions accurately.

Perform the task as per above instr