# Imports

In [None]:
import os
import gspread 
import openai
import operator
import requests
import json
import numpy as np

import os.path
import string

from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError

from pydantic import BaseModel, Field, Extra
from typing import Type, TypedDict, Annotated, List

from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain.tools import BaseTool

from langgraph.graph import StateGraph, END, START
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage, ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# API

В этом разделе размещены все функции и классы для операций с Google Spreadsheet листами через API

Подробнее в [документации](https://developers.google.com/sheets/api/guides/concepts)

In [None]:
def column_to_index(column):
    """Convert a spreadsheet column label to its zero-based index.

    The label is read as a bijective base-26 number over ``A``-``Z``:
    ``"A"`` -> 0, ``"Z"`` -> 25, ``"AA"`` -> 26. An empty label yields -1.

    Raises:
        ValueError: if a character is not found in ``string.ascii_uppercase``.
    """
    alphabet = string.ascii_uppercase
    position = 0
    # Fold left to right: every letter shifts the running value by one
    # base-26 digit and adds its 1-based rank in the alphabet.
    for letter in column:
        position = position * 26 + alphabet.index(letter) + 1
    return position - 1

def index_to_column(index):
    """Convert a zero-based column index to its spreadsheet label.

    Inverse of ``column_to_index``: 0 -> ``"A"``, 25 -> ``"Z"``,
    26 -> ``"AA"``. A negative index yields an empty string.
    """
    letters = []
    remaining = index
    while remaining >= 0:
        quotient, digit = divmod(remaining, 26)
        letters.append(string.ascii_uppercase[digit])
        # Labels have no "zero" digit, hence the extra -1 per position.
        remaining = quotient - 1
    # Letters were collected least-significant first.
    return "".join(reversed(letters))

def parse_range(sheetId, cell_range):
    """Convert an A1-notation range into a Sheets API ``GridRange`` dict.

    Args:
        sheetId: Numeric sheet id placed into the result unchanged.
        cell_range: Range such as ``"B2:D10"``. A single cell (``"B2"``) is
            treated as the one-cell range ``"B2:B2"``. An optional
            ``"SheetName!"`` prefix is ignored (the sheet is selected by
            ``sheetId``), and column letters may be lower case.

    Returns:
        Dict with ``sheetId`` and zero-based, half-open row/column bounds:
        start indices are inclusive, end indices are exclusive.

    Raises:
        ValueError: if a cell reference has no row number or contains a
            character that is not an ASCII letter in its column part.
    """
    # Without this, the letters and digits of a sheet name would be mixed
    # into the column / row parts by the filters below.
    a1 = cell_range.rsplit('!', 1)[-1].strip().upper()

    start_cell, _, end_cell = a1.partition(':')
    if not end_cell:
        # Single-cell reference: the range starts and ends on the same cell.
        end_cell = start_cell

    start_col = ''.join(filter(str.isalpha, start_cell))
    start_row = ''.join(filter(str.isdigit, start_cell))

    end_col = ''.join(filter(str.isalpha, end_cell))
    end_row = ''.join(filter(str.isdigit, end_cell))

    return {
        "sheetId": sheetId,
        "startRowIndex": int(start_row) - 1,
        # A1 rows are 1-based and inclusive, so the row number itself is
        # already the exclusive zero-based end index.
        "endRowIndex": int(end_row),
        "startColumnIndex": column_to_index(start_col),
        "endColumnIndex": column_to_index(end_col) + 1
    }
    
    
class GooogleSheetsApi:
    """Small wrapper around the Google Sheets v4 API for one spreadsheet.

    The target spreadsheet is selected with ``set_sheet_id``. Every
    range-based request built here uses ``sheetId`` 0 (presumably the
    default first sheet of the spreadsheet - confirm for multi-sheet files).
    """

    def __init__(self, credentials, user_data):
        """Load, refresh or interactively obtain OAuth user credentials.

        Args:
            credentials: Path to the OAuth client secrets file.
            user_data: Path to the cached authorized-user token file. It is
                read when it exists and (re)written whenever the loaded
                credentials were missing or not valid.
        """
        SCOPES = ["https://www.googleapis.com/auth/spreadsheets", "https://www.googleapis.com/auth/drive.file"]
        # Defined up front so the attribute exists even before
        # set_sheet_id() has been called.
        self.spreadsheet_id = None
        self.creds = None
        if os.path.exists(user_data):
            self.creds = Credentials.from_authorized_user_file(user_data, SCOPES)
        if not self.creds or not self.creds.valid:
            if self.creds and self.creds.expired and self.creds.refresh_token:
                self.creds.refresh(Request())
            else:
                flow = InstalledAppFlow.from_client_secrets_file(
                    credentials, SCOPES
                )
                self.creds = flow.run_local_server(port=0)
            with open(user_data, "w") as token:
                token.write(self.creds.to_json())

    def _service(self):
        """Build a Sheets v4 service object from the stored credentials."""
        return build("sheets", "v4", credentials=self.creds)

    def _batch_update(self, batch_requests):
        """Send ``batch_requests`` in one ``spreadsheets.batchUpdate`` call."""
        body = {"requests": batch_requests}
        return (
            self._service()
            .spreadsheets()
            .batchUpdate(spreadsheetId=self.spreadsheet_id, body=body)
            .execute()
        )

    def set_sheet_id(self, spreadsheet_id):
        """Select the spreadsheet that all later calls operate on."""
        self.spreadsheet_id = spreadsheet_id

    def create_spreadsheet(self, title):
        """Create a new spreadsheet titled ``title`` and return its id."""
        spreadsheet = {"properties": {"title": title}}
        spreadsheet = (
            self._service()
            .spreadsheets()
            .create(body=spreadsheet, fields="spreadsheetId")
            .execute()
        )
        return spreadsheet.get("spreadsheetId")

    def read_values(self, range_name='A1:AZ100000'):
        """Return the values of ``range_name`` as a list of rows.

        Returns an empty list when the response has no ``values`` key.
        """
        result = (
            self._service()
            .spreadsheets()
            .values()
            .get(spreadsheetId=self.spreadsheet_id, range=range_name)
            .execute()
        )
        return result.get("values", [])

    def write_values(self, range_name, values):
        """Write a 2-D list ``values`` starting at ``range_name``.

        Values are sent with ``valueInputOption="USER_ENTERED"``.
        Returns the raw API response of the update call.
        """
        body = {"values": values}
        return (
            self._service()
            .spreadsheets()
            .values()
            .update(
                spreadsheetId=self.spreadsheet_id,
                range=range_name,
                valueInputOption="USER_ENTERED",
                body=body)
            .execute()
        )

    def autofill(self, range_name):
        """Send an ``autoFill`` request for ``range_name`` on sheet 0."""
        return self._batch_update([
            {
                "autoFill": {
                    "useAlternateSeries": False,
                    "range": parse_range(0, range_name)
                }
            }
        ])

    def repeat_formula(self, formula, range_name):
        """Set ``formula`` on every cell of ``range_name`` via ``repeatCell``."""
        return self._batch_update([
            {
                "repeatCell": {
                    "range": parse_range(0, range_name),
                    "cell": {
                        "userEnteredValue": {
                            "formulaValue": formula
                        }
                    },
                    "fields": "userEnteredValue"
                }
            }
        ])

    def create_pivot_table(self, source_range, target_cell, rows=None, columns=None, values=None):
        """Create a pivot table from ``source_range`` anchored by ``target_cell``.

        Args:
            source_range: A1 range with the source data.
            target_cell: A1 cell used to compute the anchor of the table.
            rows: Source column offsets used as pivot row groups.
            columns: Source column offsets used as pivot column groups.
            values: Iterable of ``(column_offset, summarize_function)`` pairs.

        Returns:
            The raw ``batchUpdate`` response.
        """
        # None instead of mutable list defaults, which would be shared
        # between calls.
        rows = [] if rows is None else rows
        columns = [] if columns is None else columns
        values = [] if values is None else values

        pivot_table = {
            "source": parse_range(0, source_range),
            "rows": [
                {
                    "sourceColumnOffset": row,
                    "sortOrder": "ASCENDING",
                    "showTotals": False,
                } for row in rows
            ],
            "columns": [
                {
                    "sourceColumnOffset": col,
                    "sortOrder": "ASCENDING",
                    "showTotals": False,
                } for col in columns
            ],
            "values": [
                {
                    "summarizeFunction": func,
                    "sourceColumnOffset": val,
                } for val, func in values
            ],
            "valueLayout": "HORIZONTAL",
        }
        return self._batch_update([
            {
                "updateCells": {
                    "rows": {"values": [{"pivotTable": pivot_table}]},
                    "start": {
                        "sheetId": 0,
                        # NOTE(review): unlike create_chart, the 1-based row
                        # number is not decremented here, so this index points
                        # one row below target_cell. describe_pivot_in_cell
                        # reads back from target_cell and may rely on that
                        # offset - confirm before changing it.
                        "rowIndex": int(''.join(filter(str.isdigit, target_cell))),
                        "columnIndex": column_to_index(''.join(filter(str.isalpha, target_cell))),
                    },
                    "fields": "pivotTable",
                }
            }
        ])

    def create_chart(self, chart_type, target_cell, domain_range, chart_series, title="", botton_axis_title="", left_axis_title=""):
        """Add a basic chart overlaid on sheet 0, anchored at ``target_cell``.

        Args:
            chart_type: Basic chart type, e.g. ``LINE``, ``COLUMN``, ``AREA``.
            target_cell: A1 cell used as the overlay anchor.
            domain_range: A1 range used as the chart domain.
            chart_series: Iterable of A1 ranges, one per series (all on the
                left axis).
            title: Chart title.
            botton_axis_title: Title of the bottom axis.
            left_axis_title: Title of the left axis.

        Returns:
            The raw ``batchUpdate`` response.
        """
        basic_chart = {
            "chartType": chart_type,  # LINE, COLUMN, AREA
            "legendPosition": "BOTTOM_LEGEND",
            "axis": [
                {
                    "position": "BOTTOM_AXIS",
                    "title": botton_axis_title
                },
                {
                    "position": "LEFT_AXIS",
                    "title": left_axis_title
                }
            ],
            "domains": [
                {
                    "domain": {
                        "sourceRange": {
                            # "sources" is a list of GridRange objects, the
                            # same shape the series below already use.
                            "sources": [parse_range(0, domain_range)]
                        }
                    }
                }
            ],
            "series": [
                {
                    "series": {
                        "sourceRange": {
                            "sources": [parse_range(0, serie)]
                        }
                    },
                    "targetAxis": "LEFT_AXIS"
                } for serie in chart_series
            ],
        }
        anchor_cell = {
            "sheetId": 0,
            # A1 rows are 1-based, grid indices are 0-based.
            "rowIndex": int(''.join(filter(str.isdigit, target_cell))) - 1,
            "columnIndex": column_to_index(''.join(filter(str.isalpha, target_cell))),
        }
        return self._batch_update([
            {
                "addChart": {
                    "chart": {
                        "spec": {
                            "title": title,
                            "basicChart": basic_chart
                        },
                        "position": {
                            "overlayPosition": {
                                "anchorCell": anchor_cell
                            }
                        }
                    }
                }
            }
        ])

    def get_metadata(self):
        """Return the raw ``spreadsheets.get`` response for the spreadsheet."""
        return (
            self._service()
            .spreadsheets()
            .get(spreadsheetId=self.spreadsheet_id)
            .execute()
        )

    def get_errors(self):
        """List cells in ``A1:V50`` whose value starts with ``#``.

        Returns:
            List of ``(cell_id, value)`` tuples with ``cell_id`` in A1
            notation (presumably spreadsheet error values such as ``#REF!``).
        """
        res = self.read_values("A1:V50")
        return [
            (index_to_column(col) + str(row + 1), el)
            for row, r in enumerate(res)
            for col, el in enumerate(r)
            if len(el) > 0 and el[0] == '#'
        ]

# Tools

Определения функций (tools), которыми будет пользоваться агент

Tools основаны на абстракции **BaseTool** из gigachain

In [None]:

# ------------------------------------------------- Auxiliary functions  ----------------------------------------------------
def matrix_to_markdown(header, matrix):
    """Render ``header`` and the rows of ``matrix`` as a Markdown table.

    Header cells must already be strings; data cells are converted with
    ``str``. Every line, including the last one, ends with a newline.
    """
    def as_line(cells):
        return "| " + " | ".join(cells) + " |\n"

    lines = [as_line(header), as_line(["---"] * len(header))]
    lines.extend(as_line(map(str, row)) for row in matrix)
    return "".join(lines)

def column_to_index(column):
    """Convert a spreadsheet column label to its zero-based index.

    ``"A"`` -> 0, ``"Z"`` -> 25, ``"AA"`` -> 26; an empty label yields -1.
    Same logic as the definition in the API section of this file.

    Raises:
        ValueError: if a character is not found in ``string.ascii_uppercase``.
    """
    alphabet = string.ascii_uppercase
    position = 0
    # Left-to-right fold over the letters, each worth its 1-based rank.
    for letter in column:
        position = position * 26 + alphabet.index(letter) + 1
    return position - 1

def index_to_column(index):
    """Convert a zero-based column index to its spreadsheet label.

    0 -> ``"A"``, 25 -> ``"Z"``, 26 -> ``"AA"``; a negative index yields an
    empty string. Same logic as the definition in the API section of this file.
    """
    letters = []
    remaining = index
    while remaining >= 0:
        quotient, digit = divmod(remaining, 26)
        letters.append(string.ascii_uppercase[digit])
        # Labels have no "zero" digit, hence the extra -1 per position.
        remaining = quotient - 1
    return "".join(reversed(letters))

def markdown_to_matrix(markdown):
    """Parse a Markdown table into a list of rows of stripped cell strings.

    Literal ``\\n`` escape sequences in the input are treated as line breaks.
    The first non-blank line is the header. The second line is dropped only
    when it really is a delimiter row (cells made of ``-`` and ``:``), so a
    table sent without one does not lose its first data row. Empty cells
    are kept as ``''`` so that later cells stay in their columns.

    Returns:
        ``[header, row1, row2, ...]``, or ``[]`` when the input has no
        non-blank lines.
    """
    def split_cells(line):
        # Remove at most one outer pipe per side, then split on the rest.
        line = line.strip()
        if line.startswith('|'):
            line = line[1:]
        if line.endswith('|'):
            line = line[:-1]
        return [cell.strip() for cell in line.split('|')]

    def is_delimiter_row(cells):
        return bool(cells) and all(
            '-' in cell and set(cell) <= set('-:') for cell in cells
        )

    # A single pass is enough: str.replace handles every occurrence.
    markdown = markdown.replace('\\n', '\n')
    matrix = [split_cells(line) for line in markdown.split('\n') if line.strip()]
    if len(matrix) > 1 and is_delimiter_row(matrix[1]):
        del matrix[1]
    return matrix

def longest_leftmost_non_empty_subarray(lst):
    """Return the first run of consecutive non-empty items and their indices.

    Leading ``''`` items are skipped; the run then extends until the next
    ``''`` item or the end of ``lst``.

    Returns:
        Tuple ``(items, indices)``; both are empty lists when ``lst`` has no
        item different from ``''``.
    """
    size = len(lst)
    # First position holding something other than the empty string.
    start = next((pos for pos in range(size) if lst[pos] != ''), size)
    # First empty string after that position closes the run.
    stop = next((pos for pos in range(start, size) if lst[pos] == ''), size)
    return [lst[pos] for pos in range(start, stop)], list(range(start, stop))

def get_data_outline(gapi, max_rows=20):
    """Summarise the table found on the sheet as Markdown plus its A1 range.

    The table starts at the first non-empty row returned by
    ``gapi.read_values()``; its columns are the first run of non-empty cells
    in that row. When there are more than ``max_rows`` data rows, a random
    sample of ``max_rows`` of them (without replacement) is shown under the
    header.

    Args:
        gapi: Object exposing ``read_values()`` that returns a list of rows.
        max_rows: Maximum number of data rows in the Markdown preview.

    Returns:
        Tuple ``(markdown_table, a1_range)``. Both elements are the string
        ``"No data on the sheet"`` when nothing is found.
    """
    no_data = "No data on the sheet"
    values_list = gapi.read_values()

    first_row = next((pos for pos, row in enumerate(values_list) if len(row) > 0), None)
    if first_row is None:
        return no_data, no_data
    values_list = values_list[first_row:]
    first_row += 1  # switch to the 1-based row number used in A1 notation

    _, ind = longest_leftmost_non_empty_subarray(values_list[0])
    if not ind:
        return no_data, no_data

    # Rows may be shorter than the header row, so pad the missing cells
    # instead of relying on a rectangular array.
    table = [
        [str(row[col]) if col < len(row) else '' for col in ind]
        for row in values_list
    ]
    if len(table) > max_rows + 1:
        body = table[1:]
        selected_indices = np.random.choice(len(body), max_rows, replace=False)
        table = [table[0]] + [body[pos] for pos in selected_indices]

    last_row = len(values_list) + first_row - 1
    # index_to_column also covers columns past "Z" (the default read goes to AZ).
    data_range = f"{index_to_column(ind[0])}{first_row}:{index_to_column(ind[-1])}{last_row}"
    return matrix_to_markdown(table[0], table[1:]), data_range


def describe_pivot_in_cell(api, cell):
    """Build a text description of what is found from ``cell`` to the sheet end.

    Reads the grid size of the first sheet from ``api.get_metadata()``,
    fetches all values from ``cell`` to the bottom-right corner of the grid
    and derives the reported range and counts from the number of rows and the
    widest row that came back.
    """
    grid = api.get_metadata()['sheets'][0]['properties']['gridProperties']
    last_row = grid['rowCount']
    last_col = index_to_column(grid['columnCount'] - 1)
    data = api.read_values(f"{cell}:{last_col}{last_row}")

    row_number = len(data)
    col_number = max(map(len, data))

    start_col = ''.join(ch for ch in cell if ch.isalpha())
    # Reported range starts one row below the given cell.
    start_row = int(''.join(ch for ch in cell if ch.isdigit())) + 1
    end_col = index_to_column(column_to_index(start_col) + col_number - 1)
    end_row = start_row + row_number - 2

    description = f"""
    Pivot table was written to cell {cell}
    Pivot table main data is located in range {start_col}{start_row}:{end_col}{end_row}
    Pivot table has {row_number-2} rows values and {col_number-1} columns values
    """
    return description

# ------------------------------------------------- Descriptions -----------------------------------------------------------
    
class write_table_description(BaseModel):
    # Argument schema referenced by write_table_tool.args_schema.
    range_name: str = Field(
        description="The name of the range on this sheet in A1 notation.",
    )
    markdown_table: str = Field(
        description="Markdown table that will be written into the range.",
    )
    
class write_value_description(BaseModel):
    # Argument schema referenced by write_value_tool.args_schema.
    cell_id: str = Field(
        description="ID of the cell in A1 notation.",
    )
    value: str = Field(
        description="Value or formula that will be written into the cell. Formula should start with symbol '='.",
    )
    
class ask_user_general_question_description(BaseModel):
    # Argument schema referenced by ask_user_general_question_tool.args_schema.
    question: str = Field(
        description="General question to the user. Should be a single sentence. Answer will be a text string.",
    )
    
class ask_user_alternative_question_description(BaseModel):
    # Argument schema referenced by ask_user_alternative_question_tool.args_schema.
    question: str = Field(
        description="Alternative question to the user. Should be a single sentence.",
    )
    options: list = Field(
        description="Options for user to answer this question, should be a list. The number of answer options should not exceed 5.",
    )
    
class create_pivot_table_description(BaseModel):
    # Argument schema referenced by create_pivot_table_tool.args_schema.
    # The description texts are what the model is given as guidance for each
    # argument, so the examples must agree with the 0-based numbering rule.
    source_range: str = Field(description="The name of the source range of pivot table in A1 notation. Defines where data will be read from.")
    target_cell: str = Field(description="ID of the target cell in A1 notation. Pivot table will be written into this cell.")
    # Example fixed: it previously said [0, 2] selects indices 1 and 3,
    # contradicting "Numbering starts from 0".
    rows: list = Field(description=\
        """
        Column indices in the source range that will be rows in the pivot table. Numbering starts from 0. Should be list of non-negative integers.
        For example: [0, 2] means that columns with indices 0 and 2 will be rows in the pivot table.
        List can be empty, then there will be no rows in pivot table.
        """
    )
    columns: list = Field(description=\
        """
        Column indices in the source range that will be columns in the pivot table. Numbering starts from 0. Should be list of non-negative integers.
        For example: [1, 3] means that columns with indices 1 and 3 will be columns in the pivot table.
        List can be empty, then there will be no columns in pivot table.
        """
    )
    values: list = Field(description=\
        """
        Column indices in the source range that will be data values the pivot table. Numbering starts from 0. 
        Should be list of tuples: the first tuple element is column index, the second tuple element is aggregation function.
        Aggregation function can be one of these: "SUM", "COUNTA", "AVERAGE", "MAX", "MIN"
        For example: [(4, "SUM")] means that columns with index 4 will summed up in the pivot table.
        List can be empty, then no values will be written into the pivot table.
        """
    )

class draw_chart_description(BaseModel):
    # Argument schema referenced by draw_chart_tool.args_schema.
    chart_type: str = Field(
        description="Type of the chart, can be on of the following: LINE, COLUMN, AREA",
    )
    target_cell: str = Field(
        description="ID of the target cell in A1 notation. Chart will be written into this cell.",
    )
    domain_range: str = Field(
        description="Range name of chart domain in A1 notation (like K9:M123). This domain will be the X axis.",
    )
    chart_series: list = Field(
        description="""
        List of range names in A1 notation. Each element of this list defines range name of data for chart series.
        For example: ["I5:I19", "J5:J19"] means that there will be two series on the chart: with data from range "I5:I19" and "J5:J19".
        """,
    )
    title: str = Field(
        description="Title of the chart",
    )
    x_title: str = Field(
        description="Title of the X axis",
    )
    y_title: str = Field(
        description="Title of the Y axis (or series)",
    )
    
class repeat_formula_description(BaseModel):
    # Argument schema referenced by repeat_formula_tool.args_schema.
    formula: str = Field(
        description="Formula to be repeated in range. Should start with '='",
    )
    range_name: str = Field(
        description="The name of the range, where formula will be repeated.",
    )

    
class autofill_constant_description(BaseModel):
    # Argument schema referenced by autofill_constant_tool.args_schema.
    constant: str = Field(
        description="Constant value to be written in a range (not only one cell).",
    )
    range_name: str = Field(
        description="The name of the range to be autofilled with constant value.",
    )
    
    
class autofill_delta_description(BaseModel):
    # Argument schema referenced by autofill_delta_tool.args_schema.
    first_value: str = Field(
        description="First value to be written in a range.",
    )
    second_value: str = Field(
        description="Sencond value to be written in a range.",
    )
    range_name: str = Field(
        description="The name of the range to be autofilled with constant difference between values.",
    )
    
class get_stock_data_description(BaseModel):
    # Argument schema referenced by get_stock_data_tool.args_schema.
    stock_id: str = Field(
        description="Identifier of a stock.",
    )
    start_dt: str = Field(
        description="Start date in form yyyy-mm-dd",
    )
    end_dt: str = Field(
        description="End date in form yyyy-mm-dd",
    )
    cell_id: str = Field(
        description="ID of the target cell in A1 notation.",
    )

# ------------------------------------------------- Tools -----------------------------------------------------------    

class write_table_tool(BaseTool, extra='allow'):
    # Tool that parses a Markdown table and writes it to the sheet via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "write_table"
    description = """
    Call this function with a cell ID and a markdown table to write this table into the specified cell.
    Markdown table must have headers. Pass it into argument as a string.
    Formulas should begin with "=".
    """
    args_schema: Type[BaseModel] = write_table_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, range_name, markdown_table):
        """Write the parsed table at ``range_name`` and return a confirmation text."""
        self.gapi.write_values(range_name, markdown_to_matrix(markdown_table))
        return f"The data from the table {markdown_table} has been written into the cell {range_name}"
    
    
