Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions scripts/post_generate_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,260 @@ def _first_subscript_arg(node: ast.Subscript) -> ast.AST | None:
return slice_node


# ---------------------------------------------------------------------------
# #624: widen documented extension-point list[X] fields to Sequence[X].
#
# Adopters who follow Critical Pattern #1 (subclass a library response type
# and override the parent's list field with a more specific element type)
# hit `# type: ignore[assignment]` on every override under mypy --strict —
# list is invariant in its element type. Sequence is covariant, so a
# Sequence[Parent] parent permits list[Child] override cleanly.
#
# Scope is intentionally narrow: only fields the SDK documents as
# extension points (response payloads adopters routinely subclass, plus
# request bodies that compose extendable sub-records like packages and
# creatives). Internal scalars stay as list.
#
# Allowlist format: (class_name, field_name). datamodel-codegen emits
# bundled response files that each inline copies of subordinate types
# (Placement, TargetingOverlay, etc.); the rewriter walks every generated
# .py file and applies the substitution to every emission of the named
# (class, field) pair so all copies stay consistent.

_SEQUENCE_EXTENSION_POINTS: list[tuple[str, str]] = [
# Response payloads adopters subclass to add internal-only fields.
# `UpdateMediaBuySuccessResponse` is the success variant of the
# `UpdateMediaBuyResponse` discriminated union — emitted as
# `UpdateMediaBuyResponse1` (v3.0) and `UpdateMediaBuyResponse3`
# (v3.0.6 bundled).
("UpdateMediaBuyResponse1", "affected_packages"),
("UpdateMediaBuyResponse3", "affected_packages"),
("GetMediaBuyDeliveryResponse", "media_buy_deliveries"),
("GetCreativeDeliveryResponse", "creatives"),
("Signal", "deployments"),
("GetSignalsResponse", "signals"),
("GetMediaBuysResponse", "media_buys"),
("ListCreativesResponse", "creatives"),
# Request bodies that carry extendable sub-records — adopters subclass
# the inner record type and need to override the list element type.
("PackageRequest", "creatives"),
("CreateMediaBuyRequest", "packages"),
("UpdateMediaBuyRequest", "packages"),
# Cross-cutting record types referenced from multiple responses; each
# bundled response file inlines its own copy. The walker rewrites
# every emission.
("Placement", "format_ids"),
("TargetingOverlay", "geo_countries_exclude"),
("TargetingOverlay", "geo_regions_exclude"),
("TargetingOverlay", "geo_metros_exclude"),
("TargetingOverlay", "geo_postal_areas_exclude"),
]


def widen_extension_point_lists_to_sequence():
"""Rewrite ``list[X]`` to ``Sequence[X]`` on documented extension-point fields.

Walks every generated ``.py`` file under :data:`OUTPUT_DIR`. For each
file, applies every ``(class, field)`` pair in
:data:`_SEQUENCE_EXTENSION_POINTS` that matches a class declaration
in that file. The same ``(class, field)`` pair commonly appears in
multiple files because bundled response emission inlines copies of
subordinate types — every emission is rewritten so all paths stay
consistent. Each rewritten file gets ``from collections.abc import
Sequence`` added if it isn't already present.

Pairs that produce zero rewrites across the whole tree emit a WARN
so allowlist drift surfaces fast (a renamed field or removed class
means the override pattern this entry was protecting no longer
exists).

See `adcp-client-python#624 <https://github.com/adcontextprotocol/adcp-client-python/issues/624>`_
for the design rationale and the spike that validated the Pydantic
plugin accepts ``Sequence[Parent]`` parent + ``list[Child]`` child
override under mypy --strict.
"""
print("Widening extension-point list[X] fields to Sequence[X] (#624)...")

# Track total rewrites per (class, field) — a pair with zero hits is
# a stale allowlist entry and surfaces as a WARN.
# Track per-pair state across all files:
# rewrites: how many list[X] sites were rewritten this run
# already_widened: how many sites are already in Sequence[X] form
# A pair with rewrites == 0 AND already_widened == 0 is genuinely stale
# (field renamed/removed) and warrants a WARN. A pair with already_widened
# > 0 is silent — that's the steady-state idempotent run.
rewrites_per_pair: dict[tuple[str, str], int] = {pair: 0 for pair in _SEQUENCE_EXTENSION_POINTS}
already_per_pair: dict[tuple[str, str], int] = {pair: 0 for pair in _SEQUENCE_EXTENSION_POINTS}
files_touched = 0
total_widened = 0

