Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions backend/connector_auth_v2/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,21 @@ class SocialAuthConstants:

GOOGLE_OAUTH = "google-oauth2"
GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S"


# Whitelist of OAuth token fields that may be merged across connectors that
# share one (provider, uid) identity. Every other key — per-connector form
# fields such as site_url or drive_id, and provider-specific enrichment kept
# in ConnectorAuth.extra_data — must never cross connector boundaries.
OAUTH_TOKEN_KEYS: frozenset[str] = frozenset(
    (
        SocialAuthConstants.ACCESS_TOKEN,
        SocialAuthConstants.REFRESH_TOKEN,
        SocialAuthConstants.TOKEN_TYPE,
        SocialAuthConstants.EXPIRES,
        SocialAuthConstants.AUTH_TIME,
        SocialAuthConstants.REFRESH_AFTER,
        "expires_in",
        "scope",
    )
)
24 changes: 20 additions & 4 deletions backend/connector_auth_v2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from social_django.models import AbstractUserSocialAuth, DjangoStorage
from social_django.strategy import DjangoStrategy

from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.constants import OAUTH_TOKEN_KEYS, SocialAuthConstants
from connector_auth_v2.pipeline.google import GoogleAuthHelper

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,11 +96,27 @@ def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bo
refreshed_token = True
related_connector_instances = self.connectorinstance_set.all()
for connector_instance in related_connector_instances:
connector_instance.connector_metadata = self.extra_data
# Whitelist-merge: only OAuth token keys flow from the shared
# extra_data into each sibling's metadata. Non-token keys (form
# fields like site_url, drive_id, or provider-enrichment stored
# in extra_data) must NOT leak between connectors sharing the
# same (provider, uid).
existing_metadata = connector_instance.connector_metadata or {}
token_updates = {
key: self.extra_data[key]
for key in OAUTH_TOKEN_KEYS
if self.extra_data.get(key) is not None
}
connector_instance.connector_metadata = {
**existing_metadata,
**token_updates,
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.
connector_instance.save()
logger.info(
f"Refreshed access token for connector {connector_instance.id}, "
f"provider: {self.provider}, uid: {self.uid}"
"Refreshed access token for connector %s, provider: %s, uid: %s",
connector_instance.id,
self.provider,
self.uid,
)

return self.extra_data, refreshed_token
Expand Down
31 changes: 21 additions & 10 deletions backend/connector_v2/fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from datetime import datetime

from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.constants import OAUTH_TOKEN_KEYS, SocialAuthConstants
from connector_auth_v2.models import ConnectorAuth
from django.db import models

Expand All @@ -25,17 +25,28 @@ def from_db_value(self, value, expression, connection): # type: ignore
refresh_after_str, SocialAuthConstants.REFRESH_AFTER_FORMAT
)
if datetime.now() > refresh_after:
metadata = self._refresh_tokens(provider, uid)
metadata = self._refresh_tokens(provider, uid, metadata)
return metadata

def _refresh_tokens(self, provider: str, uid: str) -> dict[str, str]:
"""Retrieves PSA object and refreshes the token if necessary."""
def _refresh_tokens(
self, provider: str, uid: str, existing_metadata: dict[str, str]
) -> dict[str, str]:
"""Retrieves PSA object and refreshes the token if necessary.

Whitelist-merges refreshed token fields over existing metadata so
per-instance form fields (e.g. site_url, drive_id) are preserved on
read and non-token keys from the shared ConnectorAuth.extra_data
cannot leak into a connector's metadata.
"""
connector_auth: ConnectorAuth = ConnectorAuth.get_social_auth(
provider=provider, uid=uid
)
if connector_auth:
(
connector_metadata,
_,
) = connector_auth.get_and_refresh_tokens()
return connector_metadata # type: ignore
if not connector_auth:
return existing_metadata
refreshed_metadata, _ = connector_auth.get_and_refresh_tokens()
token_updates = {
key: refreshed_metadata[key]
for key in OAUTH_TOKEN_KEYS
if refreshed_metadata.get(key) is not None
}
return {**existing_metadata, **token_updates}
72 changes: 66 additions & 6 deletions backend/connector_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from collections import OrderedDict
from typing import Any

