# FPE Rank Model

In [71]:
import boto3
import time
import json
import os
import pandas as pd
from io import StringIO, BytesIO
from urllib.parse import urlparse

SITE = "WESTKILL"

In [72]:
AWS_PROFILE="conte-prod"

model_bucket = "usgs-chs-conte-prod-fpe-models"
session = boto3.Session(profile_name=AWS_PROFILE)
session

Session(region_name='us-west-2')

In [59]:

def get_url_path(url):
    return urlparse(url).path[1:]

def create_manifest(session, data_bucket, data_key, image_bucket, filename = "images.manifest"):
    s3 = session.client('s3')

    # Read CSV file from S3
    print(f"downloading: {data_key}")
    csv_obj = s3.get_object(Bucket=data_bucket, Key=data_key)
    csv_data = csv_obj['Body'].read().decode('utf-8')
    data = pd.read_csv(StringIO(csv_data))
    print(f"rows: {len(data)}")

    # Extract 'URL' column and get keys
    manifest = data['url'].apply(get_url_path).to_list()
    manifest.insert(0, {"prefix": f"s3://{image_bucket}/"})

    # Upload JSON to S3
    prefix = os.path.dirname(data_key)
    manifest_key = f"{prefix}/{filename}"
    body = json.dumps(manifest)
    print(f"uploading: {manifest_key}")
    s3.put_object(Bucket=data_bucket, Key=manifest_key, Body=body)
    return f"s3://{data_bucket}/{manifest_key}"

In [60]:
create_manifest(
    session,
    data_bucket="usgs-chs-conte-prod-fpe-models",
    data_key=f"rank/{SITE}/data/flow-images-train.csv",
    image_bucket="usgs-chs-conte-prod-fpe-storage",
    filename="flow-images-train.manifest"
)

downloading: rank/WESTKILL/data/flow-images-train.csv
rows: 4622
uploading: rank/WESTKILL/data/flow-images-train.manifest


's3://usgs-chs-conte-prod-fpe-models/rank/WESTKILL/data/flow-images-train.manifest'

In [61]:
AWS_PROFILE="conte-prod"
AWS_REGION="us-west-2"
JOB_ROLE_ARN="arn:aws:iam::694155575325:role/fpe-prod-sagemaker-execution-role"

def timestamp():
    return time.strftime("%Y%m%d-%H%M%S")

def get_batch_creds(session, role_arn):
    sts = session.client("sts")
    response = sts.assume_role(
        RoleArn=role_arn,
        RoleSessionName=f"fpe-sagemaker-session--{timestamp()}"
    )
    return response['Credentials']

# session = boto3.Session(profile_name=AWS_PROFILE)

In [73]:
import sagemaker
creds = get_batch_creds(session, JOB_ROLE_ARN)
sm_boto_session = boto3.Session(
    aws_access_key_id=creds['AccessKeyId'],
    aws_secret_access_key=creds['SecretAccessKey'],
    aws_session_token=creds['SessionToken'],
    region_name=AWS_REGION
)

sm_session = sagemaker.Session(boto_session = sm_boto_session)

## Training

In [63]:
from sagemaker.pytorch import PyTorch

output_path = f"s3://{model_bucket}/rank/{SITE}/jobs/"
checkpoint_path = f"s3://{model_bucket}/models/{SITE}/checkpoints/"
estimator = PyTorch(
    entry_point="train.py",
    source_dir="src",
    py_version="py38",
    framework_version="1.12",
    role="arn:aws:iam::694155575325:role/fpe-prod-sagemaker-execution-role",
    instance_count=1,
    # instance_type="ml.m5.large",
    instance_type="ml.p3.2xlarge",
    volume_size=100,
    hyperparameters={
        "data-file": "flow-images-train.csv",
        "num-image-stats": 1000,
        "num-train-pairs": 5000,
        "num-eval-pairs": 1000,
        "epochs": 15
    },
    base_job_name="fpe-rank",
    output_path=output_path,
    checkpoint_s3_uri=checkpoint_path,
    code_location=output_path[:-1],
    disable_output_compression=False,
    sagemaker_session=sm_session
)

