feat: enable users to disable validation only on specific topics #759

Open · wants to merge 3 commits into base: main
2 changes: 1 addition & 1 deletion README.rst
@@ -464,7 +464,7 @@ Keys to take special care are the ones needed to configure Kafka and advertised_
- Name strategy to use when storing schemas from the Kafka REST proxy service. You can opt between ``topic_name``, ``record_name`` and ``topic_record_name``
* - ``name_strategy_validation``
- ``true``
- If enabled, validate that given schema is registered under used name strategy when producing messages from Kafka Rest
- If enabled, validate that the given schema is registered under the subjects required by the specified name strategy when producing messages from Kafka REST. Otherwise, no validation is performed
eliax1996 marked this conversation as resolved.
* - ``master_election_strategy``
- ``lowest``
- Decides on what basis the Karapace cluster master is chosen (only relevant in a multi node setup)
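
For reference, the two settings involved in this row, sketched as the Python Config mapping Karapace builds from its JSON configuration file (values illustrative, not from this PR):

    config = {
        "name_strategy": "topic_name",        # how subjects are derived for REST publishes
        "name_strategy_validation": True,     # the global switch documented above
    }
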
12 changes: 11 additions & 1 deletion karapace/in_memory_database.py
@@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from karapace.schema_models import SchemaVersion, TypedSchema
from karapace.schema_references import Reference, Referents
from karapace.typing import ResolvedVersion, SchemaId, Subject
from karapace.typing import ResolvedVersion, SchemaId, Subject, TopicName
from threading import Lock, RLock
from typing import Iterable, Sequence

@@ -32,6 +32,7 @@ def __init__(self) -> None:
self.schemas: dict[SchemaId, TypedSchema] = {}
self.schema_lock_thread = RLock()
self.referenced_by: dict[tuple[Subject, ResolvedVersion], Referents] = {}
self.topic_without_validation: set[TopicName] = set()

# Content based deduplication of schemas. This is used to reduce memory
# usage when the same schema is produced multiple times to the same or
@@ -229,6 +230,15 @@ def find_subject_schemas(self, *, subject: Subject, include_deleted: bool) -> di
if schema_version.deleted is False
}

def is_topic_requiring_validation(self, *, topic_name: TopicName) -> bool:
return topic_name not in self.topic_without_validation

def override_topic_validation(self, *, topic_name: TopicName, skip_validation: bool) -> None:
if skip_validation:
self.topic_without_validation.add(topic_name)
else:
self.topic_without_validation.discard(topic_name)

def delete_subject(self, *, subject: Subject, version: ResolvedVersion) -> None:
with self.schema_lock_thread:
for schema_version in self.subjects[subject].schemas.values():
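
A minimal usage sketch of the new toggle, not part of the diff (InMemoryDatabase and TopicName as imported above); topics validate by default, and only explicit opt-outs are stored in the set:

    db = InMemoryDatabase()
    topic = TopicName("orders")  # hypothetical topic
    assert db.is_topic_requiring_validation(topic_name=topic)        # default: validate
    db.override_topic_validation(topic_name=topic, skip_validation=True)
    assert not db.is_topic_requiring_validation(topic_name=topic)    # opted out
    db.override_topic_validation(topic_name=topic, skip_validation=False)
    assert db.is_topic_requiring_validation(topic_name=topic)        # re-enabled
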
105 changes: 83 additions & 22 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from aiokafka import AIOKafkaProducer
from aiokafka.errors import KafkaConnectionError
from binascii import Error as B64DecodeError
@@ -31,20 +33,21 @@
get_subject_name,
InvalidMessageSchema,
InvalidPayload,
SchemaRegistryClient,
SchemaRegistrySerializer,
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType, TopicName
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Final, MutableMapping, NewType

import asyncio
import base64
import datetime
import logging
import time

SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value]
SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value_]
PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"}
RECORD_CODES = [42201, 42202]
KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"}
@@ -66,10 +69,10 @@ def __init__(self, config: Config) -> None:
super().__init__(config=config)
self._add_kafka_rest_routes()
self.serializer = SchemaRegistrySerializer(config=config)
self.proxies: Dict[str, "UserRestProxy"] = {}
self.proxies: dict[str, UserRestProxy] = {}
self._proxy_lock = asyncio.Lock()
log.info("REST proxy starting with (delegated authorization=%s)", self.config.get("rest_authorization", False))
self._idle_proxy_janitor_task: Optional[asyncio.Task] = None
self._idle_proxy_janitor_task: asyncio.Task | None = None

async def close(self) -> None:
if self._idle_proxy_janitor_task is not None:
@@ -415,13 +418,56 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq
await proxy.topic_publish(topic, content_type, request=request)


LastTimeCheck = NewType("LastTimeCheck", float)

DEFAULT_CACHE_INTERVAL_NS: Final = 120 * 1_000_000_000 # 120 seconds


class ValidationCheckWrapper:
def __init__(
self,
registry_client: SchemaRegistryClient,
topic_name: TopicName,
cache_interval_ns: float = DEFAULT_CACHE_INTERVAL_NS,
):
self._last_check = 0
# By default, unless specified otherwise, be conservative and require validation
self._require_validation = True
self._topic_name = topic_name
self._registry_client = registry_client
self._cache_interval_ns = cache_interval_ns

async def _query_registry(self) -> bool:
require_validation = await self._registry_client.topic_require_validation(self._topic_name)
return require_validation

async def require_validation(self) -> bool:
if (time.monotonic_ns() - self._last_check) > self._cache_interval_ns:
self._require_validation = await self._query_registry()
self._last_check = time.monotonic_ns()

return self._require_validation

@classmethod
async def construct_new(
cls,
registry_client: SchemaRegistryClient,
topic_name: TopicName,
cache_interval_ns: float = DEFAULT_CACHE_INTERVAL_NS,
) -> ValidationCheckWrapper:
validation_checker = cls(registry_client, topic_name, cache_interval_ns)
validation_checker._require_validation = await validation_checker._query_registry()
validation_checker._last_check = time.monotonic_ns()
return validation_checker
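
A sketch of the intended use, not part of the diff; registry_client is assumed to be an initialized SchemaRegistryClient, and the calls run inside an async context:

    checker = await ValidationCheckWrapper.construct_new(registry_client, TopicName("orders"))
    if await checker.require_validation():   # cached answer within the 120 s window
        ...  # enforce the name-strategy subject check before producing
    # Once cache_interval_ns has elapsed, the next call queries the registry again.
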


class UserRestProxy:
def __init__(
self,
config: Config,
kafka_timeout: int,
serializer: SchemaRegistrySerializer,
auth_expiry: Optional[datetime.datetime] = None,
auth_expiry: datetime.datetime | None = None,
):
self.config = config
self.kafka_timeout = kafka_timeout
Expand All @@ -440,8 +486,18 @@ def __init__(
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AIOKafkaProducer] = None
self._async_producer: AIOKafkaProducer | None = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])
self.topic_validation: MutableMapping[TopicName, ValidationCheckWrapper] = {}

async def is_validation_required(self, topic_name: TopicName) -> bool:
if topic_name not in self.topic_validation:
self.topic_validation[topic_name] = await ValidationCheckWrapper.construct_new(
self.serializer.registry_client,
topic_name,
)

return await self.topic_validation[topic_name].require_validation()
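
Note the per-topic memoization above: repeated publishes reuse one wrapper, so the registry is asked about a given topic at most once per cache interval. A sketch, with proxy as a hypothetical initialized UserRestProxy, inside an async context:

    first = await proxy.is_validation_required(TopicName("orders"))   # creates and primes the wrapper
    again = await proxy.is_validation_required(TopicName("orders"))   # reuses the cached wrapper
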

def __str__(self) -> str:
return f"UserRestProxy(username={self.config['sasl_plain_username']})"
@@ -601,7 +657,7 @@ async def get_topic_config(self, topic: str) -> dict:
async with self.admin_lock:
return self.admin_client.get_topic_config(topic)

async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
async def cluster_metadata(self, topics: list[str] | None = None) -> dict:
async with self.admin_lock:
if self._metadata_birth is None or time.monotonic() - self._metadata_birth > self.metadata_max_age:
self._cluster_metadata = None
@@ -671,7 +727,7 @@ async def aclose(self) -> None:
self.admin_client = None
self.consumer_manager = None

async def publish(self, topic: str, partition_id: Optional[str], content_type: str, request: HTTPRequest) -> None:
async def publish(self, topic: str, partition_id: str | None, content_type: str, request: HTTPRequest) -> None:
formats: dict = request.content_type
data: dict = request.json
_ = await self.get_topic_info(topic, content_type)
@@ -769,7 +825,7 @@ async def get_schema_id(
:raises InvalidSchema:
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
schema_id: SchemaId | None = (
SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
)
schema_str = data.get(f"{subject_type}_schema")
@@ -788,8 +844,9 @@
)
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:
is_validation_required = await self.is_validation_required(topic_name=TopicName(topic))

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
def subject_not_included(schema: TypedSchema, subjects: list[Subject]) -> bool:
subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
return subject not in subjects

@@ -798,14 +855,18 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
need_new_call=subject_not_included,
)

if self.config["name_strategy_validation"] and subject_not_included(parsed_schema, valid_subjects):
if (
self.config["name_strategy_validation"]
and is_validation_required
and subject_not_included(parsed_schema, valid_subjects)
):
raise InvalidSchema()

return schema_id

async def _query_schema_and_subjects(
self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
self, schema_id: SchemaId, *, need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None
) -> tuple[TypedSchema, list[Subject]]:
try:
return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
except SchemaRetrievalError as schema_error:
@@ -896,10 +957,10 @@ async def _prepare_records(
content_type: str,
data: dict,
ser_format: str,
key_schema_id: Optional[int],
value_schema_id: Optional[int],
default_partition: Optional[int] = None,
) -> List[Tuple]:
key_schema_id: int | None,
value_schema_id: int | None,
default_partition: int | None = None,
) -> list[tuple]:
prepared_records = []
for record in data["records"]:
key = record.get("key")
@@ -950,8 +1011,8 @@ async def serialize(
self,
content_type: str,
obj=None,
ser_format: Optional[str] = None,
schema_id: Optional[int] = None,
ser_format: str | None = None,
schema_id: int | None = None,
) -> bytes:
if not obj:
return b""
@@ -975,7 +1036,7 @@
return await self.schema_serialize(obj, schema_id)
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
async def schema_serialize(self, obj: dict, schema_id: int | None) -> bytes:
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_
@@ -1038,7 +1099,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
sub_code=RESTErrorCodes.INVALID_DATA.value,
)

async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
async def produce_messages(self, *, topic: str, prepared_records: list) -> list:
producer = await self._maybe_create_async_producer()

produce_futures = []
3 changes: 3 additions & 0 deletions karapace/key_format.py
@@ -70,6 +70,9 @@ def format_key(
corrected_key["subject"] = key["subject"]
if "version" in key:
corrected_key["version"] = key["version"]
if "topic" in key:
corrected_key["topic"] = key["topic"]

# Magic is the last element
corrected_key["magic"] = key["magic"]
return json_encode(corrected_key, binary=True, sort_keys=False, compact=True)
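
An illustrative before/after for the new branch above (hypothetical topic; the point is the field ordering, with "keytype" leading and "magic" trailing in the canonical layout):

    # Key as assembled by the registry for a validation record:
    {"topic": "orders", "keytype": "SCHEMA_VALIDATION", "magic": 0}
    # Canonical ordering emitted by format_key:
    {"keytype": "SCHEMA_VALIDATION", "topic": "orders", "magic": 0}
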
14 changes: 11 additions & 3 deletions karapace/schema_reader.py
@@ -8,7 +8,6 @@

from avro.schema import Schema as AvroSchema
from contextlib import closing, ExitStack
from enum import Enum
from jsonschema.validators import Draft7Validator
from kafka import KafkaConsumer, TopicPartition
from kafka.errors import (
@@ -32,7 +31,7 @@
from karapace.schema_models import parse_protobuf_schema_definition, SchemaType, TypedSchema, ValidatedTypedSchema
from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping, Referents
from karapace.statsd import StatsClient
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, StrEnum, Subject, TopicName
from karapace.utils import json_decode, JSONDecodeError, KarapaceKafkaClient
from threading import Event, Thread
from typing import Final, Mapping, Sequence
@@ -59,10 +58,11 @@
METRIC_SUBJECT_DATA_SCHEMA_VERSIONS_GAUGE: Final = "karapace_schema_reader_subject_data_schema_versions"


class MessageType(Enum):
class MessageType(StrEnum):
config = "CONFIG"
schema = "SCHEMA"
delete_subject = "DELETE_SUBJECT"
schema_validation = "SCHEMA_VALIDATION"
no_operation = "NOOP"
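
MessageType now derives from Karapace's StrEnum; assuming its __str__ returns the member value (as karapace.typing defines it), the keytype lands on the wire as a plain string, which is what lets schema_registry.py build the key with str(MessageType.schema_validation):

    assert str(MessageType.schema_validation) == "SCHEMA_VALIDATION"
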


@@ -429,6 +429,12 @@ def _handle_msg_delete_subject(self, key: dict, value: dict | None) -> None: #
LOG.info("Deleting subject: %r, value: %r", subject, value)
self.database.delete_subject(subject=subject, version=version)

def _handle_msg_schema_validation(self, key: dict, value: dict | None) -> None:
assert isinstance(value, dict)
topic = TopicName(key["topic"])
skip_validation = bool(value["skip_validation"])
self.database.override_topic_validation(topic_name=topic, skip_validation=skip_validation)
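
For reference, a validation record as the reader consumes it (hypothetical values):

    key = {"keytype": "SCHEMA_VALIDATION", "topic": "orders", "magic": 0}
    value = {"skip_validation": True}
    # handle_msg dispatches on keytype; _handle_msg_schema_validation then adds
    # "orders" to database.topic_without_validation, while skip_validation=False
    # would discard it again.
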

def _handle_msg_schema_hard_delete(self, key: dict) -> None:
subject, version = key["subject"], key["version"]

@@ -532,6 +538,8 @@ def handle_msg(self, key: dict, value: dict | None) -> None:
self._handle_msg_schema(key, value)
elif message_type == MessageType.delete_subject:
self._handle_msg_delete_subject(key, value)
elif message_type == MessageType.schema_validation:
self._handle_msg_schema_validation(key, value)
elif message_type == MessageType.no_operation:
pass
except ValueError:
17 changes: 15 additions & 2 deletions karapace/schema_registry.py
@@ -27,9 +27,9 @@
from karapace.messaging import KarapaceProducer
from karapace.offset_watcher import OffsetWatcher
from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema
from karapace.schema_reader import KafkaSchemaReader
from karapace.schema_reader import KafkaSchemaReader, MessageType
from karapace.schema_references import LatestVersionReference, Reference
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, Version
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, TopicName, Version
from typing import Mapping, Sequence

import asyncio
@@ -466,6 +466,19 @@ def send_schema_message(
value = None
self.producer.send_message(key=key, value=value)

def is_topic_requiring_validation(self, *, topic_name: TopicName) -> bool:
return self.database.is_topic_requiring_validation(topic_name=topic_name)

def update_require_validation_for_topic(
self,
*,
topic_name: TopicName,
skip_validation: bool,
) -> None:
key = {"topic": topic_name, "keytype": str(MessageType.schema_validation), "magic": 0}
Contributor: The formatting of the key needs to be aligned in key_format.py. This record needs the canonical format; "topic" is currently an unknown key in the data. How does the Confluent Schema Registry react to a custom message like this?

Contributor Author: Still a valid point, taking a look tomorrow.

Contributor Author: Discussed offline, probably solved. Waiting for an update on that.

Contributor Author: I just realized that I need to add the topic to the key, otherwise compaction could delete the topic information. This doesn't change the behaviour of the Confluent Schema Registry, which simply skips messages it cannot understand (a behaviour we also follow when receiving messages we can't parse).

value = {"skip_validation": skip_validation}
self.producer.send_message(key=key, value=value)
Contributor: Depending on skip_validation, this could also send a tombstone record instead of toggling between true/false.

Contributor Author (eliax1996, Nov 24, 2023): I don't think so; otherwise, how could you re-enable validation for a certain topic after disabling it?
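
Putting it together, a sketch of disabling validation for one topic through the registry — not part of the diff; registry is a hypothetical KarapaceSchemaRegistry instance, and the REST endpoint that would expose this is outside this excerpt:

    registry.update_require_validation_for_topic(
        topic_name=TopicName("orders"),
        skip_validation=True,
    )
    # Produces a record with key
    #   {"topic": "orders", "keytype": "SCHEMA_VALIDATION", "magic": 0}
    # and value {"skip_validation": true} to the schemas topic; each reader
    # replays it into its in-memory database via _handle_msg_schema_validation.
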


def send_config_message(self, compatibility_level: CompatibilityModes, subject: Subject | None = None) -> None:
key = {"subject": subject, "magic": 0, "keytype": "CONFIG"}
value = {"compatibilityLevel": compatibility_level.value}