<a href="https://colab.research.google.com/github/aswinaus/Reinforcement-Learning/blob/main/Agentic_RewardFunction_GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install llama-index -q
!pip install langchain -q
!pip install langchain_experimental -q

In [None]:
import os
import nest_asyncio
nest_asyncio.apply()

In [None]:
from google.colab import userdata
# Set the OpenAI API key as an environment variable
os.environ["OPENAI_API_KEY"] =  userdata.get('OPENAI_API_KEY')

In [None]:
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import Settings
# Setup OpenAI Model and Embeddings used for indexing the documents
Settings.llm = OpenAI(model='gpt-4o-mini', temperature=0.2)
Settings.embed_model = OpenAIEmbedding(model='text-embedding-3-small')
Settings.chunk_size = 1024

In [None]:
from google.colab import drive
drive.mount('/content/drive')
data_dir = '/content/drive/MyDrive' # Input a data dir path from your mounted Google Drive

In [None]:
from llama_index.core import SimpleDirectoryReader

# Load OECD guidelines documents for Transfer Pricing
docs_OECD_guidelines = SimpleDirectoryReader(f"{data_dir}/RAG/data/OECD/").load_data()
# Load Form 990 guidelines documents
docs_Form990_guidelines = SimpleDirectoryReader(f"{data_dir}/RAG/data/Form990/").load_data()

# Task
Calculate rewards based on cosine similarity, explain how to update the policy directly in GRPO for the current situation, and define the policy network for this example, specifically referring to Group Relative Policy Optimization (GRPO) in the context of a RAG system where the reward is based on cosine similarity.
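
As a rough illustration of the reward named in this task, the sketch below computes cosine similarity between two embedding vectors. It assumes the generated answer and the ground truth have already been embedded (for example with the embedding model configured above); the function name `cosine_similarity` and the toy vectors are hypothetical and are not the notebook's own `cosine_similarity_reward`.

```python
import math

def cosine_similarity(vec_a, vec_b):
    """Cosine of the angle between two equal-length vectors."""
    if len(vec_a) != len(vec_b):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for all-zero vectors
    return dot / (norm_a * norm_b)

# Toy 3-dimensional "embeddings"; real embeddings have far more dimensions.
answer_embedding = [1.0, 0.0, 1.0]
truth_embedding = [1.0, 1.0, 0.0]
reward = cosine_similarity(answer_embedding, truth_embedding)  # about 0.5
```

Identical directions give a reward of 1, orthogonal vectors give 0, so a higher value means the answer is semantically closer to the ground truth.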

## Understand the role of the policy network

### Subtask:
Clarify what aspects of the RAG system the policy network will control (e.g., retrieval parameters, generation parameters, or both).


## Choose a suitable neural network architecture

### Subtask:
Select a type of neural network (e.g., a feedforward network, recurrent network, or transformer-based model) that can process the input question and output the control signals for the RAG system.


**Reasoning**:
The policy network takes a text question as input and outputs control signals for the RAG system. A transformer-based model is a good fit because it handles sequential data well and captures the semantic relationships within the text.



In [None]:
# Explanation of the chosen neural network architecture

# For this task, where the input is a text question and the output needs to be control signals
# for a RAG system (e.g., parameters for retrieval or generation), a transformer-based model
# is a suitable choice.

# Justification:
# 1. Text Input: Transformers are specifically designed to process sequential data like text.
# 2. Semantic Understanding: They excel at capturing long-range dependencies and semantic relationships
#    within the input question, which is crucial for determining appropriate RAG control signals.
# 3. Contextual Embeddings: Transformer models generate rich contextual embeddings of the input text,
#    allowing the policy network to make informed decisions based on the nuances of the question.
# 4. Adaptability: Transformer architectures can be adapted with output layers to produce numerical
#    or categorical control signals as needed by the RAG system.

# Examples of transformer-based models that could be adapted for this purpose include
# models based on BERT, RoBERTa, or smaller, task-specific transformer variants.
# The specific output layer would depend on the exact control signals being predicted (e.g.,
# a dense layer for numerical parameters or a softmax layer for categorical choices).

# This choice aligns with the task of learning a policy based on textual input to influence
# the RAG system's behavior for better performance according to the defined reward.

## Define the input and output layers

### Subtask:
Define the format of the input (the question) and the output (the parameters or actions that influence the RAG system) for the chosen transformer-based policy network.


**Reasoning**:
Describe the input and output format of the chosen transformer-based policy network based on the previous subtask.



In [None]:
# 1. Input Layer Format:
# The input to the policy network will be the user's question, which is a string of text.
# Before being fed into the transformer layers, this text will undergo standard NLP preprocessing steps:
# - Tokenization: The text will be broken down into a sequence of tokens (words or sub-word units) using a tokenizer appropriate for the chosen transformer model (e.g., WordPiece for BERT, BPE for RoBERTa).
# - Embedding: The sequence of tokens will be converted into a sequence of numerical embeddings. Transformer models typically use learned token embeddings, positional embeddings (to capture token order), and potentially segment embeddings. The input to the transformer layers will be a tensor of shape (batch_size, sequence_length, embedding_dim), where:
#   - batch_size: The number of questions processed in parallel.
#   - sequence_length: The maximum number of tokens in a question (padded or truncated).
#   - embedding_dim: The dimensionality of the token embeddings.

# 2. Output Layer(s) Format:
# The output layer(s) of the policy network will produce control signals for the RAG system. Based on the understanding that the policy network controls retrieval parameters (like similarity_top_k) and potentially influences generation, the output could be structured as follows:
# - For a numerical parameter like `similarity_top_k`: A single dense layer with one output neuron, potentially followed by an activation function (e.g., ReLU to ensure non-negativity) and possibly scaled to a reasonable range. The output would be a tensor of shape (batch_size, 1).
# - For influencing generation (less direct control in this setup, but conceptually): This could be represented as a vector influencing attention mechanisms or providing context to the generation model. However, focusing on retrieval parameters as the primary policy output in this GRPO context is more straightforward.
# - For simplicity and direct control over a key retrieval parameter, let's define the output as a single numerical value representing `similarity_top_k`. The output layer will be a dense layer with 1 output neuron.

# Therefore, the output of the policy network will be a tensor of shape (batch_size, 1), representing the predicted value for `similarity_top_k` for each question in the batch.

# 3. Interpretation and Usage of the Policy Network's Output:
# The output of the policy network (the predicted `similarity_top_k` value) will be used to configure the retrieval step of the RAG system for the given question.
# - During training: The predicted `similarity_top_k` will be used to perform retrieval. The retrieved context, along with the question, will then be passed to the generation model to produce an answer. This answer will be compared to the ground truth to calculate the cosine similarity reward. This reward will be used by the GRPO algorithm to update the policy network's weights, encouraging it to predict `similarity_top_k` values that lead to higher rewards.
# - During inference: The policy network will predict `similarity_top_k` for a new question, and this value will be used directly in the retrieval process to gather context for generating the final answer.
# The predicted numerical output for `similarity_top_k` might need to be post-processed (e.g., rounded to an integer, clipped to a valid range) before being used by the RAG system's retriever.
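
The post-processing step mentioned above (rounding and clipping) can be sketched as follows. The helper name `to_top_k` and the bounds of 1 and 10 are assumptions made for illustration; the valid range depends on the index and retriever in use.

```python
def to_top_k(raw_value, min_k=1, max_k=10):
    """Round a raw policy output to an integer and clip it to [min_k, max_k]."""
    return max(min_k, min(max_k, int(round(raw_value))))

# A raw network output such as 2.8 becomes 3 retrieved documents.
top_k = to_top_k(2.8)
```

Clipping keeps the retriever from being asked for zero or a negative number of documents when the sampled value falls outside the sensible range.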

## Implement the policy network

### Subtask:
Implement the policy network using a deep learning framework. This involves defining the transformer layers and the output layer(s) based on the input and output formats defined in the previous steps.


**Reasoning**:
Implement the policy network using PyTorch, defining the transformer layers and the output layer as specified in previous steps.

In essence, this network takes a question, processes it through a pre-trained transformer to understand its context, and then uses a simple linear layer to predict a non-negative numerical value intended to represent the optimal similarity_top_k for retrieving documents for that question.

In the context of Reinforcement Learning (RL), a policy network is a neural network that learns to map states (in our case, the user's question) to actions (the parameters or decisions that control the RAG system).

**State:** The input is the user's question. The policy network processes this question to understand its meaning and context.

**Action:** The output of the policy network is a value (or values) that influences how the RAG system operates. In the initial version, the action space is simplified to a single value, similarity_top_k, which determines how many relevant documents are retrieved. In the modified version, the network predicts the parameters of a distribution from which similarity_top_k is sampled.

The goal of training the policy network using an algorithm like GRPO is to adjust its internal parameters (the weights and biases of the neural network) so that, when presented with a question, it predicts/samples actions (like a specific similarity_top_k) that lead to higher rewards (in our case, higher cosine similarity between the generated answer and the ground truth).

So, the policy network's role is to learn the optimal strategy for configuring the RAG system based on the input question to maximize the desired outcome (answer quality, measured by the reward).



# Task
**Implement the policy update rule for the GRPO algorithm using cosine similarity as the reward signal to adjust the policy network's parameters in the provided notebook.**

## Modify policy network output

### Subtask:
Adjust the `RAGPolicyNetwork` to output parameters for a distribution over `similarity_top_k`, such as the mean and log-variance of a Gaussian distribution.


**Reasoning**:
Modify the RAGPolicyNetwork class to output parameters for a Gaussian distribution (mean and log-variance) over `similarity_top_k`.



## Implement action sampling and log probability calculation

### Subtask:
Implement functions to sample `similarity_top_k` from the Gaussian distribution predicted by the policy network and calculate the log probability of the sampled action.


**Reasoning**:
Implement functions to sample the action (similarity_top_k) from the predicted Gaussian distribution and calculate the log probability of the sampled action, using the predicted mean and log-variance from the policy network.
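
A minimal sketch of this step in plain Python, assuming a univariate Gaussian parameterised by a mean and a log-variance. The function names are hypothetical; in the notebook the same quantities would be computed on tensors so that gradients can flow back into the policy network.

```python
import math
import random

def gaussian_log_prob(action, mean, log_var):
    """Log density of N(mean, exp(log_var)) evaluated at `action`."""
    var = math.exp(log_var)
    return -0.5 * (math.log(2.0 * math.pi) + log_var + (action - mean) ** 2 / var)

def sample_action(mean, log_var, rng=random):
    """Draw an action from N(mean, exp(log_var)) and return (action, log_prob)."""
    std = math.exp(0.5 * log_var)  # std = sqrt(variance)
    action = rng.gauss(mean, std)
    return action, gaussian_log_prob(action, mean, log_var)

# Example: a policy centred on 3 documents with unit variance.
rng = random.Random(0)
raw_top_k, log_prob = sample_action(mean=3.0, log_var=0.0, rng=rng)
```

Predicting the log-variance rather than the variance keeps the variance positive without constraining the network output.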



## Implement baseline calculation

### Subtask:
Create a function to calculate a baseline for the rewards, such as the average reward in a batch.


**Reasoning**:
Define a function to calculate the mean of a list or tensor of rewards to be used as a baseline.
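
A small sketch of the mean-reward baseline and the advantages derived from it. The function names and the example rewards are illustrative only.

```python
def mean_baseline(rewards):
    """Average reward of the batch, used as the baseline."""
    rewards = list(rewards)
    if not rewards:
        raise ValueError("rewards must not be empty")
    return sum(rewards) / len(rewards)

def compute_advantages(rewards):
    """Advantage of each sample: its reward minus the batch baseline."""
    baseline = mean_baseline(rewards)
    return [r - baseline for r in rewards]

batch_rewards = [0.2, 0.4, 0.3, 0.5]           # made-up cosine similarity rewards
baseline = mean_baseline(batch_rewards)          # about 0.35
advantages = compute_advantages(batch_rewards)   # about [-0.15, 0.05, -0.05, 0.15]
```

Subtracting the baseline leaves the expected gradient unchanged while reducing its variance: samples that did better than the batch average get a positive advantage, the rest a negative one.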



## Set up training loop

### Subtask:
Structure a training loop that iterates through the dataset, performs forward passes with the policy network, executes the RAG system with sampled actions, calculates rewards and advantages, and computes the policy gradient.


**Reasoning**:
Structure a training loop that iterates through the dataset, performs forward passes with the policy network, executes the RAG system with sampled actions, calculates rewards and advantages, and computes the policy gradient.
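
The quantity such a loop would minimise can be sketched as a REINFORCE-style surrogate loss. This is an assumption about the form of the loss, written in plain Python for clarity; with an autograd framework the advantages would be treated as constants and the log-probabilities would carry the gradient.

```python
def policy_gradient_loss(log_probs, advantages):
    """Surrogate loss: -(1/N) * sum(log_prob_i * advantage_i)."""
    if len(log_probs) != len(advantages):
        raise ValueError("log_probs and advantages must have the same length")
    n = len(log_probs)
    return -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / n

# Two sampled actions: the first did better than the baseline, the second worse.
loss = policy_gradient_loss([-1.0, -2.0], [0.5, -0.5])  # -0.25
```

Minimising this loss raises the log-probability of actions with positive advantage and lowers it for actions with negative advantage.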



## Implement policy update

### Subtask:
Apply the calculated policy gradient to update the parameters of the policy network using an optimizer.


**Reasoning**:
Apply the calculated policy gradient to update the parameters of the policy network using the optimizer.
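
To make the update concrete without a deep learning framework, the toy sketch below applies the policy gradient by hand to a Gaussian policy with a single learnable mean and a fixed standard deviation. Everything here is an assumption for illustration: the reward `-(a - 3)^2` stands in for the cosine similarity reward, the learning rate is chosen for the toy problem, and the notebook itself would instead call an optimizer on the network parameters.

```python
import random

def train_toy_policy(steps=300, batch_size=8, learning_rate=0.05, seed=0):
    """Gradient ascent on expected reward for a 1-parameter Gaussian policy."""
    rng = random.Random(seed)
    mean, std = 0.0, 1.0   # learnable mean, fixed std (simplification)
    target_k = 3.0         # hypothetical best top_k for the toy reward
    for _ in range(steps):
        actions = [rng.gauss(mean, std) for _ in range(batch_size)]
        rewards = [-(a - target_k) ** 2 for a in actions]
        baseline = sum(rewards) / batch_size
        # d/d(mean) of log N(a; mean, std^2) is (a - mean) / std^2
        grad = sum(
            (r - baseline) * (a - mean) / std ** 2
            for a, r in zip(actions, rewards)
        ) / batch_size
        mean += learning_rate * grad  # ascent on reward == descent on the loss
    return mean

learned_mean = train_toy_policy()  # drifts toward the toy target of 3
```

The same pattern with an optimizer is: zero the gradients, backpropagate the surrogate loss, then take an optimizer step.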



## Evaluate and refine

### Subtask:
After training, evaluate the performance of the policy-controlled RAG system and refine the implementation or hyperparameters as needed.


**Reasoning**:
Evaluate the performance of the policy-controlled RAG system after the training loop. This involves setting the policy network to evaluation mode, using a dataset (can be the same as training for demonstration), iterating through questions, predicting/sampling `similarity_top_k`, executing the RAG query, calculating the cosine similarity reward, and finally reporting the average reward.



# Task
Implement observability metrics for the training process of a policy network using Weights and Biases.

## Identify key metrics

### Subtask:
Determine which metrics are most important to track for monitoring the training process of the policy network (e.g., epoch number, batch number, average policy loss per batch, average reward per batch, average predicted similarity_top_k per batch, average advantage per batch).


**Reasoning**:
Identify and list the key metrics for monitoring the training process of the policy network.



## Integrate Weights & Biases

### Subtask:
Install the `wandb` library and initialize a Weights & Biases run at the beginning of the training script.


**Reasoning**:
Install the wandb library and initialize a Weights & Biases run at the beginning of the training script as requested by the subtask.



## Log hyperparameters

### Subtask:
Log the training hyperparameters (e.g., learning rate, batch size, number of epochs) to Weights & Biases.


**Reasoning**:
Define a dictionary containing the training hyperparameters and log it to the initialized Weights & Biases run.
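
A sketch of such a dictionary, using the values reported in the summary of this notebook (learning rate 1e-4, batch size 8, 100 epochs). The key names and the project placeholder are assumptions.

```python
hyperparameters = {
    "learning_rate": 1e-4,
    "batch_size": 8,
    "num_epochs": 100,
}

# With wandb installed, this dict would typically be passed when the run starts:
#   wandb.init(project="<your-project>", config=hyperparameters)
```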



## Integrate metric calculation into the training loop

### Subtask:
Modify the training loop to calculate the chosen metrics for each batch and/or epoch.


**Reasoning**:
Modify the training loop to calculate the selected metrics for each batch and epoch, including average policy loss, average reward, average predicted similarity_top_k, average advantage, average mean, and average log-variance.
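
One possible way to gather the per-batch metrics listed above into a single record. The function name and the metric keys are hypothetical; the inputs are plain Python lists of per-sample values.

```python
from statistics import fmean

def batch_metrics(policy_loss, rewards, top_ks, advantages, means, log_vars):
    """Collect per-batch averages into one flat dictionary."""
    return {
        "batch/policy_loss": policy_loss,
        "batch/avg_reward": fmean(rewards),
        "batch/avg_top_k": fmean(top_ks),
        "batch/avg_advantage": fmean(advantages),
        "batch/avg_mean": fmean(means),
        "batch/avg_log_var": fmean(log_vars),
    }

metrics = batch_metrics(
    policy_loss=0.02,
    rewards=[0.2, 0.4],
    top_ks=[2, 4],
    advantages=[-0.1, 0.1],
    means=[2.5, 3.5],
    log_vars=[0.0, -0.2],
)
```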



## Log metrics to Weights & Biases

### Subtask:
Add code to log the calculated metrics to Weights & Biases within the training loop. This will typically involve using `wandb.log()`.


**Reasoning**:
Add code to log the calculated metrics to Weights & Biases within the training loop as requested by the subtask. This involves using `wandb.log()` for both batch and epoch metrics.
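
A hedged sketch of epoch-level logging. To stay runnable without a W&B account, the logger is passed in as a callable; in the notebook `wandb.log` could be supplied as `log_fn`. The helper name `log_epoch` and the `epoch/` key prefix are assumptions.

```python
def log_epoch(log_fn, epoch, batch_records):
    """Average per-batch metric dicts and emit one epoch-level record."""
    if not batch_records:
        raise ValueError("batch_records must not be empty")
    keys = batch_records[0].keys()
    epoch_record = {
        f"epoch/{key}": sum(rec[key] for rec in batch_records) / len(batch_records)
        for key in keys
    }
    epoch_record["epoch"] = epoch
    log_fn(epoch_record)  # e.g. log_fn=wandb.log
    return epoch_record

# Stand-in logger that just collects records in a list.
logged = []
record = log_epoch(
    logged.append,
    epoch=1,
    batch_records=[{"reward": 0.25, "loss": 0.5}, {"reward": 0.75, "loss": 1.5}],
)
```

Prefixing epoch metrics separately from batch metrics keeps the two series apart on the dashboard.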



## Visualize and analyze metrics in Weights & Biases

### Subtask:
Use the Weights & Biases dashboard to visualize the logged metrics, monitor training progress, and identify potential issues.


## Refine logging and metrics

### Subtask:
Based on the analysis in Weights & Biases, refine the set of metrics being tracked or the logging frequency for better insights.


**Reasoning**:
Based on the analysis in Weights & Biases, refine the set of metrics being tracked and their logging frequency in the training code. The current metrics are informative, but batch-level metrics can be noisy. Epoch-level metrics provide a smoother view of overall progress. We will keep logging both but ensure epoch metrics are clearly distinguished. We can also consider adding the standard deviation of `similarity_top_k` predictions at the epoch level to see how the policy's uncertainty evolves.



The training loop includes the actual RAG execution using the OECD_index and the cosine_similarity_reward function to calculate the reward based on the generated answer and the ground truth for each question in the batch. This means the policy network is being trained using real rewards derived from the RAG system's performance.

**Explanation of training results logged in W&B**

**Epoch 27/100, Avg Loss: 0.0183, Avg Reward: 0.3199, Avg Predicted Top K: 2.80, Predicted Top K Std: 0.40**

**Epoch** 27/100: This indicates that the training is currently on the 27th iteration (epoch) out of a planned total of 100 epochs.

**Avg Loss**: 0.0183: This is the average policy loss calculated across all the batches in Epoch 27. The optimizer minimizes this loss, which pushes the policy toward actions with a higher advantage. In policy-gradient methods the loss value on its own is not a reliable measure of policy quality, because it depends on the log-probabilities and advantages of the current batch, so it is best read alongside the average reward. The value 0.0183 is the average loss for this specific epoch.

**Avg Reward**: 0.3199: This is the average cosine similarity reward obtained during Epoch 27. This reward is calculated by running the RAG system with the actions (predicted similarity_top_k values) sampled by the policy network for each question in the dataset and then comparing the generated answers to the ground truth. An average reward of 0.3199 means that, on average across all samples in this epoch, the generated answers had a cosine similarity of approximately 0.32 with their respective ground truth answers. Higher values indicate better performance in terms of semantic similarity.

**Avg Predicted Top K**: 2.80: This is the average value of similarity_top_k predicted (or sampled and then processed into an integer) by the policy network across all questions in Epoch 27. This metric tells you what the policy network is generally choosing for the number of documents to retrieve. An average of 2.80 suggests the policy is typically selecting around 3 documents.

**Predicted Top K Std**: 0.40: This is the standard deviation of the similarity_top_k values predicted (or sampled and processed) by the policy network across all questions in Epoch 27. The standard deviation measures the dispersion or spread of the predicted values. A standard deviation of 0.40 indicates that the predicted similarity_top_k values in this epoch were relatively close to the average (2.80), meaning the policy's predictions for similarity_top_k didn't vary widely within this epoch.
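
To see how an average of 2.80 and a standard deviation of 0.40 can arise together, consider a made-up epoch of five sampled values. These numbers are illustrative and are not taken from the logged run; the sketch also assumes the population standard deviation, since the notebook's exact formula is not shown here.

```python
from statistics import fmean, pstdev

sampled_top_ks = [3, 3, 3, 3, 2]     # hypothetical epoch of five questions

avg_top_k = fmean(sampled_top_ks)    # about 2.8
std_top_k = pstdev(sampled_top_ks)   # about 0.4 (population standard deviation)
```

In this example four of the five questions retrieve 3 documents and one retrieves 2, which matches the reading that the policy is choosing values close to its average.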

## Summary:

### Data Analysis Key Findings

*   The training process successfully logged key metrics at both the batch and epoch levels to Weights & Biases, including policy loss, average reward, average predicted `similarity_top_k`, average advantage, average mean, and average log-variance of the predicted distribution.
*   Hyperparameters such as learning rate (1e-4), batch size (8), and number of epochs (100) were successfully logged to the Weights & Biases config.
*   The training loop executed for 100 epochs, with console output confirming the progress and epoch-average loss and reward.
*   A new epoch-level metric, the standard deviation of the predicted `similarity_top_k`, was successfully added and logged to provide insight into the variability of the policy's actions.

### Completed Steps

*   Visualized the logged metrics in the Weights & Biases dashboard to analyze trends, identify correlations between metrics (e.g., reward and predicted top\_k), and diagnose potential training issues such as instability or convergence problems.
*   Implemented the actual RAG system reward calculation to replace the dummy reward function, allowing the policy to learn based on real retrieval performance.


# Task
Extend the provided Python code for training a RAG policy to maintain and train a group of policies using a modified training loop.

## Modify policy network management - GRPO implementation

### Subtask:
Change the code to create and manage a list or collection of `RAGPolicyNetwork` instances instead of just one.


**Reasoning**:
The subtask is to replace the single policy network instance with a list of multiple policy network instances. This involves defining the number of policies, creating a container for them (a list or ModuleList), and instantiating the specified number of policies within a loop. This needs to be done in a code block that replaces the current instantiation of the single policy network.



## Adapt data collection - GRPO: comparing the performance of policies within the group

### Subtask:
Modify the training loop to collect data (questions, actions, rewards, log probabilities) for each policy in the group over an iteration. This might involve running each policy on the same batch or different batches.


**Reasoning**:
Update the training loop structure to iterate through the `policy_group`, process a batch of data for each policy, perform a forward pass, sample actions and calculate log probabilities, execute the RAG system, calculate rewards, and store the results for each policy.



## Implement group performance evaluation

### Subtask:
Add logic to evaluate the performance of each policy in the group based on the collected rewards. This could be the average reward over the data collected for that policy.


**Reasoning**:
Implement the logic to evaluate the performance of each policy in the group based on the collected rewards within the training loop. This involves calculating the average reward for each policy after collecting data across the dataset for the epoch. Store these average rewards and identify the best-performing policy.
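
A minimal sketch of the group evaluation, assuming the rewards collected in an epoch are stored as one list per policy. The function name and the example rewards are hypothetical.

```python
def evaluate_group(rewards_per_policy):
    """Return (average reward per policy, index of the best policy)."""
    if not rewards_per_policy or any(not r for r in rewards_per_policy):
        raise ValueError("every policy needs at least one reward")
    avg_rewards = [sum(r) / len(r) for r in rewards_per_policy]
    best_index = max(range(len(avg_rewards)), key=avg_rewards.__getitem__)
    return avg_rewards, best_index

# Three policies, each evaluated on the same two questions (made-up rewards).
avg_rewards, best_policy = evaluate_group([[0.2, 0.4], [0.5, 0.7], [0.1, 0.3]])
```

When two policies tie, `max` returns the first one, so the choice of best policy is deterministic.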



## Adapt policy update rule

### Subtask:
Modify the policy update step to incorporate the group information, potentially using the average reward or relative performance within the group to adjust gradients.


**Reasoning**:
Implement the policy update step for each policy in the group after data collection and evaluation, calculating the individual policy loss, zeroing gradients, performing the backward pass, and updating the optimizer, and logging the policy loss to Weights & Biases.



**Reasoning**:
Implement the Weights & Biases logging for policy-specific and group-level metrics as requested by the subtask. This involves modifying the training loop to collect and log these metrics at the end of each epoch. I will ensure all necessary variables and functions from previous cells are included in this code block to make it runnable.



# Task
Extend the provided Python code to implement a basic version of Group Relative Policy Optimization (GRPO). This involves maintaining a group of policies, comparing their performance, and using a trust region-like approach for policy updates within the group. Ensure unnecessary code is removed before proceeding.

## Modify policy network management

### Subtask:
Change the code to create and manage a list or collection of `RAGPolicyNetwork` instances instead of just one.


**Reasoning**:
The subtask is to create and manage a group of policy networks and their optimizers. This involves defining the number of policies, instantiating the `RAGPolicyNetwork` class multiple times, storing them in a list, and creating a corresponding list of optimizers. I will include all necessary imports and definitions from previous successful steps to make this code block runnable and self-contained as requested by the instructions.



# Task
Extend the provided Python code to implement a training loop that maintains a group of policies, compares their performance, and uses a trust region method for policy updates, reflecting the principles of Group Relative Policy Optimization (GRPO).
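
One common way to approximate a trust region is the PPO-style clipped surrogate, sketched below for a single sample. Using it here is an assumption: the task only asks for "a trust region method", and the clipping range of 0.2 is a frequently used default rather than a value taken from this notebook.

```python
import math

def clipped_surrogate(new_log_prob, old_log_prob, advantage, clip_eps=0.2):
    """PPO-style clipped objective for one sample (to be maximised)."""
    ratio = math.exp(new_log_prob - old_log_prob)  # pi_new(a) / pi_old(a)
    clipped_ratio = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)

# The new policy made a good action about 1.65x more likely; credit is capped at 1.2x.
objective = clipped_surrogate(new_log_prob=-0.5, old_log_prob=-1.0, advantage=1.0)
```

Taking the minimum of the clipped and unclipped terms removes the incentive to move the policy far from the one that collected the data, while still penalising changes that make bad actions more likely.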

## Adapt policy update rule

### Subtask:
Modify the policy update step to incorporate the group information. This is where the "Relative" part of GRPO comes in. The update for each policy might depend on its performance relative to others in the group, or updates might be averaged across the group. (Note: A full GRPO implementation would involve more sophisticated trust region methods based on the group, but for this plan, we'll focus on managing the group and adapting the basic update).
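
The "relative" part can be sketched by standardising rewards within the group, so that each member is scored against the group mean and spread. This is a sketch under assumptions: whether the group consists of several policies (as in this plan) or of several actions sampled for the same question, the arithmetic is the same, and the small `eps` is added only to avoid division by zero.

```python
from statistics import fmean, pstdev

def group_relative_advantages(group_rewards, eps=1e-8):
    """Standardise rewards within a group: (r - mean) / (std + eps)."""
    mean_reward = fmean(group_rewards)
    std_reward = pstdev(group_rewards)
    return [(r - mean_reward) / (std_reward + eps) for r in group_rewards]

# Made-up rewards for three group members answering the same question.
relative_advantages = group_relative_advantages([0.2, 0.3, 0.4])
# roughly [-1.22, 0.0, 1.22]
```

Because the advantages are measured against the group itself, no separate value network or learned baseline is needed.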


**Reasoning**:
Implement the policy update step for each policy in the group after data collection and evaluation, calculating the individual policy loss, zeroing gradients, performing the backward pass, and updating the optimizer, and logging the policy loss to Weights & Biases.



## Update logging

### Subtask:
Modify Weights & Biases logging to track metrics for each policy in the group, or group-level metrics (e.g., average reward of the best policy, average reward across the group).
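
A possible shape for the per-epoch record, combining per-policy and group-level metrics in one flat dictionary. The key naming scheme (`policy_<i>/...`, `group/...`) is an assumption; the resulting dict could be handed to `wandb.log`.

```python
def group_log_record(epoch, avg_rewards, losses):
    """Build one flat metrics dict for a group of policies."""
    if len(avg_rewards) != len(losses) or not avg_rewards:
        raise ValueError("need one reward and one loss per policy")
    record = {"epoch": epoch}
    for i, (reward, loss) in enumerate(zip(avg_rewards, losses)):
        record[f"policy_{i}/avg_reward"] = reward
        record[f"policy_{i}/loss"] = loss
    record["group/avg_reward"] = sum(avg_rewards) / len(avg_rewards)
    record["group/best_avg_reward"] = max(avg_rewards)
    return record

record = group_log_record(epoch=5, avg_rewards=[0.25, 0.75], losses=[0.5, 0.25])
```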


# Task
Extend the provided Python code to maintain a group of policies, adapt data collection for each policy in the group, implement group performance evaluation, and adapt the policy update rule to incorporate group information. Ensure Weights & Biases logging is updated to track metrics for each policy or group-level metrics.

**Reasoning**:
Remove the deprecated OpenAIAssistantAgent import and instantiation from cell `hlIsonI9Hlx6` as requested by the subtask instructions.



**Reasoning**:
Remove the deprecated OpenAIAssistantAgent import and instantiation from cell `COLLlq4wJZZP` as requested by the subtask instructions.



## Define tools for Chat Completions API

### Subtask:
Adapt the existing `QueryEngineTool` instances to a format compatible with the Chat Completions API's `tools` parameter (list of function definitions).


**Reasoning**:
Define the function `openai_tool_definition` to convert a QueryEngineTool to the OpenAI Chat Completions API tool format and then apply it to the existing tools.



In [None]:
from llama_index.core.tools import QueryEngineTool, ToolMetadata
import json # Import the json module for schema definition

# Redefine or ensure tools are available from previous successful cells
# Assuming OECD_query_tool and Form990_query_tool are defined here or in a prior cell
# Example definition if not already available:
# OECD_engine = OECD_index.as_query_engine(similarity_top_k=3) # Assuming OECD_index is loaded
# form990_guidelines_engine = form990_guidelines_index.as_query_engine(similarity_top_k=3) # Assuming form990_guidelines_index is loaded
# OECD_query_tool = QueryEngineTool(
#                       query_engine=OECD_engine,
#                       metadata=ToolMetadata(
#                           name="OECD_QueryEngineTool_2022",
#                           description="Provides information about Transfer Pricing Guidelines for Organization from OECD for year 2022"
#                       )
#                     )
# Form990_query_tool = QueryEngineTool(
#                       query_engine=form990_guidelines_engine,
#                       metadata=ToolMetadata(
#                           name="form990_2022",
#                           description="Provides information about Form990 filling guidelines for Non-Profit Organization only from the index which was set for Form990_Guidelines.pdf "
#                       )
#                     )
# tools = [OECD_query_tool, Form990_query_tool] # Ensure 'tools' list is defined

# 1. Define a function openai_tool_definition
def openai_tool_definition(query_engine_tool: QueryEngineTool) -> dict:
    """
    Converts a LlamaIndex QueryEngineTool to the OpenAI Chat Completions API
    tool format.

    Args:
        query_engine_tool: The QueryEngineTool instance.

    Returns:
        A dictionary representing the tool definition in OpenAI's format.
    """
    # 4. Define the parameters schema
    parameters_schema = {
        "type": "object",
        "properties": {
            "input": {
                "type": "string",
                "description": "The query string to pass to the tool."
            }
        },
        "required": ["input"],
    }

    # 2. Construct the dictionary for the tool definition
    tool_definition = {
        "type": "function",
        "function": {
            # 3. Extract name and description from metadata
            "name": query_engine_tool.metadata.name,
            "description": query_engine_tool.metadata.description,
            "parameters": parameters_schema,
        },
    }
    return tool_definition

# 5. Apply this function to existing tools
# Ensure 'tools' list is available from previous cells
# If not, you'd need to define OECD_query_tool and Form990_query_tool here
# Assuming 'tools' is defined and contains OECD_query_tool and Form990_query_tool
if 'tools' in globals() and isinstance(tools, list) and len(tools) > 0:
    openai_tools = [openai_tool_definition(tool) for tool in tools]
    print("OpenAI tool definitions created successfully:")
    # Print the created tool definitions for verification
    print(json.dumps(openai_tools, indent=2))
else:
    print("Error: 'tools' list not found or is empty. Cannot create OpenAI tool definitions.")
    openai_tools = [] # Initialize as empty if tools are not available



## Implement chat completions logic

### Subtask:
Create a function or class that sends messages to the model through the Chat Completions API. With `openai` v1.0 or later this is `client.chat.completions.create`; the legacy `openai.ChatCompletion.create` call was removed in v1.0.


**Reasoning**:
Define a function `chat_with_tools` that uses `client.chat.completions.create` to send messages and tools to the model and handle the response.



In [None]:
import openai
import os # Import os to access environment variables

# Ensure the OpenAI API key is set
if "OPENAI_API_KEY" not in os.environ:
    print("Error: OPENAI_API_KEY environment variable not set.")
    # In a real application, you would handle this more robustly
    # For this example, we'll proceed but the API call will fail.


# 1. Define a Python function, for example chat_with_tools
def chat_with_tools(messages: list[dict], openai_tools: list[dict], model: str = "gpt-4o-mini", temperature: float = 0.7):
    """
    Sends messages and OpenAI tool definitions to the Chat Completions API
    and returns the response.

    Args:
        messages: A list of message dictionaries in the OpenAI format.
        openai_tools: A list of tool definitions in the OpenAI format.
        model: The name of the OpenAI model to use (default: gpt-4o-mini).
        temperature: The sampling temperature (default: 0.7).

    Returns:
        The ChatCompletion response object returned by the OpenAI client.
        Returns an empty dictionary and prints an error message if the API call fails.
    """
    try:
        # 2. Inside the function, use the openai.ChatCompletion.create method
        # Using the newer openai library syntax (v1.0+)
        # The older openai.ChatCompletion.create is deprecated in favor of client.chat.completions.create
        # Let's use the newer client syntax
        client = openai.OpenAI() # Assumes OPENAI_API_KEY is set as an environment variable

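        # Only pass tool arguments when tools are defined; an empty tools list is rejected by the API
        request_kwargs = {"model": model, "messages": messages, "temperature": temperature}
        if openai_tools:
            request_kwargs["tools"] = openai_tools
            request_kwargs["tool_choice"] = "auto"  # auto lets the model decide whether to call a tool or respond
        response = client.chat.completions.create(**request_kwargs)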
        # The response object structure is slightly different in the new client
        # Returning the full response object for now
        return response

    # 3. Include error handling (e.g., a try...except block) for the API call.
    except openai.APIError as e:
        print(f"OpenAI API error: {e}")
        return {}
    except Exception as e:
        print(f"An unexpected error occurred during the OpenAI API call: {e}")
        return {}

# Example usage (assuming 'openai_tools' is defined from the previous step)
# and 'messages' is a list of dictionaries like [{"role": "user", "content": "Your question here"}]
# Example messages setup:
# messages_history = [{"role": "user", "content": "What does Articles 9 of the OECD Model Tax Convention state?"}]
# if 'openai_tools' in globals() and openai_tools:
#     print("Attempting to call chat_with_tools...")
#     api_response = chat_with_tools(messages_history, openai_tools)
#     print("\nAPI Response:")
#     print(api_response)
# else:
#     print("openai_tools not defined. Skipping chat_with_tools example call.")


## Implement tool calling handling

### Subtask:
Within the chat loop, check the model's response for `tool_calls`. If present, execute the corresponding tool (query the appropriate LlamaIndex engine) and send the tool output back to the model.


**Reasoning**:
Implement the main chat loop that iteratively calls the chat_with_tools function, checks for tool_calls in the response, executes the tools using the corresponding LlamaIndex query engines, appends the tool outputs to the messages history, and continues the conversation until the model provides a final answer without tool calls. This combines steps 1-10 of the instructions.
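The core of that loop is the mapping from one tool call to one `role: "tool"` message. A minimal sketch of that step follows, assuming each tool call carries a JSON argument string with an `input` key. `execute_tool_call` and `EchoEngine` are hypothetical names, and the `SimpleNamespace` object only mimics the attribute layout of a tool call so the snippet runs without an API key.

```python
import json
from types import SimpleNamespace


def execute_tool_call(tool_call, engine_map: dict) -> dict:
    """Run one tool call and return the matching tool message (hypothetical helper)."""
    name = tool_call.function.name
    try:
        tool_input = json.loads(tool_call.function.arguments).get("input", "")
    except (json.JSONDecodeError, AttributeError):
        # Malformed or non-object arguments: fall back to an empty query.
        tool_input = ""

    engine = engine_map.get(name)
    if engine is None:
        output = f"Error: Tool '{name}' not supported or found."
    else:
        output = str(engine.query(tool_input))

    # tool_call_id links this output to the specific call the model made.
    return {"role": "tool", "tool_call_id": tool_call.id, "content": output}


class EchoEngine:
    """Stand-in for a LlamaIndex query engine; it only echoes the query."""

    def query(self, text: str) -> str:
        return f"echo: {text}"


fake_call = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(
        name="OECD_QueryEngineTool_2022",
        arguments='{"input": "Article 9"}',
    ),
)
tool_message = execute_tool_call(fake_call, {"OECD_QueryEngineTool_2022": EchoEngine()})
missing_message = execute_tool_call(fake_call, {})
print(tool_message)
```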



In [None]:
import json # Import json if not already imported in this block
import time # Import time for potential delays
import openai # Import openai

# Assume necessary variables and functions are defined from previous successful cells:
# OECD_index, form990_guidelines_index (if used), tools (LlamaIndex QueryEngineTool instances),
# openai_tools (OpenAI tool definitions), chat_with_tools, cosine_similarity_reward (if needed elsewhere)
# Settings (LlamaIndex) are assumed to be configured.

# Ensure query engines corresponding to the tools are accessible
# Assuming they are defined from prior steps, e.g.:
# OECD_engine = OECD_index.as_query_engine(similarity_top_k=3)
# form990_guidelines_engine = form990_guidelines_index.as_query_engine(similarity_top_k=3)

# Create a dictionary mapping tool names to query engines for easy lookup
# Ensure OECD_engine and form990_guidelines_engine are defined and loaded from index
query_engine_map = {}
if 'OECD_index' in globals() and OECD_index is not None:
    try:
        query_engine_map["OECD_QueryEngineTool_2022"] = OECD_index.as_query_engine(similarity_top_k=3)
        print("Created OECD query engine.")
    except Exception as e:
        print(f"Error creating OECD query engine: {e}")
else:
    print("OECD_index not found or loaded. OECD tool execution will fail.")

# Assuming form990_guidelines_index is also loaded and available
if 'form990_guidelines_index' in globals() and form990_guidelines_index is not None:
     try:
          query_engine_map["form990_2022"] = form990_guidelines_index.as_query_engine(similarity_top_k=3)
          print("Created Form990 query engine.")
     except Exception as e:
          print(f"Error creating Form990 query engine: {e}")
else:
     print("form990_guidelines_index not found or loaded. Form990 tool execution will fail.")


# 1. Initialize the conversation with a system message and the user's first question
messages = [{"role": "system", "content": "You are an assistant that provides answers to questions on OECD and Form990 using the available tools. Answer as accurately as possible based on the tool outputs. Whenever there is comparison make sure the results are in side by side comparison table with headers and add links to the document."}]
user_question = "What is Form990 EZ and when should an organization complete Form990 EZ form? And how is it different from Schedule H? Can you show the results in side by side comparison table with headers and also link to the document?"
# user_question = "What does Articles 9 and 25 of the OECD Model Tax Convention state?" # Example question for OECD tool
messages.append({"role": "user", "content": user_question})

# Ensure openai_tools is defined from the previous step (conversion of LlamaIndex tools)
if 'openai_tools' not in globals() or not openai_tools:
    print("OpenAI tool definitions ('openai_tools') not found or are empty. Cannot proceed with tool-using chat loop.")
    # Define dummy tools to prevent crash if previous cell failed, but tool calls won't work
    openai_tools = [{"type": "function", "function": {"name": "dummy_tool", "description": "A dummy tool.", "parameters": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]}}}]


# --- Main Chat Loop ---
print(f"Starting chat loop for question: {user_question}")

# Keep track of the number of turns to prevent infinite loops
max_turns = 10
turn_count = 0
final_response_content = None # Variable to store the final answer from the model

while turn_count < max_turns:
    turn_count += 1
    print(f"\n--- Turn {turn_count} ---")

    # 2. Send messages to the model using chat_with_tools
    # Ensure chat_with_tools function is defined from a previous cell
    api_response = chat_with_tools(messages, openai_tools)

    # Check if the API call was successful and has choices
    if not api_response or not hasattr(api_response, 'choices') or not api_response.choices:
        print("API call failed or returned no choices.")
        break # Exit loop if API call fails

    # Extract the message from the response
    response_message = api_response.choices[0].message
    print(f"Model response received (Role: {response_message.role})")

    # 3. Check if the response contains tool_calls
    tool_calls = response_message.tool_calls

    if tool_calls:
        print("Model requested tool calls.")
        # Append the model's message (requesting tools) to the messages history
        messages.append(response_message)

        # 4. Iterate through each tool call
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_args_str = tool_call.function.arguments
            tool_call_id = tool_call.id # Get the tool call ID

            print(f"  Tool call requested: {function_name} with args: {function_args_str}")

            # 5. Parse the arguments string
            try:
                function_args = json.loads(function_args_str)
                tool_input = function_args.get("input") # Extract the input argument
                if tool_input is None:
                     print(f"Warning: 'input' argument not found in tool call args for {function_name}.")
                     tool_input = "" # Use empty string if input is missing

            except json.JSONDecodeError:
                print(f"Error decoding tool call arguments JSON for {function_name}: {function_args_str}")
                tool_input = "" # Use empty string or handle as error

            # 6. Identify and execute the corresponding LlamaIndex query engine
            query_engine = query_engine_map.get(function_name)
            tool_output = "" # Initialize tool output

            if query_engine:
                print(f"  Executing tool: {function_name} with input: '{tool_input}'")
                try:
                    # Execute the query using the LlamaIndex engine
                    llama_response = query_engine.query(tool_input)
                    tool_output = str(llama_response) # Convert response to string
                    print(f"  Tool execution successful. Output snippet: '{tool_output[:100]}...'")
                except Exception as e:
                    print(f"  Error executing LlamaIndex tool '{function_name}': {e}")
                    tool_output = f"Error executing tool: {e}" # Provide error message as tool output
            else:
                print(f"  Error: No LlamaIndex query engine found for tool name '{function_name}'.")
                tool_output = f"Error: Tool '{function_name}' not supported or found."

            # 7. Format the output from the LlamaIndex query engine into the required format for the API
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call_id, # Link the tool output to the specific tool call
                "content": tool_output,
            }

            # 8. Append this tool message to the messages history
            messages.append(tool_message)
            print("  Tool output appended to messages history.")

        # 9. Send the updated messages history back to the chat_with_tools function
        # The loop continues, and the next iteration will send the updated 'messages'
        # The model will then process the tool outputs and generate a response.

    else:
        # 10. If the model responds without tool_calls, it's the final answer
        print("Model responded without tool calls. This is the final answer.")
        final_response_content = response_message.content
        messages.append(response_message) # Append the final response to history
        break # Exit the loop

    # Add a small delay to avoid hitting rate limits too quickly during development
    time.sleep(1)

# After the loop, print the final answer if available
if final_response_content:
    print("\n--- Final Answer ---")
    print(final_response_content)
elif turn_count >= max_turns:
    print("\n--- Chat loop ended due to reaching max turns ---")
    # Optionally print the last message from the model
    if messages:
        print("Last message from model:")
        print(messages[-1])


print("\nChat loop finished.")


## Handle final response

### Subtask:
Process the model's final response after tool execution to extract the answer.


**Reasoning**:
Extract the final answer string from the `content` attribute of the `response_message` object after the loop concludes, and store it in `final_answer_content`. Print this variable.
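The extraction rule can be written as a small function. This is a sketch rather than the notebook's own code: `extract_final_answer` is a hypothetical name, and `SimpleNamespace` objects stand in for the response message so the behaviour can be checked offline.

```python
from types import SimpleNamespace

FALLBACK = "Chat loop ended without a final answer."


def extract_final_answer(response_message, fallback: str = FALLBACK) -> str:
    """Return the message content only when it is a final answer (no pending tool calls)."""
    if response_message is None:
        return fallback
    if getattr(response_message, "tool_calls", None):
        # The model is still asking for tools, so there is no final answer yet.
        return fallback
    return getattr(response_message, "content", None) or fallback


final_message = SimpleNamespace(content="Form 990-EZ is a shorter return.", tool_calls=None)
pending_message = SimpleNamespace(content=None, tool_calls=[object()])

print(extract_final_answer(final_message))
print(extract_final_answer(pending_message))
```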



In [None]:
# Assume the chat loop from the previous step has just completed.
# The `response_message` variable holds the last message received from the model.
# The `final_response_content` variable was intended to store the final answer.
# The `messages` list holds the entire conversation history.

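# Check if the loop terminated because a final answer was received (i.e., no tool calls in the last message)
# and if the last message is not empty.
# Fall back to None if the loop exited before any response was received (e.g., the first API call failed).
response_message = globals().get("response_message")
if response_message and not response_message.tool_calls and response_message.content: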
    final_answer_content = response_message.content
    print("\n--- Final Answer (Extracted after loop) ---")
    print(final_answer_content)
elif turn_count >= max_turns:
    print("\n--- Chat loop ended due to reaching max turns ---")
    # If max turns reached, the last message might contain a partial answer or just model thoughts.
    # We can still try to extract content if available, but it might not be a complete final answer.
    if response_message and response_message.content:
         final_answer_content = response_message.content
         print("Last message content:")
         print(final_answer_content)
    else:
         final_answer_content = "Chat loop ended without a final answer."
         print(final_answer_content)
else:
    # Handle other potential loop exit conditions or errors
    final_answer_content = "Chat loop terminated unexpectedly without a final answer."
    print(final_answer_content)

# The final_answer_content variable now holds the extracted final answer or an informative message.
# It can be used for further processing or evaluation.

# Note: The previous cell's code already included printing logic after the loop.
# This cell explicitly focuses on ensuring the extraction and storage in `final_answer_content`.
# The printing is included here to demonstrate the result of the extraction.

## Update agent usage

### Subtask:
Replace the `agent.chat()` calls with calls to your new Chat Completions based function/class.


**Reasoning**:
Replace the existing agent.chat() calls in the training loop with calls to the new chat_with_tools function and integrate the logic to collect the final answer and the contexts retrieved during the tool execution.
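One way to picture the replacement is a single function that runs the whole conversation for one question and returns both the final answer and the retrieved contexts. The sketch below is an assumption about how this could be wired, not the training loop itself: `run_episode`, `fake_chat`, and `StubEngine` are hypothetical, and the chat function is injected so the example runs without calling the API.

```python
import json
from types import SimpleNamespace


def run_episode(chat_fn, engine_map, question, tools, max_turns=5):
    """Run one tool-using conversation; return (final_answer, contexts)."""
    messages = [{"role": "user", "content": question}]
    contexts = []
    for _ in range(max_turns):
        response = chat_fn(messages, tools)
        if not response or not getattr(response, "choices", None):
            break  # API failure: stop without an answer
        message = response.choices[0].message
        messages.append(message)
        if not message.tool_calls:
            return message.content, contexts
        for call in message.tool_calls:
            tool_input = json.loads(call.function.arguments).get("input", "")
            engine = engine_map.get(call.function.name)
            if engine is None:
                output = f"Error: Tool '{call.function.name}' not supported or found."
            else:
                output = str(engine.query(tool_input))
                contexts.append(output)  # keep only real retrievals for the reward
            messages.append({"role": "tool", "tool_call_id": call.id, "content": output})
    return None, contexts


class StubEngine:
    def query(self, text):
        return f"context for: {text}"


def fake_chat(messages, tools):
    """Request one tool call first, then answer once a tool output is in the history."""
    has_tool_output = any(isinstance(m, dict) and m.get("role") == "tool" for m in messages)
    if has_tool_output:
        msg = SimpleNamespace(role="assistant", content="Final answer.", tool_calls=None)
    else:
        call = SimpleNamespace(
            id="call_1",
            function=SimpleNamespace(name="form990_2022", arguments='{"input": "Form 990-EZ"}'),
        )
        msg = SimpleNamespace(role="assistant", content=None, tool_calls=[call])
    return SimpleNamespace(choices=[SimpleNamespace(message=msg)])


answer, contexts = run_episode(fake_chat, {"form990_2022": StubEngine()}, "What is Form 990-EZ?", tools=[])
print(answer, contexts)
```

In the training loop, `chat_fn` would be `chat_with_tools`, and the joined `contexts` would be the text scored by the cosine-similarity reward.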



**Reasoning**:
The previous command failed because `Dataset` was not defined. This indicates that the import `from torch.utils.data import Dataset, DataLoader` was either not executed correctly or was not included in the necessary cell dependencies. I will add the import for `Dataset` and `DataLoader` again to ensure they are available.



**Reasoning**:
The previous command failed because `policy_group` was not defined. This indicates that the code block responsible for creating the list of policy networks and their optimizers was not executed or was not included in the necessary dependencies. I will add the code to define `NUM_POLICIES`, instantiate the `RAGPolicyNetwork` instances, create the `policy_group` list, and create the corresponding `optimizers` list to ensure they are available before the training loop begins.



**Reasoning**:
The previous command failed because `LEARNING_RATE` was not defined when trying to instantiate the optimizers. I need to ensure `LEARNING_RATE` is defined before the line where optimizers are created. I will add the definition for `LEARNING_RATE` and ensure all necessary imports and variable definitions from previous successful steps are present in this code block to make it runnable and self-contained.



**Reasoning**:
The previous command failed because `optimizers` was not defined in the policy update phase. This is happening because the code block that defines and instantiates `policy_group` and `optimizers` is guarded by an `if 'policy_group' not in globals()` check. If the cell was run before and `policy_group` exists, this block is skipped, and `optimizers` is not defined in the current execution scope, even if it was defined in a *previous* execution of the cell. To fix this, I need to ensure `optimizers` is always defined when the training loop starts, regardless of whether the policy group was just instantiated or already existed. I will move the optimizer instantiation outside the `if/else` block for policy group instantiation, but still within the section that handles policy network setup, ensuring `LEARNING_RATE` and `policy_group` are available.



**Reasoning**:
The previous command failed because `LEARNING_RATE` was not defined when creating the optimizers. This is happening because the definition of `LEARNING_RATE` was inside an `if` block that might have been skipped. I need to ensure `LEARNING_RATE` is defined unconditionally before it is used to create the optimizers. I will move the definition of `LEARNING_RATE` to the top of the cell, along with other essential variable definitions, ensuring it is always available. I will also include all necessary imports and variable definitions from previous successful steps to make this code block runnable and self-contained.
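The scoping problem described in these notes can be reproduced without PyTorch. In the sketch below, the `namespace` dictionary stands in for the notebook's `globals()`, and the `(policy, learning_rate)` tuples are a hypothetical stand-in for real optimizers. It illustrates the failure pattern and the fix. It is not the notebook's training code.

```python
def setup_guarded(namespace: dict) -> list:
    """Fragile pattern: LEARNING_RATE is only defined when the guard fires."""
    if "policy_group" not in namespace:
        namespace["policy_group"] = ["policy_0", "policy_1"]
        namespace["LEARNING_RATE"] = 1e-4
    # Raises KeyError when policy_group already existed from an earlier run.
    return [(p, namespace["LEARNING_RATE"]) for p in namespace["policy_group"]]


def setup_unconditional(namespace: dict) -> list:
    """Robust pattern: hyperparameters are always defined before they are used."""
    namespace["LEARNING_RATE"] = 1e-4
    if "policy_group" not in namespace:
        namespace["policy_group"] = ["policy_0", "policy_1"]
    return [(p, namespace["LEARNING_RATE"]) for p in namespace["policy_group"]]


# State left over from a previous, partial execution of the cell.
stale_state = {"policy_group": ["policy_0"]}

try:
    setup_guarded(dict(stale_state))
    guarded_ok = True
except KeyError:
    guarded_ok = False

optimizers_sketch = setup_unconditional(dict(stale_state))
print(guarded_ok, optimizers_sketch)
```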



In [None]:
import json
import time
import openai
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Normal
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import wandb
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex, SummaryIndex, StorageContext, load_index_from_storage
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from google.colab import userdata # Import userdata

# Redefine necessary variables and functions from previous cells to ensure scope

# Assuming OPENAI_API_KEY is already set as an environment variable
if "OPENAI_API_KEY" not in os.environ:
    # Attempt to get the key from Colab secrets if it is not in the environment.
    # userdata.get takes only the secret name and raises if the secret is missing
    # or notebook access has not been granted.
    try:
        os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
        print("Set OPENAI_API_KEY from Colab userdata.")
    except Exception as e:
        print(f"Error: OPENAI_API_KEY environment variable and Colab secret not set: {e}")


# Setup OpenAI Model and Embeddings - Ensure these are set within this cell's execution
# Check if Settings is already configured to avoid redundant calls if cell is re-run
if not hasattr(Settings, '_llm') or Settings.llm is None:
    Settings.llm = OpenAI(model='gpt-4o-mini', temperature=0.2)
if not hasattr(Settings, '_embed_model') or Settings.embed_model is None:
    Settings.embed_model = OpenAIEmbedding(model='text-embedding-3-small')
if not hasattr(Settings, '_chunk_size') or Settings.chunk_size != 1024:
    Settings.chunk_size = 1024
print("LlamaIndex Settings configured.")


# Assuming Google Drive is mounted at /content/drive and data_dir is defined
if 'data_dir' not in globals():
     data_dir = '/content/drive/MyDrive' # Define if not already
if 'PERSIST_INDEX_DIR' not in globals():
     PERSIST_INDEX_DIR = f"{data_dir}/RAG/data/" # Define if not already (data_dir already starts with '/')


# Redefine get_index function to ensure it's available and handles creation
def get_index(index_name, doc_file_path):
  index = None
  full_index_dir = f"{PERSIST_INDEX_DIR}{index_name}/"
  if not os.path.exists(full_index_dir):
    print(f"Index not found at {full_index_dir}. Attempting to create index...")
    # Load the documents
    try:
        documents = SimpleDirectoryReader(input_files=[doc_file_path]).load_data()
        print(f"Loaded documents from {doc_file_path}.")
        index = VectorStoreIndex.from_documents(documents)
        print("Created VectorStoreIndex.")
        # Store the index to disk
        os.makedirs(full_index_dir, exist_ok=True) # Ensure directory exists
        index.storage_context.persist(full_index_dir)
        print(f"Created and persisted index at {full_index_dir}")
    except FileNotFoundError:
        print(f"Error: Document file not found at {doc_file_path}. Cannot create index.")
        return None # Return None if document is not found
    except Exception as e:
        print(f"An error occurred during index creation: {e}")
        return None
  else: # Load index from disk
    print(f"Loading index from storage at {full_index_dir}")
    try:
        storage_context = StorageContext.from_defaults(persist_dir=full_index_dir)
        index = load_index_from_storage(storage_context)
        print("Loaded index from storage.")
    except Exception as e:
        print(f"An error occurred during index loading from {full_index_dir}: {e}")
        return None

  return index

# Load or create the OECD index using the redefined get_index function
# Ensure the path to the document is correct and the file exists
oecd_doc_path = f"{data_dir}/RAG/data/OECD/OECD_Transfer_Pricing_Guidelines.pdf"
OECD_index = get_index("OECDTPGuidelines", oecd_doc_path)

# Assuming form990_guidelines_index is also needed and defined from previous cells
# Add similar loading/creation logic for form990_guidelines_index if it's used by the tools
form990_doc_path = f"{data_dir}/RAG/data/Form990/Form990_Guidelines.pdf"
form990_guidelines_index = get_index("Form990Guidelines", form990_doc_path)


# Redefine cosine_similarity_reward function if not available
if 'cosine_similarity_reward' not in globals():
    def cosine_similarity_reward(retrieved_context, ground_truth):
        """
        Calculates a reward based on cosine similarity between the retrieved context
        and the ground truth using TF-IDF vectorization.
        """
        # Return zero reward if either string is empty
        if not retrieved_context or not ground_truth:
            return 0.0

        # Create TF-IDF vectors
        vectorizer = TfidfVectorizer().fit([retrieved_context, ground_truth])
        vectors = vectorizer.transform([retrieved_context, ground_truth])

        # Calculate cosine similarity
        # Handle potential division by zero if vectors are zero vectors (e.g., empty strings after tokenization)
        if vectors[0].sum() == 0 or vectors[1].sum() == 0:
            return 0.0

        similarity_score = cosine_similarity(vectors[0], vectors[1])[0][0]

        return similarity_score

# Redefine sample_action_and_continuous function if not available
if 'sample_action_and_continuous' not in globals():
    def sample_action_and_continuous(mean, log_variance):
        std_dev = torch.exp(0.5 * log_variance)
        distribution = Normal(mean, std_dev)
        continuous_sample = distribution.sample()
        # Ensure action is a positive integer
        processed_action = torch.max(torch.tensor(1.0), torch.round(torch.abs(continuous_sample)))
        return processed_action, continuous_sample

# Redefine calculate_baseline function if not available
if 'calculate_baseline' not in globals():
    def calculate_baseline(rewards):
        if isinstance(rewards, list):
            rewards = torch.tensor(rewards, dtype=torch.float32)
        if rewards.numel() == 0:
            return 0.0
        return torch.mean(rewards)

# Redefine calculate_log_prob function if not available
if 'calculate_log_prob' not in globals():
    def calculate_log_prob(mean, log_variance, action):
        std_dev = torch.exp(0.5 * log_variance)
        distribution = Normal(mean, std_dev)
        log_prob = distribution.log_prob(action)
        return log_prob

# Redefine RAGPolicyNetwork class if not available
if 'RAGPolicyNetwork' not in globals():
    class RAGPolicyNetwork(nn.Module):
        def __init__(self, transformer_model_name="bert-base-uncased", output_dim=2):
            super(RAGPolicyNetwork, self).__init__()
            self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
            self.transformer = AutoModel.from_pretrained(transformer_model_name)
            transformer_output_dim = self.transformer.config.hidden_size
            self.output_layer = nn.Linear(transformer_output_dim, output_dim)

        def forward(self, questions):
            encoded_input = self.tokenizer(questions, return_tensors='pt', padding=True, truncation=True)
            outputs = self.transformer(**encoded_input)
            pooled_output = outputs.pooler_output
            mean_and_log_variance = self.output_layer(pooled_output)
            mean = mean_and_log_variance[:, 0]
            log_variance = mean_and_log_variance[:, 1]
            return mean, log_variance

# Redefine questions and ground truth if not available
if 'questions' not in globals() or 'ground_truth' not in globals():
    questions = ["What does Articles 9 of the OECD Model Tax Convention state?",
                 "What does Articles 25 of the OECD Model Tax Convention state?",
                 "What does Allocation of Taxing Rights mean in OECD Model Tax Convention state?",
                 "How is Mutual Agreement Procedure(MAP) help in resolving disputes between countries when there's a conflict in interpreting the treaty?",
                 "As per OECD Model Tax Convention States what does Residence and Source Country mean?"]
    ground_truth = ["addresses corresponding adjustments in transfer pricing",
                    "outlines the mutual agreement procedure, which resolves disputes related to the application of double tax conventions.",
                    "principles that determine how different jurisdictions can tax income generated by multinational enterprises (MNEs).",
                    "serves as a mechanism for tax administrations to consult and resolve disputes related to the interpretation and application of double tax conventions. It is particularly useful in situations where there is taxation not in accordance with the provisions of the Convention.",
                    "Resident country: The country where the taxpayer lives, Source country: The country where the income originates may also have taxing rights but often with limits."]


# Define hyperparameters unconditionally so they are always available when the cell is re-run
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4

# Redefine Dataset and DataLoader if not available
if 'RAGDataset' not in globals() or 'train_dataloader' not in globals():
    class RAGDataset(Dataset):
        def __init__(self, questions, ground_truth):
            self.questions = questions
            self.ground_truth = ground_truth
        def __len__(self):
            return len(self.questions)
        def __getitem__(self, idx):
            return self.questions[idx], self.ground_truth[idx]

    rag_dataset = RAGDataset(questions, ground_truth)
    train_dataloader = DataLoader(rag_dataset, batch_size=BATCH_SIZE, shuffle=True)


# Re-instantiate policy_group if it is missing or its size does not match NUM_POLICIES
NUM_POLICIES = 5 # Defined unconditionally so the size check below cannot raise a NameError
if 'policy_group' not in globals() or len(policy_group) != NUM_POLICIES:
    policy_group = nn.ModuleList(
        [RAGPolicyNetwork(transformer_model_name="bert-base-uncased") for _ in range(NUM_POLICIES)]
    )
    print(f"Re-instantiated a group of {NUM_POLICIES} RAGPolicyNetwork instances.")
else:
    print(f"Policy group with {NUM_POLICIES} instances already exists.")

# Ensure optimizers are defined for each policy
if 'optimizers' not in globals() or len(optimizers) != NUM_POLICIES:
     print(f"Optimizers not defined or count mismatch. Creating {NUM_POLICIES} optimizers.")
     optimizers = [optim.Adam(policy.parameters(), lr=LEARNING_RATE) for policy in policy_group]
else:
    print(f"Optimizers for {NUM_POLICIES} policies already exist.")


# Redefine tools (LlamaIndex QueryEngineTool instances) and openai_tools (OpenAI definitions)
# Assuming these were defined in previous cells and need to be available here.
# If not defined, create dummy tools or add logic to load/create them.
if 'tools' not in globals() or not tools:
    print("LlamaIndex tools not found. Creating tools from loaded indices.")
    tools = []
    if OECD_index is not None:
        oecd_engine = OECD_index.as_query_engine(similarity_top_k=3)
        oecd_tool = QueryEngineTool(query_engine=oecd_engine, metadata=ToolMetadata(name="OECD_QueryEngineTool_2022", description="Provides information about Transfer Pricing Guidelines for Organization from OECD for year 2022"))
        tools.append(oecd_tool)
    if form990_guidelines_index is not None:
        form990_engine = form990_guidelines_index.as_query_engine(similarity_top_k=3)
        form990_tool = QueryEngineTool(query_engine=form990_engine, metadata=ToolMetadata(name="form990_2022", description="Provides information about Form990 filing guidelines for Non-Profit Organization only from the index which was set for Form990_Guidelines.pdf"))
        tools.append(form990_tool)

    if tools:
        print(f"Created {len(tools)} LlamaIndex tools.")
        # Create openai_tools from LlamaIndex tools
        openai_tools = [openai_tool_definition(tool) for tool in tools]
        print("Created OpenAI tool definitions.")
    else:
        print("Could not create any LlamaIndex tools. Tool execution will likely fail.")
        openai_tools = [] # Ensure openai_tools is an empty list if no tools were created

elif 'openai_tools' not in globals() or not openai_tools:
    print("OpenAI tool definitions not found. Creating from existing LlamaIndex tools.")
    openai_tools = [openai_tool_definition(tool) for tool in tools]
    print("Created OpenAI tool definitions from existing LlamaIndex tools.")
else:
    print("LlamaIndex tools and OpenAI tool definitions already exist.")


# Redefine query_engine_map if not available or needs update
if 'query_engine_map' not in globals() or not query_engine_map:
    print("Query engine map not found or is empty. Creating map.")
    query_engine_map = {}
    if 'tools' in globals() and tools:
        for tool in tools:
            # Assuming the tool's query_engine is accessible
            query_engine_map[tool.metadata.name] = tool.query_engine
        print(f"Created query engine map with {len(query_engine_map)} entries.")
    else:
        print("LlamaIndex tools not available to create query engine map.")
else:
    print("Query engine map already exists.")


# Initialize a Weights & Biases run (if not already initialized and active)
# Use reinit=True to allow re-initialization in a notebook environment
if wandb.run is None:
    try:
        # wandb is imported at the top of this cell, so no further import check is needed
        wandb.init(project="rag-policy-training", name="grpo-cosine-similarity-group-update-chat-api", reinit=True)

        # Define and log hyperparameters
        config = {
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "num_epochs": NUM_EPOCHS,
            "transformer_model": "bert-base-uncased",
            "output_dim": 2,
            "num_policies": NUM_POLICIES,
            "max_chat_turns": 5 # Log max chat turns
        }
        wandb.config.update(config)
        print("Training hyperparameters logged to Weights & Biases config.")
    except Exception as e:
        print(f"Error initializing Weights & Biases: {e}")
        print("Weights & Biases logging will be skipped.")
elif wandb.run is not None:
    print(f"Weights & Biases run '{wandb.run.name}' is already active.")
    # Optionally update config if needed, though reinit=True handles this to some extent
    # wandb.config.update(config, allow_val_change=True)


# --- Training Loop ---
print("Starting policy group training with Chat Completions API...")

# Ensure the loop constants exist before they are used
if 'NUM_EPOCHS' not in globals():
     NUM_EPOCHS = 100 # Define if not already
if 'MAX_GRAD_NORM' not in globals():
     MAX_GRAD_NORM = 0.5 # Define if not already
if 'global_step' not in globals():
     global_step = 0
# Calculate total steps for logging
if 'total_steps' not in globals() and 'train_dataloader' in globals() and 'NUM_POLICIES' in globals():
     total_steps = NUM_EPOCHS * len(train_dataloader) * NUM_POLICIES


# Define max turns for the chat conversation within the training loop
max_turns_chat = 5 # Limit the number of turns for each question's chat interaction


# Check if OECD_index was loaded successfully before starting training
if OECD_index is not None and form990_guidelines_index is not None: # Ensure both indices are loaded if both tools are intended
    for epoch in range(NUM_EPOCHS):
        # Data structures to collect data across policies for this iteration/epoch
        all_policy_rewards = {}
        all_policy_log_probs = {}
        all_policy_sampled_k_processed = {}
        all_policy_advantages = {}
        all_policy_means = {}
        all_policy_log_variances = {}
        all_policy_losses = {} # Store losses for logging per policy
        collected_answers = {} # Store generated answers
        collected_contexts = {} # Store retrieved contexts

        # --- Data Collection Phase ---
        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Collecting data using Chat Completions API...")
        for policy_idx, policy in enumerate(policy_group):
            policy.train() # Set policy to training mode
            policy_name = f"policy_{policy_idx}"

            # Initialize storage for current policy's data
            all_policy_rewards[policy_name] = []
            all_policy_log_probs[policy_name] = []
            all_policy_sampled_k_processed[policy_name] = []
            all_policy_means[policy_name] = []
            all_policy_log_variances[policy_name] = []
            all_policy_losses[policy_name] = [] # Initialize loss storage
            collected_answers[policy_name] = []
            collected_contexts[policy_name] = []


            # Process the entire dataset for the current policy to collect data
            for batch_idx, (batch_questions, batch_ground_truth) in enumerate(train_dataloader):
                if not batch_questions:
                    continue # Skip empty batches

                # Note: global_step is incremented once per batch, after all samples in the batch are processed

                # Perform a forward pass through the policy network for the batch
                mean_output, log_variance_output = policy(list(batch_questions))

                for i in range(len(batch_questions)):
                    question = batch_questions[i]
                    ground_truth_answer = batch_ground_truth[i]

                    # b. Use the sample_action_and_continuous function to sample similarity_top_k actions
                    sampled_k_processed_item, sampled_k_continuous_item = sample_action_and_continuous(mean_output[i], log_variance_output[i])

                    # Store sampled action and log probability for this sample
                    all_policy_sampled_k_processed[policy_name].append(sampled_k_processed_item.item())
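                    # Keep the log probability as a 0-dim tensor (no .item()) so the autograd graph
                    # back to the policy parameters survives until the update phase
                    log_prob = calculate_log_prob(mean_output[i], log_variance_output[i], sampled_k_continuous_item)
                    all_policy_log_probs[policy_name].append(log_prob.reshape(()))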
                    all_policy_means[policy_name].append(mean_output[i].item())
                    all_policy_log_variances[policy_name].append(log_variance_output[i].item())

                    # --- Integrate Actual RAG Execution using Chat Completions API ---
                    predicted_top_k_int = max(1, int(sampled_k_processed_item.item())) # Ensure it's at least 1

                    # Initialize messages for the chat conversation for this specific question
                    messages = [{"role": "system", "content": "You are an assistant that provides answers to questions on OECD and Form990 using the available tools. Answer as accurately as possible based on the tool outputs. Whenever there is comparison make sure the results are in side by side comparison table with headers and add links to the document."}]
                    messages.append({"role": "user", "content": question})

                    chat_turn_count = 0
                    generated_answer = ""
                    retrieved_contexts_for_sample = [] # Store contexts for this sample

                    while chat_turn_count < max_turns_chat:
                        chat_turn_count += 1
                        # print(f"    Chat Turn {chat_turn_count}") # Too verbose

                        # 2. Send messages to the model using chat_with_tools
                        api_response = chat_with_tools(messages, openai_tools)

                        if not api_response or not hasattr(api_response, 'choices') or not api_response.choices:
                            print(f"    API call failed or returned no choices for question: {question}")
                            break

                        response_message = api_response.choices[0].message
                        messages.append(response_message) # Append model's response to history

                        tool_calls = response_message.tool_calls

                        if tool_calls:
                            # 4. Iterate through each tool call
                            for tool_call in tool_calls:
                                function_name = tool_call.function.name
                                function_args_str = tool_call.function.arguments
                                tool_call_id = tool_call.id

                                # print(f"      Tool call requested: {function_name}") # Too verbose

                                # 5. Parse the arguments string
                                try:
                                    function_args = json.loads(function_args_str)
                                    tool_input = function_args.get("input")
                                    if tool_input is None:
                                         print(f"      Warning: 'input' arg missing for {function_name}.")
                                         tool_input = ""

                                except json.JSONDecodeError:
                                    print(f"      Error decoding args JSON for {function_name}: {function_args_str}")
                                    tool_input = ""

                                # 6. Identify and execute the corresponding LlamaIndex query engine
                                # Use the policy-predicted similarity_top_k for the query engine
                                query_engine = query_engine_map.get(function_name)
                                tool_output = ""

                                if query_engine:
                                    # Temporarily set similarity_top_k for this query engine if it supports it
                                    original_similarity_top_k = getattr(query_engine.retriever, 'similarity_top_k', None)
                                    if original_similarity_top_k is not None:
                                         query_engine.retriever.similarity_top_k = predicted_top_k_int
                                         # print(f"      Set similarity_top_k to {predicted_top_k_int} for {function_name}") # Too verbose

                                    try:
                                        llama_response = query_engine.query(tool_input)
                                        tool_output = str(llama_response)
                                        # Collect the retrieved nodes' text as context
                                        if hasattr(llama_response, 'source_nodes'):
                                             retrieved_contexts_for_sample.extend([node.node.text for node in llama_response.source_nodes])

                                        # print(f"      Tool execution successful.") # Too verbose
                                    except Exception as e:
                                        print(f"      Error executing LlamaIndex tool '{function_name}': {e}")
                                        tool_output = f"Error executing tool: {e}"

                                    finally:
                                         # Restore original similarity_top_k if it was modified
                                         if original_similarity_top_k is not None:
                                              query_engine.retriever.similarity_top_k = original_similarity_top_k


                                else:
                                    print(f"      Error: No LlamaIndex query engine found for tool '{function_name}'.")
                                    tool_output = f"Error: Tool '{function_name}' not supported or found."

                                # 7. Format the output and 8. Append to messages history
                                tool_message = {
                                    "role": "tool",
                                    "tool_call_id": tool_call_id,
                                    "content": tool_output,
                                }
                                messages.append(tool_message)
                                # print("      Tool output appended.") # Too verbose

                            # 9. The loop continues, sending updated messages in the next turn
                        else:
                            # 10. If no tool_calls, it's the final answer
                            # print("    Model provided final answer.") # Too verbose
                            # content can be None, so fall back to an empty string for the reward function
                            generated_answer = response_message.content or ""
                            break # Exit chat loop for this question

                    # --- End of Chat Conversation for one question ---

                    # Store generated answer and collected contexts for this sample
                    collected_answers[policy_name].append(generated_answer)
                    collected_contexts[policy_name].append(retrieved_contexts_for_sample)


                    # Reward: cosine similarity between the generated answer and the ground-truth answer.
                    # The retrieved contexts are kept for evaluation and analysis only; they do not enter the reward.
                    reward = cosine_similarity_reward(generated_answer, ground_truth_answer)
                    all_policy_rewards[policy_name].append(reward)

                # Increment global step once per batch processed by any policy
                global_step += 1


        # --- End of Policy Data Collection within Epoch ---

        # --- Implement Group Performance Evaluation and Logging ---
        policy_avg_rewards = {}
        best_policy_name = None
        highest_avg_reward = -float('inf') # Initialize with negative infinity

        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Evaluating policy performance and logging epoch metrics...")

        # Data structure to store epoch metrics for logging
        epoch_metrics = {}
        group_avg_reward = 0.0
        total_valid_rewards = 0

        for policy_idx, policy in enumerate(policy_group):
            policy_name = f"policy_{policy_idx}"
            epoch_rewards = all_policy_rewards[policy_name]

            # 1. Calculate the average reward for each policy
            avg_epoch_reward = np.mean(epoch_rewards) if epoch_rewards else 0.0
            policy_avg_rewards[policy_name] = avg_epoch_reward

            # 3. Identify the policy with the highest average reward
            if avg_epoch_reward > highest_avg_reward:
                highest_avg_reward = avg_epoch_reward
                best_policy_name = policy_name

            # Also calculate other epoch metrics for logging
            epoch_predicted_k = all_policy_sampled_k_processed[policy_name]
            # Recalculate advantages based on the actual epoch rewards collected
            epoch_rewards_tensor = torch.tensor(epoch_rewards, dtype=torch.float32)
            epoch_advantages = (epoch_rewards_tensor - calculate_baseline(epoch_rewards_tensor)).tolist() # Use individual policy's baseline for advantage calculation

            epoch_means = all_policy_means[policy_name]
            epoch_log_variances = all_policy_log_variances[policy_name]

            avg_epoch_predicted_top_k = np.mean(epoch_predicted_k) if epoch_predicted_k else 0
            epoch_predicted_top_k_std = np.std(epoch_predicted_k) if epoch_predicted_k else 0
            avg_epoch_advantage = np.mean(epoch_advantages) if epoch_advantages else 0
            avg_epoch_mean = np.mean(epoch_means) if epoch_means else 0
            avg_epoch_log_variance = np.mean(epoch_log_variances) if epoch_log_variances else 0

            # Store policy-specific epoch metrics for logging
            epoch_metrics[f"{policy_name}/epoch_average_reward"] = avg_epoch_reward
            epoch_metrics[f"{policy_name}/epoch_average_predicted_top_k"] = avg_epoch_predicted_top_k
            epoch_metrics[f"{policy_name}/epoch_predicted_top_k_std"] = epoch_predicted_top_k_std
            epoch_metrics[f"{policy_name}/epoch_average_advantage"] = avg_epoch_advantage
            epoch_metrics[f"{policy_name}/epoch_average_mean"] = avg_epoch_mean
            epoch_metrics[f"{policy_name}/epoch_average_log_variance"] = avg_epoch_log_variance


            # Accumulate reward for group average calculation
            group_avg_reward += np.sum(epoch_rewards)
            total_valid_rewards += len(epoch_rewards) # Sum of samples across all policies


            # 4. Print or store the average rewards for each policy
            print(f"    {policy_name}: Avg Reward = {avg_epoch_reward:.4f}, Avg Predicted Top K = {avg_epoch_predicted_top_k:.2f}, Predicted Top K Std = {epoch_predicted_top_k_std:.2f}")


        # Calculate group-level average reward across all policies
        group_avg_reward = group_avg_reward / total_valid_rewards if total_valid_rewards > 0 else 0.0


        # 4. Print or store the identification of the best performing policy
        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Best performing policy is {best_policy_name} with Avg Reward = {highest_avg_reward:.4f}")
        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Group Average Reward = {group_avg_reward:.4f}")


        # Log epoch metrics to Weights & Biases
        epoch_metrics["epoch/best_policy"] = best_policy_name
        epoch_metrics["epoch/highest_avg_reward"] = highest_avg_reward
        epoch_metrics["epoch/group_average_reward"] = group_avg_reward # Log group average reward

        # Log all collected epoch metrics
        if wandb.run is not None:
             wandb.log(epoch_metrics, step=epoch + 1) # Log all epoch metrics at once
        else:
             print("Weights & Biases not initialized. Skipping epoch metric logging.")


        # --- Implement Policy Update Phase ---
        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Starting policy update...")

        # --- Subtask Implementation: Modify policy update step ---
        # Baseline choice: the best policy's average reward is used as a shared baseline for every policy,
        # so each sample's advantage measures how far it falls short of (or exceeds) the group's best performer.
        # Note that this differs from standard GRPO, which normalizes each reward by the mean and
        # standard deviation of its group.

        best_policy_avg_reward_tensor = torch.tensor(highest_avg_reward, dtype=torch.float32) # Get the best policy's avg reward as a tensor


        for policy_idx, policy in enumerate(policy_group):
            policy_name = f"policy_{policy_idx}"
            optimizer = optimizers[policy_idx] # Get the specific optimizer for this policy

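            # Get collected data for the current policy
            stored_log_probs = all_policy_log_probs[policy_name]
            if stored_log_probs and torch.is_tensor(stored_log_probs[0]):
                # Stacking tensors preserves the autograd graph needed by policy_loss.backward()
                policy_log_probs = torch.stack([lp.reshape(()) for lp in stored_log_probs])
            else:
                # Plain floats carry no gradient information
                policy_log_probs = torch.tensor(stored_log_probs, dtype=torch.float32)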
            policy_rewards = torch.tensor(all_policy_rewards[policy_name], dtype=torch.float32) # Use raw rewards

            # Calculate advantage relative to the best policy's average reward
            # Advantage = Reward - Best Policy's Average Reward
            policy_advantages_relative_to_best = policy_rewards - best_policy_avg_reward_tensor

            # Filter out samples where original reward was 0 (likely due to errors)
            # We still use the original rewards to determine which samples were valid.
            valid_indices = policy_rewards != 0
            if torch.sum(valid_indices) > 0:
                valid_log_probs = policy_log_probs[valid_indices]
                # Detach advantages to prevent gradients from flowing through the reward calculation
                valid_advantages = policy_advantages_relative_to_best[valid_indices].detach()


                # 3. Calculate the policy loss using collected log probabilities and relative advantages
                policy_loss = -torch.mean(valid_log_probs * valid_advantages)

                # 5. Perform optimizer.zero_grad() for the current policy's optimizer
                optimizer.zero_grad()

                # --- Implement basic "trust region" with gradient clipping ---
                # 6. Call policy_loss.backward() to compute gradients
                policy_loss.backward()

                # Apply gradient clipping
                torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD_NORM)


                # 7. Call optimizer.step() to update the current policy's parameters
                optimizer.step()

                # Log the policy loss for each policy after its update
                if wandb.run is not None:
                     wandb.log({
                         f"{policy_name}/policy_loss": policy_loss.item(),
                     }, step=epoch + 1) # Log policy loss per epoch per policy
                else:
                     print(f"Weights & Biases not initialized. Skipping policy loss logging for {policy_name}.")


                # print(f"    {policy_name}: Policy loss = {policy_loss.item():.4f}") # Too verbose
            else:
                # print(f"    {policy_name}: No valid samples/advantages for update in this epoch.") # Too verbose
                if wandb.run is not None:
                     wandb.log({
                         f"{policy_name}/policy_loss": 0.0, # Log 0 loss if no update
                     }, step=epoch + 1)
                else:
                     print(f"Weights & Biases not initialized. Skipping policy loss logging for {policy_name} (no update).")


        print(f"  Epoch {epoch+1}/{NUM_EPOCHS}: Policy update completed.")


    print("Training finished.")

else:
    print("Training skipped because the OECD index or the Form 990 index was not loaded.")

# Finish the Weights & Biases run if one is active
if wandb.run is not None:
    wandb.finish()
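
The update step above subtracts the best policy's average reward from every sample's reward. For comparison, here is a minimal sketch of the group-relative advantage that GRPO is usually described with: the rewards that several policies earn on the same question are normalized by the group's mean and standard deviation. The function name `group_relative_advantages`, the `eps` constant, and the example rewards are hypothetical and not part of the notebook. Plain Python is used so the snippet runs without the training setup.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """Normalize one group's rewards to zero mean and (roughly) unit variance.

    `rewards` holds one reward per group member for the same question.
    `eps` is an assumed small constant that avoids division by zero
    when every member receives the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]

# Hypothetical cosine-similarity rewards of three policies on one question
rewards_for_question = [0.62, 0.80, 0.71]
advantages = group_relative_advantages(rewards_for_question)

for policy_idx, adv in enumerate(advantages):
    print(f"policy_{policy_idx}: advantage = {adv:+.3f}")
```

With this normalization, above-average members get positive advantages and below-average members get negative ones. With the best-policy baseline, by contrast, most advantages come out negative.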