In [1]:
import boto3
import sagemaker
import boto3
import os
from sagemaker.pytorch import PyTorch
from sklearn.model_selection import train_test_split
from sagemaker.feature_store.feature_group import FeatureGroup

boto_session = boto3.Session()
region = boto_session.region_name

sm_session = sagemaker.Session()
sm_client = boto_session.client("sagemaker")
sm_role = sagemaker.get_execution_role()

s3_client = boto3.client('s3')

# Define your feature group name and region
feature_group_name = 'fire-image-feature-group'

# Athena client
athena_client = boto3.client('athena', region_name=region)

# MLFLow
tracking_server_arn = 'arn:aws:sagemaker:eu-central-1:567821811420:mlflow-tracking-server/wildfire-mj'
experiment_name = 'wildfire-classification'

bucket = 'wildfires'
prefix = 'sagemaker/fire-image-classification'

# FS query
query = """SELECT *
    FROM "AwsDataCatalog"."sagemaker_featurestore"."fire_image_feature_group_1718694943";
    """

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


## Preparing Data from feature store

In [2]:
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sm_session)

def get_responce_athena():
    # Run query
    response = athena_client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': 'sagemaker_featurestore'  # Replace with your Athena database name
        },
        ResultConfiguration={
            'OutputLocation': 's3://wildfires/feature-store-output/'  # Replace with your S3 bucket
        }
    )
    
    # Get query execution ID
    query_execution_id = response['QueryExecutionId']
    
    # Wait for the query to complete
    status = 'RUNNING'
    while status != 'SUCCEEDED':
        response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
        status = response['QueryExecution']['Status']['State']
    
    # Get the results
    response = athena_client.get_query_results(QueryExecutionId=query_execution_id)

    return response

response = get_responce_athena()

In [3]:
def download_images(metadata, download_dir='images'):
    if not os.path.exists(download_dir):
        os.makedirs(download_dir)

    for record in metadata:
        image_location = record['image_location']
        bucket, key = image_location.replace('s3://', '').split('/', 1)
        local_path = os.path.join(download_dir, os.path.basename(key))

        s3_client.download_file(bucket, key, local_path)

        record['local_path'] = local_path  # Add the local path to the record

    return metadata

def get_metadata(response):
    rows = [row['Data'] for row in response['ResultSet']['Rows'][1:]]
    columns = [col['VarCharValue'] for col in response['ResultSet']['Rows'][0]['Data']]

    metadata = [
        {
            'image_id': row[0]['VarCharValue'],
            'image_location': row[1]['VarCharValue'],
            'label': int(row[2]['VarCharValue']),
            'image_type': row[3]['VarCharValue'],
            'event_time': row[4]['VarCharValue'],
        } for row in rows
    ]

    metadata = download_images(metadata)

    return metadata 

metadata = get_metadata(response)

In [4]:
# Split the metadata into train, validation, and test sets
train_metadata, test_metadata = train_test_split(metadata, test_size=0.2, stratify=[m['label'] for m in metadata], random_state=42)
train_metadata, val_metadata = train_test_split(train_metadata, test_size=0.25, stratify=[m['label'] for m in train_metadata], random_state=42)

print(f"Training samples: {len(train_metadata)}")
print(f"Validation samples: {len(val_metadata)}")
print(f"Test samples: {len(test_metadata)}")

Training samples: 599
Validation samples: 200
Test samples: 200


In [8]:
!mkdir datajob

In [22]:
%%writefile requirements.txt
mlflow==2.13.2
torchinfo
sagemaker-mlflow==0.1.0

Writing requirements.txt


In [19]:
import pickle

# Save to a pickle file
with open('datajob/train.pkl', 'wb') as f:
    pickle.dump(train_metadata, f)
with open('datajob/val.pkl', 'wb') as f:
    pickle.dump(val_metadata, f)
with open('datajob/test.pkl', 'wb') as f:
    pickle.dump(test_metadata, f)

train_data_path = sm_session.upload_data(path='datajob', bucket=bucket, key_prefix=prefix)
images_data_path = sm_session.upload_data(path='images', bucket=bucket, key_prefix=prefix)

print(train_data_path, images_data_path)

s3://wildfires/sagemaker/fire-image-classification s3://wildfires/sagemaker/fire-image-classification


## Train Job

In [27]:
run_name = 'train-resnet-fire'

new_estimator = PyTorch(
    entry_point='train-mlflow.py',
    role=sm_role,
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    py_version="py39",
    framework_version="1.13",
    hyperparameters={
        'num-epochs': 10,
        'batch-size': 32,
        'learning-rate': 0.1,
        'run-name': run_name
    },
    input_mode='File',
    environment={
        'MLFLOW_TRACKING_URIs': tracking_server_arn,
        'MLFLOW_EXPERIMENT_NAME': experiment_name,
    },
    dependencies=['requirements.txt'],
    output_path=f's3://{bucket}/{prefix}/'
)

# Fit the estimator
new_estimator.fit({'train': train_data_path})

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2024-06-22-16-20-33-531


2024-06-22 16:20:33 Starting - Starting the training job...
2024-06-22 16:21:01 Pending - Preparing the instances for training......
2024-06-22 16:21:40 Downloading - Downloading input data...
2024-06-22 16:22:10 Downloading - Downloading the training image..................
2024-06-22 16:25:06 Training - Training image download completed. Training in progress.[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2024-06-22 16:25:24,540 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2024-06-22 16:25:24,561 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-06-22 16:25:24,574 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2024-06-22 16:25:24,576 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m2024-06-22 16:25:25,947 