Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions awswrangler/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,16 @@ def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session
"""Extract region from Subnet ID."""
session: boto3.Session = ensure_session(session=boto3_session)
client_ec2: boto3.client = client(service_name="ec2", session=session)
    # FIXME: slicing the Availability Zone name to a fixed 9 characters is wrong for
    # regions whose names are not 9 characters long (e.g. "ap-south-1a"[:9] == "ap-south-").
return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9]


def get_region_from_session(boto3_session: Optional[boto3.Session] = None) -> str:
    """Return the region name configured on the given (or default) boto3 session."""
    # Normalize the optional argument into a concrete session before reading its region.
    resolved_session: boto3.Session = ensure_session(session=boto3_session)
    return resolved_session.region_name


def extract_partitions_from_paths(
path: str, paths: List[str]
) -> Tuple[Optional[Dict[str, str]], Optional[Dict[str, List[str]]]]:
Expand Down
15 changes: 9 additions & 6 deletions awswrangler/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_default_logging_path(
_account_id = account_id
if (region is None) and (subnet_id is not None):
boto3_session = _utils.ensure_session(session=boto3_session)
_region: str = _utils.get_region_from_subnet(subnet_id=subnet_id, boto3_session=boto3_session)
_region: str = _utils.get_region_from_session(boto3_session=boto3_session)
elif (region is None) and (subnet_id is None):
raise exceptions.InvalidArgumentCombination("You must pass region or subnet_id or both.")
else:
Expand All @@ -63,7 +63,7 @@ def _get_default_logging_path(

def _build_cluster_args(**pars): # pylint: disable=too-many-branches,too-many-statements
account_id: str = _utils.get_account_id(boto3_session=pars["boto3_session"])
region: str = _utils.get_region_from_subnet(subnet_id=pars["subnet_id"], boto3_session=pars["boto3_session"])
region: str = _utils.get_region_from_session(boto3_session=pars["boto3_session"])

# S3 Logging path
if pars.get("logging_s3_path") is None:
Expand Down Expand Up @@ -155,6 +155,7 @@ def _build_cluster_args(**pars): # pylint: disable=too-many-branches,too-many-s
],
}
)

if spark_env is not None:
args["Configurations"].append(
{
Expand Down Expand Up @@ -934,7 +935,9 @@ def submit_ecr_credentials_refresh(
session: boto3.Session = _utils.ensure_session(session=boto3_session)
client_s3: boto3.client = _utils.client(service_name="s3", session=session)
bucket, key = _utils.parse_path(path=path_script)
client_s3.put_object(Body=_get_ecr_credentials_refresh_content().encode(encoding="utf-8"), Bucket=bucket, Key=key)
region: str = _utils.get_region_from_session(boto3_session=boto3_session)
client_s3.put_object(
Body=_get_ecr_credentials_refresh_content(region).encode(encoding="utf-8"), Bucket=bucket, Key=key)
command: str = f"spark-submit --deploy-mode cluster {path_script}"
name: str = "ECR Credentials Refresh"
step: Dict[str, Any] = build_step(
Expand All @@ -946,14 +949,14 @@ def submit_ecr_credentials_refresh(
return response["StepIds"][0]


def _get_ecr_credentials_refresh_content() -> str:
return """
def _get_ecr_credentials_refresh_content(region) -> str:
return f"""
import subprocess
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ECR Setup Job").getOrCreate()

COMMANDS = [
"sudo -s eval $(aws ecr get-login --region us-east-1 --no-include-email)",
"sudo -s eval $(aws ecr get-login --region {region} --no-include-email)",
"sudo hdfs dfs -put -f /root/.docker/config.json /user/hadoop/"
]

Expand Down