In [None]:
import logging
import re
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from enum import StrEnum
from typing import Any, Callable, Optional
import requests
import json
import os
import time
from typing import Dict, List, Optional, Union
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

## logger

In [None]:
def configure_logging():
    """Configure root logging to write DEBUG records to ``app.log`` and the console.

    The file handler is installed through ``logging.basicConfig`` (which is a
    no-op when the root logger already has handlers). The console handler is
    added separately and tagged with a name so that calling this function more
    than once (e.g. re-running a notebook cell) does not stack duplicate
    console handlers and print every record several times.
    """
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    console_name = "app-console"

    logging.basicConfig(
        filename="app.log",
        filemode="w",
        level=logging.DEBUG,
        format=log_format,
    )

    root = logging.getLogger()
    # Skip the console handler if a previous call already attached it.
    if any(h.get_name() == console_name for h in root.handlers):
        return

    # Add console handler
    console_handler = logging.StreamHandler()
    console_handler.set_name(console_name)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter(log_format))
    root.addHandler(console_handler)

## odata_filters

In [None]:
# AST NODE TYPES
class Expr:
    """Base node of the filter AST.

    Subclasses render themselves with ``to_odata``; the ``&``, ``|`` and ``~``
    operators build ``and`` / ``or`` / ``not`` nodes around existing ones.
    """

    def to_odata(self) -> str:
        """Render this node as OData filter text (must be overridden)."""
        raise NotImplementedError

    def _combine(self, keyword: str, rhs: "Expr") -> "BinaryOp":
        # Shared builder for the two logical binary operators.
        return BinaryOp(self, keyword, rhs)

    def __and__(self, other: "Expr") -> "BinaryOp":
        return self._combine("and", other)

    def __or__(self, other: "Expr") -> "BinaryOp":
        return self._combine("or", other)

    def __invert__(self) -> "UnaryOp":
        negated = UnaryOp("not", self)
        return negated


@dataclass(frozen=True)
class Field(Expr):
    """A reference to a (possibly nested) field, written in dot notation."""

    path: str

    def _compare(self, op: str, val: Any) -> "BinaryOp":
        # Every comparison has the shape: <field> <op> <literal>.
        return BinaryOp(self, op, literal(val))

    def _call(self, fn_name: str, val: Any) -> "Func":
        # Every string function has the shape: fn(<field>, <literal>).
        return Func(fn_name, [self, literal(val)])

    # comparisons
    def eq(self, val: Any) -> "BinaryOp":
        return self._compare("eq", val)

    def ne(self, val: Any) -> "BinaryOp":
        return self._compare("ne", val)

    def gt(self, val: Any) -> "BinaryOp":
        return self._compare("gt", val)

    def ge(self, val: Any) -> "BinaryOp":
        return self._compare("ge", val)

    def lt(self, val: Any) -> "BinaryOp":
        return self._compare("lt", val)

    def le(self, val: Any) -> "BinaryOp":
        return self._compare("le", val)

    # string functions
    def contains(self, val: Any) -> "Func":
        return self._call("contains", val)

    def startswith(self, val: Any) -> "Func":
        return self._call("startswith", val)

    def endswith(self, val: Any) -> "Func":
        return self._call("endswith", val)

    def isin(self, values: List[Any]) -> "Expr":
        """Emulate ``IN`` as a left-nested chain of ``eq`` clauses joined by ``or``."""
        if not values:
            # An empty IN can never match; encode it as the constant-false (1 eq 0).
            return BinaryOp(Literal(1), "eq", Literal(0))
        disjunction = None
        for candidate in values:
            term = BinaryOp(self, "eq", literal(candidate))
            if disjunction is None:
                disjunction = term
            else:
                disjunction = BinaryOp(disjunction, "or", term)
        return disjunction

    def to_odata(self) -> str:
        """Render the path with ``/`` separators.

        Example: ``workers.workAssignments.positionID`` becomes
        ``workers/workAssignments/positionID``.
        """
        segments = self.path.split(".")
        return "/".join(segments)