for file_path in sorted(OUTPUT_DIR.rglob("*.py")):
original = file_path.read_text()
content = original
widened_in_file = 0

for class_name, field_name in _SEQUENCE_EXTENSION_POINTS:
# Quick filter — skip files that don't declare this class.
if f"class {class_name}(" not in content and f"class {class_name}:" not in content:
continue
new_content, did_widen = _widen_field_annotation(content, class_name, field_name)
if did_widen:
content = new_content
widened_in_file += 1
rewrites_per_pair[(class_name, field_name)] += 1
elif _field_already_widened(content, class_name, field_name):
already_per_pair[(class_name, field_name)] += 1

if widened_in_file == 0:
continue

content = _ensure_sequence_import(content)
file_path.write_text(content)
files_touched += 1
total_widened += widened_in_file
print(f" ✓ {file_path.relative_to(OUTPUT_DIR)}: widened {widened_in_file} field(s)")

stale = [
pair
for pair in _SEQUENCE_EXTENSION_POINTS
if rewrites_per_pair[pair] == 0 and already_per_pair[pair] == 0
]
for class_name, field_name in stale:
print(
f" WARN: {class_name}.{field_name} — neither list[X] nor Sequence[X] "
"found in any generated file (field renamed or removed?)"
)

if total_widened == 0:
print(" No extension-point fields to widen")
else:
print(
f" ✓ Widened {total_widened} extension-point field(s) "
f"across {files_touched} file(s)"
)


def _widen_field_annotation(content: str, class_name: str, field_name: str) -> tuple[str, bool]:
"""Rewrite ``list[X]`` → ``Sequence[X]`` in one field's annotation.

Locates ``class {class_name}(...):`` then walks forward to the first
`` {field_name}:`` line at class-body indentation, **bounded to the
target class** so a same-named field on a later class in the same
file cannot mis-match. Within the AnnAssign's annotation block (which
may span multiple lines for ``Annotated[..., Field(...)]``), replaces
the first ``list[`` with ``Sequence[``. Idempotent — a second pass
over already-widened content is a no-op.
"""
# Anchor on the class definition.
class_pattern = re.compile(rf"^class {re.escape(class_name)}\b", re.MULTILINE)
class_match = class_pattern.search(content)
if class_match is None:
return content, False

# Bound the search region to the current class body. Scanning past the
# next `^class ` would let `re.search` mis-target a same-named field
# on a sibling class in the same file (the lookahead in
# field_start_pattern terminates a *match*, but `re.search` is free to
# scan past the first class's boundary looking for a hit).
class_body_start = class_match.end()
next_class = re.compile(r"^class ", re.MULTILINE).search(content, class_body_start)
region_end = next_class.start() if next_class is not None else len(content)
region = content[class_body_start:region_end]

# The annotation block runs from the field name to the next class-body
# statement at 4-space indentation (next field, model_config, or method).
field_start_pattern = re.compile(
rf"^( {re.escape(field_name)}: )(.*?)(?=^ [a-zA-Z_]|\Z)",
re.MULTILINE | re.DOTALL,
)
field_match = field_start_pattern.search(region)
if field_match is None:
return content, False

annotation_block = field_match.group(2)
# Replace the first list[ inside the annotation only. Generated
# annotations always have `list[X]` as the outer container; the
# narrow scope of the allowlist (no `dict[str, list[X]]` entries)
# makes this safe in practice. If a future entry has nested list,
# this needs to anchor on the outer container explicitly.
new_annotation = re.sub(r"\blist\[", "Sequence[", annotation_block, count=1)
if new_annotation == annotation_block:
return content, False

# Stitch back. .start()/.end() are relative to `region`; convert to
# absolute offsets in `content`.
abs_start = class_body_start + field_match.start(2)
abs_end = class_body_start + field_match.end(2)
new_content = content[:abs_start] + new_annotation + content[abs_end:]
return new_content, True


