In [7]:
import boto3
import io
import urllib
import s3fs
import json
from pathlib import Path
import attr
import numpy
import tiledb
import tiledb.cloud
from tiledb.cloud.compute import DelayedArrayUDF, Delayed
import pandas
import geopandas
import fiona
from fiona.session import AWSSession
import pystac
from scipy.stats import skew, kurtosis
import uuid

In [8]:
import pystac
from pystac.extensions.projection import ProjectionExtension
from pystac.extensions.pointcloud import (
PointcloudExtension,
SchemaType,
PhenomenologyType,
Schema,
Statistic,
)

In [9]:
from reap_gsf import reap, data_model
from bathy_datasets import rhealpix, storage, geometry, asb_spreadsheet,stac_metadata

In [10]:
# Resolve AWS credentials from a boto3 session created with no explicit
# arguments. `creds` is reused below by s3fs, the TileDB config, fiona and the
# TileDB Cloud tasks.
session = boto3.Session()
creds = session.get_credentials()

In [11]:
# S3 filesystem handle used for the JSON reads and the file glob below; the
# directory listing cache is switched off.
# NOTE(review): only the access key and secret are forwarded. If the session
# ever resolves to temporary credentials, a session token would presumably be
# needed as well -- confirm.
fs = s3fs.S3FileSystem(key=creds.access_key, secret=creds.secret_key, use_listings_cache=False)

In [12]:
# Random identifier embedded in every output name of this run. Re-running this
# cell produces a new identifier, and therefore new output locations.
uid = uuid.uuid4()

In [13]:
# Source (provided) and destination (ingested) S3 locations for this survey.
# The two metadata documents live underneath the survey prefix.
survey_uri = "s3://ausseabed-pl019-provided-data/DeakinUniversity/WilsonsPromontory_MNP/"
outdir_uri = "s3://ausseabed-pl019-ingested-data/L2/WilsonsPromontory_MNP/"
asb_metadata_uri = survey_uri + "metadata/spreadsheet-metadata.json"
survey_info_uri = survey_uri + "schema-info.json"

In [14]:
# Names and destinations of the products written by this run. Every name
# shares the "<base_prefix>_<uid>" stem so products of one run can be matched.
base_prefix = "ga_ausseabed"
dataset_stem = f"{base_prefix}_{uid}"
array_name = f"{dataset_stem}_bathymetry"
array_uri = f"{outdir_uri}{array_name}.tiledb"
tiledb_array_uri = f"tiledb://sixy6e/{array_name}"
soundings_cell_density_uri = (
    f"{outdir_uri}{dataset_stem}_soundings-cell-density-resolution-12.geojson"
)
coverage_uri = f"{outdir_uri}{dataset_stem}_coverage.geojson"
stac_md_uri = f"{outdir_uri}{dataset_stem}_stac-metadata.geojson"

In [15]:
# Destination of the soundings cell-density product whose file name carries the
# "resolution-15" suffix (the region codes are generated at resolution 15).
soundings_cell_density_uri_15 = f"{outdir_uri}{base_prefix}_{uid}_soundings-cell-density-resolution-15.geojson"

In [16]:
def get_sonar_metadata(json_uri):
    """
    Temporary helper that pulls sonar metadata out of a sample GSF file.

    The JSON document at `json_uri` supplies the GSF location. The registered
    `retrieve_stream` and `decode_gsf` UDFs are then run on TileDB Cloud, and
    two records are read back from the retrieved stream.

    :param json_uri:
        URI of a JSON document containing a "gsf_uri" entry.

    :return:
        The object read from record 0 of `finfo[3]`, with every field of
        record 0 of `finfo[6]` (converted via `attr.asdict`) copied into it
        by key.
    """
    with fs.open(json_uri) as sidecar:
        sidecar_md = json.loads(sidecar.read())

    retrieve = Delayed("sixy6e/retrieve_stream", name="retrieve")(
        sidecar_md["gsf_uri"], creds.access_key, creds.secret_key
    )
    decode = Delayed("sixy6e/decode_gsf", name="decode", image_name="3.7-geo")(
        retrieve, slice(10)
    )

    # Only the file info is needed here; the decoded dataframe is discarded.
    _, finfo = decode.compute()

    sonar_metadata = finfo[3].record(0).read(retrieve.result()[0])
    history_record = finfo[6].record(0).read(retrieve.result()[0])
    for field, value in attr.asdict(history_record).items():
        sonar_metadata[field] = value

    return sonar_metadata


