Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(jans-pycloudlib): split aws secrets when payload is larger than 65536 bytes #3971

Merged
merged 1 commit into from
Feb 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 107 additions & 81 deletions jans-pycloudlib/jans/pycloudlib/secret/aws_secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import lzma
import os
import sys
import typing as _t
from contextlib import suppress
from functools import cached_property
Expand All @@ -14,40 +15,19 @@
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import NoRegionError
from math import ceil

from jans.pycloudlib.secret.base_secret import BaseSecret
from jans.pycloudlib.utils import safe_value

logger = logging.getLogger(__name__)


def _dump_value(value: _t.Any) -> bytes:
    """Serialize a value to JSON text and compress it with LZMA.

    Args:
        value: Any value accepted by ``json.dumps``.

    Returns:
        The LZMA-compressed bytes of the encoded JSON text.
    """
    serialized = json.dumps(value)
    raw = serialized.encode()
    return lzma.compress(raw)


def _load_value(value: bytes) -> _t.Any:
    """Decompress LZMA bytes and parse the result as JSON.

    Args:
        value: LZMA-compressed bytes holding encoded JSON text.

    Returns:
        The Python object decoded from the JSON text.
    """
    text = lzma.decompress(value).decode()
    return json.loads(text)


class AwsSecret(BaseSecret):
"""This class interacts with AWS Secrets Manager backend.

If the secret's size is larger than the size limit (64KB), it will be stored in multiple secrets (maximum 10).

The instance of this class is configured via environment variables.

Supported environment variables:
Expand Down Expand Up @@ -94,7 +74,7 @@ class AwsSecret(BaseSecret):
```
"""

def __init__(self) -> None:
def __init__(self) -> None: # noqa: D107
# unique name used as prefix to distinguish with other secrets
# a typical usage is to use vendor/organization name
prefix = os.environ.get("CN_AWS_SECRETS_PREFIX", "jans")
Expand All @@ -103,8 +83,11 @@ def __init__(self) -> None:
# see https://docs.aws.amazon.com/cli/latest/reference/secretsmanager/create-secret.html#options
self.basepath = f"{prefix}_secrets"

# flag to determine whether AWS secrets already created
self.basepath_exists = False
# iterable contains multipart secret names
self.multiparts: list[str] = []

# max size of payload (currently 64K)
self.max_payload_size = 65536

@cached_property
def client(self) -> boto3.session.Session.client:
Expand All @@ -129,11 +112,26 @@ def get_all(self) -> dict[str, _t.Any]:
Returns:
A mapping of secrets (if any).
"""
self._prepare_secret()
resp = self.client.get_secret_value(SecretId=self.basepath)
# get all existing multipart secrets
resp = self.client.list_secrets(
Filters=[{"Key": "name", "Values": [self.basepath]}],
)
names = [secret["Name"] for secret in resp["SecretList"]]

if not names:
return {}

payload = b"".join([
self.client.get_secret_value(SecretId=name)["SecretBinary"]
for name in names
])

# SecretBinary is a `dict` data type
data: dict[str, _t.Any] = _load_value(resp["SecretBinary"])
try:
# previously data is compressed using lzma
data: dict[str, _t.Any] = json.loads(lzma.decompress(payload).decode())
logger.warning("Loaded legacy data.")
except lzma.LZMAError:
data = json.loads(payload.decode())
return data

def get(self, key: str, default: _t.Any = "") -> _t.Any:
Expand All @@ -157,16 +155,11 @@ def set(self, key: str, value: _t.Any) -> bool:
value: Value of the key.

Returns:
A boolean to mark whether config is set or not.
A boolean to indicate if secret was set successfully.
"""
data = self.get_all()
data[key] = safe_value(value)

resp = self.client.update_secret(
SecretId=self.basepath,
SecretBinary=_dump_value(data),
)
return bool(resp)
return self._update_secret_multipart(json.dumps(data))

def set_all(self, data: dict[str, _t.Any]) -> bool:
"""Set all key-value pairs.
Expand All @@ -175,37 +168,89 @@ def set_all(self, data: dict[str, _t.Any]) -> bool:
data: key-value pairs of secrets.

Returns:
A boolean indicating operation is succeed or not.
A boolean indicating if the operation was successful.
"""
self._prepare_secret()

# fetch existing data (if any) as we will merge them;
# note that existing value will be overwritten
payload = self.get_all()

for k, v in data.items():
# ensure value that has bytes is converted to text
# ensure key-value that has bytes is converted to text
payload[k] = safe_value(v)
return self._update_secret_multipart(json.dumps(payload))