@dataclass(frozen=True)
class Literal(Expr):
    """A constant value appearing in a filter expression."""

    value: Any

    def to_odata(self) -> str:
        """Render the value: ``null``, ``true``/``false``, a bare number, or a quoted string."""
        raw = self.value
        if raw is None:
            return "null"
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(raw, bool):
            return ("false", "true")[raw]
        if isinstance(raw, (int, float)):
            return str(raw)
        # Anything else is rendered as a string; embedded single quotes are doubled.
        escaped = str(raw).replace("'", "''")
        return "'" + escaped + "'"


def literal(v: Any) -> Literal:
    """Wrap a plain Python value in a ``Literal`` node."""
    node = Literal(value=v)
    return node


@dataclass(frozen=True)
class Func(Expr):
    """A function call node such as ``contains(field, 'text')``."""

    name: str
    args: List[Expr]

    def to_odata(self) -> str:
        """Render as ``name(arg1, arg2, ...)`` with each argument rendered recursively."""
        rendered = [arg.to_odata() for arg in self.args]
        return self.name + "(" + ", ".join(rendered) + ")"


@dataclass(frozen=True)
class BinaryOp(Expr):
    """Two operands joined by a comparison or logical operator."""

    left: Expr
    # * Could be replaced with enum
    op: str  # one of 'eq','ne','gt','ge','lt','le','and','or'
    right: Expr

    def to_odata(self) -> str:
        """Render as ``(left op right)``.

        The result is always parenthesised so nesting never depends on
        operator precedence.
        """
        lhs = self.left.to_odata()
        rhs = self.right.to_odata()
        return "(" + " ".join((lhs, self.op, rhs)) + ")"


@dataclass(frozen=True)
class UnaryOp(Expr):
    """A prefix operator applied to a single operand."""

    op: str  # 'not'
    expr: Expr

    def to_odata(self) -> str:
        """Render as ``(op operand)``, always parenthesised."""
        operand = self.expr.to_odata()
        return "(" + self.op + " " + operand + ")"


# ---------------------------
# Public facade
# ---------------------------


class FilterExpression(Expr):
    """
    Public entry point for building and parsing filter expressions.

    Wraps an AST node, renders through it, and keeps the result of
    ``&`` / ``|`` / ``~`` wrapped in a FilterExpression.
    """

    def __init__(self, node: Expr):
        self._node = node

    # facade pass-through
    def to_odata(self) -> str:
        wrapped = self._node
        return wrapped.to_odata()

    # convenience constructors
    @staticmethod
    def field(path: str) -> Field:
        """Return a ``Field`` for ``path`` to start a fluent expression."""
        return Field(path=path)

    # parse a limited OData subset into an AST
    @staticmethod
    def from_string(s: str) -> "FilterExpression":
        """Parse OData filter text and wrap the resulting AST."""
        return FilterExpression(_FilterParser(s).parse())

    # combinators keep returning FilterExpression
    def __and__(self, other: Expr) -> "FilterExpression":
        conjunction = BinaryOp(self._node, "and", _unwrap(other))
        return FilterExpression(conjunction)

    def __or__(self, other: Expr) -> "FilterExpression":
        disjunction = BinaryOp(self._node, "or", _unwrap(other))
        return FilterExpression(disjunction)

    def __invert__(self) -> "FilterExpression":
        negation = UnaryOp("not", self._node)
        return FilterExpression(negation)


def _unwrap(e: Union[Expr, FilterExpression]) -> Expr:
    """Return the AST node inside a FilterExpression; pass any other Expr through."""
    if isinstance(e, FilterExpression):
        return e._node
    return e


# ---------------------------
# Minimal OData filter parser
# Supports:
#   - parentheses
#   - and/or/not
#   - eq, ne, gt, ge, lt, le
#   - contains(), startswith(), endswith()
#   - identifiers with dot or slash separators (field paths),
#     string/number/bool/null
# ---------------------------