def reduce_region_codes(results):
    """
    Reduce step of the map-reduce construct for the region_code counts.

    :param results:
        A sequence of `(cell_count_dataframe, [start, end])` pairs, one per
        mapped task. Each dataframe has "region_code" and "count" columns.

    :return:
        A tuple `(cell_count, [start, end])` where `cell_count` holds the
        summed count per region_code, `start` is the earliest start and
        `end` is the latest end, both as Python datetimes.
    """
    frames = [item[0] for item in results]
    spans = [item[1] for item in results]

    # Sum the partial counts for each region code across all tasks.
    combined = pandas.concat(frames)
    per_code = combined.groupby(["region_code"])["count"].sum()
    cell_count = per_code.to_frame("count").reset_index()

    # Overall time span covered by all tasks.
    starts = pandas.Series([span[0] for span in spans])
    ends = pandas.Series([span[1] for span in spans])
    start_end_timestamp = [
        starts.min().to_pydatetime(),
        ends.max().to_pydatetime(),
    ]

    return cell_count, start_end_timestamp


def gather_stats(results):
    """
    Merge the per-task statistics mappings into a single dict.

    :param results:
        An iterable of mappings. When a key occurs in more than one mapping,
        the value from the later mapping wins.

    :return:
        A `dict` holding every key/value pair found in `results`.
    """
    return {key: item[key] for item in results for key in item}

In [17]:
def retrieve_stream(uri, access_key, skey):
    """
    Download an S3 object into memory and return it as a seekable stream.

    Not testing the creation of the stream object at this point.
    But for testing, we also need to keep the download to occur only
    once.

    :param uri:
        The `s3://bucket/key` URI of the object to download.

    :param access_key:
        AWS access key id used to create the boto3 session.

    :param skey:
        AWS secret access key used to create the boto3 session.

    :return:
        A tuple `(stream, length)` where `stream` is an `io.BytesIO`
        holding the whole object and `length` is its size in bytes.
    """
    # A bare `import urllib` does not guarantee that the `parse` submodule is
    # loaded, so import the function explicitly.
    from urllib.parse import urlparse

    session = boto3.Session(aws_access_key_id=access_key, aws_secret_access_key=skey)
    location = urlparse(uri)

    # The URI path starts with "/", which is not part of the object key.
    obj = session.resource("s3").Object(
        bucket_name=location.netloc, key=location.path[1:]
    )
    payload = obj.get()["Body"].read()

    # Take the length from the downloaded bytes rather than issuing a second
    # metadata request for `content_length`.
    return io.BytesIO(payload), len(payload)


def append_ping_dataframe(dataframe, array_uri, access_key, skey):
    """
    Append the ping dataframe read from a GSF file to a TileDB array.

    :param dataframe:
        The pandas dataframe of pings to append.

    :param array_uri:
        URI of the TileDB array to append to.

    :param access_key:
        AWS access key id placed in the TileDB S3 configuration.

    :param skey:
        AWS secret access key placed in the TileDB S3 configuration.
    """
    ctx = tiledb.Ctx(
        config=tiledb.Config(
            {
                "vfs.s3.aws_access_key_id": access_key,
                "vfs.s3.aws_secret_access_key": skey,
            }
        )
    )

    tiledb.dataframe_.from_pandas(
        array_uri, dataframe, mode="append", sparse=True, ctx=ctx
    )