resp = self.client.update_secret(
SecretId=self.basepath,
SecretBinary=_dump_value(payload),
)
return bool(resp)
@cached_property
def replica_regions(self) -> list[dict[str, _t.Any]]:
    """Get replica regions specified in a file.

    The location of the file is pointed by `CN_AWS_SECRETS_REPLICA_FILE` environment variable.
    The file content is parsed as JSON; entries whose ``Region`` equals the
    region of the current client are dropped.

    Returns:
        The remaining region entries, or an empty list when the file cannot
        be read (missing path or a directory).

    Raises:
        ValueError: If the file content is not valid JSON. The original
            ``JSONDecodeError`` is attached as the cause.
    """
    regions = []

    # an absent or unreadable file is not an error; fall through with
    # whatever ``regions`` holds at that point
    with suppress(FileNotFoundError, TypeError, IsADirectoryError):
        file_ = os.environ.get("CN_AWS_SECRETS_REPLICA_FILE", "")
        try:
            txt = Path(file_).read_text().strip()
            regions = json.loads(txt)
        except json.decoder.JSONDecodeError as exc:
            # chain explicitly so the traceback shows the parsing failure as the cause
            raise ValueError(f"Unable to load replica regions from {file_}; reason={exc}") from exc
        else:
            # ensure regions does not include current client's region
            regions = [
                region for region in regions
                if region["Region"] != self.client.meta.region_name
            ]
    return regions

def _update_secret_multipart(self, payload: _t.AnyStr) -> bool:
    """Write a payload to one or more secrets, splitting it by size.

    The payload is cut into consecutive fragments of at most
    ``self.max_payload_size`` bytes. Fragment ``0`` goes to the secret named
    by ``self._prepare_secret_multipart(0)``, fragment ``1`` to the one
    returned for part ``1``, and so on.

    Args:
        payload: Data to store; a ``str`` is encoded to bytes first.

    Returns:
        Always ``True`` once every fragment has been sent.
    """
    payload_bytes = payload.encode() if isinstance(payload, str) else payload

    # use the real byte count: ``sys.getsizeof`` also counts the bytes-object
    # header, which overestimates the size and can produce a trailing empty part
    data_length = len(payload_bytes)
    parts = ceil(data_length / self.max_payload_size)

    if parts > 1:
        logger.warning(
            f"The secret payload size is {data_length} bytes and is exceeding max. size of {self.max_payload_size} bytes. "
            f"It will be splitted into {parts} parts."
        )

    # an empty payload yields zero parts, so nothing is written
    for part in range(parts):
        name = self._prepare_secret_multipart(part)
        start_bytes = part * self.max_payload_size
        fragment = payload_bytes[start_bytes:start_bytes + self.max_payload_size]
        self.client.update_secret(SecretId=name, SecretBinary=fragment)
    return True

def _prepare_secret_multipart(self, part: int) -> str:
"""Check individual secrets if they exist or create new secrets with empty value if they don't.

Args:
part: part number of a multipart secret.

Returns:
Newly created secret's name
"""
name = self.basepath

if part > 0:
name = f"{self.basepath}_{part}"

if name in self.multiparts:
return name

try:
# get the secret
self.client.get_secret_value(SecretId=self.basepath)
self.client.get_secret_value(SecretId=name)

# mark the secret as exists so subsequent checks made by
# client instance won't need to make requests to AWS service
self.basepath_exists = True
self.multiparts.append(name)

except ClientError as exc:
# raise exception if not related to missing secrets;
Expand All @@ -215,9 +260,10 @@ def _prepare_secret(self) -> None:

create_secret = partial(
self.client.create_secret,
Name=self.basepath,
SecretBinary=_dump_value({}),
Name=name,
SecretBinary=json.dumps({}),
Description="Secrets for Janssen cluster",
Tags=[{"Key": "multipart_enabled", "Value": "true"}],
)

if self.replica_regions:
Expand All @@ -231,9 +277,11 @@ def _prepare_secret(self) -> None:
# run the actual secrets creation
create_secret()

# mark the secrets as exists so subsequent checks made by
logger.info(f"Created secret: {name}")

# mark the secret as exists so subsequent checks made by
# client instance won't need to make requests to AWS service
self.basepath_exists = True
self.multiparts.append(name)

except NoCredentialsError:
raise RuntimeError(
Expand All @@ -242,26 +290,4 @@ def _prepare_secret(self) -> None:
"by AWS_SHARED_CREDENTIALS_FILE environment variable, or specify profile "
"name via AWS_PROFILE environment variable."
)

@cached_property
def replica_regions(self) -> list[dict[str, _t.Any]]:
    """Get replica regions specified in a file.

    The location of the file is pointed by `CN_AWS_SECRETS_REPLICA_FILE` environment variable.
    The file content is parsed as JSON; entries whose ``Region`` equals the
    region of the current client are dropped.

    Returns:
        The remaining region entries, or an empty list when the file cannot
        be read (missing path or a directory).

    Raises:
        ValueError: If the file content is not valid JSON. The original
            ``JSONDecodeError`` is attached as the cause.
    """
    regions = []

    # an absent or unreadable file is not an error; fall through with
    # whatever ``regions`` holds at that point
    with suppress(FileNotFoundError, TypeError, IsADirectoryError):
        file_ = os.environ.get("CN_AWS_SECRETS_REPLICA_FILE", "")
        try:
            txt = Path(file_).read_text().strip()
            regions = json.loads(txt)
        except json.decoder.JSONDecodeError as exc:
            # chain explicitly so the traceback shows the parsing failure as the cause
            raise ValueError(f"Unable to load replica regions from {file_}; reason={exc}") from exc
        else:
            # ensure regions does not include current client's region
            regions = [
                region for region in regions
                if region["Region"] != self.client.meta.region_name
            ]
    return regions
return name