In [1]:
!pip install mlflow==2.13.2 sagemaker-mlflow==0.1.0



In [2]:
# Sanity check: invoking the CLI with no arguments prints its usage text,
# which confirms the mlflow executable is available after the install above.
!mlflow

Usage: mlflow [OPTIONS] COMMAND [ARGS]...

Options:
  --version  Show the version and exit.
  --help     Show this message and exit.

Commands:
  artifacts    Upload, list, and download artifacts from an MLflow...
  db           Commands for managing an MLflow tracking database.
  deployments  Deploy MLflow models to custom targets.
  doctor       Prints out useful information for debugging issues with MLflow.
  experiments  Manage experiments.
  gc           Permanently delete runs in the `deleted` lifecycle stage.
  models       Deploy MLflow models locally.
  recipes      Run MLflow Recipes and inspect recipe results.
  run          Run an MLflow project from the given URI.
  runs         Manage runs.
  sagemaker    Serve models on SageMaker.
  server       Run the MLflow tracking server.


In [3]:
import sagemaker
import boto3
import mlflow

# Initialize the SageMaker session and the values derived from it.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

# All boto3 clients are pinned to the session's region so they talk to the
# same region as the SageMaker session itself.
region = sagemaker_session.boto_session.region_name
athena_client = boto3.client('athena', region_name=region)

# Initialize SageMaker FeatureStore client (same region as the Athena client).
featurestore_client = boto3.client('sagemaker-featurestore-runtime', region_name=region)

# MLflow tracking server and experiment that training runs are logged to.
tracking_server_arn = 'arn:aws:sagemaker:eu-central-1:567821811420:mlflow-tracking-server/wildfire-mj'
experiment_name = 'wildfire-resnet'

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


In [4]:
# mlflow.set_tracking_uri(tracking_server_arn)
# mlflow.autolog()

In [5]:

# mlflow.set_experiment(experiment_name)

# with mlflow.start_run(run_name=sagemaker.utils.name_from_base("HPODemo")) as run:
#     runid = run.info.run_id
#     # mlflow.autolog()

#     mlflow.log_params({'a': 12, 'b': 1234}, run_id=runid)

In [6]:
import pandas as pd
from io import BytesIO
import base64
from PIL import Image
import numpy as np
from sagemaker.feature_store.feature_group import FeatureGroup

# Feature group handle used for the wildfire image records.
# NOTE: `bucket` is rebound here, replacing the session default bucket assigned
# earlier; run_query() reads this global when building its Athena output location.
feature_group_name = "test-fire-image-hs"
bucket = "fire-project-hs"
feature_group = FeatureGroup(
    sagemaker_session=sagemaker_session,
    name=feature_group_name,
)


In [7]:
def run_query(query, poll_interval=1.0):
    """Start an Athena query and block until it reaches a terminal state.

    Uses the module-level ``athena_client`` and writes query output under
    ``s3://<bucket>/feature-store-output/`` (``bucket`` is the module-level name).

    Args:
        query: SQL text, executed against the 'sagemaker_featurestore' database.
        poll_interval: Seconds to wait between status checks.

    Returns:
        The Athena query execution id, once the query state is 'SUCCEEDED'.

    Raises:
        RuntimeError: If the query ends in the 'FAILED' or 'CANCELLED' state.
    """
    import time  # local import so the function does not depend on another cell

    response = athena_client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': 'sagemaker_featurestore'
        },
        ResultConfiguration={
            'OutputLocation': 's3://' + bucket + '/feature-store-output/'
        }
    )
    query_execution_id = response['QueryExecutionId']

    while True:
        response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
        status = response['QueryExecution']['Status']
        state = status['State']

        if state == 'SUCCEEDED':
            return query_execution_id

        # Without this check a failed or cancelled query would be polled forever.
        if state in ('FAILED', 'CANCELLED'):
            reason = status.get('StateChangeReason', 'no reason reported')
            raise RuntimeError(
                'Athena query {} ended in state {}: {}'.format(query_execution_id, state, reason)
            )

        # Still queued/running: back off instead of hammering the API.
        time.sleep(poll_interval)


def get_query_results_dataframe(query_execution_id):
    """Collect every result page of an Athena query into a DataFrame.

    Pages are fetched with the module-level ``athena_client`` until the
    response no longer carries a ``NextToken``. The first returned row supplies
    the column names; each remaining row becomes one record, with ``None`` for
    any cell that has no ``VarCharValue``.
    """
    rows = []
    request = {'QueryExecutionId': query_execution_id}

    while True:
        page = athena_client.get_query_results(**request)
        rows += page['ResultSet']['Rows']

        token = page.get('NextToken')
        if not token:
            break
        # Ask for the following page on the next iteration.
        request['NextToken'] = token

    header_cells = rows[0]['Data']
    column_names = [cell['VarCharValue'] for cell in header_cells]

    records = [
        [cell.get('VarCharValue', None) for cell in row['Data']]
        for row in rows[1:]
    ]

    return pd.DataFrame(records, columns=column_names)


