Skip to content

Commit

Permalink
adding AWS Dataservices connectors utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
skondakindi committed Apr 2, 2024
1 parent db2cc4f commit 20e09f8
Show file tree
Hide file tree
Showing 16 changed files with 1,493 additions and 305 deletions.
16 changes: 16 additions & 0 deletions numalogic/connectors/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from enum import Enum, EnumMeta


class MetaEnum(EnumMeta):
    """Enum metaclass that lets ``in`` test raw values as well as members."""

    def __contains__(cls, item):
        """Return True if ``item`` can be resolved to a member of ``cls``.

        The lookup is delegated to the normal enum call ``cls(item)``; a
        ``ValueError`` from that call means ``item`` is not a valid value.
        """
        try:
            cls(item)
        except ValueError:
            return False
        else:
            return True


class BaseEnum(Enum, metaclass=MetaEnum):
    """Base class for enums that support value membership tests and listing."""

    @classmethod
    def list(cls):
        """Return the values of all members, in definition order."""
        return [member.value for member in cls]
138 changes: 138 additions & 0 deletions numalogic/connectors/aws/boto3_client_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from boto3 import Session
import logging
from numalogic.connectors.aws import BaseEnum
from numalogic.connectors.aws.exceptions import UnRecognizedAWSClientException
from numalogic.connectors.aws.sts_client_manager import STSClientManager
from numalogic.connectors.aws.db_configurations import (
load_db_conf,
DatabaseServiceProvider,
DatabaseTypes,
)

# NOTE(review): basicConfig() runs at import time and configures the root
# logger for whatever application imports this module — confirm this is
# intended for library code.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)


class Boto3ClientManager:
    """Create boto3 clients for AWS data services using assumed-role credentials.

    Temporary credentials are obtained through an ``STSClientManager`` and used
    to build a boto3 ``Session``, from which RDS and Athena clients are created.
    """

    def __init__(self, configurations):
        """
        Initialize the manager with the given configurations.

        Args:
            configurations (object): An object exposing the attributes read by
                this class:
                - aws_assume_role_arn: The ARN of the role to assume.
                - aws_assume_role_session_name: The session name to use when
                  assuming the role.
                - endpoint: The database endpoint (used for RDS auth tokens).
                - port: The database port (used for RDS auth tokens).
                - database_username: The database username (used for RDS auth
                  tokens).
                - aws_region: The AWS region of the services.

        Attributes:
            rds_client: The most recently created RDS client, initially None.
            athena_client: The most recently created Athena client, initially
                None.
            configurations (object): The configurations passed in.
            sts_client_manager (STSClientManager): Supplies the temporary
                credentials used to build sessions.
        """
        self.rds_client = None
        self.athena_client = None
        self.configurations = configurations
        self.sts_client_manager = STSClientManager()

    def get_boto3_session(self) -> "Session":
        """
        Build a boto3 session from assumed-role credentials.

        The credentials are requested from the STS client manager using the
        configured role ARN and session name.

        Returns:
            Session: A boto3 session carrying the temporary access key, secret
            key and session token.
        """
        credentials = self.sts_client_manager.get_credentials(
            self.configurations.aws_assume_role_arn,
            self.configurations.aws_assume_role_session_name,
        )
        return Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )

    def get_rds_token(self, rds_boto3_client) -> str:
        """
        Generate an RDS authentication token with the provided RDS client.

        The token is requested via ``generate_db_auth_token`` using the
        configured endpoint, port, database username and AWS region.

        Parameters:
            rds_boto3_client (boto3.client): The RDS boto3 client used to
                generate the authentication token.

        Returns:
            str: The generated RDS authentication token.
        """
        return rds_boto3_client.generate_db_auth_token(
            DBHostname=self.configurations.endpoint,
            Port=self.configurations.port,
            DBUsername=self.configurations.database_username,
            Region=self.configurations.aws_region,
        )

    def get_client(self, client_type: str):
        """
        Create and return an AWS client for the requested service.

        Every call builds a fresh session and client; the result is also
        stored on ``rds_client`` or ``athena_client`` respectively.

        Parameters:
            client_type (str): One of the values defined in the
                ``DatabaseServiceProvider`` enum. A ``DatabaseServiceProvider``
                member is accepted as well.

        Returns:
            boto3.client: The generated AWS client object.

        Raises:
            UnRecognizedAWSClientException: If the client type is not
                recognized; the message lists the available options.
        """
        # The configurations object is deliberately not logged: it may carry
        # credentials (e.g. a database password) that must not reach log files.
        _LOGGER.debug("Generating AWS client for client_type: %s", client_type)

        # Normalise to an enum member so that both raw values and members are
        # handled by the same comparisons below.
        try:
            provider = DatabaseServiceProvider(client_type)
        except ValueError:
            provider = None

        if provider is DatabaseServiceProvider.rds:
            self.rds_client = self.get_boto3_session().client(
                "rds", region_name=self.configurations.aws_region
            )
            return self.rds_client

        if provider is DatabaseServiceProvider.athena:
            self.athena_client = self.get_boto3_session().client(
                "athena", region_name=self.configurations.aws_region
            )
            # Previously this branch fell through and returned None.
            return self.athena_client

        raise UnRecognizedAWSClientException(
            f"Unrecognized Client Type : {client_type}, please choose one from {DatabaseServiceProvider.list()}"
        )

