diff --git a/backend/connector_auth_v2/constants.py b/backend/connector_auth_v2/constants.py index 886968d87b..2508df7edd 100644 --- a/backend/connector_auth_v2/constants.py +++ b/backend/connector_auth_v2/constants.py @@ -16,3 +16,21 @@ class SocialAuthConstants: GOOGLE_OAUTH = "google-oauth2" GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S" + + +# OAuth token-specific keys safe to merge across connectors sharing the same +# (provider, uid). Anything outside this set (form fields like site_url, +# drive_id, or provider-specific enrichment stored in ConnectorAuth.extra_data) +# must NOT leak between connectors. +OAUTH_TOKEN_KEYS: frozenset[str] = frozenset( + { + SocialAuthConstants.ACCESS_TOKEN, + SocialAuthConstants.REFRESH_TOKEN, + SocialAuthConstants.TOKEN_TYPE, + SocialAuthConstants.EXPIRES, + SocialAuthConstants.AUTH_TIME, + SocialAuthConstants.REFRESH_AFTER, + "expires_in", + "scope", + } +) diff --git a/backend/connector_auth_v2/models.py b/backend/connector_auth_v2/models.py index cd36e75cce..a92630c41a 100644 --- a/backend/connector_auth_v2/models.py +++ b/backend/connector_auth_v2/models.py @@ -10,7 +10,7 @@ from social_django.models import AbstractUserSocialAuth, DjangoStorage from social_django.strategy import DjangoStrategy -from connector_auth_v2.constants import SocialAuthConstants +from connector_auth_v2.constants import OAUTH_TOKEN_KEYS, SocialAuthConstants from connector_auth_v2.pipeline.google import GoogleAuthHelper logger = logging.getLogger(__name__) @@ -96,11 +96,27 @@ def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bo refreshed_token = True related_connector_instances = self.connectorinstance_set.all() for connector_instance in related_connector_instances: - connector_instance.connector_metadata = self.extra_data + # Whitelist-merge: only OAuth token keys flow from the shared + # extra_data into each sibling's metadata. 
Non-token keys (form + # fields like site_url, drive_id, or provider-enrichment stored + # in extra_data) must NOT leak between connectors sharing the + # same (provider, uid). + existing_metadata = connector_instance.connector_metadata or {} + token_updates = { + key: self.extra_data[key] + for key in OAUTH_TOKEN_KEYS + if self.extra_data.get(key) is not None + } + connector_instance.connector_metadata = { + **existing_metadata, + **token_updates, + } connector_instance.save() logger.info( - f"Refreshed access token for connector {connector_instance.id}, " - f"provider: {self.provider}, uid: {self.uid}" + "Refreshed access token for connector %s, provider: %s, uid: %s", + connector_instance.id, + self.provider, + self.uid, ) return self.extra_data, refreshed_token diff --git a/backend/connector_v2/fields.py b/backend/connector_v2/fields.py index 2a0f18c549..d3a65d91b7 100644 --- a/backend/connector_v2/fields.py +++ b/backend/connector_v2/fields.py @@ -1,7 +1,7 @@ import logging from datetime import datetime -from connector_auth_v2.constants import SocialAuthConstants +from connector_auth_v2.constants import OAUTH_TOKEN_KEYS, SocialAuthConstants from connector_auth_v2.models import ConnectorAuth from django.db import models @@ -25,17 +25,28 @@ def from_db_value(self, value, expression, connection): # type: ignore refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT ) if datetime.now() > refresh_after: - metadata = self._refresh_tokens(provider, uid) + metadata = self._refresh_tokens(provider, uid, metadata) return metadata - def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]: - """Retrieves PSA object and refreshes the token if necessary.""" + def _refresh_tokens( + self, provider: str, uid: str, existing_metadata: dict[str, str] + ) -> dict[str, str]: + """Retrieves PSA object and refreshes the token if necessary. + + Whitelist-merges refreshed token fields over existing metadata so + per-instance form fields (e.g. 
site_url, drive_id) are preserved on + read and non-token keys from the shared ConnectorAuth.extra_data + cannot leak into a connector's metadata. + """ connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth( provider=provider, uid=uid ) - if connector_auth: - ( - connector_metadata, - _, - ) = connector_auth.get_and_refresh_tokens() - return connector_metadata # type: ignore + if not connector_auth: + return existing_metadata + refreshed_metadata, _ = connector_auth.get_and_refresh_tokens() + token_updates = { + key: refreshed_metadata[key] + for key in OAUTH_TOKEN_KEYS + if refreshed_metadata.get(key) is not None + } + return {**existing_metadata, **token_updates} diff --git a/backend/connector_v2/serializers.py b/backend/connector_v2/serializers.py index 5517bc5257..d639c923a1 100644 --- a/backend/connector_v2/serializers.py +++ b/backend/connector_v2/serializers.py @@ -2,12 +2,13 @@ from collections import OrderedDict from typing import Any +from connector_auth_v2.constants import OAUTH_TOKEN_KEYS from connector_auth_v2.models import ConnectorAuth from connector_auth_v2.pipeline.common import ConnectorAuthHelper from connector_processor.connector_processor import ConnectorProcessor from connector_processor.constants import ConnectorKeys -from connector_processor.exceptions import OAuthTimeOut -from rest_framework.serializers import CharField, SerializerMethodField +from connector_processor.exceptions import InvalidConnectorID, OAuthTimeOut +from rest_framework.serializers import CharField, SerializerMethodField, ValidationError from utils.fields import EncryptedBinaryFieldSerializer from utils.input_sanitizer import validate_name_field @@ -28,10 +29,57 @@ class ConnectorInstanceSerializer(AuditSerializer): class Meta: model = ConnectorInstance fields = "__all__" + extra_kwargs = {"connector_name": {"required": False}} def validate_connector_name(self, value: str) -> str: return validate_name_field(value, field_name="Connector name") + def validate(self, 
attrs: dict[str, Any]) -> dict[str, Any]: + """Backfill ``connector_name`` from the JSON schema default when absent. + + Defense-in-depth: the frontend RJSF form seeds ``connector_name`` from + the schema default, but callers (including staging OAuth flows) have + been observed to POST without it. If the connector schema declares a + default name, use it. Otherwise raise a 400 explicitly rather than + letting the missing value reach the DB and surface as an + ``IntegrityError`` (the model enforces ``null=False``). + + Skipped entirely on partial updates (PATCH): the existing DB row + already has a valid name, and backfilling would overwrite a + user-renamed connector with the schema default. + """ + attrs = super().validate(attrs) + if attrs.get(CIKey.CONNECTOR_NAME) or self.partial: + return attrs + + connector_id = attrs.get(CIKey.CONNECTOR_ID) + default_name = ( + self._get_schema_default_connector_name(connector_id) + if connector_id + else None + ) + if not default_name: + raise ValidationError({CIKey.CONNECTOR_NAME: "This field is required."}) + attrs[CIKey.CONNECTOR_NAME] = default_name + logger.info( + "Filled missing connector_name with schema default for %s", + connector_id, + ) + return attrs + + @staticmethod + def _get_schema_default_connector_name(connector_id: str) -> str | None: + try: + schema_details = ConnectorProcessor.get_json_schema(connector_id=connector_id) + except InvalidConnectorID: + return None + return ( + schema_details.get(ConnectorKeys.JSON_SCHEMA, {}) + .get("properties", {}) + .get("connectorName", {}) + .get("default") + ) + def save(self, **kwargs): # type: ignore user = self.context.get("request").user or None connector_id: str = kwargs[CIKey.CONNECTOR_ID] @@ -53,10 +101,22 @@ def save(self, **kwargs): # type: ignore oauth_credentials=kwargs[CIKey.CONNECTOR_METADATA], ) kwargs[CIKey.CONNECTOR_AUTH] = connector_oauth - ( - kwargs[CIKey.CONNECTOR_METADATA], - refresh_status, - ) = connector_oauth.get_and_refresh_tokens() + # 
Merge refreshed token fields (whitelist) back into this + # connector's metadata so ``super().save(**kwargs)`` does not + # overwrite the fresh token the sibling-loop just persisted. + # Whitelisting preserves per-connector form fields (site_url, + # drive_id) that must not be leaked across connectors sharing + # the same (provider, uid). + refreshed_metadata, _ = connector_oauth.get_and_refresh_tokens() + token_updates = { + key: refreshed_metadata[key] + for key in OAUTH_TOKEN_KEYS + if refreshed_metadata.get(key) is not None + } + kwargs[CIKey.CONNECTOR_METADATA] = { + **(kwargs.get(CIKey.CONNECTOR_METADATA) or {}), + **token_updates, + } except Exception as exc: logger.error( "Error while obtaining ConnectorAuth for connector id " diff --git a/backend/connector_v2/views.py b/backend/connector_v2/views.py index 09a33be983..37525e4d5a 100644 --- a/backend/connector_v2/views.py +++ b/backend/connector_v2/views.py @@ -90,12 +90,17 @@ def _get_connector_metadata(self, connector_id: str) -> dict[str, str] | None: # Only use OAuth flow if connector supports it AND oauth_key is provided if ConnectorInstance.supportsOAuth(connector_id=connector_id) and oauth_key: - connector_metadata = ConnectorAuthHelper.get_oauth_creds_from_cache( + oauth_tokens = ConnectorAuthHelper.get_oauth_creds_from_cache( cache_key=oauth_key, delete_key=False, # Don't delete yet - wait for successful operation ) - if connector_metadata is None: + if oauth_tokens is None: raise MissingParamException(param=ConnectorAuthKey.OAUTH_KEY) + # Preserve non-secret form fields (e.g. 
site_url for the SharePoint connector) + form_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) or {} + if not isinstance(form_metadata, dict): + form_metadata = {} + connector_metadata = {**form_metadata, **oauth_tokens} else: connector_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) return connector_metadata diff --git a/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/sharepoint.py b/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/sharepoint.py index 3eacdac30d..e8acd44f31 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/sharepoint.py +++ b/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/sharepoint.py @@ -137,8 +137,8 @@ def _get_drive(self) -> Any: ctx = self._get_context() if self.drive_id: - # Specific drive by ID - self._drive = ctx.drives.get_by_id(self.drive_id) + # Specific drive by ID — EntityCollection uses bracket indexing. + self._drive = ctx.drives[self.drive_id] elif self.site_url and "sharepoint.com" in self.site_url.lower(): # SharePoint site - get default document library self._drive = self._get_sharepoint_site_drive(ctx) @@ -149,15 +149,15 @@ def _get_drive(self) -> Any: return self._drive def _get_sharepoint_site_drive(self, ctx: Any) -> Any: - """Get drive from SharePoint site URL.""" + """Get drive from SharePoint site URL. + + Uses the library's get_by_url, which maps an absolute site URL to the + Graph API's ``/sites/{hostname}:/{server-relative-path}`` addressing. 
+ """ from urllib.parse import urlparse - parsed = urlparse(self.site_url) - # Extract site path from URL like - # https://tenant.sharepoint.com/sites/sitename - site_path = parsed.path.rstrip("/") - if site_path: - return ctx.sites.get_by_path(site_path).drive + if urlparse(self.site_url).path.strip("/"): + return ctx.sites.get_by_url(self.site_url).drive return ctx.sites.root.drive def _get_onedrive_drive(self, ctx: Any) -> Any: diff --git a/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/static/json_schema.json b/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/static/json_schema.json index d83f84c8ae..e1f0757e36 100644 --- a/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/static/json_schema.json +++ b/unstract/connectors/src/unstract/connectors/filesystems/sharepoint/static/json_schema.json @@ -22,6 +22,7 @@ "drive_id": { "type": "string", "title": "Drive ID", + "format": "password", "description": "Specific Drive/Document Library ID. Leave empty to use the default drive." }, "auth_type": { @@ -64,7 +65,7 @@ "type": "string", "title": "User Email", "format": "password", - "description": "User's email address. Required ONLY for OneDrive with Client Credentials (not needed for SharePoint).", + "description": "Required only to access OneDrive with Client Credentials (e.g., user@company.com). Leave empty when accessing a SharePoint site via Site URL.", "examples": [ "user@company.onmicrosoft.com", "user@company.com"