## Auth

In [1]:
import os
from dotenv import load_dotenv

In [2]:
# Load environment variables from the local .env file
load_dotenv('.env')

# Retrieve WANDB_API_KEY. Fail fast when it is absent: the value is later placed
# in the training job's environment mapping, where a None entry is not a valid
# string value.
wandb_api_key = os.getenv('WANDB_API_KEY')
if not wandb_api_key:
    raise RuntimeError(
        "WANDB_API_KEY is not set; define it in .env or export it before running this notebook."
    )

In [3]:
# Set the AWS-related environment variables for this process in one step:
# project-local credentials/config file paths and the default region.
os.environ.update(
    {
        "AWS_SHARED_CREDENTIALS_FILE": "./.aws/credentials",
        "AWS_CONFIG_FILE": "./.aws/config",
        "AWS_DEFAULT_REGION": "eu-north-1",
    }
)

## Training Job

In [4]:
import time
from sagemaker_core.helper.session_helper import Session, get_execution_role

In [5]:
# Create a session object from sagemaker_core's session helper.
# NOTE(review): no cell visible in this notebook passes `sagemaker_session`
# anywhere afterwards — confirm whether it is still needed.
sagemaker_session = Session()

In [6]:
from sagemaker_core.resources import TrainingJob
from sagemaker_core.shapes import (
    AlgorithmSpecification,
    Channel,
    DataSource,
    S3DataSource,
    ResourceConfig,
    StoppingCondition,
    OutputDataConfig
)

In [7]:
# Upper bound on a training job's runtime, in hours. It is converted to seconds
# in the configuration cell and used for the job's stopping condition.
max_hours_runtime = 12

In [8]:
# --- Account-specific settings ---
# These are blank in the notebook as saved; presumably they must be filled in
# before a job can be submitted.
bucket = ""  # S3 bucket name; used to build the dataset and output URIs below
image = ""  # passed to the job as the training image
role = ""  # passed to the job as the role ARN
region = "eu-north-1"

# --- Compute resources ---
instance_count = 1  # Number of instances to use for training
instance_type = "ml.g4dn.2xlarge"  # SageMaker instance type to use for training
volume_size_in_gb = 30
# Maximum runtime in seconds (hours * 60 minutes * 60 seconds).
# Job exits if it doesn't finish before this time.
max_runtime_in_seconds = max_hours_runtime * 60 * 60

# --- S3 locations derived from the bucket ---
dataset_uri = f"s3://{bucket}/datasets/"
s3_output_path = f"s3://{bucket}/output"

### Standard

In [None]:
dataset = 'abalone'

# UTC timestamp (one-second resolution) so repeated submissions get distinct names.
job_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
# A hyphen separates the dataset name from the timestamp; without it the two
# run together (e.g. "...-abalone2024-...") and the name is hard to read.
job_name = f"synthetic-data-generation-{dataset}-{job_timestamp}"

# Hyperparameters handed to the training job.
hyper_parameters = {
    "dataset": dataset,
}

# Environment variables handed to the training job.
environment = {
    "WANDB_API_KEY": wandb_api_key,
}

In [None]:
# Submit the standard training job: a single "datasets" input channel read from
# S3, output written under `s3_output_path`, bounded by `max_runtime_in_seconds`.
training_job = TrainingJob.create(
    training_job_name=job_name,
    hyper_parameters=hyper_parameters,
    environment=environment,
    algorithm_specification=AlgorithmSpecification(
        training_image=image,
        training_input_mode="File"
    ),
    role_arn=role,
    input_data_config=[
        # Dataset channel: everything under the `dataset_uri` prefix.
        Channel(
            channel_name="datasets",  # plain literal; no interpolation is needed here
            content_type="csv",
            data_source=DataSource(
                s3_data_source=S3DataSource(
                    s3_data_type="S3Prefix",
                    s3_uri=dataset_uri,
                    s3_data_distribution_type="FullyReplicated",
                )
            ),
        ),
    ],
    output_data_config=OutputDataConfig(
        s3_output_path=s3_output_path,
    ),
    resource_config=ResourceConfig(
        instance_type=instance_type,
        instance_count=instance_count,
        volume_size_in_gb=volume_size_in_gb,
    ),
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=max_runtime_in_seconds
    ),
)

# Call the job's wait() before moving on to later cells.
training_job.wait()

### Custom Entrypoint 

In [24]:
dataset = 'abalone'

# UTC timestamp (one-second resolution) so repeated submissions get distinct names.
job_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
# A hyphen separates the dataset name from the timestamp; without it the two
# run together (e.g. "...-abalone2024-...") and the name is hard to read.
job_name = f"synthetic-data-generation-{dataset}-{job_timestamp}"

# Hyperparameters handed to the training job; numeric values are passed as strings.
hyper_parameters = {
    "dataset": dataset,
    "n_trials_per_worker": str(14),
    "n_workers": str(3)
}

# S3 prefix supplied to the job as the "entrypoint" input channel.
entry_script_uri = f's3://{bucket}/entrypoint/'

# Environment variables handed to the training job. SAGEMAKER_PROGRAM names a
# script under /opt/ml/input/data/entrypoint — presumably where the
# "entrypoint" channel's files land inside the container (verify against the image).
environment = {
    "WANDB_API_KEY": wandb_api_key,
    "SAGEMAKER_PROGRAM": "/opt/ml/input/data/entrypoint/run_pipeline_copula_gan.py"
}

In [None]:
# Submit the custom-entrypoint training job: the "datasets" channel plus an
# additional "entrypoint" channel read from `entry_script_uri`.
training_job = TrainingJob.create(
    training_job_name=job_name,
    hyper_parameters=hyper_parameters,
    environment=environment,
    algorithm_specification=AlgorithmSpecification(
        training_image=image,
        training_input_mode="File"
    ),
    role_arn=role,
    input_data_config=[
        # Dataset channel: everything under the `dataset_uri` prefix.
        Channel(
            channel_name="datasets",  # plain literal; no interpolation is needed here
            content_type="csv",
            data_source=DataSource(
                s3_data_source=S3DataSource(
                    s3_data_type="S3Prefix",
                    s3_uri=dataset_uri,
                    s3_data_distribution_type="FullyReplicated",
                )
            ),
        ),
        # Entrypoint channel: everything under the `entry_script_uri` prefix.
        Channel(
            channel_name="entrypoint",
            content_type="application/x-python",
            data_source=DataSource(
                s3_data_source=S3DataSource(
                    s3_data_type='S3Prefix',
                    s3_uri=entry_script_uri,
                    s3_data_distribution_type='FullyReplicated',
                )
            )
        )
    ],
    output_data_config=OutputDataConfig(
        s3_output_path=s3_output_path,
    ),
    resource_config=ResourceConfig(
        instance_type=instance_type,
        instance_count=instance_count,
        volume_size_in_gb=volume_size_in_gb,
    ),
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=max_runtime_in_seconds
    ),
)

# Call the job's wait() before moving on to later cells.
training_job.wait()