# Order matters: the alternation is tried left to right, so keywords
# (OP/FUNC/BOOL/NULL) must come before the generic IDENT pattern.
# IDENT accepts "/" as well as "." so that slash-separated paths (the form
# the AST itself renders field paths in) can be parsed back.
_TOKEN_SPEC = [
    ("WS", r"[ \t\n\r]+"),
    ("LPAREN", r"\("),
    ("RPAREN", r"\)"),
    ("COMMA", r","),
    ("OP", r"\b(eq|ne|gt|ge|lt|le|and|or|not)\b"),
    ("FUNC", r"\b(contains|startswith|endswith)\b"),
    ("BOOL", r"\b(true|false)\b"),
    ("NULL", r"\bnull\b"),
    ("NUMBER", r"-?\d+(\.\d+)?"),
    ("IDENT", r"[A-Za-z_][A-Za-z0-9_\./]*"),
    ("STRING", r"'([^']|'')*'"),
]

# One combined, case-insensitive pattern with a named group per token type.
_TOKEN_RE = re.compile(
    "|".join(f"(?P<{name}>{pat})" for name, pat in _TOKEN_SPEC), re.IGNORECASE
)


class _Token:
    """A lexical token: its type name (e.g. ``"OP"``) and the matched text."""

    def __init__(self, typ: str, val: str):
        self.type, self.value = typ, val


class _FilterParser:
    """Recursive-descent parser turning filter text into an ``Expr`` tree.

    Grammar:
        expr     := or_expr
        or_expr  := and_expr ('or' and_expr)*
        and_expr := not_expr ('and' not_expr)*
        not_expr := ['not'] cmp_expr
        cmp_expr := primary (OP primary)?
        primary  := FUNC '(' arg_list ')' | '(' expr ')' | literal | field
        arg_list := expr (',' expr)*

    Raises:
        ValueError: on unrecognised characters (raised while tokenizing, i.e.
            at construction time) or on a token sequence that does not fit
            the grammar (raised by ``parse``).
    """

    _COMPARISON_OPS = frozenset({"eq", "ne", "gt", "ge", "lt", "le"})

    def __init__(self, text: str):
        self.tokens = list(self._tokenize(text))
        self.pos = 0

    def _tokenize(self, text):
        """Yield non-whitespace tokens, rejecting any text no pattern matches."""
        cursor = 0
        for m in _TOKEN_RE.finditer(text):
            # finditer skips over unmatched characters; a gap between matches
            # means the input held something outside the supported subset
            # (e.g. an unterminated string), which must not be ignored.
            if m.start() != cursor:
                raise ValueError(
                    f"Unexpected character {text[cursor]!r} at position {cursor}"
                )
            cursor = m.end()
            typ = m.lastgroup
            if typ == "WS":
                continue
            yield _Token(typ, m.group(typ))
        if cursor != len(text):
            raise ValueError(
                f"Unexpected character {text[cursor]!r} at position {cursor}"
            )
        # implicit EOF

    def _peek(self) -> Optional[_Token]:
        """Return the current token without consuming it, or None at EOF."""
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None

    def _eat(self, typ: str) -> _Token:
        """Consume and return the current token, which must be of type ``typ``."""
        tok = self._peek()
        if not tok or tok.type != typ:
            raise ValueError(f"Expected {typ}, found {tok.type if tok else 'EOF'}")
        self.pos += 1
        return tok

    def _match(self, typ: str) -> Optional[_Token]:
        """Consume the current token if it is of type ``typ``; otherwise return None."""
        tok = self._peek()
        if tok and tok.type == typ:
            self.pos += 1
            return tok
        return None

    def _at_keyword(self, keyword: str) -> bool:
        """Return True when the current token is the OP ``keyword`` (case-insensitive)."""
        tok = self._peek()
        return bool(tok and tok.type == "OP" and tok.value.lower() == keyword)

    def parse(self) -> Expr:
        """Parse the whole input; trailing tokens are an error."""
        expr = self._parse_or()
        leftover = self._peek()
        if leftover:
            raise ValueError(f"Unexpected token: {leftover.value}")
        return expr

    def _parse_or(self) -> Expr:
        node = self._parse_and()
        while self._at_keyword("or"):
            self._eat("OP")
            rhs = self._parse_and()
            node = BinaryOp(node, "or", rhs)
        return node

    def _parse_and(self) -> Expr:
        node = self._parse_not()
        while self._at_keyword("and"):
            self._eat("OP")
            rhs = self._parse_not()
            node = BinaryOp(node, "and", rhs)
        return node

    def _parse_not(self) -> Expr:
        if self._at_keyword("not"):
            self._eat("OP")
            return UnaryOp("not", self._parse_cmp())
        return self._parse_cmp()

    def _parse_cmp(self) -> Expr:
        left = self._parse_primary()
        tok = self._peek()
        if tok and tok.type == "OP" and tok.value.lower() in self._COMPARISON_OPS:
            op = tok.value.lower()
            self._eat("OP")
            right = self._parse_primary()
            return BinaryOp(left, op, right)
        return left

    def _parse_primary(self) -> Expr:
        tok = self._peek()
        if not tok:
            raise ValueError("Unexpected EOF")

        if tok.type == "FUNC":
            name = tok.value.lower()
            self._eat("FUNC")
            self._eat("LPAREN")
            # Each argument is a full expression. _parse_or is the grammar's
            # entry rule; parse() cannot be used here because it also rejects
            # the tokens that follow the argument.
            args = [self._parse_or()]
            while self._match("COMMA"):
                args.append(self._parse_or())
            self._eat("RPAREN")
            return Func(name, args)

        if tok.type == "LPAREN":
            self._eat("LPAREN")
            node = self._parse_or()
            self._eat("RPAREN")
            return node

        if tok.type == "IDENT":
            self._eat("IDENT")
            return Field(tok.value)

        if tok.type == "STRING":
            self._eat("STRING")
            # strip the surrounding quotes and unescape doubled single quotes
            inner = tok.value[1:-1].replace("''", "'")
            return Literal(inner)

        if tok.type == "NUMBER":
            self._eat("NUMBER")
            return Literal(float(tok.value) if "." in tok.value else int(tok.value))

        if tok.type == "BOOL":
            self._eat("BOOL")
            return Literal(tok.value.lower() == "true")

        if tok.type == "NULL":
            self._eat("NULL")
            return Literal(None)

        raise ValueError(f"Unexpected token: {tok.value}")


