Skip to content

Commit

Permalink
feature: Add PeristedJobData schema (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijc committed Sep 13, 2021
1 parent ccd659d commit e721c72
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/braket/jobs_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.jobs_data.persisted_job_data_v1 import ( # noqa: F401
PersistedJobData,
PersistedJobDataFormat,
)
57 changes: 57 additions & 0 deletions src/braket/jobs_data/persisted_job_data_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License

from enum import Enum
from typing import Any, Dict

from pydantic import Field

from braket.schema_common import BraketSchemaBase, BraketSchemaHeader


class PersistedJobDataFormat(str, Enum):
"""
Enum class for the the required formats.
"""

PLAINTEXT = "plaintext"
# Pickle data format with protocol version 4 (Data is base64 encoded after pickling)
PICKLED_V4 = "pickled_v4"


class PersistedJobData(BraketSchemaBase):
"""
The schema used for persisting data during Amazon Braket job executions.
Attributes:
braketSchemaHeader (BraketSchemaHeader): Schema header. Users do not need
to set this value.
dataDictionary (Dict[str, Any]): Dict representing the data to be persisted.
dataFormat (PersistedJobDataFormat): Data format used for persisting the values
in `dataDictionary`.
Examples:
>>> data_to_persist = {"some_key": "some_value", "more_keys": True}
>>> PersistedJobData(dataDictionary=data_to_persist,
>>> dataFormat=PersistedJobDataFormat.PLAINTEXT)
"""

_PERSISTED_JOB_DATA_HEADER = BraketSchemaHeader(
name="braket.jobs_data.persisted_job_data", version="1"
)

braketSchemaHeader: BraketSchemaHeader = Field(
default=_PERSISTED_JOB_DATA_HEADER, const=_PERSISTED_JOB_DATA_HEADER
)
dataDictionary: Dict[str, Any]
dataFormat: PersistedJobDataFormat
61 changes: 61 additions & 0 deletions test/unit_tests/braket/jobs_data/test_persisted_job_data_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json

import pytest
from jsonschema import validate
from pydantic import ValidationError

from braket.jobs_data.persisted_job_data_v1 import PersistedJobData, PersistedJobDataFormat


def test_persisted_job_data_fields():
data_dict = {"key_1": "value_1", "iterations": 2, "more_keys": True}
data_format = PersistedJobDataFormat.PLAINTEXT
persisted = PersistedJobData(dataDictionary=data_dict, dataFormat=data_format)
assert persisted.dataDictionary == data_dict
assert persisted.dataFormat == data_format


@pytest.mark.xfail(raises=ValidationError)
def test_persisted_job_data_missing_data_format():
PersistedJobData(dataDictionary={"a": 1})


@pytest.mark.xfail(raises=ValidationError)
def test_persisted_job_data_missing_data_dictionary():
PersistedJobData(dataFormat=PersistedJobDataFormat.PLAINTEXT)


def test_json_validates_against_schema():
persisted_job_data = PersistedJobData(
dataDictionary={"a": 1}, dataFormat=PersistedJobDataFormat.PLAINTEXT
)
validate(json.loads(persisted_job_data.json()), persisted_job_data.schema())


def test_persisted_job_data_parses_json():
json_str = json.dumps(
{
"braketSchemaHeader": {
"name": "braket.jobs_data.persisted_job_data",
"version": "1",
},
"dataDictionary": {"converged": True, "energy": -0.2},
"dataFormat": "plaintext",
}
)
persisted_data = PersistedJobData.parse_raw(json_str)
assert persisted_data.dataDictionary == {"converged": True, "energy": -0.2}
assert persisted_data.dataFormat == PersistedJobDataFormat.PLAINTEXT
1 change: 1 addition & 0 deletions test/unit_tests/braket/schema_common/test_schema_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def test_get_schema_class_invalid_name():
"braket.task_result.rigetti_metadata",
"braket.task_result.simulator_metadata",
"braket.task_result.task_metadata",
"braket.jobs_data.persisted_job_data",
],
)
def test_no_header_typos(name):
Expand Down

0 comments on commit e721c72

Please sign in to comment.