Skip to content

Commit

Permalink
Move filesystem and serializer validation to internal validators modu…
Browse files Browse the repository at this point in the history
…le (#12464)
  • Loading branch information
serinamarie authored Mar 28, 2024
1 parent 3dd982b commit 34d2e0a
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 99 deletions.
127 changes: 126 additions & 1 deletion src/prefect/_internal/schemas/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import json
import logging
import re
import urllib.parse
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import jsonschema
import pendulum

from prefect._internal.pydantic import HAS_PYDANTIC_V2
from prefect._internal.schemas.fields import DateTimeTZ
from prefect.exceptions import InvalidNameError
from prefect.exceptions import InvalidNameError, InvalidRepositoryURLError
from prefect.utilities.annotations import NotSet
from prefect.utilities.importtools import from_qualified_name
from prefect.utilities.names import generate_slug
from prefect.utilities.pydantic import JsonPatch

Expand Down Expand Up @@ -482,3 +485,125 @@ def get_or_create_state_name(v: str, values: dict) -> str:

def get_or_create_run_name(name):
return name or generate_slug(2)


### FILESYSTEM SCHEMA VALIDATORS ###


def stringify_path(value: Union[str, Path]) -> str:
if isinstance(value, Path):
return str(value)
return value


def validate_basepath(value: str) -> str:
scheme, netloc, _, _, _ = urllib.parse.urlsplit(value)

if not scheme:
raise ValueError(f"Base path must start with a scheme. Got {value!r}.")

if not netloc:
raise ValueError(
f"Base path must include a location after the scheme. Got {value!r}."
)

if scheme == "file":
raise ValueError(
"Base path scheme cannot be 'file'. Use `LocalFileSystem` instead for"
" local file access."
)

return value


def validate_github_access_token(v: str, values: dict) -> str:
"""Ensure that credentials are not provided with 'SSH' formatted GitHub URLs.
Note: validates `access_token` specifically so that it only fires when
private repositories are used.
"""
if v is not None:
if urllib.parse.urlparse(values["repository"]).scheme != "https":
raise InvalidRepositoryURLError(
"Crendentials can only be used with GitHub repositories "
"using the 'HTTPS' format. You must either remove the "
"credential if you wish to use the 'SSH' format and are not "
"using a private repository, or you must change the repository "
"URL to the 'HTTPS' format. "
)

return v


### SERIALIZER SCHEMA VALIDATORS ###


def validate_picklelib(value: str) -> str:
"""
Check that the given pickle library is importable and has dumps/loads methods.
"""
try:
pickler = from_qualified_name(value)
except (ImportError, AttributeError) as exc:
raise ValueError(
f"Failed to import requested pickle library: {value!r}."
) from exc

if not callable(getattr(pickler, "dumps", None)):
raise ValueError(f"Pickle library at {value!r} does not have a 'dumps' method.")

if not callable(getattr(pickler, "loads", None)):
raise ValueError(f"Pickle library at {value!r} does not have a 'loads' method.")

return value


def validate_dump_kwargs(value: dict) -> dict:
# `default` is set by `object_encoder`. A user provided callable would make this
# class unserializable anyway.
if "default" in value:
raise ValueError("`default` cannot be provided. Use `object_encoder` instead.")
return value


def validate_load_kwargs(value: dict) -> dict:
# `object_hook` is set by `object_decoder`. A user provided callable would make
# this class unserializable anyway.
if "object_hook" in value:
raise ValueError(
"`object_hook` cannot be provided. Use `object_decoder` instead."
)
return value


def cast_type_names_to_serializers(value):
from prefect.serializers import Serializer

if isinstance(value, str):
return Serializer(type=value)
return value


def validate_compressionlib(value: str) -> str:
"""
Check that the given pickle library is importable and has compress/decompress
methods.
"""
try:
compressor = from_qualified_name(value)
except (ImportError, AttributeError) as exc:
raise ValueError(
f"Failed to import requested compression library: {value!r}."
) from exc

if not callable(getattr(compressor, "compress", None)):
raise ValueError(
f"Compression library at {value!r} does not have a 'compress' method."
)

if not callable(getattr(compressor, "decompress", None)):
raise ValueError(
f"Compression library at {value!r} does not have a 'decompress' method."
)

return value
45 changes: 8 additions & 37 deletions src/prefect/filesystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
else:
from pydantic import Field, SecretStr, validator

from prefect._internal.schemas.validators import (
stringify_path,
validate_basepath,
validate_github_access_token,
)
from prefect.blocks.core import Block
from prefect.exceptions import InvalidRepositoryURLError
from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible
from prefect.utilities.compat import copytree
from prefect.utilities.filesystem import filter_files
Expand Down Expand Up @@ -97,9 +101,7 @@ class LocalFileSystem(WritableFileSystem, WritableDeploymentStorage):

@validator("basepath", pre=True)
def cast_pathlib(cls, value):
if isinstance(value, Path):
return str(value)
return value
return stringify_path(value)

def _resolve_path(self, path: str) -> Path:
# Only resolve the base path at runtime, default to the current directory
Expand Down Expand Up @@ -280,23 +282,7 @@ class RemoteFileSystem(WritableFileSystem, WritableDeploymentStorage):

@validator("basepath")
def check_basepath(cls, value):
scheme, netloc, _, _, _ = urllib.parse.urlsplit(value)

if not scheme:
raise ValueError(f"Base path must start with a scheme. Got {value!r}.")

if not netloc:
raise ValueError(
f"Base path must include a location after the scheme. Got {value!r}."
)

if scheme == "file":
raise ValueError(
"Base path scheme cannot be 'file'. Use `LocalFileSystem` instead for"
" local file access."
)

return value
return validate_basepath(value)

def _resolve_path(self, path: str) -> str:
base_scheme, base_netloc, base_urlpath, _, _ = urllib.parse.urlsplit(
Expand Down Expand Up @@ -945,22 +931,7 @@ class GitHub(ReadableDeploymentStorage):

@validator("access_token")
def _ensure_credentials_go_with_https(cls, v: str, values: dict) -> str:
"""Ensure that credentials are not provided with 'SSH' formatted GitHub URLs.
Note: validates `access_token` specifically so that it only fires when
private repositories are used.
"""
if v is not None:
if urllib.parse.urlparse(values["repository"]).scheme != "https":
raise InvalidRepositoryURLError(
"Crendentials can only be used with GitHub repositories "
"using the 'HTTPS' format. You must either remove the "
"credential if you wish to use the 'SSH' format and are not "
"using a private repository, or you must change the repository "
"URL to the 'HTTPS' format. "
)

return v
return validate_github_access_token(v, values)

def _create_repo_url(self) -> str:
"""Format the URL provided to the `git clone` command.
Expand Down
74 changes: 13 additions & 61 deletions src/prefect/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from typing import Any, Generic, Optional, TypeVar

from prefect._internal.pydantic import HAS_PYDANTIC_V2
from prefect._internal.schemas.validators import (
cast_type_names_to_serializers,
validate_compressionlib,
validate_dump_kwargs,
validate_load_kwargs,
validate_picklelib,
)

if HAS_PYDANTIC_V2:
import pydantic.v1 as pydantic
Expand Down Expand Up @@ -101,27 +108,7 @@ class PickleSerializer(Serializer):

@pydantic.validator("picklelib")
def check_picklelib(cls, value):
"""
Check that the given pickle library is importable and has dumps/loads methods.
"""
try:
pickler = from_qualified_name(value)
except (ImportError, AttributeError) as exc:
raise ValueError(
f"Failed to import requested pickle library: {value!r}."
) from exc

if not callable(getattr(pickler, "dumps", None)):
raise ValueError(
f"Pickle library at {value!r} does not have a 'dumps' method."
)

if not callable(getattr(pickler, "loads", None)):
raise ValueError(
f"Pickle library at {value!r} does not have a 'loads' method."
)

return value
return validate_picklelib(value)

@pydantic.root_validator
def check_picklelib_version(cls, values):
Expand Down Expand Up @@ -196,23 +183,11 @@ class JSONSerializer(Serializer):

@pydantic.validator("dumps_kwargs")
def dumps_kwargs_cannot_contain_default(cls, value):
# `default` is set by `object_encoder`. A user provided callable would make this
# class unserializable anyway.
if "default" in value:
raise ValueError(
"`default` cannot be provided. Use `object_encoder` instead."
)
return value
return validate_dump_kwargs(value)

@pydantic.validator("loads_kwargs")
def loads_kwargs_cannot_contain_object_hook(cls, value):
# `object_hook` is set by `object_decoder`. A user provided callable would make
# this class unserializable anyway.
if "object_hook" in value:
raise ValueError(
"`object_hook` cannot be provided. Use `object_decoder` instead."
)
return value
return validate_load_kwargs(value)

def dumps(self, data: Any) -> bytes:
json = from_qualified_name(self.jsonlib)
Expand Down Expand Up @@ -251,35 +226,12 @@ class CompressedSerializer(Serializer):
compressionlib: str = "lzma"

@pydantic.validator("serializer", pre=True)
def cast_type_names_to_serializers(cls, value):
if isinstance(value, str):
return Serializer(type=value)
return value
def validate_serializer(cls, value):
return cast_type_names_to_serializers(value)

@pydantic.validator("compressionlib")
def check_compressionlib(cls, value):
"""
Check that the given pickle library is importable and has compress/decompress
methods.
"""
try:
compressor = from_qualified_name(value)
except (ImportError, AttributeError) as exc:
raise ValueError(
f"Failed to import requested compression library: {value!r}."
) from exc

if not callable(getattr(compressor, "compress", None)):
raise ValueError(
f"Compression library at {value!r} does not have a 'compress' method."
)

if not callable(getattr(compressor, "decompress", None)):
raise ValueError(
f"Compression library at {value!r} does not have a 'decompress' method."
)

return value
return validate_compressionlib(value)

def dumps(self, obj: Any) -> bytes:
blob = self.serializer.dumps(obj)
Expand Down

0 comments on commit 34d2e0a

Please sign in to comment.