class write_value_tool(BaseTool, extra='allow'):
    # Tool that writes one value or formula into a single cell via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "write_value"
    description = """
    Call this function with a cell ID and a value of formula to write this value or formula into only one specified cell.
    Only use formulas you have an access to. If formula has arguments, they should be separated by commas.
    Example of a formula: "=SUM(A1:A5)"
    """
    args_schema: Type[BaseModel] = write_value_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, cell_id, value):
        """Write ``value`` to ``cell_id`` and return a confirmation text."""
        # Every apostrophe is turned into a double quote before writing.
        cell_text = str(value).replace("'", '"')
        self.gapi.write_values(cell_id, [[cell_text]])
        return f"Value {value} has been written into the cell {cell_id}"

    
    
class ask_user_general_question_tool(BaseTool):
    # Tool that prints a free-form question and returns what the user types.
    name = "ask_user_general_question"
    description = """
    Call this function with a general question to the user.
    Do this if you need extra information from user.
    """
    args_schema: Type[BaseModel] = ask_user_general_question_description

    def _run(self, question):
        """Print ``question`` and return the line read from standard input."""
        print(question)
        return input()
    
    
class ask_user_alternative_question_tool(BaseTool):
    # Tool that prints a question with numbered options and returns the
    # option the user picked.
    name = "ask_user_alternative_question"
    description = \
    """
    Call this function with a alternative question to the user.
    Provide 5 or less option for user to chose from.
    """
    args_schema: Type[BaseModel] = ask_user_alternative_question_description

    def _run(self, question, options):
        """Ask ``question``, list ``options`` numbered from 0, return the choice.

        Returns:
            The selected element of ``options``. When the reply is not a
            valid option number, the raw reply text is returned instead so
            the user's answer is not lost.
        """
        print(question)
        print("Please chose one of the following options: (enter number)")
        print("\n".join([f"{num}) {option}" for num, option in enumerate(options)]))
        answer = input().strip()
        # input() returns text; it has to be converted before it can be
        # used as a list index.
        try:
            choice = int(answer)
        except ValueError:
            return answer
        if 0 <= choice < len(options):
            return options[choice]
        return answer
    
    