from connector_auth_v2.constants import OAUTH_TOKEN_KEYS
from connector_auth_v2.models import ConnectorAuth
from connector_auth_v2.pipeline.common import ConnectorAuthHelper
from connector_processor.connector_processor import ConnectorProcessor
from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import OAuthTimeOut
from rest_framework.serializers import CharField, SerializerMethodField
from connector_processor.exceptions import InvalidConnectorID, OAuthTimeOut
from rest_framework.serializers import CharField, SerializerMethodField, ValidationError
from utils.fields import EncryptedBinaryFieldSerializer
from utils.input_sanitizer import validate_name_field

Expand All @@ -28,10 +29,57 @@ class ConnectorInstanceSerializer(AuditSerializer):
class Meta:
model = ConnectorInstance
fields = "__all__"
extra_kwargs = {"connector_name": {"required": False}}

def validate_connector_name(self, value: str) -> str:
    """DRF per-field hook: delegates sanitization of the connector name to
    ``validate_name_field`` and returns its (possibly normalized) result."""
    return validate_name_field(value, field_name="Connector name")

def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
    """Backfill ``connector_name`` from the JSON schema default when absent.

    Defense-in-depth: the frontend RJSF form seeds ``connector_name`` from
    the schema default, but callers (including staging OAuth flows) have
    been observed to POST without it. If the connector schema declares a
    default name, use it. Otherwise raise a 400 explicitly rather than
    letting the missing value reach the DB and surface as an
    ``IntegrityError`` (the model enforces ``null=False``).

    Skipped entirely on partial updates (PATCH): the existing DB row
    already has a valid name, and backfilling would overwrite a
    user-renamed connector with the schema default.
    """
    attrs = super().validate(attrs)

    # Nothing to backfill when a name was supplied, or on PATCH.
    if self.partial or attrs.get(CIKey.CONNECTOR_NAME):
        return attrs

    connector_id = attrs.get(CIKey.CONNECTOR_ID)
    fallback_name = None
    if connector_id:
        fallback_name = self._get_schema_default_connector_name(connector_id)
    if not fallback_name:
        # Explicit 400 instead of a DB IntegrityError downstream.
        raise ValidationError({CIKey.CONNECTOR_NAME: "This field is required."})

    logger.info(
        "Filled missing connector_name with schema default for %s",
        connector_id,
    )
    attrs[CIKey.CONNECTOR_NAME] = fallback_name
    return attrs
Comment thread
greptile-apps[bot] marked this conversation as resolved.

@staticmethod
def _get_schema_default_connector_name(connector_id: str) -> str | None:
    """Return the default ``connectorName`` declared in the connector's JSON schema.

    Returns ``None`` when the connector id is unknown (``InvalidConnectorID``)
    or when the schema declares no default for ``connectorName``.
    """
    try:
        schema_details = ConnectorProcessor.get_json_schema(connector_id=connector_id)
    except InvalidConnectorID:
        return None
    json_schema = schema_details.get(ConnectorKeys.JSON_SCHEMA, {})
    name_property = json_schema.get("properties", {}).get("connectorName", {})
    return name_property.get("default")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def save(self, **kwargs): # type: ignore
