In [4]:
import os
import re
import requests
from typing import Dict, Any, Optional

#################################
# 1. LANGCHAIN AND SERPAPI TOOLS
#################################
from langchain_community.chat_models import ChatOpenAI
from langchain.agents import AgentType, Tool, initialize_agent

class SerpAPISearchTool:
    """
    Real web search using SerpAPI. 
    See https://serpapi.com/ for usage details and sign up for an API key.
    """
    def __init__(self, api_key: str):
        if not api_key:
            raise ValueError("SerpAPISearchTool requires a valid 'api_key'.")
        self.api_key = api_key

    def search(self, query: str, num_results: int = 3) -> dict:
        """
        Perform a Google search using SerpAPI and return JSON with top results.
        You can parse them in more detail to find exact numeric values.
        """
        params = {
            "engine": "google",
            "q": query,
            "api_key": self.api_key,
            "num": num_results,
        }
        resp = requests.get("https://serpapi.com/search", params=params)
        resp.raise_for_status()
        return resp.json()

def create_serpapi_tool(api_key: str) -> Tool:
    """
    Wrap the SerpAPISearchTool in a LangChain Tool.
    The agent can call it by name: serpapi_search.
    """
    serpapi_client = SerpAPISearchTool(api_key=api_key)

    def _search(query: str) -> str:
        """Return a text block summarizing top snippet results."""
        data = serpapi_client.search(query, num_results=5)
        organic = data.get("organic_results", [])
        if not organic:
            return "No search results found."
        
        # Combine the top results into a single text block
        lines = []
        for i, r in enumerate(organic, 1):
            title = r.get("title", "No title")
            snippet = r.get("snippet", "No snippet")
            link = r.get("link", "No link")
            lines.append(f"Result #{i}: {title}\nSnippet: {snippet}\nLink: {link}\n")
        return "\n".join(lines)

    return Tool(
        name="serpapi_search",
        func=_search,
        description=(
            "Use this to find the latest official UK tax thresholds, rates, "
            "National Insurance info, and more. Input is a search query."
        ),
    )

