Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mcp-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ uv run client.py http://localhost:8000/mcp \
| `POLARIS_CLIENT_SECRET` | OAuth client secret. | _unset_ |
| `POLARIS_TOKEN_SCOPE` | OAuth scope string. | _unset_ |
| `POLARIS_TOKEN_URL` | Optional override for the token endpoint URL. | `${POLARIS_BASE_URL}api/catalog/v1/oauth/tokens` |
| `POLARIS_REALM_{realm}_CLIENT_ID` | OAuth client id for a specific realm. | _unset_ |
| `POLARIS_REALM_{realm}_CLIENT_SECRET` | OAuth client secret for a specific realm. | _unset_ |
| `POLARIS_REALM_{realm}_TOKEN_SCOPE` | OAuth scope for a specific realm. | _unset_ |
| `POLARIS_REALM_{realm}_TOKEN_URL` | Token endpoint URL for a specific realm. | _unset_ |
| `POLARIS_REALM_CONTEXT_HEADER_NAME` | Header name used for realm context. | `Polaris-Realm` |
| `POLARIS_TOKEN_REFRESH_BUFFER_SECONDS` | Minimum remaining token lifetime before refreshing in seconds. | `60.0` |
| `POLARIS_HTTP_TIMEOUT_SECONDS` | Default timeout in seconds for all HTTP requests. | `30.0` |
| `POLARIS_HTTP_CONNECT_TIMEOUT_SECONDS` | Timeout in seconds for establishing HTTP connections. | `30.0` |
Expand All @@ -166,8 +171,8 @@ uv run client.py http://localhost:8000/mcp \
| `POLARIS_CONFIG_FILE` | Path to a configuration file containing configuration variables. | `.polaris_mcp.env` in current working directory |



When OAuth variables are supplied, the server automatically acquires and refreshes tokens using the client credentials flow; otherwise a static bearer token is used if provided.
Realm-specific variables (e.g., `POLARIS_REALM_${realm}_CLIENT_ID`) override the global settings for a given realm for client ID, client secret, token scope, and token URL. If realm-specific credentials are provided but incomplete, the server will not fall back to global credentials for that realm.

## Tools

Expand Down
116 changes: 79 additions & 37 deletions mcp-server/polaris_mcp/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from __future__ import annotations

import json
import os
import threading
import time
from abc import ABC, abstractmethod
from typing import Optional
from urllib.parse import urlencode
from urllib.parse import urlencode, urljoin

import urllib3

Expand All @@ -35,7 +36,7 @@ class AuthorizationProvider(ABC):
"""Return Authorization header values for outgoing requests."""

@abstractmethod
def authorization_header(self) -> Optional[str]: ...
def authorization_header(self, realm: Optional[str] = None) -> Optional[str]: ...


class StaticAuthorizationProvider(AuthorizationProvider):
Expand All @@ -45,7 +46,7 @@ def __init__(self, token: Optional[str]) -> None:
value = (token or "").strip()
self._header = f"Bearer {value}" if value else None

def authorization_header(self) -> Optional[str]:
def authorization_header(self, realm: Optional[str] = None) -> Optional[str]:
return self._header


Expand All @@ -54,59 +55,100 @@ class ClientCredentialsAuthorizationProvider(AuthorizationProvider):

def __init__(
self,
token_endpoint: str,
client_id: str,
client_secret: str,
scope: Optional[str],
base_url: str,
http: urllib3.PoolManager,
refresh_buffer_seconds: float,
timeout: urllib3.Timeout,
) -> None:
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._scope = scope
self._base_url = base_url
self._http = http
self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
self._timeout = timeout
self._lock = threading.Lock()
self._cached: Optional[tuple[str, float]] = None # (token, expires_at_epoch)
self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
# {realm: (token, expires_at_epoch)}
self._cached: dict[str, tuple[str, float]] = {}

def authorization_header(self) -> Optional[str]:
token = self._current_token()
def authorization_header(self, realm: Optional[str] = None) -> Optional[str]:
token = self._get_token_from_realm(realm)
return f"Bearer {token}" if token else None