In [64]:
from sagemaker.inputs import TrainingInput

data_dir = f"s3://{model_bucket}/rank/{SITE}/data"
images = TrainingInput(
    s3_data = f"{data_dir}/flow-images-train.manifest",
    s3_data_type = "ManifestFile",
    input_mode = "File"
)
values = f"{data_dir}/flow-images-train.csv"
(images, values)

(<sagemaker.inputs.TrainingInput at 0x1cb2cdc43a0>,
 's3://usgs-chs-conte-prod-fpe-models/rank/WESTKILL/data/flow-images-train.csv')

In [65]:
estimator.fit({ "images": images, "values": values }, wait=False)

Using provided s3_resource


INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: fpe-rank-2023-09-08-13-51-26-026


## Batch Transform

In [78]:
model_artifact_s3_location = f"s3://usgs-chs-conte-prod-fpe-models/rank/{SITE}/jobs/fpe-rank-2023-09-08-13-51-26-026/output/model.tar.gz"
model_artifact_s3_location

's3://usgs-chs-conte-prod-fpe-models/rank/WESTKILL/jobs/fpe-rank-2023-09-08-13-51-26-026/output/model.tar.gz'

In [75]:
from sagemaker.pytorch.model import PyTorchModel

pytorch_model = PyTorchModel(
    model_data=model_artifact_s3_location,
    role="arn:aws:iam::694155575325:role/fpe-prod-sagemaker-execution-role",
    py_version="py38",
    framework_version="1.12",
    source_dir="src/",
    entry_point="transform.py",
    sagemaker_session = sm_session
)
pytorch_model

<sagemaker.pytorch.model.PyTorchModel at 0x1cb22b7d940>

In [76]:
transformer = pytorch_model.transformer(
    instance_count=1,
    instance_type="ml.c5.xlarge",
    output_path=f"s3://{model_bucket}/rank/{SITE}/transform/output",
    job_name="job_name"
)
transformer