###########################################
# 2. HELPER FUNCTION: Parse Tax Thresholds
###########################################
def parse_tax_thresholds_from_snippet(snippet_text: str) -> Dict[str, float]:
    """
    Attempt to parse known numeric thresholds from free-form snippet text, 
    using naive regex. This is fragile, but demonstrates the idea.

    Example: if snippet says: 
    'The Personal Allowance for the 2024/25 tax year is £12,570. 
     The basic rate threshold is £37,700, higher rate starts at £50,270, etc.'
    we might capture 12570, 37700, 50270.

    Returns a dictionary with possible keys:
      personal_allowance
      basic_rate_limit
      higher_rate_threshold
      additional_rate_threshold
      ...
    """
    # A naive approach: search for patterns like "personal allowance ... £XX,XXX"
    # and so on. In real usage, you'd refine or do multiple queries for each piece.
    thresholds = {}

    # 1) Personal Allowance
    pa_match = re.search(r"(?:Personal Allowance|personal allowance)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if pa_match:
        # remove comma, convert to float
        val = float(pa_match.group(1).replace(",", ""))
        thresholds["personal_allowance"] = val

    # 2) Basic Rate Threshold
    br_match = re.search(r"(?:basic rate threshold|basic rate limit|basic threshold)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if br_match:
        val = float(br_match.group(1).replace(",", ""))
        thresholds["basic_rate_limit"] = val

    # 3) Higher Rate Threshold
    hr_match = re.search(r"(?:higher rate threshold|higher rate limit|higher threshold)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if hr_match:
        val = float(hr_match.group(1).replace(",", ""))
        thresholds["higher_rate_threshold"] = val

    # 4) Additional Rate Threshold
    ar_match = re.search(r"(?:additional rate threshold|additional rate limit)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if ar_match:
        val = float(ar_match.group(1).replace(",", ""))
        thresholds["additional_rate_threshold"] = val

    # 5) NI lower/upper thresholds - similarly
    # Example matches: 
    # "National Insurance primary threshold is £12,570" 
    # "upper earnings limit is £50,270"
    ni_primary_match = re.search(r"(?:NI|National Insurance).{0,40}?(?:primary threshold|lower threshold)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if ni_primary_match:
        val = float(ni_primary_match.group(1).replace(",", ""))
        thresholds["ni_lower_threshold"] = val

    ni_upper_match = re.search(r"(?:NI|National Insurance).{0,40}?(?:upper earnings limit|upper threshold)\D+(\d{1,3}(?:,\d{3})*|\d+)", snippet_text)
    if ni_upper_match:
        val = float(ni_upper_match.group(1).replace(",", ""))
        thresholds["ni_upper_threshold"] = val

    return thresholds

###########################################
# 3. OUR CUSTOM FUNCTIONS (FOR THE AGENT) 
###########################################
def collect_user_info() -> Dict[str, Any]:
    """
    In real usage, you'd collect from a chat or web form. 
    We'll hardcode for demonstration.
    """
    return {
        "annual_income": 55000,
        "marital_status": "married",
        "spouse_income": 20000,
        "number_of_children": 1,
        "ages_of_children": [3],
        "postcode": "EH1 2NG",   # Scotland example
        "pension_contributions": 3000,
    }

def validate_and_enhance_data(user_data: Dict[str, Any]) -> Dict[str, Any]:
    # Minimal validation:
    required_fields = ["annual_income", "marital_status", "postcode"]
    for f in required_fields:
        if f not in user_data:
            raise ValueError(f"Missing field: {f}")

    # Determine region from postcode or user input
    # Example: EH => Scotland
    if user_data["postcode"].upper().startswith("EH"):
        user_data["region"] = "Scotland"
    else:
        user_data["region"] = "England/Wales/NI"

    return user_data

def dynamic_tax_calculation(user_data: Dict[str, Any], snippet_text: str) -> Dict[str, float]:
    """
    Calculate the user's tax based on real thresholds parsed from snippet_text.
    1) parse thresholds
    2) apply logic (Scotland vs. rest of UK)
    """
    thresholds = parse_tax_thresholds_from_snippet(snippet_text)

    # If we can't find them in the text, set some fallback defaults 
    personal_allowance = thresholds.get("personal_allowance", 12570.0)
    basic_rate_limit = thresholds.get("basic_rate_limit", 37700.0)
    higher_rate_threshold = thresholds.get("higher_rate_threshold", 125140.0)  # example
    # etc.

    annual_income = user_data["annual_income"]
    region = user_data["region"]

    taxable_income = max(0, annual_income - personal_allowance)
    income_tax = 0.0

    if region == "Scotland":
        # For demonstration, let's do a naive band approach.
        # You could parse the exact Scottish bands from snippet_text as well.
        # We'll pretend snippet_text gave us 19%, 20%, 21%, 41%, 46% with breakpoints.
        # Example thresholds:
        scot_starter_end = 14549
        scot_basic_end = 25168
        scot_intermediate_end = 43662
        scot_higher_end = 125140

        if taxable_income <= scot_starter_end:
            income_tax = taxable_income * 0.19
        elif taxable_income <= scot_basic_end:
            step1 = scot_starter_end * 0.19
            step2 = (taxable_income - scot_starter_end) * 0.20
            income_tax = step1 + step2
        elif taxable_income <= scot_intermediate_end:
            step1 = scot_starter_end * 0.19
            step2 = (scot_basic_end - scot_starter_end) * 0.20
            step3 = (taxable_income - scot_basic_end) * 0.21
            income_tax = step1 + step2 + step3
        elif taxable_income <= scot_higher_end:
            step1 = scot_starter_end * 0.19
            step2 = (scot_basic_end - scot_starter_end) * 0.20
            step3 = (scot_intermediate_end - scot_basic_end) * 0.21
            step4 = (taxable_income - scot_intermediate_end) * 0.41
            income_tax = step1 + step2 + step3 + step4
        else:
            # additional or top rate
            step1 = scot_starter_end * 0.19
            step2 = (scot_basic_end - scot_starter_end) * 0.20
            step3 = (scot_intermediate_end - scot_basic_end) * 0.21
            step4 = (scot_higher_end - scot_intermediate_end) * 0.41
            step5 = (taxable_income - scot_higher_end) * 0.46
            income_tax = step1 + step2 + step3 + step4 + step5

    else:
        # For England/Wales/NI (with the data we parsed if available):
        # basic up to basic_rate_limit, higher up to around 125k, then additional
        if taxable_income <= basic_rate_limit:
            # all at 20%
            income_tax = taxable_income * 0.20
        elif taxable_income <= 100000:  # example placeholder for next bracket
            step1 = basic_rate_limit * 0.20
            step2 = (taxable_income - basic_rate_limit) * 0.40
            income_tax = step1 + step2
        else:
            # You could parse the exact threshold for additional rate
            # This is a naive approach
            step1 = basic_rate_limit * 0.20
            step2 = (100000 - basic_rate_limit) * 0.40
            step3 = (taxable_income - 100000) * 0.45
            income_tax = step1 + step2 + step3

    # Simple NI (use snippet_text for thresholds if found)
    ni_lower = thresholds.get("ni_lower_threshold", 12570.0)
    if annual_income > ni_lower:
        ni = (annual_income - ni_lower) * 0.12
    else:
        ni = 0.0

    total_tax = income_tax + ni
    return {
        "income_tax": round(income_tax, 2),
        "national_insurance": round(ni, 2),
        "total": round(total_tax, 2),
    }

def recommend_tax_optimizations(user_data: Dict[str, Any], tax_breakdown: Dict[str, float]) -> Dict[str, Any]:
    """
    Provide suggestions based on real data. In practice, the LLM could
    also search the web for additional strategies or thresholds (e.g., pension caps).
    """
    strategies = []
    income = user_data["annual_income"]
    if income > 50000:
        strategies.append({
            "strategy": "Increase Pension Contributions",
            "details": "Contribute more via salary sacrifice to reduce higher-rate liability.",
            "estimated_savings": 500.0
        })
    if user_data["marital_status"] == "married" and user_data["spouse_income"] < 12570:
        strategies.append({
            "strategy": "Marriage Allowance Transfer",
            "details": "Transfer unused personal allowance from spouse.",
            "estimated_savings": 252.0
        })
    return {"strategies": strategies}

def check_benefits_eligibility(user_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    This might also call the search tool to find the latest thresholds for 
    Universal Credit, Child Benefit, etc. We'll do placeholders here.
    """
    results = []
    if user_data["number_of_children"] > 0:
        results.append({"benefit": "Child Benefit", "eligible": True, "notes": "Watch for High Income Charge above £50k."})
    return {"benefits": results}

def generate_summary(
    user_data: Dict[str, Any],
    snippet_text: str,
    tax_calc: Dict[str, float],
    optimizations: Dict[str, Any],
    benefits: Dict[str, Any]
) -> str:
    summary = "=== Your Real-Time UK Tax Calculation ===\n"
    summary += f"Region: {user_data['region']}\n"
    summary += f"Annual Income: £{user_data['annual_income']}\n"
    summary += "Raw Snippet Data (for transparency):\n"
    summary += snippet_text + "\n\n"
    summary += f"Income Tax: £{tax_calc['income_tax']}\n"
    summary += f"National Insurance: £{tax_calc['national_insurance']}\n"
    summary += f"Total Deductions: £{tax_calc['total']}\n"
    summary += f"Net Income: £{user_data['annual_income'] - tax_calc['total']}\n\n"

    summary += "=== Possible Optimizations ===\n"
    for s in optimizations["strategies"]:
        summary += f"- {s['strategy']}: {s['details']} (Est. Save £{s['estimated_savings']})\n"

    summary += "\n=== Potential Benefits ===\n"
    for b in benefits["benefits"]:
        summary += f"- {b['benefit']}: Eligible? {b['eligible']}. {b.get('notes','')}\n"

    summary += "\nDisclaimer: These figures are best-effort from real-time search data. Always verify with HMRC.\n"
    return summary

##########################################
# 4. BUILD THE LANGCHAIN AGENT & RUN IT
##########################################
def build_agent(openai_api_key: str, serpapi_api_key: str):
    """
    Creates a LangChain agent in ReAct mode that can:
      - serpapi_search (tool)
      - collect_user_info
      - validate_and_enhance_data
      - dynamic_tax_calculation
      - recommend_tax_optimizations
      - check_benefits_eligibility
      - generate_summary
    """
    llm = ChatOpenAI(temperature=0.0, 
                     name="gpt-4o"
                )

    # Real web search tool
    serpapi_tool = create_serpapi_tool(api_key=serpapi_api_key)

    # We'll wrap some of our custom functions in Tools:
    collect_tool = Tool(
        name="collect_user_info",
        func=collect_user_info,
        description="Collect user data for tax/benefits calculation."
    )
    validate_tool = Tool(
        name="validate_and_enhance_data",
        func=validate_and_enhance_data,
        description="Validate user data and infer region from postcode."
    )

    # For dynamic tax calc, we want to pass both user_data and snippet_text. 
    # We'll handle that in the actual code or with a small wrapper.
    def dynamic_tax_calc_wrapper(combined_input: str) -> str:
        """
        Expects combined_input to be something like:
          'user_data=..., snippet_text=...'
        We'll parse that, call dynamic_tax_calculation, and return a string.
        """
        # A quick parse approach:
        # In a real multi-step chain, the agent can store user_data in memory (ConversationBuffer)
        # For demonstration, we'll just do naive parsing:
        import json
        data = json.loads(combined_input)
        user_data_loc = data["user_data"]
        snippet = data["snippet_text"]
        result = dynamic_tax_calculation(user_data_loc, snippet)
        return str(result)

    dynamic_tax_calc_tool = Tool(
        name="dynamic_tax_calculation",
        func=dynamic_tax_calc_wrapper,
        description=(
            "Given user_data and snippet_text with real thresholds, calculates tax breakdown."
            "Input must be a JSON string with keys: user_data (dict) and snippet_text (str)."
            "Output is a JSON-like string with keys [income_tax, national_insurance, total]."
        )
    )

    def recommend_wrapper(json_data: str) -> str:
        parsed = eval(json_data)
        user_data_loc = parsed["user_data"]
        tax_data = parsed["tax_breakdown"]
        return str(recommend_tax_optimizations(user_data_loc, tax_data))

    recommend_tool = Tool(
        name="recommend_tax_optimizations",
        func=recommend_wrapper,
        description=(
            "Given user_data and tax_breakdown, suggest ways to reduce tax. "
            "Input must be a string with: {'user_data':..., 'tax_breakdown':...}. "
            "Output is a JSON-like string with 'strategies'."
        )
    )

    def benefits_wrapper(json_data: str) -> str:
        user_data_loc = eval(json_data)
        return str(check_benefits_eligibility(user_data_loc))

    benefits_tool = Tool(
        name="check_benefits_eligibility",
        func=benefits_wrapper,
        description="Given user_data, returns a JSON-like string of possible benefits."
    )

    def summary_wrapper(json_data: str) -> str:
        """
        Expects: 
        {
          'user_data': ...,
          'snippet_text': ...,
          'tax_calc': ...,
          'optimizations': ...,
          'benefits': ...
        }
        """
        parsed = eval(json_data)
        final = generate_summary(
            parsed["user_data"],
            parsed["snippet_text"],
            parsed["tax_calc"],
            parsed["optimizations"],
            parsed["benefits"]
        )
        return final

    summary_tool = Tool(
        name="generate_summary",
        func=summary_wrapper,
        description=(
            "Given user_data, snippet_text, tax_calc, optimizations, and benefits, "
            "produce a final text summary. Input must be a JSON string with those keys."
        )
    )

    tools = [
        serpapi_tool,
        collect_tool,
        validate_tool,
        dynamic_tax_calc_tool,
        recommend_tool,
        benefits_tool,
        summary_tool
    ]

    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True
    )
    return agent


def main():
    """
    Demonstration of a universal approach that:
    1) collects user data,
    2) uses serpapi to find real threshold info,
    3) parses the thresholds,
    4) calculates tax,
    5) suggests optimizations,
    6) checks benefits,
    7) merges into final summary.
    """

    from dotenv import load_dotenv

    load_dotenv()
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
    SERPAPI_API_KEY = os.getenv('SERPAPI_API_KEY')
    
    # Provide real keys
    openai_key = os.getenv("OPENAI_API_KEY", OPENAI_API_KEY)
    serpapi_key = os.getenv("SERPAPI_API_KEY", SERPAPI_API_KEY)

    agent = build_agent(openai_key, serpapi_key)

    # We can just do a single user request telling the agent what to do:
    user_prompt = (
        "Hi Agent. Please fetch the real tax thresholds for the 2024/25 UK tax year. "
        "Then collect my user info, validate it, compute my tax, check benefits, "
        "and give me a final summary. Make sure you use the real data from your web search."
    )

    final_output = agent.run(user_prompt)
    print("\n==============================\nFINAL OUTPUT:\n==============================")
    print(final_output)

if __name__ == "__main__":
    main()


ValidationError: 6 validation errors for ChatOpenAI
model_name
  Input should be a valid string [type=string_type, input_value=FieldInfo(default='gpt-3....as_priority=2, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
model_kwargs
  Input should be a valid dictionary [type=dict_type, input_value=FieldInfo(default=Pydanti...class 'dict'>, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/dict_type
openai_api_key
  Input should be a valid string [type=string_type, input_value=FieldInfo(alias='api_key'...as_priority=2, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
openai_api_base
  Input should be a valid string [type=string_type, input_value=FieldInfo(alias='base_url...as_priority=2, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
openai_organization
  Input should be a valid string [type=string_type, input_value=FieldInfo(alias='organiza...as_priority=2, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
max_retries
  Input should be a valid integer [type=int_type, input_value=FieldInfo(default=2, extra={}), input_type=FieldInfo]
    For further information visit https://errors.pydantic.dev/2.10/v/int_type