class create_pivot_table_tool(BaseTool, extra='allow'):
    # Tool that creates a pivot table via ``gapi`` and reports where it landed.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "create_pivot_table"
    description = """
    Call this function to create a pivot table with data on the sheet.
    """
    args_schema: Type[BaseModel] = create_pivot_table_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, source_range, target_cell, rows, columns, values):
        """Create the pivot table, then append a description read back from the sheet."""
        self.gapi.create_pivot_table(
            source_range,
            target_cell,
            rows=rows,
            columns=columns,
            values=values,
        )
        summary = f"Created pivot table from data in {source_range} with rows {rows}, columns {columns} and values {values}\n"
        return summary + describe_pivot_in_cell(self.gapi, target_cell)
    
    
class draw_chart_tool(BaseTool, extra='allow'):
    # Tool that adds a chart to the sheet via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "draw_chart"
    description = """
    Call this function to create a line, column or area chart with data on the sheet.
    """
    args_schema: Type[BaseModel] = draw_chart_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, chart_type, target_cell, domain_range, chart_series, title, x_title, y_title):
        """Create the chart and return a confirmation text."""
        # x_title / y_title map onto the bottom / left axis titles of the API wrapper.
        self.gapi.create_chart(
            chart_type,
            target_cell,
            domain_range,
            chart_series,
            title=title,
            botton_axis_title=x_title,
            left_axis_title=y_title,
        )
        return f"Created {chart_type} chart from data in {chart_series} with domain in {domain_range}"
    
    
class repeat_formula_tool(BaseTool, extra='allow'):
    # Tool that repeats one formula over a range via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "repeat_formula"
    description = """
    Call this function to repeat formula in specified range.
    The formula's range automatically increments for each row and column in the range, starting with the upper left cell. 
    For example, if cell B1 has the formula =FLOOR(A1*PI()), while cell D6 has the formula =FLOOR(C6*PI()).
    """
    args_schema: Type[BaseModel] = repeat_formula_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, formula, range_name):
        """Repeat ``formula`` over ``range_name`` and return a confirmation text."""
        self.gapi.repeat_formula(formula, range_name)
        confirmation = f"Repeated formula {formula} in range {range_name}"
        return confirmation
    
    
class autofill_constant_tool(BaseTool, extra='allow'):
    # Tool that seeds the first cell of a range and autofills the rest via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "autofill_constant"
    description = """
    Call this function to autofill constant value in specified range.
    For example, if you need to write value "678" into all cells of range "G3:G8", call this function with constant "678" and range_name "G3:G8".
    """
    args_schema: Type[BaseModel] = autofill_constant_description

    def __init__(self, gapi, **kwargs):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**kwargs)
        self.gapi = gapi

    def _run(self, constant, range_name):
        """Write ``constant`` to the first cell, autofill the range, return a confirmation."""
        # The text before the first ':' is the top-left cell of the range.
        seed_cell = range_name.partition(':')[0]
        self.gapi.write_values(seed_cell, [[constant]])
        self.gapi.autofill(range_name)
        return f"Constant value {constant} was written in every cell of range {range_name}"
    
    
class autofill_delta_tool(BaseTool, extra='allow'):
    # Tool that seeds the first two cells of a range and autofills the rest
    # via ``gapi``, producing a series with a constant difference.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "autofill_delta"
    description = \
    """
    Call this function to autofill values in specified range with a constant difference.
    For example, if you need to write values "1 2 3 4 5 6" into range "G3:G8", call this function with 
    - fisrt_value: "1"
    - second_value: "2"
    - range_name "G3:G8".
    This will work because all elements in "1 2 3 4 5 6" have constant difference.
    """
    args_schema: Type[BaseModel] = autofill_delta_description

    def __init__(self, gapi, **data):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**data)
        self.gapi = gapi

    def _run(self, first_value, second_value, range_name):
        """Write the two seed values, autofill ``range_name``, return a confirmation.

        The second seed goes to the cell right of the first one when the
        range is a single row, and to the cell below it otherwise, so that
        both seeds lie inside the range being autofilled.
        """
        first_cell, _, last_cell = range_name.partition(':')
        start_col = ''.join(filter(str.isalpha, first_cell))
        start_row = ''.join(filter(str.isdigit, first_cell))
        end_col = ''.join(filter(str.isalpha, last_cell))
        end_row = ''.join(filter(str.isdigit, last_cell))

        if end_row == start_row and end_col.upper() != start_col.upper():
            # Single-row range such as "B2:G2": the cell below would fall
            # outside the range, so the series has to run left to right.
            next_col = index_to_column(column_to_index(start_col.upper()) + 1)
            second_cell = next_col + start_row
        else:
            second_cell = start_col + str(int(start_row) + 1)

        self.gapi.write_values(first_cell, [[first_value]])
        self.gapi.write_values(second_cell, [[second_value]])
        self.gapi.autofill(range_name)
        return f"Values were written into range {range_name}"
    

class get_stock_data_tool(BaseTool, extra='allow'):
    # Tool that downloads daily candles from the MOEX ISS endpoint and writes
    # a date / price table to the sheet via ``gapi``.
    # ``extra='allow'`` presumably lets ``__init__`` store ``gapi`` on the instance.
    name = "get_stock_data"
    description = \
    """
    Call this function to write stock data into specified cell.
    You have access to the following stock identifiers (stock_id):
    - "SBER" for SberBank or Sber
    - "YNDX" for Yandex 
    - "GAZP" for Gazprom
    Data will occupy two columns: one for dates, one for prices.
    """
    args_schema: Type[BaseModel] = get_stock_data_description

    def __init__(self, gapi, **data):
        """Keep a reference to the sheets API wrapper used by ``_run``."""
        super().__init__(**data)
        self.gapi = gapi

    def _run(self, stock_id, start_dt, end_dt, cell_id):
        """Fetch candles for ``stock_id`` and write ``[date, price]`` rows at ``cell_id``.

        Returns:
            A text describing the written range, or a notice that nothing
            was written when the service returned no usable candles.

        Raises:
            requests.RequestException: on timeout, connection failure or a
                non-success HTTP status.
        """
        url = f'https://iss.moex.com/iss/engines/stock/markets/shares/securities/{stock_id}/candles.json'
        reply = requests.get(
            url,
            params={'from': start_dt, 'till': end_dt, 'interval': 24},
            timeout=30,  # without a timeout a stalled connection blocks the agent
        )
        reply.raise_for_status()
        candles = reply.json()['candles']['data']

        # Price is field 4 divided by field 5; the date is the first 10
        # characters of field 6. Presumably these are value, volume and the
        # candle start time - TODO confirm against the response's "columns".
        # Candles with a zero divisor are skipped to avoid ZeroDivisionError.
        matrix = [
            [candle[6][:10], candle[4] / candle[5]]
            for candle in candles
            if candle[5]
        ]
        if not matrix:
            return f"No stock data was found for {stock_id} between {start_dt} and {end_dt}, nothing was written."

        response = self.gapi.write_values(cell_id, matrix)
        # updatedRange is split on '!' and the part after it (the A1 range) is reported.
        result_range = response['updatedRange'].split('!')[1]
        return \
        f"""
        Stock data has been written to range: {result_range}
        First column contains dates, second column contains prices.
        Table does not contain header.
        """
        

# Agent

Основной код агента на gigagraph 

Подробно основные компоненты и логика работы рассмотрены [здесь](наша_статья)

In [None]:
class AgentState(TypedDict):
    # State schema handed to StateGraph in GoogleSheetsAgent.__init__.
    # ``messages`` carries ``operator.add`` as annotation metadata; presumably
    # the graph library uses it as a reducer, so message lists returned by
    # nodes are concatenated rather than replaced - confirm against its docs.
    messages: Annotated[list[AnyMessage], operator.add]
    
    