def ingest_gsf_slice(
    file_record, stream, access_key, skey, array_uri, idx=slice(None)
):
    """
    Ingest one slice of pings from a GSF stream.

    General steps:
    Extract the ping data.
    Calculate the rHEALPIX code (resolution 15).
    Summarise the rHEALPIX codes (frequency count).
    Get timestamps of first and last pings.
    Write the ping data to a TileDB array.

    :param file_record:
        The ping file record handed to `SwathBathymetryPing.from_records`.

    :param stream:
        The stream the ping records are read from.

    :param access_key:
        AWS access key id used when writing to the array.

    :param skey:
        AWS secret access key used when writing to the array.

    :param array_uri:
        URI of the TileDB array the pings are appended to.

    :param idx:
        The slice of ping records to ingest. Defaults to all records.

    :return:
        A tuple `(cell_count, [start, end])`: a dataframe of counts per
        region_code, and the minimum and maximum ping timestamps as
        Python datetimes.
    """
    pings = data_model.SwathBathymetryPing.from_records(file_record, stream, idx)
    frame = pings.ping_dataframe
    frame["region_code"] = rhealpix.rhealpix_code(frame.X, frame.Y, 15)

    # frequency of dggs cells
    per_code = frame.groupby(["region_code"])["region_code"].count()
    cell_count = per_code.to_frame("count").reset_index()

    ping_times = frame.timestamp
    start_end_time = [
        ping_times.min().to_pydatetime(),
        ping_times.max().to_pydatetime(),
    ]

    # write to tiledb array
    append_ping_dataframe(frame, array_uri, access_key, skey)

    return cell_count, start_end_time


def ingest_gsf_slices(gsf_uri, access_key, skey, array_uri, slices):
    """
    Ingest a list of ping slices from a given GSF file.

    The file is downloaded once and every slice is read from that stream.

    :param gsf_uri:
        URI of the GSF file to ingest.

    :param access_key:
        AWS access key id used for the download and the array writes.

    :param skey:
        AWS secret access key used for the download and the array writes.

    :param array_uri:
        URI of the TileDB array the pings are appended to.

    :param slices:
        An iterable of ping slices; each one is ingested in turn.

    :return:
        A tuple `(cell_count, [start, end])` aggregated over all slices:
        summed counts per region_code, plus the earliest start and latest
        end timestamps as Python datetimes.
    """
    stream, stream_length = retrieve_stream(gsf_uri, access_key, skey)
    ping_file_record = reap.file_info(stream, stream_length)[1]

    per_slice = [
        ingest_gsf_slice(ping_file_record, stream, access_key, skey, array_uri, idx)
        for idx in slices
    ]
    cell_counts = [outcome[0] for outcome in per_slice]
    spans = [outcome[1] for outcome in per_slice]

    # aggregate the ping slices and calculate the cell counts
    merged = pandas.concat(cell_counts)
    summed = merged.groupby(["region_code"])["count"].sum()
    cell_count = summed.to_frame("count").reset_index()

    # collect the per-slice min and max timestamps, then take the overall span
    starts = pandas.Series([span[0] for span in spans])
    ends = pandas.Series([span[1] for span in spans])
    start_end_timestamp = [
        starts.min().to_pydatetime(),
        ends.max().to_pydatetime(),
    ]

    return cell_count, start_end_timestamp

In [18]:
def scatter(iterable, n):
    """
    Evenly scatters an iterable into `n` blocks.
    Sourced from:
    http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts

    :param iterable:
        A sliceable sequence, preferably a 1D list or array.

    :param n:
        An integer indicating how many blocks to create.

    :return:
        A `list` consisting of `n` blocks of roughly equal size, each
        containing elements from `iterable`. The first
        `len(iterable) % n` blocks hold one extra element; blocks are
        empty once the elements run out.
    """
    base_size, extras = divmod(len(iterable), n)

    blocks = []
    start = 0
    for position in range(n):
        # The leading `extras` blocks absorb the remainder, one element each.
        stop = start + base_size + (1 if position < extras else 0)
        blocks.append(iterable[start:stop])
        start = stop

    return blocks

