Skip to content

Commit

Permalink
Allow schema migration of block documents during Block.save (#8056)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored and github-actions[bot] committed Jan 6, 2023
1 parent 8cb84c1 commit 90a94af
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/prefect/client/orion.py
Expand Up @@ -1007,7 +1007,7 @@ async def update_block_document(
json=block_document.dict(
json_compatible=True,
exclude_unset=True,
include={"data", "merge_existing_data"},
include={"data", "merge_existing_data", "block_schema_id"},
include_secrets=True,
),
)
Expand Down
7 changes: 5 additions & 2 deletions src/prefect/orion/api/block_documents.py
Expand Up @@ -126,8 +126,11 @@ async def update_block_document_data(
block_document_id=block_document_id,
block_document=block_document,
)
except ValueError:
raise HTTPException(status.HTTP_400_BAD_REQUEST)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(exc),
)

if not result:
raise HTTPException(
Expand Down
30 changes: 29 additions & 1 deletion src/prefect/orion/models/block_documents.py
Expand Up @@ -492,11 +492,39 @@ async def update_block_document(
new_block_document_references,
) = _separate_block_references_from_data(update_values["data"])

# encrypt the data
# encrypt the data and write updated data to the block document
await current_block_document.encrypt_data(
session=session, data=block_document_data_without_refs
)

# `proposed_block_schema` is always the same as the schema on the client-side
# Block class that is calling `save`, which may or may not be the same schema
# as the one on the saved block document
proposed_block_schema_id = block_document.block_schema_id

# if a new schema is proposed, update the block schema id for the block document
if (
proposed_block_schema_id is not None
and proposed_block_schema_id != current_block_document.block_schema_id
):
proposed_block_schema = await session.get(
db.BlockSchema, proposed_block_schema_id
)

# make sure the proposed schema is of the same block type as the current document
if (
proposed_block_schema.block_type_id
!= current_block_document.block_type_id
):
raise ValueError(
"Must migrate block document to a block schema of the same block type."
)
await session.execute(
sa.update(db.BlockDocument)
.where(db.BlockDocument.id == block_document_id)
.values(block_schema_id=proposed_block_schema_id)
)

unchanged_block_document_references = []
for secret_key, reference_block_document_id in new_block_document_references:
matching_current_block_document_reference = _find_block_document_reference(
Expand Down
3 changes: 3 additions & 0 deletions src/prefect/orion/schemas/actions.py
Expand Up @@ -314,6 +314,9 @@ def validate_name_is_present_if_not_anonymous(cls, values):
class BlockDocumentUpdate(ActionBaseModel):
"""Data used by the Orion API to update a block document."""

block_schema_id: Optional[UUID] = Field(
default=None, description="A block schema ID"
)
data: dict = FieldFrom(schemas.core.BlockDocument)
merge_existing_data: bool = True

Expand Down
100 changes: 73 additions & 27 deletions tests/blocks/test_core.py
Expand Up @@ -8,7 +8,6 @@
import pytest
from packaging.version import Version
from pydantic import BaseModel, Field, SecretBytes, SecretStr
from pydantic.fields import ModelField

import prefect
from prefect.blocks.core import Block, InvalidBlockRegistration
Expand All @@ -19,7 +18,7 @@
from prefect.orion import models
from prefect.orion.schemas.actions import BlockDocumentCreate
from prefect.orion.schemas.core import DEFAULT_BLOCK_SCHEMA_VERSION
from prefect.utilities.dispatch import get_registry_for_type, lookup_type, register_type
from prefect.utilities.dispatch import lookup_type, register_type
from prefect.utilities.names import obfuscate_string


Expand Down Expand Up @@ -2076,50 +2075,52 @@ class Config:


class TestBlockSchemaMigration:
@pytest.fixture
def new_field(self):
return {
"y": ModelField.infer(
name="y",
value=...,
annotation=int,
class_validators=None,
config=Block.__config__,
)
}

def test_schema_mismatch_with_validation_raises(self, new_field):
def test_schema_mismatch_with_validation_raises(self):
class A(Block):
_block_type_name = "a"
_block_type_slug = "a"
x: int = 1

a = A()

a.save("test")

A.__fields__.update(new_field) # simulate a schema change
with pytest.warns(UserWarning, match="matches existing registered type 'A'"):

class A_Alias(Block):
_block_type_name = "a"
_block_type_slug = "a"
x: int = 1
y: int

with pytest.raises(
RuntimeError, match="try loading again with `validate=False`"
):
A.load("test")
A_Alias.load("test")

def test_add_field_to_schema_with_skip_validation(self, new_field):
def test_add_field_to_schema_partial_load_with_skip_validation(self):
class A(Block):
x: int = 1

a = A()

a.save("test")

A.__fields__.update(new_field) # simulate a schema change
with pytest.warns(UserWarning, match="matches existing registered type 'A'"):

class A_Alias(Block):
_block_type_name = "a"
_block_type_slug = "a"
x: int = 1
y: int

with pytest.warns(UserWarning, match="Could not fully load"):
a = A.load("test", validate=False)
a = A_Alias.load("test", validate=False)

assert a.x == 1
assert a.y == None

def test_rm_field_from_schema_loads_with_validation(self, new_field):
def test_rm_field_from_schema_loads_with_validation(self):
class Foo(Block):
_block_type_name = "foo"
_block_type_slug = "foo"
Expand All @@ -2130,12 +2131,12 @@ class Foo(Block):

foo.save("xy")

get_registry_for_type(Block).pop("foo")
with pytest.warns(UserWarning, match="matches existing registered type 'Foo'"):

class Foo_Alias(Block):
_block_type_name = "foo"
_block_type_slug = "foo"
x: int = 1
class Foo_Alias(Block):
_block_type_name = "foo"
_block_type_slug = "foo"
x: int = 1

foo_alias = Foo_Alias.load("xy")

Expand All @@ -2146,7 +2147,7 @@ class Foo_Alias(Block):
# with pytest.raises(AttributeError):
# foo_alias.y

def test_load_with_skip_validation_keeps_metadata(self, new_field):
def test_load_with_skip_validation_keeps_metadata(self):
class Bar(Block):
x: int = 1

Expand All @@ -2157,3 +2158,48 @@ class Bar(Block):
bar_new = Bar.load("test", validate=False)

assert bar.dict() == bar_new.dict()

async def test_save_new_schema_with_overwrite(self, orion_client):
class Baz(Block):
_block_type_name = "baz"
_block_type_slug = "baz"
x: int = 1

baz = Baz()

await baz.save("test")

block_document = await orion_client.read_block_document_by_name(
name="test", block_type_slug="baz"
)
old_schema_id = block_document.block_schema_id

with pytest.warns(UserWarning, match="matches existing registered type 'Baz'"):

class Baz_Alias(Block):
_block_type_name = "baz"
_block_type_slug = "baz"
x: int = 1
y: int = 2

baz_alias = await Baz_Alias.load("test", validate=False)

await baz_alias.save("test", overwrite=True)

baz_alias_RELOADED = await Baz_Alias.load("test")

assert baz_alias_RELOADED.x == 1
assert baz_alias_RELOADED.y == 2

new_schema_id = baz_alias._block_schema_id

# new local schema ID should be different because field added
assert old_schema_id != new_schema_id

updated_schema = await orion_client.read_block_document_by_name(
name="test", block_type_slug="baz"
)
updated_schema_id = updated_schema.block_schema_id

# new local schema ID should now be saved to Orion
assert updated_schema_id == new_schema_id

0 comments on commit 90a94af

Please sign in to comment.