class GoogleSheetsAgent:
    def __init__(self, gapi, gsheet_tools, feedback_tools):
        """Build and compile the agent graph and remember the available tools.

        Args:
            gapi: Google Sheets API wrapper; stored and used by the nodes.
            gsheet_tools: Tools exposed to the model by the executor and the
                error-correction stages.
            feedback_tools: Tools exposed to the model by the annotation
                reflection stage.
        """
        self.model = None

        builder = StateGraph(AgentState)

        # (node name, handler). The caller/tooler handlers are shared by the
        # annotation, main and error-correction loops.
        node_table = (
            ("Main annotator", self.annotator_node),
            ("Annotation reflector", self.annotator_reflection_node),
            ("Annotation caller", self.call_node),
            ("Annotation tooler", self.tool_node),
            ("Plan proxy", self.plan_proxy),
            ("Usual planner", self.plan_node),
            ("COT SC planner", self.plan_node_cot),
            ("TOT planner", self.plan_node_tot),
            ("TOT SC planner", self.plan_node_tot_sc),
            ("Planner cells critic", self.plan_critic_cells_node),
            ("Planner formulas critic", self.plan_critic_formulas_node),
            ("Main executor", self.executor_node),
            ("Main caller", self.call_node),
            ("Main tooler", self.tool_node),
            ("Error reflector", self.error_corrector_node),
            ("Error caller", self.call_node),
            ("Error tooler", self.tool_node),
        )
        for node_name, handler in node_table:
            builder.add_node(node_name, handler)

        # Unconditional transitions (source, target).
        edge_table = (
            ("Annotation reflector", "Annotation caller"),
            ("Annotation tooler", "Annotation caller"),
            ("Usual planner", "Planner cells critic"),
            ("COT SC planner", "Planner cells critic"),
            ("TOT planner", "Planner cells critic"),
            ("TOT SC planner", "Planner cells critic"),
            ("Main executor", "Main caller"),
            ("Main tooler", "Main caller"),
            ("Error tooler", "Error caller"),
        )
        for source, target in edge_table:
            builder.add_edge(source, target)

        # Conditional transitions (source, router, router result -> target).
        routing_table = (
            ("Main caller", self.exists_action,
             {True: "Main tooler", False: "Error reflector"}),
            ("Annotation caller", self.exists_action,
             {True: "Annotation tooler", False: "Plan proxy"}),
            (START, self.outline_node,
             {True: "Main annotator", False: "Plan proxy"}),
            ("Main annotator", self.data_is_relation,
             {True: "Annotation reflector", False: "Plan proxy"}),
            ("Planner cells critic", self.continue_plan_critic_cells,
             {True: "Planner cells critic", False: "Planner formulas critic"}),
            ("Planner formulas critic", self.continue_plan_critic_formulas,
             {True: "Planner formulas critic", False: "Main executor"}),
            ("Plan proxy", self.switch_planner,
             {
                 "Usual": "Usual planner",
                 "COT_SC": "COT SC planner",
                 "TOT": "TOT planner",
                 "TOT_SC": "TOT SC planner",
             }),
            ("Error reflector", self.exists_errors,
             {True: "Error caller", False: END}),
            ("Error caller", self.exists_action,
             {True: "Error tooler", False: "Error reflector"}),
        )
        for source, router, routes in routing_table:
            builder.add_conditional_edges(source, router, routes)

        self.graph = builder.compile()
        self.logs = ""

        # Each tool set is kept both as a list (for bind_tools) and as a
        # name -> tool mapping (for dispatching tool calls).
        self.gsheet_tools_list = gsheet_tools
        self.gsheet_tools_dict = {tool.name: tool for tool in gsheet_tools}
        self.feedback_tools_list = feedback_tools
        self.feedback_tools_dict = {tool.name: tool for tool in feedback_tools}

        self.gapi = gapi
        
        
    def get_logs(self):
        return self.logs
        
        
    def invoke(
        self,
        query,
        prompts,
        model_name='gpt-4o',
        temperature=0.0,
        planner_mode="Usual",
        plan_critic_formulas_max_tries=2,
        plan_critic_cells_max_tries=2,
        plan_sc_number=5,
        tot_number=5,
        tot_max_rounds=20,
        tot_mode="one_call",
        errors_correction_max_tries=2,
    ):
        """Run the compiled graph for one user query.

        All run settings are stored on the instance, where the nodes read
        them; the log buffer and the token counters are reset first.

        Args:
            query: User request text.
            prompts: Mapping of prompt names to prompt texts used by the nodes.
            model_name: Model name passed to ``ChatOpenAI``.
            temperature: Sampling temperature passed to ``ChatOpenAI``.
            planner_mode: Key returned by ``switch_planner`` to select the
                planner node ("Usual", "COT_SC", "TOT" or "TOT_SC").
            plan_critic_formulas_max_tries: Initial budget of the formulas critic.
            plan_critic_cells_max_tries: Initial budget of the cells critic.
            plan_sc_number: Number of plans sampled by the self-consistency planners.
            tot_number: Number of next-step options per tree-of-thoughts round.
            tot_max_rounds: Upper bound on tree-of-thoughts rounds.
            tot_mode: Option generation mode ("one_call" or "many_calls").
            errors_correction_max_tries: Error-correction budget; stored plus one
                because ``exists_errors`` decrements before it compares.

        Returns:
            Tuple ``(result, logs)``: the content of the last message in the
            final state and the run log prefixed with model, temperature and query.
        """
        initial_state = {"messages": []}

        self.query = query
        self.logs = ""
        self.model = ChatOpenAI(model_name=model_name, temperature=temperature)
        self.prompts = prompts

        # Token usage counters, incremented by every model call.
        self.input_tokens = 0
        self.output_tokens = 0

        # Retry budgets and planner settings read by the nodes and routers.
        self.plan_critic_formulas_tries = plan_critic_formulas_max_tries
        self.plan_critic_cells_tries = plan_critic_cells_max_tries
        self.plan_sc_number = plan_sc_number
        self.planner_mode = planner_mode
        self.tot_number = tot_number
        self.tot_max_rounds = tot_max_rounds
        self.tot_mode = tot_mode
        self.errors_correction_tries = errors_correction_max_tries + 1

        final_state = self.graph.invoke(initial_state)
        answer = final_state.get('messages')[-1].content

        run_header = f"Model: {model_name}\nTemperature: {temperature}\nQuery: {query}\n\n"
        return answer, run_header + self.get_logs()
    
    
    def print_logs(self, log):
        self.logs += str(log) + '\n'
        
    def print_header(self, header):
        self.logs += f"\n{header:-^80}\n"
        
        
    def outline_node(self, state: AgentState):
        """Entry router: fetch a data sample and report whether the sheet has data.

        Stores the sample and its range on the instance for the later nodes.

        Args:
            state: Graph state (not read here).

        Returns:
            False when the sample equals "No data on the sheet", True otherwise.
        """
        self.print_logs("Data outlining: getting data...")
        fetched_sample, fetched_range = get_data_outline(self.gapi)
        self.print_logs("Data outlining: data obtained")
        self.sample, self.data_range = fetched_sample, fetched_range
        has_data = self.sample != "No data on the sheet"
        return has_data
        
        
    def annotator_node(self, state: AgentState):
        """Ask the model to annotate the sampled sheet data.

        Args:
            state: Graph state (not read here).

        Returns:
            State update carrying the model response as the only new message.
        """
        system_message = SystemMessage(content=self.prompts['annotation_prompt'])
        task_message = HumanMessage(
            content=self.prompts['annotation_task'].format(query=self.query, sample=self.sample)
        )
        prompt_messages = [system_message, task_message]
        response = self.model.invoke(prompt_messages)

        # Log the full exchange only after the call has succeeded.
        self.print_header("ANNOTATION PROMPT")
        for prompt_message in prompt_messages:
            self.print_logs(prompt_message.content)
        self.print_header("ANNOTATION RESPONSE")
        self.print_logs(response.content)

        usage = response.usage_metadata
        self.input_tokens += usage['input_tokens']
        self.output_tokens += usage['output_tokens']

        return {"messages": [response]}
        
        
    def annotator_reflection_node(self, state: AgentState):
        """Prepare the reflection prompt for the annotation produced so far.

        Takes the content of the last message as the annotation, wraps it into
        a system/human message pair and switches the active tool set to the
        feedback tools for the following caller node.

        Args:
            state: Graph state; its last message is read.

        Returns:
            State update carrying the two prompt messages.
        """
        annotation_text = state['messages'][-1].content
        reflection_messages = [
            SystemMessage(content=self.prompts['annotation_reflector']),
            HumanMessage(content=annotation_text),
        ]

        self.print_header("ANNOTATION REFLECTOR PROMPT")
        for reflection_message in reflection_messages:
            self.print_logs(reflection_message.content)

        # The caller/tooler nodes read these attributes to know which tools apply.
        self.tools_list, self.tools_dict = self.feedback_tools_list, self.feedback_tools_dict
        return {"messages": reflection_messages}
        
    def plan_proxy(self, state):
        pass
        
    def plan_node(self, state: AgentState):
        """Single-shot planner: ask the model for a plan and store it numbered.

        The outline is the content of the last message, or the literal
        "No data on the sheet" when there are no messages. The resulting plan
        is stored in ``self.plan``; the node returns no state update.

        Args:
            state: Graph state; its last message (if any) is read.
        """
        history = state['messages']
        if len(history) > 0:
            self.outline = history[-1].content
        else:
            self.outline = "No data on the sheet"

        planner_messages = [
            SystemMessage(content=self.prompts['planner_prompt']),
            HumanMessage(content=self.prompts['planner_task'].format(
                query=self.query, outline=self.outline, cells=self.data_range
            )),
        ]
        response = self.model.invoke(planner_messages)

        self.print_header("PLAN PROMPT")
        for planner_message in planner_messages:
            self.print_logs(planner_message.content)
        self.print_header("PLAN RESPONSE")
        self.print_logs(response.content)

        usage = response.usage_metadata
        self.input_tokens += usage['input_tokens']
        self.output_tokens += usage['output_tokens']

        # Keep what follows the last "Result:\n" marker and drop code fences.
        # NOTE(review): the second replace removes every "json" substring,
        # including ones inside plan steps -- confirm this is acceptable.
        result_text = response.content.split("Result:\n")[-1]
        parsed = json.loads(result_text.replace("```", "").replace("json", ""))

        numbered_steps = []
        for position, step in enumerate(parsed.get('plan'), start=1):
            numbered_steps.append(f"{position}) {step}")
        self.plan = "\n".join(numbered_steps)
        

    def plan_node_cot(self, state: AgentState):
        """Self-consistency planner: sample several plans and let the model pick one.

        The planner prompt is sent ``self.plan_sc_number`` times; the numbered
        plans are then passed to the voter prompt and the chosen plan is stored
        in ``self.plan``. The node returns no state update.

        Args:
            state: Graph state; its last message (if any) is read as the outline.
        """
        history = state['messages']
        if len(history) > 0:
            self.outline = history[-1].content
        else:
            self.outline = "No data on the sheet"

        planner_messages = [
            SystemMessage(content=self.prompts['planner_prompt']),
            HumanMessage(content=self.prompts['planner_task'].format(
                query=self.query, outline=self.outline, cells=self.data_range
            )),
        ]

        candidate_plans = []
        for _ in range(self.plan_sc_number):
            sample = self.model.invoke(planner_messages)
            sample_json = json.loads(
                sample.content.split("Result:\n")[-1].replace("```", "").replace("json", "")
            )
            candidate_plans.append(
                "\n".join(f"{idx + 1}) {step}" for idx, step in enumerate(sample_json.get('plan')))
            )
            self.input_tokens += sample.usage_metadata['input_tokens']
            self.output_tokens += sample.usage_metadata['output_tokens']

        self.print_header("PLAN PROMPT")
        for planner_message in planner_messages:
            self.print_logs(planner_message.content)

        # Candidates are labelled starting from "Plan 0".
        # NOTE(review): only the plans are sent to the voter, while the voter
        # prompt text also mentions the user query and data information.
        voting_text = "\n\n".join(
            f"Plan {idx}:\n{candidate}" for idx, candidate in enumerate(candidate_plans)
        )
        voter_messages = [
            SystemMessage(content=self.prompts['planner_cot_voter_prompt']),
            HumanMessage(content=voting_text),
        ]
        verdict = self.model.invoke(voter_messages)

        self.print_header("PLAN COT PROMPT")
        for voter_message in voter_messages:
            self.print_logs(voter_message.content)
        self.print_header("PLAN COT RESPONSE")
        self.print_logs(verdict.content)

        self.input_tokens += verdict.usage_metadata['input_tokens']
        self.output_tokens += verdict.usage_metadata['output_tokens']

        verdict_json = json.loads(
            verdict.content.split("Result:\n")[-1].replace("```", "").replace("json", "")
        )
        self.plan = "\n".join(
            f"{idx + 1}) {step}" for idx, step in enumerate(verdict_json.get('plan'))
        )
        
        
    def plan_node_tot(self, state: AgentState):
        """Tree-of-thoughts planner: build one plan step by step.

        Stores the outline and the plan returned by ``generate_tot_plan`` on
        the instance; the node returns no state update.

        Args:
            state: Graph state; its last message (if any) is read as the outline.
        """
        history = state['messages']
        if len(history) > 0:
            self.outline = history[-1].content
        else:
            self.outline = "No data on the sheet"
        generated_plan = self.generate_tot_plan(mode=self.tot_mode)
        self.plan = generated_plan
        
        
    def plan_node_tot_sc(self, state: AgentState):
        """Tree-of-thoughts planner with self-consistency voting.

        Generates ``self.plan_sc_number`` tree-of-thoughts plans, passes them
        to the voter prompt and stores the chosen plan in ``self.plan``. The
        node returns no state update.

        Args:
            state: Graph state; its last message (if any) is read as the outline.
        """
        history = state['messages']
        if len(history) > 0:
            self.outline = history[-1].content
        else:
            self.outline = "No data on the sheet"

        candidate_plans = [
            self.generate_tot_plan(mode=self.tot_mode)
            for _ in range(self.plan_sc_number)
        ]

        # Candidates are labelled starting from "Plan 0".
        voting_text = "\n\n".join(
            f"Plan {idx}:\n{candidate}" for idx, candidate in enumerate(candidate_plans)
        )
        voter_messages = [
            SystemMessage(content=self.prompts['planner_cot_voter_prompt']),
            HumanMessage(content=voting_text),
        ]
        verdict = self.model.invoke(voter_messages)

        self.print_header("PLAN COT PROMPT")
        for voter_message in voter_messages:
            self.print_logs(voter_message.content)
        self.print_header("PLAN COT RESPONSE")
        self.print_logs(verdict.content)

        self.input_tokens += verdict.usage_metadata['input_tokens']
        self.output_tokens += verdict.usage_metadata['output_tokens']

        verdict_json = json.loads(
            verdict.content.split("Result:\n")[-1].replace("```", "").replace("json", "")
        )
        self.plan = "\n".join(
            f"{idx + 1}) {step}" for idx, step in enumerate(verdict_json.get('plan'))
        )
    
    
    def executor_node(self, state):
        """Prepare the execution prompt for the final plan.

        Wraps ``self.plan`` into a system/human message pair and switches the
        active tool set to the Google Sheets tools for the following caller node.

        Args:
            state: Graph state; its 'messages' key is overwritten locally.

        Returns:
            State update carrying the two prompt messages.
        """
        execution_messages = [
            SystemMessage(content=self.prompts['executor_prompt']),
            HumanMessage(content=self.plan),
        ]
        # NOTE(review): this only rebinds the key on the dict passed in;
        # presumably it does not reset the history held by the graph -- confirm.
        state['messages'] = []

        self.print_header("TASK PROMPT")
        for execution_message in execution_messages:
            self.print_logs(execution_message.content)

        # The caller/tooler nodes read these attributes to know which tools apply.
        self.tools_list, self.tools_dict = self.gsheet_tools_list, self.gsheet_tools_dict
        return {"messages": execution_messages}
    

    def plan_critic_cells_node(self, state: AgentState):
        """Let the model review the cell addresses used in the current plan.

        Spends one unit of the cells-critic budget, records whether the critic
        reported a change in ``self.change_str`` and, if so, replaces
        ``self.plan`` with the returned steps joined by newlines. The node
        returns no state update.

        Args:
            state: Graph state (not read here).
        """
        critic_task = self.prompts['planner_critic_prompt']
        critic_messages = [
            SystemMessage(content=self.prompts['planner_critic_cells_prompt']),
            HumanMessage(content=critic_task.format(
                query=self.query,
                outline=self.outline,
                cells=self.data_range,
                plan=self.plan,
            )),
        ]
        response = self.model.invoke(critic_messages)

        self.print_header("PLAN CRITIC (CELLS) PROMPT")
        for critic_message in critic_messages:
            self.print_logs(critic_message.content)
        self.print_header("PLAN CRITIC (CELLS) RESPONSE")
        self.print_logs(response.content)

        usage = response.usage_metadata
        self.input_tokens += usage['input_tokens']
        self.output_tokens += usage['output_tokens']

        result_text = response.content.split("Result:\n")[-1]
        review = json.loads(result_text.replace("```", "").replace("json", ""))

        # The prompt asks for the string "True"/"False", hence the string compare.
        self.change_str = review.get("changed") == "True"
        self.plan_critic_cells_tries -= 1
        if not self.change_str:
            return
        self.plan = "\n".join(review.get('plan'))
    
    
    def plan_critic_formulas_node(self, state: AgentState):
        """Let the model review the formulas used in the current plan.

        Spends one unit of the formulas-critic budget, records whether the
        critic reported a change in ``self.change_str`` and, if so, replaces
        ``self.plan`` with the returned steps joined by newlines. The node
        returns no state update.

        Args:
            state: Graph state (not read here).
        """
        critic_task = self.prompts['planner_critic_prompt']
        critic_messages = [
            SystemMessage(content=self.prompts['planner_critic_formulas_prompt']),
            HumanMessage(content=critic_task.format(
                query=self.query,
                outline=self.outline,
                cells=self.data_range,
                plan=self.plan,
            )),
        ]
        response = self.model.invoke(critic_messages)

        self.print_header("PLAN CRITIC (FORMULAS) PROMPT")
        for critic_message in critic_messages:
            self.print_logs(critic_message.content)
        self.print_header("PLAN CRITIC (FORMULAS) RESPONSE")
        self.print_logs(response.content)

        usage = response.usage_metadata
        self.input_tokens += usage['input_tokens']
        self.output_tokens += usage['output_tokens']

        result_text = response.content.split("Result:\n")[-1]
        review = json.loads(result_text.replace("```", "").replace("json", ""))

        # The prompt asks for the string "True"/"False", hence the string compare.
        self.change_str = review.get("changed") == "True"
        self.plan_critic_formulas_tries -= 1
        if not self.change_str:
            return
        self.plan = "\n".join(review.get('plan'))
    
    
    def call_node(self, state: AgentState, config):
        messages = state['messages']
        model_with_tools = self.model.bind_tools(self.tools_list)
        response = model_with_tools.invoke(messages)

        self.print_header("TASK AGENT RESPONSE")
        self.print_logs("Content: " + response.content) 
        self.print_logs("Tool calling: " + str(len(response.tool_calls) > 0))
        self.input_tokens += response.usage_metadata['input_tokens']
        self.output_tokens += response.usage_metadata['output_tokens']
        
        return {"messages": [response]}


    def tool_node(self, state: AgentState, config):
        """Execute the tool calls requested by the last message.

        Unknown tool names are not executed; the model receives a
        "Bad tool name, retry!" result for them instead.

        Args:
            state: Graph state; ``tool_calls`` of its last message are executed.
            config: Run configuration supplied by the graph (unused).

        Returns:
            State update with one ``ToolMessage`` per requested call, in order.
        """
        requested_calls = state['messages'][-1].tool_calls
        tool_messages = []

        self.print_header("TOOL CALLING")

        for call in requested_calls:
            tool_name = call['name']
            self.print_logs(f"Tool: {tool_name}")
            if tool_name in self.tools_dict:
                self.print_logs('Args: ' + str(call['args']))
                outcome = self.tools_dict[tool_name].invoke(call['args'])
                self.print_logs('Result: ' + str(outcome))
            else:
                self.print_logs("\nBad tool name!")
                outcome = "Bad tool name, retry!"
            tool_messages.append(
                ToolMessage(tool_call_id=call['id'], name=tool_name, content=str(outcome))
            )
        return {'messages': tool_messages}
    
    
    def error_corrector_node(self, state):
        """Prepare the error-correction prompt from the errors found on the sheet.

        Collects the current errors from ``gapi.get_errors()``, stores them as
        newline-joined text in ``self.errors`` and switches the active tool set
        to the Google Sheets tools for the following caller node.

        Only the two new prompt messages are returned. The 'messages' channel
        is declared with ``operator.add``, so the returned list is appended to
        the existing history; returning the history plus the new messages (as
        was done before, together with an ineffective local reset of
        ``state['messages']``) duplicated the whole conversation on every
        correction round.

        Args:
            state: Graph state (the history it holds is left untouched).

        Returns:
            State update carrying the system and human correction messages.
        """
        self.errors = "\n".join(str(error) for error in self.gapi.get_errors())
        correction_messages = [
            SystemMessage(content=self.prompts['error_corrector_prompt']),
            HumanMessage(content=self.prompts['error_corrector_task'].format(
                errors=self.errors
            )),
        ]

        self.print_header("ERROR CORRECTION PROMPT")
        for correction_message in correction_messages:
            self.print_logs(correction_message.content)

        # The caller/tooler nodes read these attributes to know which tools apply.
        self.tools_list = self.gsheet_tools_list
        self.tools_dict = self.gsheet_tools_dict
        return {"messages": correction_messages}


    def exists_action(self, state: AgentState):
        """Router: report whether the last message requests any tool call."""
        pending_calls = state['messages'][-1].tool_calls
        return len(pending_calls) > 0
    
    def exists_errors(self, state: AgentState):
        self.errors_correction_tries -= 1
        return len(self.gapi.get_errors()) > 0 and self.errors_correction_tries > 0
    
    def data_is_relation(self, state: AgentState):
        """Router: report whether the annotator labelled the table as relational.

        The annotation prompt asks the model to begin its answer with the word
        "relational" or "dictionary". Only the first whitespace-delimited token
        is compared, so the label is recognised whether the JSON follows on a
        new line or on the same line, and regardless of leading whitespace.
        Comparing the whole first line (as before) misrouted such answers.

        Args:
            state: Graph state; the content of its last message is inspected.

        Returns:
            True if the first word of the last message is exactly 'relational',
            False otherwise (including for empty content).
        """
        content = state['messages'][-1].content
        # Slicing instead of indexing keeps empty content from raising.
        first_word = content.split(maxsplit=1)[:1]
        return first_word == ['relational']
    
    def switch_planner(self, state):
        return self.planner_mode
    
    def continue_plan_critic_cells(self, state: AgentState):
        return self.change_str and self.plan_critic_cells_tries > 0

    def continue_plan_critic_formulas(self, state):
        return self.change_str and self.plan_critic_formulas_tries > 0
    
    def tot_plan_generate_options(self, mode, current_plan_str):
        match mode:
            case 'one_call':
                messages = [
                    SystemMessage(content=self.prompts['planner_tot_generator_prompt'].format(
                        tot_number=self.tot_number
                    )), 
                    HumanMessage(content=self.prompts['planner_tot_generator_task'].format(
                        query=self.query, 
                        outline=self.outline, 
                        cells=self.data_range, 
                        plan=current_plan_str
                    ))
                ]
                response = self.model.invoke(messages)
                self.input_tokens += response.usage_metadata['input_tokens']
                self.output_tokens += response.usage_metadata['output_tokens']
                self.print_logs(response.content)
                options = json.loads(response.content.split("Result:\n")[-1].replace("```", "").replace("json", "")).get('next_step_options')
            case 'many_calls':
                options = []
                messages = [
                    SystemMessage(content=self.prompts['planner_tot_generator_prompt'].format(
                        tot_number=1
                    )), 
                    HumanMessage(content=self.prompts['planner_tot_generator_task'].format(
                        query=self.query, 
                        outline=self.outline, 
                        cells=self.data_range, 
                        plan=current_plan_str
                    ))
                ]
                for i in range(self.tot_number):
                    response = self.model.invoke(messages)
                    self.input_tokens += response.usage_metadata['input_tokens']
                    self.output_tokens += response.usage_metadata['output_tokens']
                    self.print_logs(response.content)
                    opt = json.loads(response.content.split("Result:\n")[-1].replace("```", "").replace("json", "")).get('next_step_options')[0]
                    options.append(opt)
        return options

            
    def generate_tot_plan(self, mode):
        """Build a plan step by step with a generate-then-choose loop.

        Each round generates next-step options, asks the discriminator prompt
        to choose one and appends the chosen option to the plan. The loop
        stops when the discriminator sets "terminate" to "True" or after
        ``self.tot_max_rounds`` rounds.

        Args:
            mode: Option generation mode passed to ``tot_plan_generate_options``.

        Returns:
            The numbered plan text, or "Plan is empty" if no round was run.
        """
        chosen_steps = []
        plan_text = "Plan is empty"
        for round_number in range(self.tot_max_rounds):
            self.print_header(f"TOT PLAN (STEP {round_number})")
            options = self.tot_plan_generate_options(mode, plan_text)
            numbered_options = "\n".join(
                f"{idx + 1}) {option}" for idx, option in enumerate(options)
            )
            discriminator_messages = [
                SystemMessage(content=self.prompts['planner_tot_discriminator_prompt']),
                HumanMessage(content=self.prompts['planner_tot_discriminator_task'].format(
                    query=self.query,
                    outline=self.outline,
                    cells=self.data_range,
                    plan=plan_text,
                    options=numbered_options,
                )),
            ]
            verdict = self.model.invoke(discriminator_messages)
            self.print_logs(discriminator_messages[1].content)
            self.print_logs(verdict.content)
            self.input_tokens += verdict.usage_metadata['input_tokens']
            self.output_tokens += verdict.usage_metadata['output_tokens']

            choice = json.loads(
                verdict.content.split("Result:\n")[-1].replace("```", "").replace("json", "")
            )
            # The discriminator answers with a 1-based option number.
            chosen_steps.append(options[int(choice.get('next_step')) - 1])
            plan_text = "\n".join(
                f"{idx + 1}) {step}" for idx, step in enumerate(chosen_steps)
            )
            if choice.get('terminate') == "True":
                break
        return plan_text