if __name__ == "__main__":
    # Demo 1: build filters with the fluent API and print their OData text.
    print("=== Programmatic Filter Building ===\n")

    field = FilterExpression.field
    given_name = field("worker.person.legalName.givenName")
    family_name = field("worker.person.legalName.familyName")
    department = field("department")

    # (caption, expression) pairs, printed in order below.
    showcase = [
        # Simple equality filter
        ("givenName = 'John'", given_name.eq("John")),
        # Comparison operators
        ("hireDate >= '2020-01-01'", field("employee.hireDate").ge("2020-01-01")),
        # String functions
        ("familyName contains 'Smith'", family_name.contains("Smith")),
        # and / or combinations (operands wrapped in FilterExpression)
        (
            "givenName = 'John' AND familyName = 'Doe'",
            FilterExpression(given_name.eq("John"))
            & FilterExpression(family_name.eq("Doe")),
        ),
        (
            "department = 'Engineering' OR department = 'Sales'",
            FilterExpression(department.eq("Engineering"))
            | FilterExpression(department.eq("Sales")),
        ),
        # isin for multiple values
        (
            "status IN ('Active', 'OnLeave', 'Pending')",
            field("status").isin(["Active", "OnLeave", "Pending"]),
        ),
        # not operator (operand wrapped in FilterExpression)
        (
            "NOT isTerminated = true",
            ~FilterExpression(field("isTerminated").eq(True)),
        ),
    ]
    for caption, expression in showcase:
        print(f"{caption}:\n  {expression.to_odata()}\n")

    # Demo 2: parse an existing OData filter string and render it back.
    print("=== Parsing OData Filter Strings ===\n")

    raw_filter = (
        "(worker.person.legalName.givenName eq 'John') and (hireDate ge '2020-01-01')"
    )
    try:
        parsed = FilterExpression.from_string(raw_filter)
        print(f"Parsed filter:\n  Input:  {raw_filter}")
        print(f"  Output: {parsed.to_odata()}\n")
    except Exception as e:
        print(f"Parse error: {e}\n")

## sessions