def convert_data(dataset, icolumn="image_data", lcolumn="label"):
    """Decode the image and label columns of ``dataset`` in place.

    Args:
        dataset: DataFrame holding base64-encoded images and string labels.
        icolumn: Name of the column with base64-encoded image bytes.
        lcolumn: Name of the column with labels; 'fire' maps to 1, anything
            else maps to 0.

    Returns:
        None. ``dataset`` is modified in place: ``icolumn`` ends up holding
        numpy arrays and ``lcolumn`` integer class ids.

    The conversion is idempotent, so re-running the cell on an already
    converted frame leaves it unchanged instead of failing on the decoded
    images or silently turning every label into 0.
    """
    def decode_image(base64_str):
        # Already decoded by a previous run of this function: keep as is.
        if isinstance(base64_str, np.ndarray):
            return base64_str
        img_data = base64.b64decode(base64_str)
        # The context manager releases the image's resources once the pixel
        # data has been copied into the array.
        with Image.open(BytesIO(img_data)) as img:
            return np.array(img)

    def decode_label(label):
        # Already an integer class id from a previous run: keep as is.
        if isinstance(label, (int, np.integer)) and not isinstance(label, bool):
            return int(label)
        return 1 if label == 'fire' else 0

    dataset[icolumn] = dataset[icolumn].apply(decode_image)
    dataset[lcolumn] = dataset[lcolumn].apply(decode_label)

In [8]:
# Select every row of the feature table in the "sagemaker_featurestore" Athena
# database. The table name is hard-coded; presumably it belongs to the
# "test-fire-image-hs" feature group -- TODO confirm.
query = """SELECT *
FROM "sagemaker_featurestore"."test_fire_image_hs_1718785080"
"""
qid = run_query(query)  # returns the execution id once the query has succeeded
data = get_query_results_dataframe(qid)
# convert_data works in place: afterwards data['image_data'] holds numpy arrays
# and data['label'] holds 1 for 'fire' and 0 for everything else.
convert_data(data, 'image_data', 'label')

In [9]:
# Class balance: number of rows per encoded label (0 = no fire, 1 = fire).
(
    'no fire', len(data[data['label'] == 0]),
    'fire', len(data[data['label'] == 1]),
)

('no fire', 488, 'fire', 1510)

In [10]:
!mkdir datanp

In [11]:
from sklearn.model_selection import train_test_split

# Stack the per-row image arrays into a single array (np.stack requires every
# image to have the same shape) and pull out the matching labels.
X = np.stack(data['image_data'].values)
y = data['label'].values

# The classes are imbalanced (488 'no fire' vs 1510 'fire'), so stratify the
# split to keep the same class ratio in the training and validation sets.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Save data locally
np.save('datanp/X_train.npy', X_train)
np.save('datanp/X_val.npy', X_val)
np.save('datanp/y_train.npy', y_train)
np.save('datanp/y_val.npy', y_val)

In [12]:
# mlflow.set_tracking_uri(tracking_server_arn)
# mlflow.set_experiment(experiment_name)


# with mlflow.start_run(run_name=sagemaker.utils.name_from_base("Resnet-wildfire")) as run:
#     mlflow.log_params({
#         'train_split_size': len(X_train),
#         'val_split_size': len(X_val),
#         'epochs': 10,
#         'batch_size': 32
#     }, run_id=run.info.run_id)

In [13]:
%%writefile requirements.txt
mlflow==2.13.2
torchinfo
sagemaker-mlflow==0.1.0

Writing requirements.txt


In [14]:
# List the installed packages whose name contains "mlflow" to confirm the
# versions available in this environment.
!pip list | grep mlflow

mlflow                                2.13.2
sagemaker-mlflow                      0.1.0


In [16]:
from sagemaker.pytorch import PyTorch

# Upload the locally saved arrays to S3 for the training job.
train_dir = 's3://{}/train'.format(bucket)
val_dir = 's3://{}/val'.format(bucket)
# NOTE(review): both uploads send the whole 'datanp' directory, so each prefix
# receives the train AND validation arrays. Presumably the entry point picks
# the files it needs by name -- confirm against train-mlflow.py.
train_data_path = sagemaker_session.upload_data(path='datanp', bucket=bucket, key_prefix='train')
val_data_path = sagemaker_session.upload_data(path='datanp', bucket=bucket, key_prefix='val')

# Configure the PyTorch estimator and mlflow

run_base_name = 'resnet-gpu'

pytorch_estimator = PyTorch(
    entry_point='train-mlflow.py',
    role=role,
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    py_version="py39",
    framework_version="1.13",
    hyperparameters={
        'epochs': 10,
        'batch_size': 32,
        'experiment_name': experiment_name,
        'run_name': run_base_name
    },
    input_mode='File',
    environment={
        # MLflow reads MLFLOW_TRACKING_URI (singular); the previous key
        # 'MLFLOW_TRACKING_URIs' was a typo that MLflow itself never picks up.
        'MLFLOW_TRACKING_URI': tracking_server_arn,
        # Legacy misspelled key kept in case train-mlflow.py reads it directly;
        # remove once the entry point is confirmed not to use it.
        'MLFLOW_TRACKING_URIs': tracking_server_arn,
        'MLFLOW_EXPERIMENT_NAME': experiment_name,
        # 'MLFLOW_PARENT_RUN_ID': run.info.run_id
    },
    # Installed inside the training container before the entry point runs.
    dependencies=['requirements.txt']
)

# Launch the training job
pytorch_estimator.fit({'train': train_data_path, 'validation': val_data_path})

's3://fire-project-hs/val'