# Prompts

Весь набор необходимых промптов для агента

In [None]:
# Load the previously saved prompts; entries are overridden further below.
with open('prompts.json', 'r') as prompts_file:
    prompts = json.loads(prompts_file.read())

# Text that is interpolated into the formulas critic prompt.
with open('formulas.txt', 'r') as formulas_file:
    formulas = formulas_file.read()
    
# System prompt of the annotator node (stored as prompts['annotation_prompt']).
# data_is_relation inspects the beginning of the answer requested here.
# NOTE(review): the text contains typos ("Yout", "them") and the dictionary
# example ends with a trailing comma, which json.loads rejects; left untouched
# because any change to the prompt changes the model input.
annotation_prompt = \
"""
You are a professional data annotator
You are given a markdown table with columns names and sample data
Also you are given a user query
First, you should find out if this table is in a usual relation database form or in a form of dictionary
Example of a dictionary:
| cats | 5 |
| --- | --- |
| dogs | 4 |
| all | 9 |

If this table is in a dictionary form, then you should return a json dictionary of this table, for example:
{
    "cats" : 5,
    "dogs" : 4,
    "all" : 9,
}

If table is in a relational database form, then you should do following steps:
Using user query and sample in markdown table describe all data columns in the following json form:
{
    "name" : name of column,
    "description" : brief description of this column, try to provide meaning of this column for the user
    "type" : one of the following: number, date, text, empty
}.
If you don't understand meaning of the columns, write unknown in description.

At the end give brief (1-2 sentences) annotation about entire table in the following json form:
"table_description" : description of this table
Make sure that this description useful for the user.

Example:
{
    "columns": [
        {
            "name": "column1",
            "description": "Number of cats",
            "type": "number"
        },
        {
            "name": "column2",
            "description": "Number of dogs",
            "type": "number"
        }
    ],
    "table_description": "Table with information about cats and dogs"
}

Yout first word should be either "dictionary" (if table is in a dictionary form) or "relational" (if table is in a relational database form)
Then return only json dictionary (if table is in a dictionary form) or json description (if table is in a relational database form)
If sample is empty, them return empty json!
Make sure your output after first word can be read by json.loads!
"""