def _field_already_widened(content: str, class_name: str, field_name: str) -> bool:
"""Return True when the named field's annotation is already ``Sequence[X]``.

Used to silence the WARN on idempotent re-runs: a pair that's already
widened is the steady state, not allowlist drift.
"""
class_match = re.search(rf"^class {re.escape(class_name)}\b", content, re.MULTILINE)
if class_match is None:
return False
class_body_start = class_match.end()
next_class = re.compile(r"^class ", re.MULTILINE).search(content, class_body_start)
region_end = next_class.start() if next_class is not None else len(content)
region = content[class_body_start:region_end]
field_match = re.search(
rf"^( {re.escape(field_name)}: )(.*?)(?=^ [a-zA-Z_]|\Z)",
region,
re.MULTILINE | re.DOTALL,
)
if field_match is None:
return False
return "Sequence[" in field_match.group(2)


def _ensure_sequence_import(content: str) -> str:
"""Add ``from collections.abc import Sequence`` if not already present.

Inserts after the ``from __future__ import annotations`` line so the
import sits with sibling stdlib imports rather than landing at the top
of the file.
"""
if "from collections.abc import Sequence" in content:
return content
# If `collections.abc` is already imported, extend the import line.
extend_pattern = re.compile(r"^from collections\.abc import ([^\n]+)$", re.MULTILINE)
match = extend_pattern.search(content)
if match is not None:
existing = match.group(1)
# Maintain alphabetical order if the existing import is sorted.
names = sorted({*[n.strip() for n in existing.split(",")], "Sequence"})
new_line = f"from collections.abc import {', '.join(names)}"
return content[: match.start()] + new_line + content[match.end() :]

# Otherwise insert after the typing imports block. Codegen always emits
# ``from typing import Annotated`` near the top, so anchor on it.
typing_pattern = re.compile(r"^from typing import [^\n]+$", re.MULTILINE)
match = typing_pattern.search(content)
if match is not None:
return (
content[: match.end()]
+ "\nfrom collections.abc import Sequence"
+ content[match.end() :]
)

# Fallback: prepend after the `from __future__` line.
future_pattern = re.compile(r"^from __future__ import annotations$", re.MULTILINE)
match = future_pattern.search(content)
if match is not None:
return (
content[: match.end()]
+ "\n\nfrom collections.abc import Sequence"
+ content[match.end() :]
)

return "from collections.abc import Sequence\n\n" + content


def main():
"""Apply all post-generation fixes."""
print("Applying post-generation fixes...")
Expand All @@ -816,6 +1070,7 @@ def main():
fix_reuse_model_discriminator_bug()
restore_format_category_deprecation_shim()
inject_literal_discriminator_defaults()
widen_extension_point_lists_to_sequence()

print("\n✓ Post-generation fixes complete\n")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from enum import Enum, IntEnum
from typing import Annotated, Any, Literal
from collections.abc import Sequence

