# Merging Sessions Script

This script is used to merge two or more sessions, provided they do not contain overlapping regions of interest (ROIs).

### Prerequisites:
- Paths to two or more session directories that contain extracted shorelines.
- The name of the merged session directory, which will be created in the `sessions` directory.

### Optional:
- A `config.json` file with transect settings for calculating shoreline-transect intersections.

### Instructions:
1. Enter the paths to the session directories below:
   ``` python
   session_locations=[
       '<path_to_first_session_directory>',
       '<path_to_second_session_directory>'
   ]
   ```
   Example:
   - Because these are Windows paths, each string is prefixed with `r` (a raw string) so that the backslashes are not treated as escape characters.
   ``` python
   session_locations=[
       r'C:\development\doodleverse\coastseg\CoastSeg\sessions\es1\ID_13_datetime06-05-23__04_16_45',
       r'C:\development\doodleverse\coastseg\CoastSeg\sessions\es1\ID_12_datetime06-05-23__04_16_45'
   ]
   ```
2. Specify the name for the merged session directory:
   - `merged_session_directory`: `"<name_of_merged_session_directory>"`

3. (Optional) If you want to use your own advanced settings in a `config.json` file, include its path:
   - `config_file`: `"<path_to_config_json>"`

Once these values are set, run the cells below in order to merge the specified sessions into a single session directory.


In [None]:
# replace these with the ROI directories from your own extracted shoreline sessions

session_locations=[r'C:\development\doodleverse\coastseg\CoastSeg\sessions\ID_rrw15_datetime11-21-23__11_32_09\ID_rrw15_datetime11-21-23__11_32_09',
                   r'C:\development\doodleverse\coastseg\CoastSeg\sessions\ID_rrw15_datetime11-21-23__11_35_25_es3\ID_rrw15_datetime11-21-23__11_35_25']
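Before running the remaining cells, it can help to confirm that every path points at a usable session. The sketch below is a hypothetical helper (`find_session_problems` is not part of CoastSeg) that checks for the files the later cells read: `config_gdf.geojson` and one of `extracted_shorelines_points.geojson` or `extracted_shorelines.geojson`.

```python
import os
from typing import List

# Files this notebook reads from each session directory
CONFIG_FILE = "config_gdf.geojson"
SHORELINE_FILES = ("extracted_shorelines_points.geojson", "extracted_shorelines.geojson")


def find_session_problems(session_dirs: List[str]) -> List[str]:
    """Return human-readable problems; an empty list means every session looks usable."""
    problems = []
    if len(session_dirs) < 2:
        problems.append("at least two session directories are required")
    for session_dir in session_dirs:
        if not os.path.isdir(session_dir):
            problems.append(f"{session_dir}: directory does not exist")
            continue
        if not os.path.isfile(os.path.join(session_dir, CONFIG_FILE)):
            problems.append(f"{session_dir}: missing {CONFIG_FILE}")
        # at least one of the extracted shoreline files must be present
        if not any(os.path.isfile(os.path.join(session_dir, name)) for name in SHORELINE_FILES):
            problems.append(f"{session_dir}: no extracted shoreline GeoJSON file found")
    return problems


# Example usage in this notebook:
# problems = find_session_problems(session_locations)
# assert not problems, "\n".join(problems)
```

Returning a list of problems, rather than raising on the first one, reports every misconfigured path at once.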


In [None]:
merged_session_directory='merged_session_name'

## Create the merged session directory under sessions

In [None]:
import os
# enter the location of your sessions directory if this is not correct
sessions_directory = os.path.join(os.getcwd(), 'sessions')
print(sessions_directory)
merged_session_location = os.path.join(sessions_directory, merged_session_directory)
os.makedirs(merged_session_location, exist_ok=True)

print(f"Merged session will be saved to {merged_session_location}")

## Shoreline-Transect Intersection Analysis Settings

The default settings listed below should suffice for most use cases when finding where extracted shorelines intersect transects. If you modified these advanced settings when extracting shorelines, adjust the values below to match.

In [None]:
settings_transects = {
    "along_dist": 25,  # along-shore distance (m) used to compute the intersection
    "min_points": 3,  # minimum number of shoreline points to calculate an intersection
    "max_std": 15,  # max standard deviation (m) of the points around the transect
    "max_range": 30,  # max range (m) of the points around the transect
    "min_chainage": -100,  # largest negative value (m) along the transect (landwards of transect origin)
    "multiple_inter": "auto",  # mode for removing outliers ('auto', 'nan', 'max')
    "prc_multiple": 0.1,  # fraction of the time (0.1 = 10%) that multiple intersections must be present to use the max
}
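The instructions mention an optional `config_file`, but the cells in this notebook use `settings_transects` exactly as written above. One way to apply values from your own `config.json` is sketched below. `load_transect_settings` is a hypothetical helper, and the layout of `config.json` is an assumption: it looks for the transect keys at the top level and under a `"settings"` key, and ignores any other keys.

```python
import json
from typing import Optional


def load_transect_settings(defaults: dict, config_path: Optional[str] = None) -> dict:
    """Return a copy of `defaults`, with any matching keys overridden by values in config_path."""
    settings = dict(defaults)
    if not config_path:
        return settings
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # ASSUMPTION: the transect settings are stored at the top level of config.json
    # and/or under a "settings" key; adjust this lookup to match your file's layout.
    sources = [config]
    if isinstance(config.get("settings"), dict):
        sources.append(config["settings"])
    for source in sources:
        for key in defaults:
            if key in source:
                settings[key] = source[key]
    return settings


# Example usage (config_file is not defined elsewhere in this notebook):
# config_file = None  # or r'<path_to_config_json>'
# settings_transects = load_transect_settings(settings_transects, config_file)
```

Only keys already present in the defaults are copied, so unrelated entries in `config.json` cannot leak into the transect settings.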

In [None]:
# Standard library imports
from collections import defaultdict
import os
from typing import List, Optional, Union

# Related third party imports
import geopandas as gpd
import numpy as np
import pandas as pd
from shapely.geometry import LineString, MultiLineString, MultiPoint, Point
from shapely.ops import unary_union

# Local application/library specific imports
from coastseg import geodata_processing


def convert_multipoints_to_linestrings(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """
    Convert MultiPoint geometries in a GeoDataFrame to LineString geometries.

    Args:
    - gdf (gpd.GeoDataFrame): The input GeoDataFrame.

    Returns:
    - gpd.GeoDataFrame: A new GeoDataFrame with LineString geometries. If the input GeoDataFrame
                        already contains only LineStrings, an unchanged copy is returned.
    """

    # Create a copy of the GeoDataFrame
    gdf_copy = gdf.copy()

    # Check if all geometries in the gdf are LineStrings
    if all(gdf_copy.geometry.type == "LineString"):
        return gdf_copy

    def multipoint_to_linestring(multipoint):
        if isinstance(multipoint, MultiPoint):
            return LineString(multipoint.geoms)
        return multipoint

    # Convert each MultiPoint to a LineString
    gdf_copy["geometry"] = gdf_copy["geometry"].apply(multipoint_to_linestring)

    return gdf_copy


def dataframe_to_dict(df: pd.DataFrame, key_map: dict) -> dict:
    """
    Converts a DataFrame to a dictionary, with specific mapping between dictionary keys and DataFrame columns.

    Parameters:
    df : DataFrame
        The DataFrame to convert.
    key_map : dict
        A dictionary where keys are the desired dictionary keys and values are the corresponding DataFrame column names.

    Returns:
    dict
        The resulting dictionary.
    """
    result_dict = defaultdict(list)

    for dict_key, df_key in key_map.items():
        if df_key in df.columns:
            if df_key == "date":
                # Assumes the column to be converted to date is the one specified in the mapping with key 'date'
                result_dict[dict_key] = list(
                    df[df_key].apply(
                        lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
                        if pd.notnull(x)
                        else None
                    )
                )
            elif df_key == "geometry":
                # Assumes the column to be converted to geometry is the one specified in the mapping with key 'geometry'
                result_dict[dict_key] = list(
                    df[df_key].apply(
                        lambda x: np.array([list(point.coords[0]) for point in x.geoms])
                        if pd.notnull(x)
                        else None
                    )
                )
            else:
                result_dict[dict_key] = list(df[df_key])

    return dict(result_dict)


def convert_lines_to_multipoints(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """
    Convert LineString or MultiLineString geometries in a GeoDataFrame to MultiPoint geometries.

    Parameters
    ----------
    gdf : GeoDataFrame
        The input GeoDataFrame containing LineString or MultiLineString geometries.

    Returns
    -------
    GeoDataFrame
        A new GeoDataFrame with MultiPoint geometries.

    """
    # Create a copy of the input GeoDataFrame to avoid modifying it in place
    gdf = gdf.copy()

    # Define a function to convert LineString or MultiLineString to MultiPoint
    def line_to_multipoint(geometry):
        if isinstance(geometry, LineString):
            return MultiPoint(geometry.coords)
        elif isinstance(geometry, MultiLineString):
            points = [MultiPoint(line.coords) for line in geometry.geoms]
            return MultiPoint([point for multi in points for point in multi.geoms])
        elif isinstance(geometry, MultiPoint):
            return geometry
        elif isinstance(geometry, Point):
            return MultiPoint([geometry])
        else:
            raise TypeError(f"Unsupported geometry type: {type(geometry)}")

    # Apply the conversion function to each row in the GeoDataFrame
    gdf["geometry"] = gdf["geometry"].apply(line_to_multipoint)

    return gdf


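def read_first_geojson_file(
    directory: str,
    filenames=("extracted_shorelines_lines.geojson", "extracted_shorelines.geojson"),
):
    """
    Read the first file in `filenames` that exists in `directory` and return it as a GeoDataFrame.

    Raises:
    FileNotFoundError: If none of the files exist in `directory`.
    """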
    # Loop over the filenames
    for filename in filenames:
        filepath = os.path.join(directory, filename)

        # If the file exists, read it and return the GeoDataFrame
        if os.path.exists(filepath):
            return geodata_processing.read_gpd_file(filepath)

    # If none of the files exist, raise an exception
    raise FileNotFoundError(
        f"None of the files {filenames} exist in the directory {directory}"
    )


def clip_gdfs(gdfs, overlap_gdf):
    """
    Clips GeoDataFrames to an overlapping region.

    Parameters:
    gdfs : list of GeoDataFrames
        The GeoDataFrames to be clipped.
    overlap_gdf : GeoDataFrame
        The overlapping region to which the GeoDataFrames will be clipped.

    Returns:
    list of GeoDataFrames
        The clipped GeoDataFrames.
    """
    clipped_gdfs = []
    for gdf in gdfs:
        clipped_gdf = gpd.clip(gdf, overlap_gdf)
        if not clipped_gdf.empty:
            clipped_gdfs.append(clipped_gdf)
            clipped_gdf.plot()
    return clipped_gdfs


def calculate_overlap(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """
    Calculates the intersection of all pairs of polygons in a GeoDataFrame.

    Parameters:
    -----------
    gdf : GeoDataFrame
        A GeoDataFrame containing polygons.

    Returns:
    --------
    overlap_gdf : GeoDataFrame
        A GeoDataFrame containing the intersection of all pairs of polygons in gdf.
    """
    # Return an empty GeoDataFrame if the input is not a GeoDataFrame
    if not hasattr(gdf, "empty"):
        return gpd.GeoDataFrame()
    if gdf.empty:
        # Return an empty GeoDataFrame with the same CRS if it exists
        return gpd.GeoDataFrame(
            geometry=[], crs=gdf.crs if hasattr(gdf, "crs") else None
        )

    # Initialize a list to store the intersections
    intersections = []

    # Loop over each pair of rows in gdf
    for i in range(len(gdf) - 1):
        for j in range(i + 1, len(gdf)):
            # Check for intersection
            if gdf.iloc[i].geometry.intersects(gdf.iloc[j].geometry):
                # Calculate the intersection
                intersection = gdf.iloc[i].geometry.intersection(gdf.iloc[j].geometry)
                # Append the intersection to the intersections list
                intersections.append(intersection)

    # Create a GeoSeries from the intersections
    intersection_series = gpd.GeoSeries(intersections, crs=gdf.crs)

    # Create a GeoDataFrame from the GeoSeries
    overlap_gdf = gpd.GeoDataFrame(geometry=intersection_series)
    return overlap_gdf


def average_multipoints(multipoints) -> MultiPoint:
    """
    Calculate the average MultiPoint geometry from a list of MultiPoint geometries.

    This function takes a list of shapely MultiPoint geometries, ensures they all have the same number of points
    by padding shorter MultiPoints with their last point, and then calculates the average coordinates
    for each point position across all the input MultiPoint geometries.

    The result is a new MultiPoint geometry that represents the average shape of the input MultiPoints.

    Parameters:
    multipoints (list of shapely.geometry.MultiPoint): A list of shapely MultiPoint geometries to be averaged.

    Returns:
    shapely.geometry.MultiPoint: A MultiPoint geometry representing the average shape of the input MultiPoints.

    Raises:
    ValueError: If the input list of MultiPoint geometries is empty.

    Example:
    >>> from shapely.geometry import MultiPoint
    >>> multipoint1 = MultiPoint([(0, 0), (1, 1), (2, 2)])
    >>> multipoint2 = MultiPoint([(1, 1), (2, 2)])
    >>> multipoint3 = MultiPoint([(0, 0), (1, 1), (2, 2), (3, 3)])
    >>> average_mp = average_multipoints([multipoint1, multipoint2, multipoint3])
    >>> print(average_mp)
    MULTIPOINT (0.3333333333333333 0.3333333333333333, 1.3333333333333333 1.3333333333333333, 2 2, 2.3333333333333335 2.3333333333333335)
    """
    if not multipoints:
        raise ValueError("The list of MultiPoint geometries is empty")

    # Find the maximum number of points in any MultiPoint
    max_len = max(len(mp.geoms) for mp in multipoints)

    # Pad shorter MultiPoints with their last point
    padded_multipoints = []
    for mp in multipoints:
        if len(mp.geoms) < max_len:
            padded_multipoints.append(
                MultiPoint(list(mp.geoms) + [mp.geoms[-1]] * (max_len - len(mp.geoms)))
            )
        else:
            padded_multipoints.append(mp)

    # Calculate the average coordinates for each point
    num_multipoints = len(padded_multipoints)
    average_coords = []
    for i in range(max_len):
        avg_x = sum(mp.geoms[i].x for mp in padded_multipoints) / num_multipoints
        avg_y = sum(mp.geoms[i].y for mp in padded_multipoints) / num_multipoints
        average_coords.append((avg_x, avg_y))

    return MultiPoint(average_coords)


def merge_geometries(merged_gdf, columns=None, operation=unary_union):
    """
    Performs a specified operation for the geometries with the same date and satname.

    Parameters:
    merged_gdf : GeoDataFrame
        The GeoDataFrame to perform the operation on.
    columns : list of str, optional
        The columns to perform the operation on. If None, all columns with 'geometry' in the name are used.
    operation : function, optional
        The operation to perform. If None, unary_union is used.

    Returns:
    GeoDataFrame
        The GeoDataFrame with the operation performed.
    """
    if columns is None:
        columns = [col for col in merged_gdf.columns if "geometry" in col]
    else:
        columns = [col for col in columns if col in merged_gdf.columns]

    merged_gdf["geometry"] = merged_gdf[columns].apply(
        lambda row: operation(row.tolist()), axis=1
    )
    for col in columns:
        if col in merged_gdf.columns and col != "geometry":
            merged_gdf = merged_gdf.drop(columns=col)
    return merged_gdf


def read_geojson_files(filepaths):
    """Read GeoJSON files into GeoDataFrames and return a list."""
    return [gpd.read_file(path) for path in filepaths]


def concatenate_gdfs(gdfs):
    """Concatenate a list of GeoDataFrames into a single GeoDataFrame."""
    return pd.concat(gdfs, ignore_index=True)


def filter_and_join_gdfs(gdf, feature_type, predicate="intersects"):
    """Filter GeoDataFrame by feature type, ensure spatial index, and perform a spatial join."""
    if "type" not in gdf.columns:
        raise ValueError("The GeoDataFrame must contain a column named 'type'")
    filtered_gdf = gdf[gdf["type"] == feature_type].copy()[["geometry"]]
    filtered_gdf["geometry"] = filtered_gdf["geometry"].simplify(
        tolerance=0.001
    )  # Simplify geometry if possible to improve performance
    filtered_gdf.sindex  # Ensure spatial index
    return gpd.sjoin(gdf, filtered_gdf[["geometry"]], how="inner", predicate=predicate)


def aggregate_gdf(gdf: gpd.GeoDataFrame, group_fields: list) -> gpd.GeoDataFrame:
    """
    Aggregate a GeoDataFrame by specified fields using a custom combination function.

    Parameters:
        gdf (GeoDataFrame): The input GeoDataFrame to be aggregated.
        group_fields (list): The fields to group the GeoDataFrame by.

    Returns:
        GeoDataFrame: The aggregated GeoDataFrame.
    """

    def combine_non_nulls(series):
        unique_values = series.dropna().unique()
        return (
            unique_values[0]
            if len(unique_values) == 1
            else ", ".join(map(str, unique_values))
        )

    if "index_right" in gdf.columns:
        gdf = gdf.drop(columns=["index_right"])

    return (
        gdf.drop_duplicates()
        .groupby(group_fields, as_index=False)
        .agg(combine_non_nulls)
    )


def merge_geojson_files(session_locations, merged_session_location):
    """Main function to merge GeoJSON files from different session locations."""
    filepaths = [
        os.path.join(location, "config_gdf.geojson") for location in session_locations
    ]
    gdfs = read_geojson_files(filepaths)
    merged_gdf = gpd.GeoDataFrame(concatenate_gdfs(gdfs), geometry="geometry")

    # Filter the geodataframe to only elements that intersect with the rois (dramatically drops the size of the geodataframe)
    merged_config = filter_and_join_gdfs(merged_gdf, "roi", predicate="intersects")
    # apply a group by operation to combine the rows with the same type and geometry into a single row
    merged_config = aggregate_gdf(merged_config, ["type", "geometry"])
    # applying the group by function in aggregate_gdf() turns the geodataframe into a dataframe
    merged_config = gpd.GeoDataFrame(merged_config, geometry="geometry")

    output_path = os.path.join(merged_session_location, "merged_config.geojson")
    merged_config.to_file(output_path, driver="GeoJSON")

    return merged_config


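def create_csv_per_transect(
    save_path: str,
    cross_distance_transects: dict,
    extracted_shorelines_dict: dict,
    roi_id: Optional[str] = None,  # optional ROI ID, added as a column when provided
    filename_suffix: str = "_timeseries_raw.csv",
):
    """
    Save one CSV file per transect containing the dates, satellite names, and
    shoreline-transect intersection distances.

    Parameters:
    save_path : str
        Directory where the CSV files are written.
    cross_distance_transects : dict
        Maps each transect ID to its list of intersection distances (one per date).
    extracted_shorelines_dict : dict
        Shoreline dictionary with at least the keys 'dates' and 'satname'.
    roi_id : str, optional
        If provided, written to a 'roi_id' column in every row.
    filename_suffix : str
        Appended to the transect ID to form each file name.
    """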
    for key, distances in cross_distance_transects.items():
        # Initialize the dictionary for DataFrame with mandatory keys
        data_dict = {
            "dates": extracted_shorelines_dict["dates"],
            "satname": extracted_shorelines_dict["satname"],
            key: distances,
        }

        # Add roi_id to the dictionary if provided
        if roi_id is not None:
            data_dict["roi_id"] = [roi_id] * len(extracted_shorelines_dict["dates"])

        # Create a DataFrame directly with the data dictionary
        df = pd.DataFrame(data_dict).set_index("dates")

        # Construct the full file path
        csv_filename = f"{key}{filename_suffix}"
        fn = os.path.join(save_path, csv_filename)

        # Save to CSV file, 'mode' set to 'w' for overwriting
        try:
            df.to_csv(fn, sep=",", mode="w")
            print(f"Time-series for transect {key} saved to {fn}")
        except Exception as e:
            print(f"Failed to save time-series for transect {key}: {e}")


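def merge_and_average(df1: gpd.GeoDataFrame, df2: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """
    Merge two extracted shoreline GeoDataFrames on 'satname' and 'date'.

    Numeric columns present in both inputs (e.g. cloud_cover, geoaccuracy) are averaged,
    and the geometries of matching rows are combined with merge_geometries().
    """
    # Perform a full outer join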
    merged = pd.merge(
        df1, df2, on=["satname", "date"], how="outer", suffixes=("_df1", "_df2")
    )

    # Identify numeric columns from both dataframes
    numeric_columns_df1 = df1.select_dtypes(include="number").columns
    numeric_columns_df2 = df2.select_dtypes(include="number").columns
    common_numeric_columns = set(numeric_columns_df1).intersection(numeric_columns_df2)

    # Average the numeric columns
    for column in common_numeric_columns:
        merged[column] = merged[[f"{column}_df1", f"{column}_df2"]].mean(axis=1)

    # Drop the original numeric columns
    merged.drop(
        columns=[f"{column}_df1" for column in common_numeric_columns]
        + [f"{column}_df2" for column in common_numeric_columns],
        inplace=True,
    )

    # Merge geometries
    geometry_columns = [col for col in merged.columns if "geometry" in col]
    merged = merge_geometries(merged, columns=geometry_columns)

    return merged
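The padding-and-averaging idea behind `average_multipoints` can be checked without shapely by working on plain lists of `(x, y)` tuples. The sketch below is a hypothetical stand-in (`average_coordinate_lists` is not part of CoastSeg): shorter lists are padded with their last coordinate, then each position is averaged across all lists. For the three shapes in the `average_multipoints` docstring, the fourth averaged point is (7/3, 7/3), roughly (2.33, 2.33), because the two shorter inputs are padded with (2, 2).

```python
from typing import List, Tuple

Coord = Tuple[float, float]


def average_coordinate_lists(shapes: List[List[Coord]]) -> List[Coord]:
    """Pad each coordinate list with its last point, then average the coordinates position by position."""
    if not shapes or any(len(shape) == 0 for shape in shapes):
        raise ValueError("Every shape must contain at least one coordinate")
    max_len = max(len(shape) for shape in shapes)
    # repeat the final coordinate so every list has max_len entries
    padded = [list(shape) + [shape[-1]] * (max_len - len(shape)) for shape in shapes]
    n = len(padded)
    return [
        (sum(shape[i][0] for shape in padded) / n, sum(shape[i][1] for shape in padded) / n)
        for i in range(max_len)
    ]


print(average_coordinate_lists([[(0, 0), (1, 1), (2, 2)], [(1, 1), (2, 2)], [(0, 0), (1, 1), (2, 2), (3, 3)]]))
```

Padding with the last point keeps the output length equal to the longest input, at the cost of pulling the tail of the average toward the end points of the shorter shapes.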


## Merge all the config_gdf.geojson files together

In [None]:
# features (ROIs, shorelines, or transects) at exactly the same location are merged into one row
# if transects with different ids share the same location, they are merged into one row that keeps both ids

merged_config = merge_geojson_files(session_locations, merged_session_location)
merged_config
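The grouping performed by `aggregate_gdf` (one row per unique `type` and `geometry`, with the remaining fields combined) can be sketched on plain dictionaries. In this hypothetical version (`combine_duplicate_features` is not part of CoastSeg), geometries are stand-in WKT strings so they can be used as grouping keys; a field with a single distinct value keeps it, and several distinct values are joined with `", "`, mirroring `combine_non_nulls`.

```python
from typing import Dict, List, Sequence, Tuple


def combine_duplicate_features(
    rows: List[dict], group_fields: Sequence[str] = ("type", "geometry")
) -> List[dict]:
    """Group rows by group_fields; every other field keeps its single value or a comma-joined list of distinct values."""
    groups: Dict[Tuple, List[dict]] = {}
    for row in rows:
        key = tuple(row[field] for field in group_fields)
        groups.setdefault(key, []).append(row)

    combined = []
    for key, members in groups.items():
        out = dict(zip(group_fields, key))
        other_fields = sorted({f for r in members for f in r if f not in group_fields})
        for field in other_fields:
            # distinct, non-null values in first-seen order
            values = []
            for r in members:
                value = r.get(field)
                if value is not None and value not in values:
                    values.append(value)
            out[field] = values[0] if len(values) == 1 else ", ".join(map(str, values))
        combined.append(out)
    return combined


rows = [
    {"type": "transect", "geometry": "LINESTRING (0 0, 1 1)", "id": "a"},
    {"type": "transect", "geometry": "LINESTRING (0 0, 1 1)", "id": "b"},
]
print(combine_duplicate_features(rows))
```

Two transects drawn at the same location with ids `a` and `b` collapse into a single row whose `id` is `"a, b"`, which is why some transect ids in `merged_config` may contain more than one value.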

### The ROIs Listed Below Will Be Merged Together

In [None]:
roi_rows = merged_config[merged_config['type'] == 'roi']
roi_rows
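The next cell calls `calculate_overlap`, which tests every pair of ROI polygons with shapely's `intersects`. The same pairwise idea can be sketched with axis-aligned bounding boxes and only the standard library. This is a simplification: `overlapping_pairs` is a hypothetical helper, real ROIs are polygons rather than boxes, and unlike `intersects` it ignores boxes that merely touch along an edge.

```python
from itertools import combinations
from typing import Dict, List, Tuple

BBox = Tuple[float, float, float, float]  # (minx, miny, maxx, maxy)


def overlapping_pairs(boxes: Dict[str, BBox]) -> List[Tuple[str, str]]:
    """Return every pair of ids whose boxes share a region of positive area."""
    pairs = []
    # compare each unordered pair exactly once, like the nested loop in calculate_overlap
    for (id_a, a), (id_b, b) in combinations(boxes.items(), 2):
        overlap_w = min(a[2], b[2]) - max(a[0], b[0])
        overlap_h = min(a[3], b[3]) - max(a[1], b[1])
        if overlap_w > 0 and overlap_h > 0:
            pairs.append((id_a, id_b))
    return pairs


print(overlapping_pairs({"roi_a": (0, 0, 2, 2), "roi_b": (1, 1, 3, 3), "roi_c": (5, 5, 6, 6)}))
```

Like the nested loop in `calculate_overlap`, this check is quadratic in the number of ROIs, which is fine for the handful of ROIs a merge typically involves.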

## Merge the Extracted Shorelines Together

In [None]:
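# these helpers are imported from coastseg.merge_utils and replace the functions of the same names defined above
from coastseg.merge_utils import calculate_overlap, clip_gdfs, read_first_geojson_file, convert_lines_to_multipoints, merge_and_average
from functools import reduce

combined_gdf = gpd.GeoDataFrame(geometry=[], crs='epsg:4326')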
# calculate the overlapping regions between the ROIs
overlap_gdf=calculate_overlap(roi_rows)

# read all the extracted shorelines from the session locations
gdfs = []
for session_dir in session_locations:
    # attempt to read the extracted shoreline files
    es_gdf = read_first_geojson_file(session_dir,['extracted_shorelines_points.geojson', 'extracted_shorelines.geojson'])
    es_gdf = convert_lines_to_multipoints(es_gdf)
    es_gdf = es_gdf.to_crs('epsg:4326')
    gdfs.append(es_gdf)
print(f"Read {len(gdfs)} extracted shorelines GeoDataFrames")

# clip the extracted shorelines to the overlapping regions
clipped_shorelines_gdfs = clip_gdfs(gdfs, overlap_gdf)

# sometimes there are no shorelines in the overlapping regions
if overlap_gdf.empty or len(clipped_shorelines_gdfs) == 0:
    print("No extracted shorelines found in overlapping ROI regions. Sessions can be merged.")

    for gdf in gdfs:
        if not gdf.crs:
            gdf.set_crs("EPSG:4326", inplace=True)

    # Merge the GeoDataFrames on date and satname with a full outer join,
    # averaging the numeric columns (e.g. cloud_cover, geoaccuracy) of merged rows
    result = reduce(merge_and_average, gdfs)

    result.sort_values(by='date', inplace=True)
    result.reset_index(drop=True, inplace=True)
else:
    raise ValueError(
        "Extracted shorelines were found where the ROIs overlap. "
        "This script can only merge sessions whose ROIs do not overlap."
    )

print(f"Combined {len(result)} rows from {len(gdfs)} GeoDataFrames")
print("The following dataframe contains the combined extracted shorelines from all sessions.\n"
      "Shorelines that were extracted on the same dates have been combined.")

combined_gdf = result
combined_gdf
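The effect of `reduce(merge_and_average, gdfs)` can be sketched without pandas: records are keyed by `(satname, date)`, a full outer join keeps every key, numeric metadata is averaged over the sessions that have it, and shoreline points are unioned. `merge_and_average_records` is a hypothetical helper; the numeric fields are assumed to be `cloud_cover` and `geoaccuracy`, and the sorted set union only approximates shapely's `unary_union` (which also removes duplicate points but does not promise this ordering).

```python
from typing import Dict, Tuple

Key = Tuple[str, str]  # (satname, date)
NUMERIC_FIELDS = ("cloud_cover", "geoaccuracy")  # assumed numeric metadata columns


def merge_and_average_records(a: Dict[Key, dict], b: Dict[Key, dict]) -> Dict[Key, dict]:
    """Full outer join on (satname, date): average numeric fields and union the shoreline points."""
    merged: Dict[Key, dict] = {}
    for key in sorted(set(a) | set(b)):
        rows = [r for r in (a.get(key), b.get(key)) if r is not None]
        out: Dict[str, object] = {}
        for field in NUMERIC_FIELDS:
            values = [r[field] for r in rows if r.get(field) is not None]
            # mean of the available values; missing values are skipped
            out[field] = sum(values) / len(values) if values else None
        # union of the shoreline points with duplicates removed (sorted for a stable order)
        out["points"] = sorted({p for r in rows for p in r["points"]})
        merged[key] = out
    return merged


# Example usage: functools.reduce(merge_and_average_records, [session_a, session_b, session_c])
```

Because both the notebook and this sketch fold the sessions together pairwise with `reduce`, three sessions sharing a date are not weighted equally: the first two each contribute a quarter of the averaged value and the third contributes half.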

### Save the Merged Extracted Shorelines to a JSON File
- This file contains the shoreline coordinates and the metadata for each extracted shoreline:
  1. cloud cover
  2. date
  3. satellite it was derived from
  4. geoaccuracy
- Filename: `extracted_shorelines_dict.json`

In [None]:
from coastseg import file_utilities

# mapping of dictionary keys to dataframe columns
keymap ={'shorelines':'geometry',
         'dates':'date',
         'satname':'satname',
         'cloud_cover':'cloud_cover',
         'geoaccuracy':'geoaccuracy'}
# shoreline dict should have keys: dates, satname, cloud_cover, geoaccuracy, shorelines
shoreline_dict = dataframe_to_dict(combined_gdf,keymap)
# save the extracted shoreline dictionary to json file
file_utilities.to_file(shoreline_dict, os.path.join(merged_session_location, "extracted_shorelines_dict.json"))
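The conversion done by `dataframe_to_dict` (column-oriented lists, dates formatted as `"%Y-%m-%d %H:%M:%S"`, shorelines as arrays of `[x, y]` pairs) can be sketched on plain row records. `records_to_shoreline_dict` is a hypothetical helper that uses the same key mapping as `keymap` above, but stores shorelines as nested lists instead of NumPy arrays so the result is directly JSON-serializable.

```python
import json
from datetime import datetime
from typing import Dict, List

# Same key -> column mapping as `keymap` above
KEYMAP = {
    "shorelines": "geometry",
    "dates": "date",
    "satname": "satname",
    "cloud_cover": "cloud_cover",
    "geoaccuracy": "geoaccuracy",
}


def records_to_shoreline_dict(records: List[dict]) -> Dict[str, list]:
    """Turn row records into the column-oriented shoreline dictionary."""
    present = {column for record in records for column in record}
    # keys whose column never appears are omitted, like dataframe_to_dict()
    out = {key: [] for key, column in KEYMAP.items() if column in present}
    for record in records:
        for key in out:
            value = record.get(KEYMAP[key])
            if KEYMAP[key] == "date" and value is not None:
                value = value.strftime("%Y-%m-%d %H:%M:%S")
            elif KEYMAP[key] == "geometry" and value is not None:
                value = [[x, y] for x, y in value]  # one [x, y] pair per shoreline point
            out[key].append(value)
    return out


example = records_to_shoreline_dict([
    {"date": datetime(2023, 6, 5, 4, 16, 45), "satname": "L8", "cloud_cover": 0.1,
     "geoaccuracy": 5.0, "geometry": [(1.0, 2.0), (3.0, 4.0)]},
])
print(json.dumps(example))
```

Index `i` of every list describes the same shoreline, so `shorelines[i]`, `dates[i]`, and `satname[i]` always belong together.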

## New Number of Extracted Shorelines Across All ROIs

In [None]:
len(shoreline_dict['shorelines'])

### Save the Merged Extracted Shorelines to GeoJSON Files


In [None]:
from coastseg.common import convert_linestrings_to_multipoints, stringify_datetime_columns
import os
# Save extracted shorelines as a GeoJSON file
es_line_path = os.path.join(merged_session_location, "extracted_shorelines_lines.geojson")
es_pts_path = os.path.join(merged_session_location, "extracted_shorelines_points.geojson")

es_lines_gdf = convert_multipoints_to_linestrings(combined_gdf)
# save extracted shorelines as interpolated linestrings
es_lines_gdf.to_file(es_line_path, driver='GeoJSON')


points_gdf = convert_linestrings_to_multipoints(combined_gdf)
points_gdf = stringify_datetime_columns(points_gdf)
# Save extracted shorelines as a MultiPoint GeoJSON file
points_gdf.to_file(es_pts_path, driver='GeoJSON')


# Find Where the Transects and Shorelines Intersect
1. Load the transects for all the ROIs.
2. Read the shorelines from the shoreline dictionary created earlier.
3. Find where the shorelines and transects intersect.
4. Save the shoreline-transect intersections as a time series to a CSV file.
5. Save the time series of intersections between the shorelines and each individual transect to its own CSV file.

In [None]:
from coastsat import SDS_transects
# 1. load the transects for all ROIs
transect_rows = merged_config[merged_config['type'] == 'transect']
transects_dict = {row['id']: np.array(row["geometry"].coords) for _, row in transect_rows.iterrows()}
# 2-3. compute the intersections between the transects and the shorelines in shoreline_dict
cross_distance = SDS_transects.compute_intersection_QC(shoreline_dict, transects_dict, settings_transects)
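`SDS_transects.compute_intersection_QC` comes from CoastSat, and its exact quality-control rules are not reproduced here. The sketch below is a simplified, hypothetical version of the general idea: project each shoreline point onto the transect, keep points within `along_dist` of the transect line, reject the result if too few points remain or their spread exceeds `max_std` or `max_range`, and return the median distance along the transect (the chainage). It ignores `min_chainage`, `multiple_inter`, and `prc_multiple`.

```python
import math
import statistics
from typing import Optional, Sequence, Tuple

XY = Tuple[float, float]


def transect_intersection(
    shoreline: Sequence[XY],
    origin: XY,
    end: XY,
    along_dist: float = 25,
    min_points: int = 3,
    max_std: float = 15,
    max_range: float = 30,
) -> Optional[float]:
    """Chainage (distance from origin along the transect) where the shoreline crosses it, or None."""
    dx, dy = end[0] - origin[0], end[1] - origin[1]
    length = math.hypot(dx, dy)
    if length == 0:
        raise ValueError("Transect origin and end must differ")
    ux, uy = dx / length, dy / length  # unit vector along the transect
    chainages = []
    for x, y in shoreline:
        px, py = x - origin[0], y - origin[1]
        along = px * ux + py * uy  # projection onto the transect
        across = -px * uy + py * ux  # signed perpendicular offset from the transect
        if abs(across) <= along_dist:
            chainages.append(along)
    if len(chainages) < min_points:
        return None
    # reject noisy intersections: points near the transect disagree too much
    if statistics.pstdev(chainages) > max_std or max(chainages) - min(chainages) > max_range:
        return None
    return statistics.median(chainages)
```

Taking the median of the nearby points, rather than a single geometric crossing, makes the result less sensitive to one stray shoreline point.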

In [None]:
from coastseg.common import get_cross_distance_df
# use coastseg.common to get the cross_distance_df
transects_df = get_cross_distance_df(shoreline_dict,cross_distance)
# save the transect shoreline intersections to csv timeseries file
filepath = os.path.join(merged_session_location, "transect_time_series.csv")
transects_df.to_csv(filepath, sep=",")
transects_df.head(5)

### Save a CSV for Each Transect
- WARNING: many of these CSV files will contain null values, because a transect only intersects the shorelines extracted for its own ROI, not those from the other ROIs.

In [None]:
# Save the time series of intersections between the shorelines and each transect to its own CSV file
create_csv_per_transect(merged_session_location, cross_distance, shoreline_dict)
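Because each transect only intersects its own ROI's shorelines, the per-transect files can contain many empty rows. The sketch below writes the same `dates, satname, <transect id>` layout with the standard `csv` module and adds a hypothetical `drop_missing` option that omits dates with no intersection. `write_transect_csvs` is not part of CoastSeg; it assumes `cross_distance` maps transect ids to one distance per date, with `None` or NaN for missing values.

```python
import csv
import math
import os
from typing import Dict, List, Optional, Sequence


def write_transect_csvs(
    save_dir: str,
    cross_distance: Dict[str, Sequence[Optional[float]]],
    dates: Sequence[str],
    satnames: Sequence[str],
    drop_missing: bool = False,
    suffix: str = "_timeseries_raw.csv",
) -> List[str]:
    """Write one CSV per transect with columns dates, satname, <transect id>; return the file paths."""
    paths = []
    for transect_id, distances in cross_distance.items():
        path = os.path.join(save_dir, f"{transect_id}{suffix}")
        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["dates", "satname", transect_id])
            for date, satname, distance in zip(dates, satnames, distances):
                missing = distance is None or (isinstance(distance, float) and math.isnan(distance))
                if missing and drop_missing:
                    continue  # skip dates where this transect had no intersection
                # missing intersections are written as empty cells
                writer.writerow([date, satname, "" if missing else distance])
        paths.append(path)
    return paths


# Example usage with the objects in this notebook:
# write_transect_csvs(merged_session_location, cross_distance,
#                     shoreline_dict["dates"], shoreline_dict["satname"], drop_missing=True)
```

Keeping the empty rows (the default) preserves a complete date index across transects, while `drop_missing=True` produces smaller files that are easier to inspect by hand.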