#
# if __name__ == "__main__":
# config = load_db_conf(
# "./db_config.yaml")
# boto3_client_manager = Boto3ClientManager(config)
# rds = DatabaseServiceProvider.rds.value
# rds_client = boto3_client_manager.get_client(rds)
# _LOGGER.info(boto3_client_manager.get_rds_token(rds_client))
166 changes: 166 additions & 0 deletions numalogic/connectors/aws/db_configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import os
import logging
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import OmegaConf
from numalogic.connectors.aws import BaseEnum
from numalogic.connectors.aws.exceptions import ConfigNotFoundError

# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)


class DatabaseServiceProvider(BaseEnum):
    """
    Enumeration of the supported database service providers.

    Members:
        rds: The RDS (Relational Database Service) provider, value ``"rds"``.
        athena: The Athena provider, value ``"athena"``.
    """

    rds = "rds"
    athena = "athena"


class DatabaseTypes(BaseEnum):
    """
    Enumeration of the supported database types.

    Members:
        mysql: The MySQL database type, value ``"mysql"``.
        athena: The Athena database type, value ``"athena"``.
    """

    mysql = "mysql"
    athena = "athena"


@dataclass
class AWSConfig:
    """
    AWS role-assumption settings.

    Attributes:
        aws_assume_role_arn (str): The ARN of the IAM role to assume.
            Defaults to an empty string.
        aws_assume_role_session_name (str): The session name used when
            assuming the IAM role. Defaults to an empty string.
    """

    aws_assume_role_arn: str = ""
    aws_assume_role_session_name: str = ""


@dataclass
class SSLConfig:
    """
    SSL/TLS settings for a database connection.

    Attributes:
        ca (Optional[str]): The path to the Certificate Authority (CA) file.
            Defaults to an empty string.
    """

    ca: Optional[str] = ""


@dataclass
class RDBMSConfig:
    """
    Connection settings for a Relational Database Management System (RDBMS).

    Attributes:
        endpoint (str): The endpoint or hostname of the database.
            Defaults to an empty string.
        port (int): The port number of the database. Defaults to 3306.
        database_name (str): The name of the database.
            Defaults to an empty string.
        database_username (str): The username for the database connection.
            Defaults to an empty string.
        database_password (str): The password for the database connection.
            Defaults to an empty string.
        database_connection_timeout (int): The connection timeout in seconds.
            Defaults to 10.
        database_type (str): The type of the database. Defaults to 'mysql'.
        database_provider (str): The provider of the database service.
            Defaults to 'rds'.
        ssl_enabled (bool): Whether SSL/TLS is enabled for the connection.
            Defaults to False.
        ssl (Optional[SSLConfig]): The SSL/TLS configuration for the
            connection. Defaults to a new, empty SSLConfig per instance.
    """

    endpoint: str = ""
    port: int = 3306
    database_name: str = ""
    database_username: str = ""
    database_password: str = ""
    database_connection_timeout: int = 10
    database_type: str = DatabaseTypes.mysql.value
    database_provider: str = DatabaseServiceProvider.rds.value
    ssl_enabled: bool = False
    ssl: Optional[SSLConfig] = field(default_factory=SSLConfig)


@dataclass
class RDSConfig(AWSConfig, RDBMSConfig):
    """
    Configuration for an RDS (Relational Database Service) instance.

    Combines the role-assumption fields of ``AWSConfig`` with the connection
    fields of ``RDBMSConfig`` and adds the RDS-specific settings below.

    Inherited from AWSConfig:
        aws_assume_role_arn, aws_assume_role_session_name.

    Inherited from RDBMSConfig:
        endpoint, port, database_name, database_username, database_password,
        database_connection_timeout, database_type, database_provider,
        ssl_enabled, ssl.

    Attributes:
        aws_region (str): The AWS region for the RDS instance.
            Defaults to an empty string.
        aws_rds_use_iam (bool): Whether to use IAM authentication for the RDS
            instance. Defaults to False.
    """

    aws_region: str = ""
    aws_rds_use_iam: bool = False


def load_db_conf(*paths: str) -> RDSConfig:
    """
    Build an RDSConfig from one or more YAML configuration files.

    Each path is loaded with OmegaConf; paths that do not exist are logged
    and skipped. The loaded configs are merged, in the order given, onto a
    structured ``RDSConfig`` schema and converted to an object.

    Parameters:
        - paths (str): One or more paths to YAML files containing the
          database configuration.

    Returns:
        - RDSConfig: The loaded database configuration.

    Raises:
        - ConfigNotFoundError: If none of the given configuration file paths
          exist.

    Example:
        load_db_conf("/path/to/config.yaml", "/path/to/another/config.yaml")
    """
    loaded = []
    for candidate in paths:
        try:
            loaded.append(OmegaConf.load(candidate))
        except FileNotFoundError:
            # A missing file is not fatal as long as at least one path loads.
            _LOGGER.warning("Config file path: %s not found. Skipping...", candidate)

    if not loaded:
        raise ConfigNotFoundError(f"None of the given conf paths exist: {paths}")

    merged = OmegaConf.merge(OmegaConf.structured(RDSConfig), *loaded)
    return OmegaConf.to_object(merged)


# if __name__ == "__main__":
# print(
# load_db_conf(
# "/Users/skondakindi/Desktop/codebase/odl/odl-ml-python-sdk/tests/resources/db_config.yaml"
# )
# )
Loading

0 comments on commit 20e09f8

Please sign in to comment.