In [19]:
def ingest_gsfs(files, size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node):
    """
    Prototype ingester that fans GSF files out as TileDB Cloud UDF tasks.

    Each file's pings are cut into slices of `ping_slice_step` pings, and the
    slices are grouped into chunks of `slices_per_node`. Every chunk becomes
    one `ingest_gsf_slices` task. Tasks are spread round-robin over
    `processing_node_limit` chains, and each task depends on the previous
    task of its chain.

    :param files:
        Iterable of GSF pathnames. A JSON metadata document is expected
        alongside each file (same name with ".gsf" replaced by ".json").

    :param size_limit_mb:
        Files whose metadata "size" (divided by 1024 twice) exceeds this
        value are not submitted; they are returned in `large_files`.

    :param processing_node_limit:
        Number of task chains to distribute the tasks over.

    :param ping_slice_step:
        Number of pings per slice.

    :param slices_per_node:
        Number of slices handled by a single task.

    :return:
        A tuple `(reduce_task, skipped_files, large_files)`. `reduce_task`
        is a local Delayed task applying `reduce_region_codes` to all
        ingest tasks; `skipped_files` are files with zero pings;
        `large_files` are files over the size limit.
    """

    node_counter = 0
    skipped_files = []
    large_files = []
    tasks = []
    # Tasks already assigned to each chain; the last entry is the task a new
    # one on the same chain has to wait for.
    tasks_dict = {n: [] for n in range(processing_node_limit)}

    for pathname in files:
        metadata_pathname = pathname.replace(".gsf", ".json")
        base_name = Path(pathname).stem
        with fs.open(metadata_pathname) as src:
            gsf_metadata = json.loads(src.read())

        if (gsf_metadata["size"] / 1024 / 1024) > size_limit_mb:
            large_files.append(pathname)
            continue

        ping_count = gsf_metadata["file_record_types"]["GSF_SWATH_BATHYMETRY_PING"]["record_count"]
        if ping_count == 0:
            skipped_files.append(pathname)
            continue

        slices = [slice(start, start+ping_slice_step) for start in numpy.arange(0, ping_count, ping_slice_step)]
        slice_chunks = [slices[i:i+slices_per_node] for i in range(0, len(slices), slices_per_node)]

        for slice_chunk in slice_chunks:
            start_idx = slice_chunk[0].start
            # The chunk ends where its LAST slice ends, so the task name
            # reflects the full ping range handled by the task.
            end_idx = slice_chunk[-1].stop
            task_name = f"{base_name}-{start_idx}-{end_idx}-{node_counter}"
            task = Delayed("sixy6e/ingest_gsf_slices", name=task_name, image_name="3.7-geo")(gsf_metadata["gsf_uri"], creds.access_key, creds.secret_key, array_uri, slice_chunk)
            task.set_timeout(1800)

            if tasks_dict[node_counter]:
                task.depends_on(tasks_dict[node_counter][-1])

            tasks.append(task)
            tasks_dict[node_counter].append(task)
            node_counter += 1

            # wrap around to the first chain
            if node_counter == processing_node_limit:
                node_counter = 0

    # Pass the name by keyword, as every other Delayed call in this file does.
    reduce_task = Delayed(reduce_region_codes, name="reduce-region_codes-timestamps", local=True)(tasks)

    return reduce_task, skipped_files, large_files