# System prompt of the annotation reflection stage
# (stored as prompts['annotation_reflector']).
# NOTE(review): contains the typo "youe"; left untouched because any change to
# the prompt changes the model input.
annotation_reflector = \
"""
You are a professional data annotator
You are given a prepared annotation of a table in the json form
Example
{
    "columns": [
        {
            "name": "column1",
            "description": "Number of cats",
            "type": "number"
        },
        {
            "name": "column2",
            "description": "Number of dogs",
            "type": "number"
        }
    ],
    "table_description": "Table with information about cats and dogs"
}

If you can't clearly understand meaning of the column or field with its description, ask the user for clarification
For example, you can ask "What does the column mean?"
Use this option if and only if you have no idea of the column meaning, do it only as a last resort!
If everything is clear, print final annotation
Make sure youe output can be read by json.loads!
"""

# System prompt of the planner nodes (stored as prompts['planner_prompt']).
# The planners parse the JSON that follows the "Result:" marker requested here.
# This string is used as-is (no str.format), so its braces are literal.
# NOTE(review): contains typos ("Gooogle", "Fow"); left untouched because any
# change to the prompt changes the model input.
planner_prompt = \
"""
You are an expert data analyst who is tasked with writing a detailed plan for data analysis task in Gooogle Sheets
Your plan is step-by-step. It should be a plain list of actions.

You can do these actions:
- Get stock market data for a specific company and time interval and write to specified cell (result table will have two columns!)
- Write a markdown table to specified cell (always use "=" in front of formulas of values)
- Write a formula or a value to one specified cell 
- Create a pivot table from existing data
- Create a simple chart of one of the following types: line, column, area
- Repeat a formula to a range of cells with automatic increment of cell address
- Autofill blank range with a constant value
- Autofill blank range with a data series with constant difference (for example, 1 2 3 4 5 ...)

For a pivot table provide, what data should be used and what columns of this data will go to rows, columns and values of pivot table
For a chart provide, what data should be used and what range will be domain (X axis) and what ranges will be data series 
If you write a formula of a cell to markdown table always use symbol "=" in front of it!

Do not give any examples, only plan for the task, be precise. Think step by step.
You can use only steps written above!
You will be given user query (task) and information about data in the sheet
Given information above, write a detailed plan for a following query
Fow a good plan I will pay 100000 dollars.
Do not give any examples!

Your answer should be in the following form:
Thoughts: *your thoughts, think step by step*
Result:
{
    "plan": result plan (should be a list of strings)
}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt used to pick the best of several candidate plans
# (stored as prompts['planner_cot_voter_prompt']).
# NOTE(review): contains typos ("an senior", "chosing", "Gooogle"); left
# untouched because any change to the prompt changes the model input.
planner_cot_voter_prompt = \
"""
You are an senior data analyst who is tasked with chosing a best plan for data analysis task in Gooogle Sheets
You will be given user query (task), information about data in the sheet and several plans for this query.
You should find the best plan and print it.

Your answer should be in the following form:
Thoughts: *your thoughts, think step by step*
Result:
{
    "plan": best plan (should be a list of strings)
}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt that generates next-step options for the tree-of-thoughts
# planner (stored as prompts['planner_tot_generator_prompt']).
# It is filled with str.format(tot_number=...), which is why the braces of the
# JSON template are doubled.
# NOTE(review): "result table will have to columns!" reads "two columns" in
# planner_prompt; also contains the typo "Gooogle". Left untouched because any
# change to the prompt changes the model input.
planner_tot_generator_prompt = \
"""
You are an expert data analyst who is tasked with writing a detailed plan for data analysis task in Gooogle Sheets
You can do these actions:
- Get stock market data for a specific company and time interval and write to specified cell (result table will have to columns!)
- Write a markdown table to specified cell (always use "=" in front of formulas of values)
- Write a formula or a value to one specified cell 
- Create a pivot table from existing data
- Create a simple chart of one of the following types: line, column, area
- Repeat a formula to a range of cells with automatic increment of cell address
- Autofill blank range with a constant value
- Autofill blank range with a data series with constant difference (for example, 1 2 3 4 5 ...)

For a pivot table provide, what data should be used and what columns of this data will go to rows, columns and values of pivot table
For a chart provide, what data should be used and what range will be domain (X axis) and what ranges will be data series 
If you write a formula of a cell to markdown table always use symbol "=" in front of it!

You will be given user query (task), information about data in the sheet and current plan
You should produce {tot_number} possible options for the next step of the plan.
Do not give any examples, only {tot_number} possible options for the next step for the task, be precise. Think step by step.
If you think, that plan is finished, print "End" as option.
You can use only steps written above!
For good options I will pay you 100000 dollars.

Your answer should be in the following form:
Thoughts: *your thoughts, think step by step*
Result:
{{
    "next_step_options": {tot_number} options for the next step of the plan (should be a list of strings)
}}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt that chooses one of the generated next-step options
# (stored as prompts['planner_tot_discriminator_prompt']).
# generate_tot_plan reads "next_step" and "terminate" from the answer.
# This string is used as-is (no str.format), so its braces are literal.
# NOTE(review): contains typos ("Gooogle", "redudant"); left untouched because
# any change to the prompt changes the model input.
planner_tot_discriminator_prompt = \
"""
You are an expert data analyst who is tasked with writing a detailed plan for data analysis task in Gooogle Sheets
You will be given user query (task), information about data in the sheet, current plan and options for the next step.
You should choose the best option for the next step of the plan.
Do not repeat plan steps and do not add redudant steps!
If this step should be terminating (to finish the plan), set "terminate" to "True".
Also set "terminate" to "True" if you chose "End" option.
For good choice I will pay you 100000 dollars.