def _current_token(self) -> Optional[str]:
now = time.time()
cached = self._cached
if not cached or cached[1] - self._refresh_buffer_seconds <= now:
with self._lock:
cached = self._cached
if (
not cached
or cached[1] - self._refresh_buffer_seconds <= time.time()
):
self._cached = cached = self._fetch_token()
return cached[0] if cached else None

def _fetch_token(self) -> tuple[str, float]:
def _get_token_from_realm(self, realm: Optional[str]) -> Optional[str]:
def needs_refresh(cached):
return (
cached is None
or cached[1] - self._refresh_buffer_seconds <= time.time()
)

cache_key = realm or ""
token = self._cached.get(cache_key)
# Token not expired
if not needs_refresh(token):
return token[0]
# Acquire lock and verify again if token expired
with self._lock:
token = self._cached.get(cache_key)
if needs_refresh(token):
credentials = self._get_credentials_from_realm(realm)
if not credentials:
return None
token = self._fetch_token(realm, credentials)
self._cached[cache_key] = token
return token[0] if token else None

def _get_credentials_from_realm(
self, realm: Optional[str]
) -> Optional[dict[str, str]]:
def get_env(key: str) -> Optional[str]:
val = os.getenv(key)
return val.strip() or None if val else None

def load_creds(realm: Optional[str] = None) -> dict[str, Optional[str]]:
prefix = f"POLARIS_REALM_{realm}_" if realm else "POLARIS_"
return {
"client_id": get_env(f"{prefix}CLIENT_ID"),
"client_secret": get_env(f"{prefix}CLIENT_SECRET"),
"scope": get_env(f"{prefix}TOKEN_SCOPE"),
"token_url": get_env(f"{prefix}TOKEN_URL"),
}

# Only use realm-specific credentials
if realm:
creds = load_creds(realm)
if creds["client_id"] and creds["client_secret"]:
return creds
return None
# No realm specified, use global credentials
creds = load_creds()
if creds["client_id"] and creds["client_secret"]:
return creds
return None

def _fetch_token(
self, realm: Optional[str], credentials: dict[str, str]
) -> tuple[str, float]:
token_url = credentials.get("token_url") or urljoin(
self._base_url, "api/catalog/v1/oauth/tokens"
)
payload = {
"grant_type": "client_credentials",
"client_id": self._client_id,
"client_secret": self._client_secret,
"client_id": credentials["client_id"],
"client_secret": credentials["client_secret"],
}
if self._scope:
payload["scope"] = self._scope
if credentials.get("scope"):
payload["scope"] = credentials["scope"]

encoded = urlencode(payload)
header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "Polaris-Realm")
headers = {"Content-Type": "application/x-www-form-urlencoded"}
if realm:
headers[header_name] = realm
response = self._http.request(
"POST",
self._token_endpoint,
token_url,
body=encoded,
headers={"Content-Type": "application/x-www-form-urlencoded"},
headers=headers,
timeout=self._timeout,
)

if response.status != 200:
raise RuntimeError(
f"OAuth token endpoint returned {response.status}: {response.data.decode('utf-8', errors='ignore')}"
Expand All @@ -132,7 +174,7 @@ def _fetch_token(self) -> tuple[str, float]:


class _NoneAuthorizationProvider(AuthorizationProvider):
def authorization_header(self) -> Optional[str]:
def authorization_header(self, realm: Optional[str] = None) -> Optional[str]:
return None


Expand Down
9 changes: 8 additions & 1 deletion mcp-server/polaris_mcp/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import json
import os
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlencode, urljoin, urlsplit, urlunsplit, quote

Expand Down Expand Up @@ -230,6 +231,7 @@ def call(self, arguments: Any) -> ToolExecutionResult:
query_params = arguments.get("query")
headers_param = arguments.get("headers")
body_node = arguments.get("body")
realm = arguments.get("realm")

query = query_params if isinstance(query_params, dict) else None
headers = headers_param if isinstance(headers_param, dict) else None
Expand All @@ -238,9 +240,14 @@ def call(self, arguments: Any) -> ToolExecutionResult:

header_values = _merge_headers(headers)
if not any(name.lower() == "authorization" for name in header_values):
token = self._authorization.authorization_header()
token = self._authorization.authorization_header(realm)
if token:
header_values["Authorization"] = token
header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "Polaris-Realm")
if realm and not any(
name.lower() == header_name.lower() for name in header_values
):
header_values[header_name] = realm