INFO:sagemaker:Repacking model artifact (s3://usgs-chs-conte-prod-fpe-models/rank/WESTKILL/jobs/fpe-rank-2023-09-08-13-51-26-026/output/model.tar.gz), script artifact (src/), and dependencies ([]) into single tar.gz file located at s3://sagemaker-us-west-2-694155575325/pytorch-inference-2023-09-08-18-45-42-638/model.tar.gz. This may take some time depending on model size...
INFO:sagemaker:Creating model with name: pytorch-inference-2023-09-08-18-46-19-875


<sagemaker.transformer.Transformer at 0x1cb2e08fe20>

In [77]:
transformer.transform(
    data=f"s3://{model_bucket}/rank/{SITE}/data/flow-images-train.manifest",
    data_type="ManifestFile",
    content_type="image/jpg",
    wait=False,
)

INFO:sagemaker:Creating transform job with name: pytorch-inference-2023-09-08-18-46-22-341


In [9]:
sm_session.stop_transform_job("pytorch-inference-2023-09-04-22-38-19-105")

INFO:sagemaker:Stopping transform job: pytorch-inference-2023-09-04-16-09-22-922


## Process Transform Output

In [81]:
data_file = f"rank/{SITE}/data/flow-images-train.csv"

# get number of rows
s3 = session.client('s3')

# Read CSV file from S3
print(f"downloading: s3://{model_bucket}/{data_file}")
csv_obj = s3.get_object(Bucket=model_bucket, Key=data_file)
csv_data = csv_obj['Body'].read().decode('utf-8')
data = pd.read_csv(StringIO(csv_data))
print(f"rows: {len(data)}")

downloading: s3://usgs-chs-conte-prod-fpe-models/rank/WESTKILL/data/flow-images-train.csv
rows: 4622


In [82]:
lambda_client = session.client("lambda")

job_size = 5000
skip = 0
while skip < len(data):
    payload = {
        "action": "process_transform_output",
        "bucket_name": model_bucket,
        "data_file": data_file,
        "data_prefix": f"rank/{SITE}/transform/output",
        "output_prefix": f"rank/{SITE}/transform",
        "n": job_size,
        "skip": skip
    }
    print(f"invoke: skip={skip}, n={job_size} ({skip} to {skip + job_size - 1})")
    lambda_client.invoke(
        FunctionName="fpe-prod-lambda-models",
        InvocationType="Event",
        Payload=json.dumps(payload)
    )
    skip += job_size

invoke: skip=0, n=5000 (0 to 4999)


In [83]:
filenames = [f"rank/{SITE}/transform/predictions-{skip:05d}-{(skip + job_size - 1):05d}.csv" for skip in range(0, len(data), job_size)]

dfs = []
for key in filenames:
    print(key)
    csv_obj = s3.get_object(Bucket=model_bucket, Key=key)
    csv_data = csv_obj['Body'].read().decode('utf-8')
    dfs.append(pd.read_csv(StringIO(csv_data)))

df = pd.concat(dfs, ignore_index=True)
df

rank/WESTKILL/transform/predictions-00000-04999.csv


Unnamed: 0,station_name,station_id,imageset_id,image_id,timestamp,filename,url,flow_cfs,score
0,01349711_West Kill,89,620,1064971,2022-10-13T15:00:00Z,imagesets/bba32ae8-7361-487a-a7f2-84b391e59fc9...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,4.52,-6.343492
1,01349711_West Kill,89,620,1064972,2022-10-13T16:00:00Z,imagesets/bba32ae8-7361-487a-a7f2-84b391e59fc9...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,4.52,-4.915704
2,01349711_West Kill,89,620,1064973,2022-10-13T17:00:00Z,imagesets/bba32ae8-7361-487a-a7f2-84b391e59fc9...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,4.71,-5.903940
3,01349711_West Kill,89,620,1064974,2022-10-13T18:00:00Z,imagesets/bba32ae8-7361-487a-a7f2-84b391e59fc9...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,4.91,-4.971095
4,01349711_West Kill,89,620,1064975,2022-10-13T19:00:00Z,imagesets/bba32ae8-7361-487a-a7f2-84b391e59fc9...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,5.79,-5.115305
...,...,...,...,...,...,...,...,...,...
4617,01349711_West Kill,89,1943,2706719,2023-06-28T16:30:00Z,imagesets/46b25625-a640-4795-9294-6fd3c5437273...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,7.59,-7.719587
4618,01349711_West Kill,89,1943,2706720,2023-06-28T16:45:00Z,imagesets/46b25625-a640-4795-9294-6fd3c5437273...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,8.11,-6.313333
4619,01349711_West Kill,89,1943,2706721,2023-06-28T17:00:00Z,imagesets/46b25625-a640-4795-9294-6fd3c5437273...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,8.37,-5.083445
4620,01349711_West Kill,89,1943,2706722,2023-06-28T17:15:00Z,imagesets/46b25625-a640-4795-9294-6fd3c5437273...,https://usgs-chs-conte-prod-fpe-storage.s3.ama...,8.37,-5.244941


In [84]:
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False)
output_key = f"rank/{SITE}/transform/predictions.csv"
s3.put_object(Body=csv_buffer.getvalue(), Bucket=model_bucket, Key=output_key)

{'ResponseMetadata': {'RequestId': 'J76KBHBWP0BF7WK1',
  'HostId': 'O24p24djPPJLTWfRS0VxbN5TKZA0WR7EeXVVvTzec885GLRQTWusLL4CikaLXxL/S/yPsmv+9o+no5N4ETBJ0Q==',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amz-id-2': 'O24p24djPPJLTWfRS0VxbN5TKZA0WR7EeXVVvTzec885GLRQTWusLL4CikaLXxL/S/yPsmv+9o+no5N4ETBJ0Q==',
   'x-amz-request-id': 'J76KBHBWP0BF7WK1',
   'date': 'Fri, 08 Sep 2023 19:55:43 GMT',
   'x-amz-server-side-encryption': 'AES256',
   'etag': '"345ad3df5dee0ffb6b6c5502562838f3"',
   'server': 'AmazonS3',
   'content-length': '0'},
  'RetryAttempts': 0},
 'ETag': '"345ad3df5dee0ffb6b6c5502562838f3"',
 'ServerSideEncryption': 'AES256'}