Your answer should be in the following form:
Thoughts: *your thoughts, think step by step*
Result:
{
    "next_step": number of chosen option (should be a number),
    "terminate": "True" or "False" (if plan should be finished with this step)
}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt of the cells critic (stored as
# prompts['planner_critic_cells_prompt']).
# plan_critic_cells_node reads "changed" and "plan" from the answer.
# NOTE(review): contains typos ("Gooogle", "adresses", "displayes",
# "convinient"); left untouched because any change to the prompt changes the
# model input.
planner_critic_cells_prompt = \
"""
You are a senior data analyst who is tasked with validating plan for data analysis task in Gooogle Sheets
You should check if all cells in a given plan are correct and change them if needed.

You will be given user query (task), information about data in the sheet and a plan for this query
Ensure that these requirements are met:
- All formulas have correct adresses of data ranges in arguments
- Result data ranges on the sheets do not overlap
- All tables and formulas on the sheet are displayes in a human convinient way.

You can only change cells addresses in formulas and tables, nothing more!

Your answer should be in the following form:
Thoughts: *your thoughts, think step by step*
Result:
{
    "changed" : "True" or "False",
    "changes_list" : List of changes (should be list of strings),
    "plan" : Modified or original plan (should be list of strings)
}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt of the formulas critic (stored as
# prompts['planner_critic_formulas_prompt']).
# This is an f-string: the contents of `formulas` are interpolated right here,
# which is why the braces of the JSON template are doubled.
# NOTE(review): contains typos ("Gooogle", "Ypu", "argumnets", "Fow"); left
# untouched because any change to the prompt changes the model input.
planner_critic_formulas_prompt = \
f"""
You are a senior data analyst who is tasked with validating plan for data analysis task in Gooogle Sheets
Ypu should check if all formulas have correct names, correct number of argumnets and correct types of arguments.

You will be given user query (task), information about data in the sheet and a plan for this query
This is the complete list of formulas you can use:
{formulas}

Ensure that plan use only these formulas with correct names, correct number of arguments and correct types of arguments.
Think step by step. Fow a good solution I will pay 100000 dollars.
Carefully study list of formulas and check, that all formulas in plan are correct!
You can only change formulas names or arguments, nothing more!

Your answer should be in the following form:
Thoughts: *your thoughts*
Result:
{{
    "changed" : "True" or "False",
    "changes_list" : List of changes (empty if nothing changed),
    "plan" : Modified plan (or original plan if nothing changed)
}}
Make sure that output after "Result" can be parsed with json.loads()!
"""

# System prompt of the error-correction stage
# (stored as prompts['error_corrector_prompt']).
# NOTE(review): contains the typo "Gooogle"; left untouched because any change
# to the prompt changes the model input.
error_corrector_prompt = \
"""
You are a senior data analyst who is tasked with correction errors in Gooogle Sheets task.
Your coworker has created and executed a plan for user task, but made some mistakes. You should correct these mistakes. 
You will be given a list of errors.
Correct these errors.
"""

# System prompt of the executor stage (stored as prompts['executor_prompt']);
# the plan itself is sent as the following human message.
# NOTE(review): contains the typo "iteract"; left untouched because any change
# to the prompt changes the model input.
executor_prompt = \
"""
You are helpful assistant who can work with markdown tables and use Google Sheets API to read and write data
You have a bunch of functions you can call to iteract with Google Sheets API.

Do all steps from the following plan:
"""

annotation_task = \
"""
Query: {query}
Sample: 
{sample}
"""

planner_task = \
"""
User query: {query}
Data outline: {outline}
Data is located in cells {cells}
"""

planner_critic_prompt = \
"""
User query: {query}
Data outline: {outline}
Data is located in cells {cells}
Plan: 
{plan}
"""

# Task template for the tree-of-thought generator; placeholders: {query},
# {outline}, {cells}, {plan}. "Current plan: " keeps its trailing space.
planner_tot_generator_task = (
    "\n"
    "User query: {query}\n"
    "Data outline: {outline}\n"
    "Data is located in cells {cells}\n"
    "\n"
    "Current plan: \n"
    "{plan}\n"
)

# Task template for the tree-of-thought discriminator; placeholders:
# {query}, {outline}, {cells}, {plan}, {options}.
# "Current plan: " keeps its trailing space before the line break.
planner_tot_discriminator_task = (
    "\n"
    "User query: {query}\n"
    "Data outline: {outline}\n"
    "Data is located in cells {cells}\n"
    "\n"
    "Current plan: \n"
    "{plan}\n"
    "\n"
    "Possible options:\n"
    "{options}\n"
)

# Task template for the error corrector; placeholder: {errors}.
error_corrector_task = (
    "\n"
    "List of errors in form (cell_id, error_name):\n"
    "{errors}\n"
)

# Store every prompt / task template in `prompts` under its string key
# (same keys and same insertion order as assigning them one by one).
prompts.update({
    'annotation_prompt': annotation_prompt,
    'planner_prompt': planner_prompt,
    'executor_prompt': executor_prompt,
    'annotation_reflector': annotation_reflector,
    'annotation_task': annotation_task,
    'planner_task': planner_task,
    'planner_critic_cells_prompt': planner_critic_cells_prompt,
    'planner_critic_formulas_prompt': planner_critic_formulas_prompt,
    'planner_critic_prompt': planner_critic_prompt,
    'planner_cot_voter_prompt': planner_cot_voter_prompt,
    'planner_tot_generator_prompt': planner_tot_generator_prompt,
    'planner_tot_discriminator_prompt': planner_tot_discriminator_prompt,
    'planner_tot_generator_task': planner_tot_generator_task,
    'planner_tot_discriminator_task': planner_tot_discriminator_task,
    'error_corrector_prompt': error_corrector_prompt,
    'error_corrector_task': error_corrector_task,
})

# Write the whole mapping to prompts.json in the working directory.
with open('prompts.json', 'w') as prompts_file:
    json.dump(prompts, prompts_file)

# Execution

Запуск агента

Предварительно необходимо создать директорию "../creds" и положить в неё файл "credentials.json" из личного кабинета Google API.

Файл "authorized_user.json" создастся автоматически

Подробнее [в документации](https://developers.google.com/workspace/guides/get-started)


In [None]:
# API client, configured with the credentials file and the authorized-user
# file under ../creds.
gapi = GooogleSheetsApi(
    credentials="../creds/credentials.json",
    user_data="../creds/authorized_user.json",
)

# Each tool factory is called with the shared API client, in this order.
gsheet_tools = [
    make_tool(gapi)
    for make_tool in (
        write_table_tool,
        write_value_tool,
        create_pivot_table_tool,
        draw_chart_tool,
        repeat_formula_tool,
        autofill_constant_tool,
        autofill_delta_tool,
        get_stock_data_tool,
    )
]
# The feedback tool is created without the API client.
feedback_tools = [ask_user_general_question_tool()]

agent = GoogleSheetsAgent(gapi, gsheet_tools, feedback_tools)

# Load the prompt templates from prompts.json in the working directory.
with open('prompts.json', 'r') as prompts_file:
    prompts = json.load(prompts_file)

Для запуска агента нужно указать:
- sheet_id: id листа (находится в url)
- query: запрос в свободной текстовой форме

Логи выполнения можно смотреть в "agent.logs"

In [None]:
# Test case handed to the agent as a (sheet_id, query) pair:
#   showcase[0] - id of the spreadsheet (taken from its URL)
#   showcase[1] - free-form text query
# There must be no comma after the closing parenthesis: it would wrap the
# pair in an outer 1-tuple, making showcase[0] the whole pair and
# showcase[1] an IndexError.
showcase = (
    'sheet_id',  # testcase_1
    """
query
""",
)

In [None]:
# Select the target spreadsheet on the API client, then run the agent.
gapi.set_sheet_id(showcase[0])
agent_response, agent_logs = agent.invoke(
    query=showcase[1],
    prompts=prompts,
    model_name="gpt-4o",  # "gpt-4o", "gpt-4o-mini"
    temperature=0.00,
    planner_mode="COT_SC",  # Usual, COT_SC, TOT, TOT_SC
    plan_critic_formulas_max_tries=2,
    plan_critic_cells_max_tries=2,
    plan_sc_number=3,
    tot_number=3,
    tot_max_rounds=20,
    tot_mode="many_calls",  # "one_call", "many_calls"
    errors_correction_max_tries=1,
)

# Token usage report. The cost line weights input tokens by 5 and output
# tokens by 15 per 10**6 tokens (presumably USD per million tokens for
# GPT-4o at the time of writing; verify against current pricing).
input_tokens = agent.input_tokens
output_tokens = agent.output_tokens
print(
    "\n"
    f"Overall input tokens: {input_tokens}\n"
    f"Overall output tokens: {output_tokens}\n"
    f"Overall total tokens: {input_tokens + output_tokens}\n"
    f"Estimate cost for GPT4o: {(input_tokens*5 + output_tokens*15)/10**6}$\n"
)

In [None]:
# Print the contents of the agent's `logs` attribute.
print(agent.logs)

# Formulas

Список формул для промптов

Нужно сохранить как "formulas.txt"

In [None]:
AVEDEV(value1, [value2, ...])	Calculates the average of the magnitudes of deviations of data from a dataset's mean.
AVERAGE(value1, [value2, ...])	Returns the numerical average value in a dataset, ignoring text.
AVERAGE.WEIGHTED(values, weights, [additional values], [additional weights])	Finds the weighted average of a set of values, given the values and the corresponding weights..
AVERAGEA(value1, [value2, ...])	Returns the numerical average value in a dataset.
AVERAGEIF(criteria_range, criterion, [average_range])	Returns the average of a range depending on criteria.
AVERAGEIFS(average_range, criteria_range1, criterion1, [criteria_range2, criterion2, ...])	Returns the average of a range depending on multiple criteria.
BETA.DIST(value, alpha, beta, cumulative, lower_bound, upper_bound)	Returns the probability of a given value as defined by the beta distribution function..
BETA.INV(probability, alpha, beta, lower_bound, upper_bound)	Returns the value of the inverse beta distribution function for a given probability.. 
BETADIST(value, alpha, beta, lower_bound, upper_bound)	See BETA.DIST.
BETAINV(probability, alpha, beta, lower_bound, upper_bound)	 See BETA.INV 
BINOM.DIST(num_successes, num_trials, prob_success, cumulative)	See BINOMDIST 
BINOM.INV(num_trials, prob_success, target_prob)	See CRITBINOM
BINOMDIST(num_successes, num_trials, prob_success, cumulative)	Calculates the probability of drawing a certain number of successes (or a maximum number of successes) in a certain number of tries given a population of a certain size containing a certain number of successes, with replacement of draws.
CHIDIST(x, degrees_freedom)	Calculates the right-tailed chi-squared distribution, often used in hypothesis testing.
CHIINV(probability, degrees_freedom)	Calculates the inverse of the right-tailed chi-squared distribution.
CHISQ.DIST(x, degrees_freedom, cumulative)	Calculates the left-tailed chi-squared distribution, often used in hypothesis testing.
CHISQ.DIST.RT(x, degrees_freedom)	Calculates the right-tailed chi-squared distribution, which is commonly used in hypothesis testing.
CHISQ.INV(probability, degrees_freedom)	Calculates the inverse of the left-tailed chi-squared distribution.
CHISQ.INV.RT(probability, degrees_freedom)	Calculates the inverse of the right-tailed chi-squared distribution.
CHISQ.TEST(observed_range, expected_range)	See CHITEST
CHITEST(observed_range, expected_range)	Returns the probability associated with a Pearson’s chi-squared test on the two ranges of data. Determines the likelihood that the observed categorical data is drawn from an expected distribution.
CONFIDENCE(alpha, standard_deviation, pop_size)	See CONFIDENCE.NORM
CONFIDENCE.NORM(alpha, standard_deviation, pop_size)	Calculates the width of half the confidence interval for a normal distribution..
CONFIDENCE.T(alpha, standard_deviation, size)	Calculates the width of half the confidence interval for a Student’s t-distribution..
CORREL(data_y, data_x)	Calculates r, the Pearson product-moment correlation coefficient of a dataset.
COUNT(value1, [value2, ...])	Returns a count of the number of numeric values in a dataset.
COUNTA(value1, [value2, ...])	Returns a count of the number of values in a dataset.
COVAR(data_y, data_x)	Calculates the covariance of a dataset.
COVARIANCE.P(data_y, data_x)	See COVAR
COVARIANCE.S(data_y, data_x)	Calculates the covariance of a dataset, where the dataset is a sample of the total population..
CRITBINOM(num_trials, prob_success, target_prob)	Calculates the smallest value for which the cumulative binomial distribution is greater than or equal to a specified criteria.
DEVSQ(value1, value2)	Calculates the sum of squares of deviations based on a sample.
EXPON.DIST(x, LAMBDA, cumulative)	Returns the value of the exponential distribution function with a specified LAMBDA at a specified value.. 
EXPONDIST(x, LAMBDA, cumulative)	See EXPON.DIST
FDIST(x, degrees_freedom1, degrees_freedom2)	See F.DIST.RT.
FINV(probability, degrees_freedom1, degrees_freedom2)	See F.INV.RT
FISHER(value)	Returns the Fisher transformation of a specified value.
FISHERINV(value)	Returns the inverse Fisher transformation of a specified value.
FORECAST(x, data_y, data_x)	Calculates the expected y-value for a specified x based on a linear regression of a dataset.
FORECAST.LINEAR(x, data_y, data_x)	See FORECAST 
FTEST(range1, range2)	Returns the probability associated with an F-test for equality of variances. Determines whether two samples are likely to have come from populations with the same variance.
GAMMA(number)	Returns the Gamma function evaluated at the specified value..
GAMMA.DIST(x, alpha, beta, cumulative)	Calculates the gamma distribution, a two-parameter continuous probability distribution.
GAMMA.INV(probability, alpha, beta)	The GAMMA.INV function returns the value of the inverse gamma cumulative distribution function for the specified probability and alpha and beta parameters..
GAMMADIST(x, alpha, beta, cumulative)	See GAMMA.DIST 
GAMMAINV(probability, alpha, beta)	See GAMMA.INV.
GAUSS(z)	The GAUSS function returns the probability that a random variable, drawn from a normal distribution, will be between the mean and z standard deviations above (or below) the mean..
GEOMEAN(value1, value2)	Calculates the geometric mean of a dataset.
HARMEAN(value1, value2)	Calculates the harmonic mean of a dataset.
HYPGEOM.DIST(num_successes, num_draws, successes_in_pop, pop_size)	See HYPGEOMDIST 
HYPGEOMDIST(num_successes, num_draws, successes_in_pop, pop_size)	 Calculates the probability of drawing a certain number of successes in a certain number of tries given a population of a certain size containing a certain number of successes, without replacement of draws.
INTERCEPT(data_y, data_x)	Calculates the y-value at which the line resulting from linear regression of a dataset will intersect the y-axis (x=0).
KURT(value1, value2)	Calculates the kurtosis of a dataset, which describes the shape, and in particular the "peakedness" of that dataset.
LARGE(data, n)	Returns the nth largest element from a data set, where n is user-defined.
LOGINV(x, mean, standard_deviation)	Returns the value of the inverse log-normal cumulative distribution with given mean and standard deviation at a specified value.
LOGNORM.DIST(x, mean, standard_deviation)	See LOGNORMDIST
LOGNORM.INV(x, mean, standard_deviation)	See LOGINV
LOGNORMDIST(x, mean, standard_deviation)	Returns the value of the log-normal cumulative distribution with given mean and standard deviation at a specified value.
MARGINOFERROR(range, confidence)	Calculates the amount of random sampling error given a range of values and a confidence level.
MAX(value1, [value2, ...])	Returns the maximum value in a numeric dataset.
MAXA(value1, value2)	Returns the maximum numeric value in a dataset.
MAXIFS(range, criteria_range1, criterion1, [criteria_range2, criterion2], …)	Returns the maximum value in a range of cells, filtered by a set of criteria..
MEDIAN(value1, [value2, ...])	Returns the median value in a numeric dataset.
MIN(value1, [value2, ...])	Returns the minimum value in a numeric dataset.
MINA(value1, value2)	Returns the minimum numeric value in a dataset.
MINIFS(range, criteria_range1, criterion1, [criteria_range2, criterion2], …)	Returns the minimum value in a range of cells, filtered by a set of criteria..
MODE(value1, [value2, ...])	Returns the most commonly occurring value in a dataset.
MODE.MULT(value1, value2)	Returns the most commonly occurring values in a dataset..
MODE.SNGL(value1, [value2, ...])	See MODE 
NEGBINOM.DIST(num_failures, num_successes, prob_success)	See NEGBINOMDIST 
NEGBINOMDIST(num_failures, num_successes, prob_success)	Calculates the probability of drawing a certain number of failures before a certain number of successes given a probability of success in independent trials.
NORM.DIST(x, mean, standard_deviation, cumulative)	See NORMDIST 
NORM.INV(x, mean, standard_deviation)	See NORMINV 
NORM.S.DIST(x)	See NORMSDIST
NORM.S.INV(x)	See NORMSINV
NORMDIST(x, mean, standard_deviation, cumulative)	Returns the value of the normal distribution function (or normal cumulative distribution function) for a specified value, mean, and standard deviation.
NORMINV(x, mean, standard_deviation)	Returns the value of the inverse normal distribution function for a specified value, mean, and standard deviation.
NORMSDIST(x)	Returns the value of the standard normal cumulative distribution function for a specified value.
NORMSINV(x)	Returns the value of the inverse standard normal distribution function for a specified value.
PEARSON(data_y, data_x)	Calculates r, the Pearson product-moment correlation coefficient of a dataset.
PERCENTILE(data, percentile)	Returns the value at a given percentile of a dataset.
PERCENTILE.EXC(data, percentile)	Returns the value at a given percentile of a dataset, exclusive of 0 and 1..
PERCENTILE.INC(data, percentile)	See PERCENTILE
PERCENTRANK(data, value, [significant_digits])	Returns the percentage rank (percentile) of a specified value in a dataset.
PERCENTRANK.EXC(data, value, [significant_digits])	Returns the percentage rank (percentile) from 0 to 1 exclusive of a specified value in a dataset.
PERCENTRANK.INC(data, value, [significant_digits])	Returns the percentage rank (percentile) from 0 to 1 inclusive of a specified value in a dataset.
PERMUTATIONA(number, number_chosen)	Returns the number of permutations for selecting a group of objects (with replacement) from a total number of objects..
PERMUT(n, k)	Returns the number of ways to choose some number of objects from a pool of a given size of objects, considering order.
PHI(x)	The PHI function returns the value of the normal distribution with mean 0 and standard deviation 1..
POISSON(x, mean, cumulative)	See POISSON.DIST
POISSON.DIST(x, mean, [cumulative])	Returns the value of the Poisson distribution function (or Poisson cumulative distribution function) for a specified value and mean.. 
PROB(data, probabilities, low_limit, [high_limit])	Given a set of values and corresponding probabilities, calculates the probability that a value chosen at random falls between two limits.
QUARTILE(data, quartile_number)	Returns a value nearest to a specified quartile of a dataset.
QUARTILE.EXC(data, quartile_number)	Returns value nearest to a given quartile of a dataset, exclusive of 0 and 4..
QUARTILE.INC(data, quartile_number)	See QUARTILE
RANK(value, data, [is_ascending])	Returns the rank of a specified value in a dataset.
RANK.AVG(value, data, [is_ascending])	Returns the rank of a specified value in a dataset. If there is more than one entry of the same value in the dataset, the average rank of the entries will be returned.
RANK.EQ(value, data, [is_ascending])	Returns the rank of a specified value in a dataset. If there is more than one entry of the same value in the dataset, the top rank of the entries will be returned.
RSQ(data_y, data_x)	Calculates the square of r, the Pearson product-moment correlation coefficient of a dataset.
SKEW(value1, value2)	Calculates the skewness of a dataset, which describes the symmetry of that dataset about the mean.
SKEW.P(value1, value2)	Calculates the skewness of a dataset that represents the entire population..
SLOPE(data_y, data_x)	Calculates the slope of the line resulting from linear regression of a dataset.
SMALL(data, n)	Returns the nth smallest element from a data set, where n is user-defined.
STANDARDIZE(value, mean, standard_deviation)	Calculates the normalized equivalent of a random variable given mean and standard deviation of the distribution.
STDEV(value1, [value2, ...])	Calculates the standard deviation based on a sample.
STDEV.P(value1, [value2, ...])	See STDEVP
STDEV.S(value1, [value2, ...])	See STDEV
STDEVA(value1, value2)	Calculates the standard deviation based on a sample, setting text to the value `0`.
STDEVP(value1, value2)	Calculates the standard deviation based on an entire population.
STDEVPA(value1, value2)	Calculates the standard deviation based on an entire population, setting text to the value `0`.
STEYX(data_y, data_x)	Calculates the standard error of the predicted y-value for each x in the regression of a dataset.
TDIST(x, degrees_freedom, tails)	Calculates the probability for Student's t-distribution with a given input (x).
TINV(probability, degrees_freedom)	See T.INV.2T
TRIMMEAN(data, exclude_proportion)	Calculates the mean of a dataset excluding some proportion of data from the high and low ends of the dataset.
TTEST(range1, range2, tails, type)	See T.TEST.
VAR(value1, [value2, ...])	Calculates the variance based on a sample.
VAR.P(value1, [value2, ...])	See VARP
VAR.S(value1, [value2, ...])	See VAR
VARA(value1, value2)	Calculates an estimate of variance based on a sample, setting text to the value `0`.
VARP(value1, value2)	Calculates the variance based on an entire population.
VARPA(value1, value2,...)	Calculates the variance based on an entire population, setting text to the value `0`.
WEIBULL(x, shape, scale, cumulative)	Returns the value of the Weibull distribution function (or Weibull cumulative distribution function) for a specified shape and scale.
WEIBULL.DIST(x, shape, scale, cumulative)	See WEIBULL
ZTEST(data, value, [standard_deviation])	See Z.TEST.