body_text = _serialize_body(body_node)
if body_text is not None and not any(
Expand Down
39 changes: 24 additions & 15 deletions mcp-server/polaris_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import argparse
import os
from typing import Any, Mapping, MutableMapping, Sequence, Optional
from urllib.parse import urljoin, urlparse
from urllib.parse import urlparse

import urllib3
from fastmcp import FastMCP
Expand Down Expand Up @@ -164,6 +164,7 @@ def polaris_iceberg_table(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
table_tool,
Expand All @@ -177,6 +178,7 @@ def polaris_iceberg_table(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"namespace": _normalize_namespace,
Expand All @@ -198,6 +200,7 @@ def polaris_namespace_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
namespace_tool,
Expand All @@ -210,6 +213,7 @@ def polaris_namespace_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"namespace": _normalize_namespace,
Expand All @@ -231,6 +235,7 @@ def polaris_principal_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
principal_tool,
Expand All @@ -241,6 +246,7 @@ def polaris_principal_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"query": _copy_mapping,
Expand All @@ -262,6 +268,7 @@ def polaris_principal_role_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
principal_role_tool,
Expand All @@ -273,6 +280,7 @@ def polaris_principal_role_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"query": _copy_mapping,
Expand All @@ -293,6 +301,7 @@ def polaris_catalog_role_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
catalog_role_tool,
Expand All @@ -305,6 +314,7 @@ def polaris_catalog_role_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"query": _copy_mapping,
Expand All @@ -326,6 +336,7 @@ def polaris_policy_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
policy_tool,
Expand All @@ -339,6 +350,7 @@ def polaris_policy_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"namespace": _normalize_namespace,
Expand All @@ -359,6 +371,7 @@ def polaris_catalog_request(
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
catalog_tool,
Expand All @@ -368,6 +381,7 @@ def polaris_catalog_request(
"query": query,
"headers": headers,
"body": body,
"realm": realm,
},
transforms={
"query": _copy_mapping,
Expand Down Expand Up @@ -482,23 +496,21 @@ def parse_timeout(raw: Optional[str]) -> Optional[float]:


def _resolve_authorization_provider(
base_url: str, http: urllib3.PoolManager, timeout: urllib3.Timeout
base_url: str,
http: urllib3.PoolManager,
timeout: urllib3.Timeout,
) -> AuthorizationProvider:
token = _resolve_token()
if token:
return StaticAuthorizationProvider(token)

client_id = _first_non_blank(
os.getenv("POLARIS_CLIENT_ID"),
)
client_secret = _first_non_blank(
os.getenv("POLARIS_CLIENT_SECRET"),
client_id = _first_non_blank(os.getenv("POLARIS_CLIENT_ID"))
client_secret = _first_non_blank(os.getenv("POLARIS_CLIENT_SECRET"))
has_realm_credentials = any(
key.startswith("POLARIS_REALM_") for key in os.environ.keys()
)

if client_id and client_secret:
scope = _first_non_blank(os.getenv("POLARIS_TOKEN_SCOPE"))
token_url = _first_non_blank(os.getenv("POLARIS_TOKEN_URL"))
endpoint = token_url or urljoin(base_url, "api/catalog/v1/oauth/tokens")
if client_id and client_secret or has_realm_credentials:
refresh_buffer_seconds = DEFAULT_TOKEN_REFRESH_BUFFER_SECONDS
refresh_buffer_seconds_str = os.getenv("POLARIS_TOKEN_REFRESH_BUFFER_SECONDS")
if refresh_buffer_seconds_str:
Expand All @@ -507,10 +519,7 @@ def _resolve_authorization_provider(
except ValueError:
pass
return ClientCredentialsAuthorizationProvider(
token_endpoint=endpoint,
client_id=client_id,
client_secret=client_secret,
scope=scope,
base_url=base_url,
http=http,
refresh_buffer_seconds=refresh_buffer_seconds,
timeout=timeout,
Expand Down
Loading