from adcp.types.base import AdCPBaseModel
from pydantic import AnyUrl, AwareDatetime, ConfigDict, Field, RootModel, StringConstraints
Expand Down Expand Up @@ -3612,7 +3613,7 @@ class GetCreativeDeliveryResponse(AdCPBaseModel):
]
reporting_period: Annotated[ReportingPeriod, Field(description='Date range for the report.')]
creatives: Annotated[
list[Creative], Field(description='Creative delivery data with variant breakdowns')
Sequence[Creative], Field(description='Creative delivery data with variant breakdowns')
]
pagination: Annotated[
Pagination | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from enum import Enum, IntEnum
from typing import Annotated, Any, Literal
from collections.abc import Sequence

from adcp.types.base import AdCPBaseModel
from pydantic import AnyUrl, AwareDatetime, ConfigDict, EmailStr, Field, RootModel, StringConstraints
Expand Down Expand Up @@ -3810,7 +3811,7 @@ class ListCreativesResponse(AdCPBaseModel):
),
]
creatives: Annotated[
list[Creative], Field(description='Array of creative assets matching the query')
Sequence[Creative], Field(description='Array of creative assets matching the query')
]
format_summary: Annotated[
dict[Annotated[str, StringConstraints(pattern=r'^[a-zA-Z0-9_-]+$')], int] | None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from enum import Enum, IntEnum
from typing import Annotated, Any, Literal
from collections.abc import Sequence

from adcp.types.base import AdCPBaseModel
from pydantic import AnyUrl, AwareDatetime, ConfigDict, EmailStr, Field, RootModel, StringConstraints
Expand Down Expand Up @@ -2576,7 +2577,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_countries_exclude: Annotated[
list[GeoCountriesExcludeItem] | None,
Sequence[GeoCountriesExcludeItem] | None,
Field(
description="Exclude specific countries from delivery. ISO 3166-1 alpha-2 codes (e.g., 'US', 'GB', 'DE').",
min_length=1,
Expand All @@ -2590,7 +2591,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_regions_exclude: Annotated[
list[GeoRegionsExcludeItem] | None,
Sequence[GeoRegionsExcludeItem] | None,
Field(
description="Exclude specific regions/states from delivery. ISO 3166-2 subdivision codes (e.g., 'US-CA', 'GB-SCT').",
min_length=1,
Expand All @@ -2604,7 +2605,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_metros_exclude: Annotated[
list[GeoMetrosExcludeItem] | None,
Sequence[GeoMetrosExcludeItem] | None,
Field(
description='Exclude specific metro areas from delivery. Each entry specifies the classification system and excluded values. Seller must declare supported systems in get_adcp_capabilities.',
min_length=1,
Expand All @@ -2618,7 +2619,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_postal_areas_exclude: Annotated[
list[GeoPostalAreasExcludeItem] | None,
Sequence[GeoPostalAreasExcludeItem] | None,
Field(
description='Exclude specific postal areas from delivery. Each entry specifies the postal system and excluded values. Seller must declare supported systems in get_adcp_capabilities.',
min_length=1,
Expand Down Expand Up @@ -5030,7 +5031,7 @@ class CreateMediaBuyRequest(AdCPBaseModel):
),
] = None
packages: Annotated[
list[Package] | None,
Sequence[Package] | None,
Field(
description="Array of package configurations. Required when not using proposal_id. When executing a proposal, this can be omitted and packages will be derived from the proposal's allocations.",
min_length=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from enum import Enum
from typing import Annotated, Any, Literal
from collections.abc import Sequence

from adcp.types.base import AdCPBaseModel
from pydantic import AnyUrl, AwareDatetime, ConfigDict, EmailStr, Field, RootModel
Expand Down Expand Up @@ -2540,7 +2541,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_countries_exclude: Annotated[
list[GeoCountriesExcludeItem] | None,
Sequence[GeoCountriesExcludeItem] | None,
Field(
description="Exclude specific countries from delivery. ISO 3166-1 alpha-2 codes (e.g., 'US', 'GB', 'DE').",
min_length=1,
Expand All @@ -2554,7 +2555,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_regions_exclude: Annotated[
list[GeoRegionsExcludeItem] | None,
Sequence[GeoRegionsExcludeItem] | None,
Field(
description="Exclude specific regions/states from delivery. ISO 3166-2 subdivision codes (e.g., 'US-CA', 'GB-SCT').",
min_length=1,
Expand All @@ -2568,7 +2569,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_metros_exclude: Annotated[
list[GeoMetrosExcludeItem] | None,
Sequence[GeoMetrosExcludeItem] | None,
Field(
description='Exclude specific metro areas from delivery. Each entry specifies the classification system and excluded values. Seller must declare supported systems in get_adcp_capabilities.',
min_length=1,
Expand All @@ -2582,7 +2583,7 @@ class TargetingOverlay(AdCPBaseModel):
),
] = None
geo_postal_areas_exclude: Annotated[
list[GeoPostalAreasExcludeItem] | None,
Sequence[GeoPostalAreasExcludeItem] | None,
Field(
description='Exclude specific postal areas from delivery. Each entry specifies the postal system and excluded values. Seller must declare supported systems in get_adcp_capabilities.',
min_length=1,
Expand Down
Loading
Loading