In [None]:
# Logger for the session helpers, named after the current module (`__name__`).
logger = logging.getLogger(__name__)


class RequestMethod(StrEnum):
    """HTTP verbs that ``ApiSession`` knows how to dispatch.

    Being a ``StrEnum``, each member compares equal to its string value.
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"


@dataclass
class ApiSession:
    session: requests.Session
    cert: tuple[str, str]
    get_headers: Optional[Callable[[], dict]] = None
    headers: Optional[dict] = None
    params: Optional[dict] = None
    timeout: int = 30
    data: Optional[Any] = None

    def __post_init__(self):
        if self.get_headers is None:
            # Default to empty header generation
            self.get_headers = lambda: {}
        if self.params is None:
            self.params = {}

    def set_params(self, params: dict):
        self.params = params

    def set_data(self, data: Any):
        self.data = data

    def _get_request_function(self, method: RequestMethod) -> Callable:
        match method:
            case RequestMethod.GET:
                return self.session.get
            case RequestMethod.POST:
                return self.session.post
            case RequestMethod.PUT:
                return self.session.put
            case RequestMethod.DELETE:
                return self.session.delete

        raise ValueError(f"Unsupported method {method}")

    def _request(
        self, url: str, method: Optional[RequestMethod] = RequestMethod.GET
    ) -> requests.Response:
        """Execute HTTP request with specified method, headers, params, and optional data.

        Args:
            url: The request URL
            method: HTTP method (GET, POST, PUT, DELETE)

        Returns:
            requests.Response object

        Raises:
            requests.RequestException: If request fails
        """
        request_fn = self._get_request_function(method)
        # Generate headers on call time for up-to-date token
        headers = self.get_headers()
        try:
            kwargs = {
                "headers": headers,
                "params": self.params,
                "cert": self.cert,
                "timeout": self.timeout,
            }
            if self.data is not None:
                kwargs["json"] = self.data

            response = request_fn(url, **kwargs)
            response.raise_for_status()

        except requests.RequestException as e:
            logger.error(
                f"Request failed for {method} request to url: {url} with params {self.params}:\n{e}"
            )
            raise

        return response

    def get(self, url: str) -> requests.Response:
        return self._request(url, RequestMethod.GET)

    def post(self, url: str, data: Optional[Any] = None) -> requests.Response:
        if data is not None:
            self.set_data(data)
        return self._request(url, RequestMethod.POST)

    def put(self, url: str, data: Optional[Any] = None) -> requests.Response:
        if data is not None:
            self.set_data(data)
        return self._request(url, RequestMethod.PUT)

    def delete(self, url: str) -> requests.Response:
        return self._request(url, RequestMethod.DELETE)

## client

In [None]:
# Logger for the client, named after the current module (`__name__`).
logger = logging.getLogger(__name__)

# Constants
# Default `timeout` argument for the token request and endpoint calls.
DEFAULT_TIMEOUT = 30
TOKEN_BUFFER_SECONDS = 300  # Refresh token 5 minutes before expiration


class AdpApiClient:
    """Client for paginated GET calls against the ADP API.

    Authenticates with the OAuth ``client_credentials`` grant, presents a
    client certificate on every request, caches the access token and refreshes
    it shortly before it expires. Usable as a context manager; leaving the
    context closes the underlying ``requests.Session``.
    """

    def __init__(
        self, client_id: str, client_secret: str, cert_path: str, key_path: str
    ):
        """Validate credentials and prepare the HTTP session.

        Raises:
            ValueError: If any argument is empty.
            FileNotFoundError: If the certificate or key file does not exist.
        """
        if not all([client_id, client_secret, cert_path, key_path]):
            raise ValueError("All credentials and paths must be provided.")
        if not os.path.exists(cert_path) or not os.path.exists(key_path):
            raise FileNotFoundError("Certificate or key file not found.")

        self.client_id = client_id
        self.client_secret = client_secret
        self.cert_path = cert_path
        self.key_path = key_path
        self.cert = (cert_path, key_path)
        self.session = requests.Session()
        self._setup_retry_strategy()

        # Token expiration tracking
        self.token = None
        self.token_expires_at = 0

    @property
    def payload(self) -> Dict[str, str]:
        """Form fields sent to the token endpoint."""
        return {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

    @property
    def base_url(self) -> str:
        """Scheme and host that every endpoint path is appended to."""
        return "https://api.adp.com"

    def _setup_retry_strategy(self, retries: int = 3, backoff_factor: float = 0.5):
        """Configure retry strategy with exponential backoff for HTTP requests."""
        retry_strategy = Retry(
            total=retries,
            backoff_factor=backoff_factor,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
        logger.debug(
            f"Retry strategy configured: {retries} retries with {backoff_factor}s backoff"
        )

    def _is_token_expired(self) -> bool:
        """Check if token is expired or will expire soon."""
        return time.time() >= self.token_expires_at - TOKEN_BUFFER_SECONDS

    def _get_token(self, timeout: int = DEFAULT_TIMEOUT) -> str:
        """Request a new access token and record when it expires.

        Raises:
            requests.RequestException: If the token request fails.
            ValueError: If the response carries no ``access_token``.
        """
        logger.debug("Requesting Token from ADP Accounts endpoint")
        TOKEN_URL = "https://accounts.adp.com/auth/oauth/v2/token"
        try:
            response = self.session.post(
                TOKEN_URL,
                data=self.payload,
                cert=self.cert,
                timeout=timeout,
            )
            response.raise_for_status()
            token_json = response.json()
            token = token_json.get("access_token")
            if not token:
                raise ValueError("No access token in response")

            # Track token expiration
            expires_in = token_json.get("expires_in", 3600)  # Default 1 hour
            self.token_expires_at = time.time() + expires_in
            logger.info(f"Token Acquired (expires in {expires_in}s)")
            return token
        except requests.RequestException as e:
            logger.error(f"Token request failed: {e}")
            raise

    def _ensure_valid_token(self, timeout: int = DEFAULT_TIMEOUT):
        """Fetch a token if none is cached or the cached one is about to expire."""
        if self.token is None or self._is_token_expired():
            logger.debug("Token expired, refreshing...")
            self.token = self._get_token(timeout)

    def _get_headers(self, masked: bool = True) -> Dict[str, str]:
        """Build request headers with Bearer token and masking preference."""
        # * May need to be tweaked in the future if OData calls or other forms are needed. Not necessary for MVP
        accept = "application/json"
        if not masked:
            accept += ";masked=false"
            # Use the module logger, consistent with the rest of the class.
            logger.debug(f"Calling _get_headers with accept = {accept}")

        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": accept,
        }

        return headers

    def get_masked_headers(self) -> Dict[str, str]:
        """Headers requesting the default (masked) representation."""
        return self._get_headers(True)

    def get_unmasked_headers(self) -> Dict[str, str]:
        """Headers whose Accept value carries ``masked=false``."""
        return self._get_headers(False)

    def _handle_filters(self, filters: Optional[Union[str, FilterExpression]] = None) -> str:
        """Convert filter input (string or FilterExpression) to OData string.

        Args:
            filters: Filter as string or FilterExpression object, or None

        Returns:
            OData filter string, or empty string if no filters

        Raises:
            ValueError: If a string filter cannot be parsed
        """
        if filters is None:
            return ""
        elif isinstance(filters, str):
            try:
                filters = FilterExpression.from_string(filters)
            except ValueError:
                logger.error(f'Error parsing filter expression: {filters}')
                raise

        # Remove outer parentheses added by BinaryOp if present
        odata_str = filters.to_odata()
        if odata_str.startswith("(") and odata_str.endswith(")"):
            odata_str = odata_str[1:-1]
        return odata_str

    def _clean_endpoint(self, endpoint: str) -> str:
        """Normalise ``endpoint`` to the part that follows ``base_url``.

        Accepts either a path starting with ``/`` or a full URL on
        ``base_url``; the latter is reduced to its path with a warning.

        Raises:
            ValueError: If ``endpoint`` is neither form. This includes URLs
                that merely begin with the base URL text but point at another
                host (e.g. ``https://api.adp.com.example.org/...``), which
                must never receive the bearer token or client certificate.
        """
        if endpoint.startswith("/"):
            return endpoint

        if endpoint.startswith(self.base_url):
            # Strip only the leading occurrence; splitting on base_url would
            # truncate URLs that contain it again (e.g. in a query string).
            remainder = endpoint[len(self.base_url):]
            # The host must end exactly where base_url ends.
            if not remainder or remainder[0] in "/?#":
                endpoint = remainder
                logger.warning(
                    "Full URL Specification not needed, prefer to use the endpoint string.\n"
                    f"(Ex: Prefer {endpoint} over {self.base_url}{endpoint})."
                )
                return endpoint

        logger.error(f"Incorrect Endpoint Received {endpoint}")
        raise ValueError(f"Incorrect Endpoint Received: {endpoint}")

    def call_endpoint(
        self,
        endpoint: str,
        select: Optional[List[str]] = None,
        filters: Optional[Union[str, FilterExpression]] = None,
        masked: Optional[bool] = True,
        timeout: Optional[int] = DEFAULT_TIMEOUT,
        page_size: Optional[int] = 100,
        max_requests: Optional[int] = None,
    ) -> List[Dict]:
        """Call any Registered ADP Endpoint

        Pages through the endpoint with ``$top``/``$skip`` until a
        204 No Content response is received or ``max_requests`` responses
        have been collected.

        Args:
            endpoint (str): API Endpoint or qualified URL to call
            select (List[str]): Table Columns to pull
            filters (str | FilterExpression, optional): OData filter, as text or as a
                FilterExpression. Defaults to None (no filter).
            masked (bool, optional): Mask Sensitive Columns Containing Personally Identifiable Information. Defaults to True.
            timeout (int, optional): Time to wait on. Defaults to 30.
            page_size (int, optional): Amount of records to pull per API call (max 100).
                None is treated as 100. Defaults to 100.
            max_requests (Optional[int], optional): Maximum number of requests to make (for quick testing). Defaults to None.

        Raises:
            ValueError: When given an endpoint not following the call convention,
                an unparsable filter string, a page_size below 1, or a
                response body that is not valid JSON

        Returns:
            List[Dict]: The collection of API responses
        """

        # Request Cleanup and Validation Logic
        if page_size is None:
            page_size = 100
        if page_size < 1:
            # A non-positive page size would never advance $skip.
            raise ValueError(f"page_size must be a positive integer, got {page_size}")
        if page_size > 100:
            logger.warning(
                "Page size > 100 not supported by API endpoint. Limiting to 100."
            )
            page_size = 100

        # Output/Request Initialization
        endpoint = self._clean_endpoint(endpoint)
        url = self.base_url + endpoint
        filter_param = self._handle_filters(filters)
        # Populate here instead of mutable default arguments
        if select is None:
            select = []
        select = ",".join(select)
        output = []
        skip = 0

        if masked:
            get_headers_fn = self.get_masked_headers
        else:
            get_headers_fn = self.get_unmasked_headers

        call_session = ApiSession(
            self.session, self.cert, get_headers_fn, timeout=timeout
        )

        params = {'$top': page_size}
        if select:
            logger.debug(f'Restricting OData Selection to {select}')
            params['$select'] = select
        if filter_param:
            logger.debug(f'Filtering Results according to OData query: {filter_param}')
            params['$filter'] = filter_param

        while True:
            params['$skip'] = skip
            call_session.set_params(params)
            # Re-check before every page so long runs survive token expiry.
            self._ensure_valid_token(timeout)
            response = call_session.get(url)

            if response.status_code == 204:
                logger.debug("End of pagination reached (204 No Content)")
                break

            try:
                data = response.json()
            except ValueError as e:
                # ValueError covers json.JSONDecodeError as well as the decode
                # error requests raises when it is backed by simplejson.
                logger.error(f"Failed to parse JSON response: {e}")
                raise
            output.append(data)

            if max_requests is not None and len(output) >= max_requests:
                logger.debug(f"Max Requests reached: {max_requests}")
                break
            skip += page_size

        return output

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - cleanup session."""
        self.session.close()
        logger.debug("Session closed")
        return False