# Shapely Validate Example 

> Parallel handling of a mixture of valid and invalid geometries using [regular](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.udf.html?highlight=udf#pyspark.sql.functions.udf) and [vectorized pandas](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html?highlight=pandas%20udf#pyspark.sql.functions.pandas_udf) UDFs.

__Libraries__


* `databricks-mosaic` (installs geopandas and its dependencies, as well as keplergl)

--- 
 __Last Update__ 22 NOV 2023 [Mosaic 0.3.12]

## Setup

### Imports

In [0]:
%pip install "databricks-mosaic<0.4,>=0.3" --quiet # <- Mosaic 0.3 series
# %pip install "databricks-mosaic<0.5,>=0.4" --quiet # <- Mosaic 0.4 series (as available)

In [0]:
# -- configure AQE for more compute heavy operations
#  - choose option-1 or option-2 below, essential for REPARTITION!
# spark.conf.set("spark.databricks.optimizer.adaptive.enabled", False) # <- option-1: turn off completely for full control
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False) # <- option-2: just tweak partition management
spark.conf.set("spark.sql.shuffle.partitions", 10_000)                 # <-- default is 200

# -- import databricks + spark functions
from pyspark.databricks.sql import functions as dbf
from pyspark.sql import functions as F
from pyspark.sql.functions import udf, col
from pyspark.sql.types import *

# -- setup mosaic
import mosaic as mos

mos.enable_mosaic(spark, dbutils)
# mos.enable_gdal(spark) # <- not needed for this example

# -- other imports
import geopandas as gpd
import json
import matplotlib.pyplot as plt
import shapely
import shapely.wkt  # <- explicit, so `shapely.wkt.loads` is available below
import warnings

warnings.simplefilter("ignore")

### Data

> Generate a dataset that includes some invalid geometries, adapted from [here](https://github.com/kleunen/boost_geometry_correct).

These are some of the types of invalidity that can occur in geometries [[1](https://stackoverflow.com/questions/49902090/dataset-of-invalid-geometries-in-boostgeometry)]:

```
//Hole Outside Shell
check("POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (15 15, 15 20, 20 20, 20 15, 15 15))");
//Nested Holes
check("POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (2 2, 2 8, 8 8, 8 2, 2 2), (3 3, 3 7, 7 7, 7 3, 3 3))");
//Disconnected Interior
check("POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (5 0, 10 5, 5 10, 0 5, 5 0))");
//Self Intersection
check("POLYGON((0 0, 10 10, 0 10, 10 0, 0 0))");
//Ring Self Intersection
check("POLYGON((5 0, 10 0, 10 10, 0 10, 0 0, 5 0, 3 3, 5 6, 7 3, 5 0))");
//Nested Shells
check<multi>("MULTIPOLYGON(((0 0, 10 0, 10 10, 0 10, 0 0)),(( 2 2, 8 2, 8 8, 2 8, 2 2)))");
//Duplicated Rings
check<multi>("MULTIPOLYGON(((0 0, 10 0, 10 10, 0 10, 0 0)),((0 0, 10 0, 10 10, 0 10, 0 0)))");
//Too Few Points
check("POLYGON((2 2, 8 2))");
//Invalid Coordinate
check("POLYGON((NaN 3, 3 4, 4 4, 4 3, 3 3))");
//Ring Not Closed
check("POLYGON((0 0, 0 10, 10 10, 10 0))");
```
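For comparison, the sketch below re-checks a few of the Boost.Geometry cases above with Shapely, which delegates validity checks to GEOS. It assumes Shapely is installed locally; the exact explanation strings come from GEOS and may differ slightly between versions. Cases such as "Ring Not Closed" or "Too Few Points" are left out because the WKT reader may reject them before any validity check runs.

```python
from shapely import wkt
from shapely.validation import explain_validity

# A subset of the Boost.Geometry test cases above, re-checked with Shapely / GEOS
cases = {
    "Hole Outside Shell": "POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (15 15, 15 20, 20 20, 20 15, 15 15))",
    "Nested Holes": "POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (2 2, 2 8, 8 8, 8 2, 2 2), (3 3, 3 7, 7 7, 7 3, 3 3))",
    "Disconnected Interior": "POLYGON((0 0, 10 0, 10 10, 0 10, 0 0), (5 0, 10 5, 5 10, 0 5, 5 0))",
    "Self Intersection": "POLYGON((0 0, 10 10, 0 10, 10 0, 0 0))",
    "Nested Shells": "MULTIPOLYGON(((0 0, 10 0, 10 10, 0 10, 0 0)),(( 2 2, 8 2, 8 8, 2 8, 2 2)))",
}

results = {}
for name, text in cases.items():
    geom = wkt.loads(text)
    # store (is_valid, GEOS explanation) for each case
    results[name] = (geom.is_valid, explain_validity(geom))
    print(f"{name:<22} valid={results[name][0]}  reason={results[name][1]}")
```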

In [0]:
test_wkts = []

__[1a] Polygon self-intersection__

> Exterior xy plot with shapely (to see the lines).

In [0]:
test_wkts.append((1, """POLYGON ((5 0, 2.5 9, 9.5 3.5, 0.5 3.5, 7.5 9, 5 0))"""))

In [0]:
plt.plot(*shapely.wkt.loads(test_wkts[0][1]).exterior.xy)

__[1b] Polygon with hole inside__

> Exterior xy plot with shapely (to see the lines).

In [0]:
test_wkts.append((2, """POLYGON ((55 10, 141 237, 249 23, 21 171, 252 169, 24 89, 266 73, 55 10))"""))

In [0]:
plt.plot(*shapely.wkt.loads(test_wkts[1][1]).exterior.xy)

__[1c] Polygon with multiple intersections at same point__

> Exterior xy plot with shapely (to see the lines).

In [0]:
test_wkts.append((3, """POLYGON ((0 0, 10 0, 0 10, 10 10, 0 0, 5 0, 5 10, 0 10, 0 5, 10 5, 10 0, 0 0))"""))

In [0]:
plt.plot(*shapely.wkt.loads(test_wkts[2][1]).exterior.xy)

__[1d] Valid Polygon__

In [0]:
test_wkts.append((4, """POLYGON (( -84.3641541604937 33.71316821215546, -84.36414611386687 33.71303657522174, -84.36409515189553 33.71303657522174, -84.36410319852232 33.71317267442025, -84.3641541604937 33.71316821215546 ))"""))

In [0]:
plt.plot(*shapely.wkt.loads(test_wkts[3][1]).exterior.xy)

__[2] Make Spark DataFrame from `test_wkts`__

In [0]:
df = (
  spark
    .createDataFrame(test_wkts, schema=['row_id', 'geom_wkt'])
)
print(f"count? {df.count():,}")
df.display()

## Regular UDF: Test + Fix Validity

> We use Mosaic to test validity first, then apply the UDF only to the invalid geometries.

### UDFs

In [0]:
@udf(returnType=StringType())
def explain_wkt_validity(geom_wkt:str) -> str:
    """
    Return an explanation of the geometry's validity or invalidity
    """
    from shapely import wkt
    from shapely.validation import explain_validity

    _geom = wkt.loads(geom_wkt)
    return explain_validity(_geom)


@udf(returnType=StringType())
def make_wkt_valid(geom_wkt:str) -> str:
    """
    - test for wkt being valid
    - attempts to make valid
    - may have to change type, e.g. POLYGON to MULTIPOLYGON
    - returns valid wkt
    """
    from shapely import wkt 
    from shapely.validation import make_valid

    _geom = wkt.loads(geom_wkt)
    if _geom.is_valid:
        return geom_wkt
    _geom_fix = make_valid(_geom)
    return _geom_fix.wkt
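Because the UDF bodies are plain Python, their logic can be sanity-checked on the driver before running it across a cluster. The sketch below is an undecorated copy of the `make_wkt_valid` body (the name `make_wkt_valid_local` is hypothetical) and uses a bow-tie polygon to show the POLYGON to MULTIPOLYGON type change mentioned in the docstring.

```python
from shapely import wkt
from shapely.validation import make_valid


def make_wkt_valid_local(geom_wkt: str) -> str:
    """Hypothetical undecorated copy of the make_wkt_valid UDF body, for local testing."""
    geom = wkt.loads(geom_wkt)
    if geom.is_valid:
        return geom_wkt          # valid input is passed through untouched
    return make_valid(geom).wkt  # may change type, e.g. POLYGON -> MULTIPOLYGON


bowtie = "POLYGON ((0 0, 10 10, 0 10, 10 0, 0 0))"  # self-intersects at (5 5)
square = "POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))"  # already valid

fixed = wkt.loads(make_wkt_valid_local(bowtie))
print(fixed.geom_type, fixed.area, fixed.is_valid)
print(make_wkt_valid_local(square) == square)
```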

### Test Validity

In [0]:
df_test_valid = (
  df
    .withColumn("is_valid", mos.st_isvalid("geom_wkt"))
)

df_test_valid.display()

__Let's get an explanation for our 3 invalids__

_We recommend using `explain_wkt_validity` only to help you understand the invalid geometries, not as part of a production pipeline, so it is run separately here._

In [0]:
display(
  df_test_valid
  .select(
    "*",
    F
      .when(col("is_valid") == False, explain_wkt_validity("geom_wkt"))
      .otherwise(F.lit(None))
      .alias("invalid_explain")
  )
  .filter("is_valid = false")
)
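`explain_validity` returns strings such as `Ring Self-intersection[1 1]` (a reason followed by the location in brackets) or `Valid Geometry`, so it can be handy to split off the location and group invalid rows by reason. A small pure-Python sketch, assuming that string format; `split_validity_reason` is a hypothetical helper, not part of Shapely.

```python
import re
from typing import Optional, Tuple

# Matches e.g. "Ring Self-intersection[1 1]": a reason, then an optional "[x y]" location
_EXPLAIN_RE = re.compile(r"^(?P<reason>[^\[]+?)\s*(?:\[(?P<x>\S+)\s+(?P<y>\S+)\])?$")


def split_validity_reason(explanation: str) -> Tuple[str, Optional[Tuple[float, float]]]:
    """Hypothetical helper: split an explain_validity() string into (reason, (x, y) or None)."""
    m = _EXPLAIN_RE.match(explanation.strip())
    if m is None:
        return explanation, None  # unexpected format: return it unchanged
    if m.group("x") is None:
        return m.group("reason"), None  # no location, e.g. "Valid Geometry"
    return m.group("reason"), (float(m.group("x")), float(m.group("y")))


print(split_validity_reason("Ring Self-intersection[1 1]"))
print(split_validity_reason("Valid Geometry"))
```

Wrapped in a `StringType` UDF that returns only the reason, this could feed a `groupBy` over the invalid rows.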

### Fix Validity

In [0]:
df_valid = (
  df
    .withColumnRenamed("geom_wkt", "orig_geom_wkt")
    .withColumn("is_orig_valid", mos.st_isvalid("orig_geom_wkt"))
  .select(
    "*",
    F
      .when(col("is_orig_valid") == False, make_wkt_valid("orig_geom_wkt"))
      .otherwise(col("orig_geom_wkt"))
      .alias("geom_wkt")
  )
  .withColumn("is_valid", mos.st_isvalid("geom_wkt"))
  .drop("orig_geom_wkt")
)

print(f"""count? {df_valid.count():,}""")
print(f"""num orig invalid? {df_valid.filter(col("is_orig_valid") == False).count():,}""")
print(f"""num final invalid? {df_valid.filter(col("is_valid") == False).count():,}""")
display(df_valid)

In [0]:
fix_wkts = df_valid.orderBy('row_id').toJSON().collect()
fix_wkts
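`toJSON()` produces one JSON document per row, with column names as keys, so `fix_wkts[0]` corresponds to row 1 only because of the `orderBy`. A minimal sketch that keys the parsed rows by `row_id` instead; the two strings below are hypothetical stand-ins shaped like `df_valid` rows, not actual notebook output.

```python
import json

# Hypothetical stand-ins for `fix_wkts`, shaped like the rows of `df_valid`
fix_wkts_sample = [
    '{"row_id":1,"is_orig_valid":false,"geom_wkt":"MULTIPOLYGON (((0 0, 5 5, 10 0, 0 0)), ((5 5, 0 10, 10 10, 5 5)))","is_valid":true}',
    '{"row_id":4,"is_orig_valid":true,"geom_wkt":"POLYGON ((0 0, 1 0, 1 1, 0 0))","is_valid":true}',
]

# Parse each JSON row once and index it by row_id rather than by list position
rows_by_id = {row["row_id"]: row for row in map(json.loads, fix_wkts_sample)}

# Rows whose geometry had to be repaired
fixed_ids = sorted(rid for rid, row in rows_by_id.items() if not row["is_orig_valid"])
print(fixed_ids)
print(rows_by_id[1]["geom_wkt"])
```

With the real `fix_wkts`, `shapely.wkt.loads(rows_by_id[1]['geom_wkt'])` could then feed the GeoPandas plots below.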

__Row 1: Fixed [Self-Intersection]__ 

> Using GeoPandas to plot the area of the fixed geometry.

In [0]:
gpd.GeoSeries(shapely.wkt.loads(json.loads(fix_wkts[0])['geom_wkt'])).plot()

__Row 2: Fixed [Self-Intersection]__

> Using GeoPandas to plot the area of the fixed geometry.

In [0]:
gpd.GeoSeries(shapely.wkt.loads(json.loads(fix_wkts[1])['geom_wkt'])).plot()

__Row 3: Fixed [Ring Self-Intersection]__

> Using GeoPandas to plot the area of the fixed geometry.

In [0]:
gpd.GeoSeries(shapely.wkt.loads(json.loads(fix_wkts[2])['geom_wkt'])).plot()

## Option: Vectorized Pandas UDF

> To push performance further, you can use a vectorized pandas UDF, which receives rows in Arrow batches as pandas Series rather than one row at a time.

__Note: We are using the Pandas Series [Vectorized UDF](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html) variant.__

In [0]:
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(StringType())
def vectorized_make_wkt_valid(s:pd.Series) -> pd.Series:
    """
    - test for wkt being valid
    - attempts to make valid
    - may have to change type, e.g. POLYGON to MULTIPOLYGON
    - returns valid wkt
    """
    from shapely import wkt
    from shapely.validation import make_valid

    def to_valid(w:str) -> str:
        _geom = wkt.loads(w)
        if _geom.is_valid:
            return w
        _geom_fix = make_valid(_geom)
        return _geom_fix.wkt

    return s.apply(to_valid)
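The pandas UDF above still calls Shapely once per row via `Series.apply`. If the cluster has Shapely 2.0 or later (an assumption; check the installed version), its array-level functions `shapely.from_wkt`, `shapely.is_valid`, `shapely.make_valid` and `shapely.to_wkt` can process a whole batch per call. A sketch of such a body; `make_wkt_valid_array` is a hypothetical name.

```python
import numpy as np
import pandas as pd
import shapely  # assumes Shapely >= 2.0 for the array-level functions below


def make_wkt_valid_array(s: pd.Series) -> pd.Series:
    """Hypothetical array-based variant of the vectorized UDF body."""
    geoms = shapely.from_wkt(s.to_numpy())     # parse all WKT strings in one call
    invalid = ~shapely.is_valid(geoms)         # boolean mask of invalid geometries
    out = s.to_numpy(dtype=object, copy=True)  # start from the original strings
    if invalid.any():
        # repair only the invalid ones; rounding_precision=-1 keeps full coordinate precision
        out[invalid] = shapely.to_wkt(shapely.make_valid(geoms[invalid]), rounding_precision=-1)
    return pd.Series(out, index=s.index)


sample = pd.Series([
    "POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))",  # valid, returned as-is
    "POLYGON ((0 0, 10 10, 0 10, 10 0, 0 0))",  # bow-tie, repaired
])
print(make_wkt_valid_array(sample).tolist())
```

To use it in Spark, it could be wrapped the same way as above, e.g. `pandas_udf(make_wkt_valid_array, StringType())`; the type hints let Spark infer the Series-to-Series variant.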

_This variation doesn't show all the interim testing, just the fixing._

In [0]:
df_valid1 = (
  df                                                                                   # <- initial dataframe
    .withColumnRenamed("geom_wkt", "orig_geom_wkt")
    .withColumn("is_orig_valid", mos.st_isvalid("orig_geom_wkt"))
    .repartition(sc.defaultParallelism * 8, "orig_geom_wkt")                           # <- useful at scale
  .select(
    "*",
    F
      .when(col("is_orig_valid") == False, vectorized_make_wkt_valid("orig_geom_wkt")) # <- Pandas UDF
      .otherwise(col("orig_geom_wkt"))
      .alias("geom_wkt")
  )
  .withColumn("is_valid", mos.st_isvalid("geom_wkt"))
  .drop("orig_geom_wkt")
)

print(f"""count? {df_valid1.count():,}""")
print(f"""num orig invalid? {df_valid1.filter(col("is_orig_valid") == False).count():,}""")
print(f"""num final invalid? {df_valid1.filter(col("is_valid") == False).count():,}""")
display(df_valid1)

> _To further optimize this as an automated workflow, you would write to Delta tables and avoid unnecessary calls to `count` / `display`._

__Notes:__

* At scale, there are benefits to adding a call like `.repartition(sc.defaultParallelism * 8, "orig_geom_wkt")` together with the Spark confs that adjust AQE (see the top of the notebook). This gives you more control over partitioning, since Spark cannot plan compute-heavy (i.e. UDF) tasks as well as it plans "pure" data-heavy operations.
* The focus of this notebook was not on rendering on a map, so we just used matplotlib with both Shapely (for the original, unfixed geometries) and GeoPandas (for the fixed geometries).
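* The `.when()` conditional keeps the original WKT wherever `is_orig_valid` is true, so the UDF result is only used for invalid rows. Note that the PySpark `udf` documentation states that UDFs do not support short-circuiting in conditional expressions and may still be executed for every row; the `is_valid` check inside both UDFs keeps that per-row cost low for valid inputs.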
* We only called Shapely's `explain_validity` to initially understand the invalid geometries, since that call can be computationally expensive and is only informational.
* This covers just a subset of validation, but hopefully offers enough breadcrumbs for common issues you may face when processing invalid geometries.