Skip to content

Commit

Permalink
feat: wrap generate_keys() in future (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
RahulDubey391 committed Nov 28, 2023
1 parent 5d95ee2 commit 964deb0
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
11 changes: 5 additions & 6 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def __init__(
# otherwise use application default credentials
else:
self._credentials, _ = default(scopes=scopes)
self._keys = generate_keys()
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)
self._client: Optional[AlloyDBClient] = None

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -123,11 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if instance_uri in self._instances:
instance = self._instances[instance_uri]
else:
instance = Instance(
instance_uri,
self._client,
self._keys,
)
instance = Instance(instance_uri, self._client, self._keys)
self._instances[instance_uri] = instance

connect_func = {
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self,
instance_uri: str,
client: AlloyDBClient,
keys: Tuple[rsa.RSAPrivateKey, str],
keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]],
) -> None:
# validate and parse instance_uri
instance_uri_split = instance_uri.split("/")
Expand Down Expand Up @@ -98,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult:

try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = self._keys
priv_key, pub_key = await self._keys
# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _write_to_file(
return (ca_filename, cert_chain_filename, key_filename)


def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pub_key = (
priv_key.public_key()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def test__get_client_certificate(
Test _get_client_certificate returns successfully.
"""
test_client = AlloyDBClient("", "", credentials, client)
keys = generate_keys()
keys = await generate_keys()
certs = await test_client._get_client_certificate(
"test-project", "test-region", "test-cluster", keys[1]
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_Instance_init() -> None:
Test to check whether the __init__ method of Instance
can tell if the instance URI that's passed in is formatted correctly.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
async with aiohttp.ClientSession() as client:
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -52,7 +52,7 @@ async def test_Instance_init_invalid_instant_uri() -> None:
Test to check whether the __init__ method of Instance
will throw error for invalid instance URI.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
async with aiohttp.ClientSession() as client:
with pytest.raises(ValueError):
Instance("invalid/instance/uri/", client, keys)
Expand All @@ -64,7 +64,7 @@ async def test_Instance_close() -> None:
Test that Instance's close method
cancels tasks gracefully.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -84,7 +84,7 @@ async def test_Instance_close() -> None:
@pytest.mark.asyncio
async def test_perform_refresh() -> None:
"""Test that _perform refresh returns valid RefreshResult"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -104,7 +104,7 @@ async def test_schedule_refresh_replaces_result() -> None:
Test to check whether _schedule_refresh replaces a valid refresh result
with another refresh result.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -131,7 +131,7 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None
Test to check whether _schedule_refresh won't replace a valid
refresh result with an invalid one.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down Expand Up @@ -160,7 +160,7 @@ async def test_schedule_refresh_expired_cert() -> None:
Test to check whether _schedule_refresh will throw RefreshError on
expired certificate.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
# set certificate to be expired
client.instance.cert_before = datetime.now() - timedelta(minutes=20)
Expand All @@ -182,7 +182,7 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down

0 comments on commit 964deb0

Please sign in to comment.