In [20]:
def ingest_gsfs_local(files, size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node):
    """
    Prototype ingester that runs the ingest tasks locally.

    Mirrors `ingest_gsfs`, but every `ingest_gsf_slices` task is a local
    Delayed task and no file is rejected for its size.

    :param files:
        Iterable of GSF pathnames. A JSON metadata document is expected
        alongside each file (same name with ".gsf" replaced by ".json").

    :param size_limit_mb:
        Accepted for signature parity with `ingest_gsfs`; not applied here.

    :param processing_node_limit:
        Number of task chains to distribute the tasks over.

    :param ping_slice_step:
        Number of pings per slice.

    :param slices_per_node:
        Number of slices handled by a single task.

    :return:
        A tuple `(reduce_task, skipped_files, large_files)`. `reduce_task`
        is a local Delayed task applying `reduce_region_codes` to all
        ingest tasks; `skipped_files` are files with zero pings;
        `large_files` is always empty because no size filter is applied.
    """

    node_counter = 0
    skipped_files = []
    large_files = []
    tasks = []
    # Tasks already assigned to each chain; the last entry is the task a new
    # one on the same chain has to wait for.
    tasks_dict = {n: [] for n in range(processing_node_limit)}

    for pathname in files:
        metadata_pathname = pathname.replace(".gsf", ".json")
        base_name = Path(pathname).stem
        with fs.open(metadata_pathname) as src:
            gsf_metadata = json.loads(src.read())

        # No size check here: this local path is what the notebook uses for
        # the files the remote ingester set aside as too large.

        ping_count = gsf_metadata["file_record_types"]["GSF_SWATH_BATHYMETRY_PING"]["record_count"]
        if ping_count == 0:
            skipped_files.append(pathname)
            continue

        slices = [slice(start, start+ping_slice_step) for start in numpy.arange(0, ping_count, ping_slice_step)]
        slice_chunks = [slices[i:i+slices_per_node] for i in range(0, len(slices), slices_per_node)]

        for slice_chunk in slice_chunks:
            start_idx = slice_chunk[0].start
            # The chunk ends where its LAST slice ends, so the task name
            # reflects the full ping range handled by the task.
            end_idx = slice_chunk[-1].stop
            task_name = f"{base_name}-{start_idx}-{end_idx}-{node_counter}"
            task = Delayed(ingest_gsf_slices, name=task_name, local=True)(gsf_metadata["gsf_uri"], creds.access_key, creds.secret_key, array_uri, slice_chunk)
            task.set_timeout(1800)

            if tasks_dict[node_counter]:
                task.depends_on(tasks_dict[node_counter][-1])

            tasks.append(task)
            tasks_dict[node_counter].append(task)
            node_counter += 1

            # wrap around to the first chain
            if node_counter == processing_node_limit:
                node_counter = 0

    # Pass the name by keyword, as every other Delayed call in this file does.
    reduce_task = Delayed(reduce_region_codes, name="reduce-region_codes-timestamps", local=True)(tasks)

    return reduce_task, skipped_files, large_files

In [21]:
# Load the survey's schema-info JSON document from S3.
with fs.open(survey_info_uri) as src:
    survey_info = json.loads(src.read())

In [22]:
#required_attributes = survey_info["schemas"][0]
# this is temporary. better to have it defined internally. or programmatically derived as a union of all schemas from all pings
required_attributes = [
    "Z",
    "across_track",
    "along_track",
    "beam_angle",
    "beam_angle_forward",
    "beam_flags",
    "beam_number",
    "centre_beam",
    "course",
    "depth_corrector",
    "gps_tide_corrector",
    "heading",
    "heave",
    "height",
    "horizontal_error",
    "ping_flags",
    "pitch",
    "roll",
    "sector_number",
    "separation",
    "speed",
    "tide_corrector",
    "timestamp",
    "travel_time",
    "vertical_error",
    "region_code",
]

In [23]:
# TileDB configuration carrying the S3 credentials. `ctx` is used for the
# array operations run from this notebook, and `config_dict` is the plain-dict
# form handed to the remote statistics tasks.
config = tiledb.Config(
        {"vfs.s3.aws_access_key_id": creds.access_key, "vfs.s3.aws_secret_access_key": creds.secret_key}
    )
config_dict = config.dict()
ctx = tiledb.Ctx(config=config)