user = self.context.get("request").user or None
connector_id: str = kwargs[CIKey.CONNECTOR_ID]
Expand All @@ -53,10 +101,22 @@ def save(self, **kwargs): # type: ignore
oauth_credentials=kwargs[CIKey.CONNECTOR_METADATA],
)
kwargs[CIKey.CONNECTOR_AUTH] = connector_oauth
(
kwargs[CIKey.CONNECTOR_METADATA],
refresh_status,
) = connector_oauth.get_and_refresh_tokens()
# Merge refreshed token fields (whitelist) back into this
# connector's metadata so ``super().save(**kwargs)`` does not
# overwrite the fresh token the sibling-loop just persisted.
# Whitelisting preserves per-connector form fields (site_url,
# drive_id) that must not be leaked across connectors sharing
# the same (provider, uid).
refreshed_metadata, _ = connector_oauth.get_and_refresh_tokens()
token_updates = {
key: refreshed_metadata[key]
for key in OAUTH_TOKEN_KEYS
if refreshed_metadata.get(key) is not None
}
kwargs[CIKey.CONNECTOR_METADATA] = {
**(kwargs.get(CIKey.CONNECTOR_METADATA) or {}),
**token_updates,
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
except Exception as exc:
logger.error(
"Error while obtaining ConnectorAuth for connector id "
Expand Down
9 changes: 7 additions & 2 deletions backend/connector_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,17 @@ def _get_connector_metadata(self, connector_id: str) -> dict[str, str] | None:

# Only use OAuth flow if connector supports it AND oauth_key is provided
if ConnectorInstance.supportsOAuth(connector_id=connector_id) and oauth_key:
connector_metadata = ConnectorAuthHelper.get_oauth_creds_from_cache(
oauth_tokens = ConnectorAuthHelper.get_oauth_creds_from_cache(
cache_key=oauth_key,
delete_key=False, # Don't delete yet - wait for successful operation
)
if connector_metadata is None:
if oauth_tokens is None:
raise MissingParamException(param=ConnectorAuthKey.OAUTH_KEY)
# Preserve non-secret form fields (e.g. site_url for the SharePoint connector)
form_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA) or {}
Comment thread
muhammad-ali-e marked this conversation as resolved.
if not isinstance(form_metadata, dict):
form_metadata = {}
connector_metadata = {**form_metadata, **oauth_tokens}
Comment thread
muhammad-ali-e marked this conversation as resolved.
Comment thread
greptile-apps[bot] marked this conversation as resolved.
else:
connector_metadata = self.request.data.get(CIKey.CONNECTOR_METADATA)
return connector_metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def _get_drive(self) -> Any:
ctx = self._get_context()

if self.drive_id:
# Specific drive by ID
self._drive = ctx.drives.get_by_id(self.drive_id)
# Specific drive by ID — EntityCollection uses bracket indexing.
self._drive = ctx.drives[self.drive_id]
Comment thread
muhammad-ali-e marked this conversation as resolved.
elif self.site_url and "sharepoint.com" in self.site_url.lower():
# SharePoint site - get default document library
self._drive = self._get_sharepoint_site_drive(ctx)
Expand All @@ -149,15 +149,15 @@ def _get_drive(self) -> Any:
return self._drive

def _get_sharepoint_site_drive(self, ctx: Any) -> Any:
"""Get drive from SharePoint site URL."""
"""Get drive from SharePoint site URL.

Uses the library's get_by_url, which maps an absolute site URL to the
Graph API's ``/sites/{hostname}:/{server-relative-path}`` addressing.
"""
from urllib.parse import urlparse

parsed = urlparse(self.site_url)
# Extract site path from URL like
# https://tenant.sharepoint.com/sites/sitename
site_path = parsed.path.rstrip("/")
if site_path:
return ctx.sites.get_by_path(site_path).drive
if urlparse(self.site_url).path.strip("/"):
return ctx.sites.get_by_url(self.site_url).drive
return ctx.sites.root.drive

def _get_onedrive_drive(self, ctx: Any) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"drive_id": {
"type": "string",
"title": "Drive ID",
"format": "password",
"description": "Specific Drive/Document Library ID. Leave empty to use the default drive."
},
"auth_type": {
Expand Down Expand Up @@ -64,7 +65,7 @@
"type": "string",
"title": "User Email",
"format": "password",
"description": "User's email address. Required ONLY for OneDrive with Client Credentials (not needed for SharePoint).",
"description": "Required only to access OneDrive with Client Credentials (e.g., user@company.com). Leave empty when accessing a SharePoint site via Site URL.",
"examples": [
"user@company.onmicrosoft.com",
"user@company.com"
Expand Down