In [None]:
storage.create_mbes_array(array_uri, required_attributes, ctx)

In [24]:
files = fs.glob(survey_uri + "**.gsf")
len(files)

705

In [25]:
sonar_metadata = get_sonar_metadata(files[0].replace(".gsf", ".json"))

In [26]:
n_partitions = 8
files_blocks = scatter(files, n_partitions)
len(files_blocks[0])

89

In [27]:
size_limit_mb = 500
processing_node_limit = 5
ping_slice_step = 2000
slices_per_node = 3
local_tasks_limit = 1
local_ping_slice_step = 2000
local_slices_per_task = 4

In [28]:
skipped_files = []
large_files = []

In [None]:
reduce_task, skipped_files1, large_files1 = ingest_gsfs(files_blocks[0], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df1, start_end_timestamps1 = reduce_task.compute()

In [None]:
reduce_task, skipped_files2, large_files2 = ingest_gsfs(files_blocks[1], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df2, start_end_timestamps2 = reduce_task.compute()

In [None]:
reduce_task, skipped_files3, large_files3 = ingest_gsfs(files_blocks[2], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df3, start_end_timestamps3 = reduce_task.compute()

In [None]:
# Fourth partition of the file list (index 3); index 2 is handled by the
# previous ingest cell.
reduce_task, skipped_files4, large_files4 = ingest_gsfs(files_blocks[3], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df4, start_end_timestamps4 = reduce_task.compute()

In [None]:
reduce_task, skipped_files5, large_files5 = ingest_gsfs(files_blocks[4], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df5, start_end_timestamps5 = reduce_task.compute()

In [None]:
reduce_task, skipped_files6, large_files6 = ingest_gsfs(files_blocks[5], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df6, start_end_timestamps6 = reduce_task.compute()

In [None]:
reduce_task, skipped_files7, large_files7 = ingest_gsfs(files_blocks[6], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df7, start_end_timestamps7 = reduce_task.compute()

In [None]:
reduce_task, skipped_files8, large_files8 = ingest_gsfs(files_blocks[7], size_limit_mb, processing_node_limit, ping_slice_step, slices_per_node)

In [None]:
cell_count_df8, start_end_timestamps8 = reduce_task.compute()

In [None]:
# Pool the over-size files reported by each of the eight remote ingest runs.
for block_large_files in (
    large_files1,
    large_files2,
    large_files3,
    large_files4,
    large_files5,
    large_files6,
    large_files7,
    large_files8,
):
    large_files.extend(block_large_files)

In [None]:
# Pool the files skipped (zero pings) by each of the eight remote ingest runs.
for block_skipped_files in (
    skipped_files1,
    skipped_files2,
    skipped_files3,
    skipped_files4,
    skipped_files5,
    skipped_files6,
    skipped_files7,
    skipped_files8,
):
    skipped_files.extend(block_skipped_files)

In [None]:
len(large_files)

In [None]:
len(skipped_files)

In [None]:
reduce_task, skipped_files_local, large_files_local = ingest_gsfs_local(large_files, size_limit_mb, local_tasks_limit, local_ping_slice_step, local_slices_per_task)

In [None]:
reduce_task.visualize()

In [None]:
cell_count_df_local, start_end_timestamps_local = reduce_task.compute()

In [None]:
# collect and reduce dataframes and ping start end times

In [None]:
local_non_local_results = [
    [cell_count_df1, start_end_timestamps1],
    [cell_count_df2, start_end_timestamps2],
    [cell_count_df3, start_end_timestamps3],
    [cell_count_df4, start_end_timestamps4],
    [cell_count_df5, start_end_timestamps5],
    [cell_count_df6, start_end_timestamps6],
    [cell_count_df7, start_end_timestamps7],
    [cell_count_df8, start_end_timestamps8],
    [cell_count_df_local, start_end_timestamps_local],
]

In [None]:
final_cell_count_df, final_start_end_timestamps = reduce_region_codes(local_non_local_results)

In [None]:
final_cell_count_df

In [None]:
final_start_end_timestamps

In [None]:
final_cell_count_df["geometry"] = rhealpix.rhealpix_geo_boundary(final_cell_count_df.region_code.values)

In [None]:
gdf_15 = geopandas.GeoDataFrame(final_cell_count_df, crs="epsg:4326")

In [None]:
with fiona.Env(session=AWSSession(aws_access_key_id=creds.access_key, aws_secret_access_key=creds.secret_key)):
    gdf_15.to_file(soundings_cell_density_uri_15, driver="GeoJSONSeq", coordinate_precision=11)

In [None]:
# Re-aggregate the counts after truncating every region code to its first 13
# characters (presumably the enclosing resolution-12 cell -- confirm against
# the rHEALPix code layout).
truncated_counts = pandas.DataFrame(
    {
        "region_code": final_cell_count_df.region_code.str[0:13],
        "count": final_cell_count_df["count"],
    }
)
resolution12_df = (
    truncated_counts.groupby(["region_code"])["count"]
    .sum()
    .to_frame("count")
    .reset_index()
)

In [None]:
resolution12_df

In [None]:
resolution12_df["geometry"] = rhealpix.rhealpix_geo_boundary(resolution12_df.region_code.values)

In [None]:
gdf = geopandas.GeoDataFrame(resolution12_df, crs="epsg:4326")

In [None]:
with fiona.Env(session=AWSSession(aws_access_key_id=creds.access_key, aws_secret_access_key=creds.secret_key)):
    gdf.to_file(soundings_cell_density_uri, driver="GeoJSONSeq", coordinate_precision=11)

In [None]:
dissolved = geopandas.GeoDataFrame(geometry.dissolve(gdf), crs="epsg:4326")

In [None]:
with fiona.Env(session=AWSSession(aws_access_key_id=creds.access_key, aws_secret_access_key=creds.secret_key)):
    dissolved.to_file(coverage_uri, driver="GeoJSONSeq", coordinate_precision=11)

In [None]:
dggs = rhealpix.RhealpixDGGS.from_ellipsoid()

In [None]:
dggs.cell_width(12)

In [None]:
area_ha = gdf.shape[0] * dggs.cell_width(12) **2 / 10000
sonar_metadata["area_ha"] = area_ha
area_ha

In [None]:
with fs.open(asb_metadata_uri) as src:
    asb_metadata = json.loads(src.read())

In [None]:
tiledb.cloud.register_array(
    uri=array_uri,
    namespace="sixy6e", # Optional, you may register it under your username, or one of your organizations
    array_name=array_name,
    description=asb_metadata["survey_general"]["abstract"],  # Optional 
    access_credentials_name="AusSeabedGMRT-PL019"
)

In [None]:
with tiledb.open(array_uri, ctx=ctx) as ds:
    schema = ds.schema
    domain = ds.domain
    non_empty_domain = ds.nonempty_domain()

In [None]:
gdf["count"].max()

In [None]:
gdf["count"].min()

In [None]:
full_idx = (slice(*non_empty_domain[0]), slice(*non_empty_domain[1]))
full_idx

In [None]:
# test first to see if stats can be generated with full domain, or if we need to iterate over region codes
# use the X from the schema. This should use the most memory. if it fails then use the scatter approach

In [None]:
# task = Delayed("sixy6e/basic_statistics_incremental", name="test-X-stat-full-idx")(array_uri, config_dict, "X", schema="X", idxs=[full_idx], summarise=True)
# result = task.compute()

In [None]:
# reduce the region code resolution

In [None]:
gdf2 = geopandas.GeoDataFrame({"region_code": gdf.region_code.str[0:11], "count": gdf["count"]}).groupby(["region_code"])["count"].agg("sum").to_frame("count").reset_index()

In [None]:
gdf2

In [None]:
# One pair of slices per reduced-resolution cell, built from the cell
# boundary's bounds: (bounds[0]..bounds[-2], bounds[1]..bounds[-1]).
# Presumably that is (min x..max x, min y..max y) -- confirm.
cell_boundaries = rhealpix.rhealpix_geo_boundary(gdf2.region_code.values, round_coords=False)
slices = [
    (slice(bounds[0], bounds[-2]), slice(bounds[1], bounds[-1]))
    for bounds in (geom.bounds for geom in cell_boundaries)
]

In [None]:
n_partitions = 2
n_sub_partitions = 2
blocks = scatter(slices, n_partitions)

In [None]:
len(blocks), len(blocks[0])

In [None]:
len(scatter(blocks[0], n_sub_partitions)[0])

In [None]:
# Statistics are generated for X and Y first, then for every required
# attribute except "timestamp" and "region_code".
stats_attrs = ["X", "Y"] + [
    at for at in required_attributes if at not in ("timestamp", "region_code")
]

In [None]:
# Build the statistics task graph:
#   * one "basic_statistics_incremental" task per (block, sub-block, attribute)
#   * one "basic_statistics_reduce" task per attribute
#   * a local task merging the per-attribute results into one dict
stats_results = []
tasks_dict = {stat: [] for stat in stats_attrs}
reduce_tasks = []

for i, block in enumerate(blocks):
    sub_blocks = scatter(block, n_sub_partitions)

    for si, sub_block in enumerate(sub_blocks):
        for attribute in stats_attrs:

            # X and Y are additionally named via the `schema` argument; every
            # other attribute passes None. A separate name is used so the
            # module-level `schema` (the array schema) is not overwritten.
            dim_schema = attribute if attribute in ("X", "Y") else None

            task_name = f"block-{i}-sub_block-{si}-{attribute}"
            task = Delayed("sixy6e/basic_statistics_incremental", name=task_name)(array_uri, config_dict, attribute, schema=dim_schema, idxs=sub_block, summarise=False)

            # Chain every task of an attribute onto the previous one, so the
            # second task also waits for the first (same pattern as the
            # ingest tasks).
            if tasks_dict[attribute]:
                task.depends_on(tasks_dict[attribute][-1])

            tasks_dict[attribute].append(task)

for attribute in stats_attrs:
    task_name = f"reduce-attibute-{attribute}"
    reducer_task = Delayed("sixy6e/basic_statistics_reduce", name=task_name)(tasks_dict[attribute], attribute)
    reduce_tasks.append(reducer_task)

collect_stats_task = Delayed(gather_stats, local=True, name="gather-stats")(reduce_tasks)

In [None]:
stats_results = collect_stats_task.compute()

In [None]:
# check the vertical datum

In [None]:
asb_metadata

In [None]:
# CRS description stored in the array metadata.
# NOTE(review): the second key reads "vertical_data" while the first reads
# "horizontal_datum" -- possibly a typo for "vertical_datum". The vertical
# entry also reuses the horizontal EPSG code. Confirm both before relying on
# this metadata.
crs_info = {
    "horizontal_datum": "epsg:4326",
    "vertical_data": "epsg:4326",
}

In [None]:
# Attach the CRS description and the gathered statistics to the array as
# JSON-encoded metadata entries. The statistics are serialised with the
# project's custom encoder.
with tiledb.open(array_uri, "w", ctx=ctx) as ds:
    ds.meta["crs_info"] = json.dumps(crs_info)
    ds.meta["basic_statistics"] = json.dumps(stats_results, cls=stac_metadata.Encoder)

In [None]:
# produce stac metadata

In [None]:
dataset_metadata = stac_metadata.prepare(
    uid,
    sonar_metadata,
    stats_results,
    asb_metadata,
    array_uri,
    coverage_uri,
    soundings_cell_density_uri,
    creds.access_key,
    creds.secret_key,
    final_start_end_timestamps,
    outdir_uri,
    stac_md_uri,
)