In [None]:
# | default_exp embeddings
%load_ext autoreload
%autoreload 2

In [None]:
import madewithclay.data
import madewithclay.model

# Embeddings

> Working with semantic embeddings of Earth

In [None]:
# | hide
# | export
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import List, Union
from shapely.geometry import Point

import contextily as ctx
import geopandas as gpd
import matplotlib.pyplot as plt
import nbdev
import numpy as np
import pandas as pd
import psycopg2
import rasterio
from geoalchemy2 import Geometry
from nbdev.showdoc import show_doc
from sqlalchemy import create_engine
from sqlalchemy.dialects.postgresql import ARRAY, DATE, FLOAT, VARCHAR
from tqdm import tqdm


### What are Embeddings?

Embeddings in the context of Earth Observation (EO) and machine learning are dense, low-dimensional representations of high-dimensional data. In simple terms, they are numerical vectors that capture the essence of complex data, such as satellite imagery or temporal sequences from Earth observation instruments. These vectors are generated by models like Clay through a process of learning, where the model identifies and encodes the most important features and patterns within the data.

### Importance in EO
- **Data Compression**: Embeddings condense the rich information present in satellite images into a more manageable form, facilitating easier storage and faster processing.
- **Pattern Recognition**: They enable the model to recognize and compare patterns across large datasets, which is crucial for tasks like change detection, anomaly identification, or land cover classification.
- **Semantic Interpretation**: Embeddings help in understanding the semantic content of EO data, such as differentiating between urban and forested areas, or recognizing the stages of crop growth.
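
As a rough illustration of the compression point, the sizes quoted later in this document (512x512 pixel chips with 13 bands, and embeddings of length 768) can be compared directly. This is a back-of-the-envelope sketch that counts values rather than bytes, since no storage dtype is assumed here.

```python
# Sizes taken from the chip and embedding descriptions in this document.
chip_height, chip_width, n_bands = 512, 512, 13
embedding_length = 768

# Number of raw pixel values in one chip versus numbers in one embedding
values_per_chip = chip_height * chip_width * n_bands
ratio = values_per_chip / embedding_length

print(f"values per chip:      {values_per_chip}")
print(f"values per embedding: {embedding_length}")
print(f"reduction factor:     {ratio:.0f}x")
```

The reduction is lossy: an embedding summarizes a chip and cannot be used to reconstruct the original pixels exactly.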

### How to Use Embeddings for EO

1. **Feature Extraction**: Use Clay to process EO data and extract embeddings. These embeddings represent the key features of the data, capturing aspects like spectral signatures, texture, and temporal changes.

2. **Similarity Searches**: Employ embeddings to perform similarity searches across EO datasets. For example, by comparing embeddings, you can find areas with similar land use patterns or detect regions showing similar changes over time.

3. **Machine Learning Integration**: Embeddings can be used as input features for various machine learning models. In tasks like classification or regression, these embeddings provide a rich, pre-processed input that can significantly improve model performance.

4. **Time-Series Analysis**: For temporal EO data, embeddings can capture the dynamics of changes over time, aiding in monitoring environmental changes, urban development, or agricultural practices.

5. **Anomaly Detection**: Compare embeddings from different time periods or regions to identify anomalies or unexpected changes in the environment, such as sudden forest loss or unusual agricultural activity.

In practice, to use embeddings in EO, you would typically process your EO dataset through the Clay model to generate embeddings, and then utilize these embeddings as per your specific application needs, be it for further analysis, integration into other models, or for direct comparisons and searches.
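
The similarity search described above can be sketched with plain NumPy. This is a minimal illustration, not part of the Clay API: the function name `most_similar` is hypothetical, the embeddings are random stand-ins of length 768, and cosine similarity is one reasonable choice of metric among several.

```python
import numpy as np


def most_similar(query: np.ndarray, embeddings: np.ndarray, top_k: int = 3) -> np.ndarray:
    """Return the indices of the `top_k` rows of `embeddings` closest to `query`."""
    # Normalize to unit length so that the dot product equals cosine similarity
    query_unit = query / np.linalg.norm(query)
    emb_unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = emb_unit @ query_unit
    # Sort by descending similarity and keep the best matches
    return np.argsort(-similarity)[:top_k]


rng = np.random.default_rng(seed=0)
embeddings = rng.normal(size=(100, 768))  # stand-in for real embeddings
query = embeddings[42] + 0.01 * rng.normal(size=768)  # noisy copy of row 42

top = most_similar(query, embeddings)
print(top)
```

With real data, `embeddings` would be the stacked contents of the `embeddings` column, and the returned indices would point back to the rows (and therefore the locations and dates) of the closest matches.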

## Generating Embeddings


You can use a Clay model to create new embeddings. 

You will need to collect and prepare the required inputs. You can use `madewithclay.data.factory` to download and prepare the data.



In [None]:
location = Point(12.5, 55.6) # Copenhagen
time = datetime(2019,1,1)
model_version = 0.0
local_path = Path('tmp/data')
madewithclay.data.factory(location,time,model_version,local_path);

Method not implemented yet.


### Producing embeddings from the pretrained model

Once the data is prepared, follow these step-by-step instructions to create
embeddings for a single MGRS tile location (e.g. 27WXN).

1. Ensure that you can access the 13-band GeoTIFF data files.

   ```
   aws s3 ls s3://clay-tiles-02/02/27WXN/
   ```

   This should report a list of filepaths if you have the correct permissions,
   otherwise, please set up authentication before continuing.

2. Download the pretrained model weights, and put them in the `checkpoints/`
   folder.

   ```bash
   aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/
   ```

   ```{tip}
   For running model inference on a large scale (hundreds or thousands of MGRS
   tiles), it is recommended to have a cloud VM instance with:

   1. A high bandwidth network (>25Gbps) to speed up data transfer from the S3
      bucket to the compute device.
   2. An NVIDIA Ampere generation GPU (e.g. A10G) or newer, which would allow
      for efficient bfloat16 dtype calculations.

   For example, an AWS g5.4xlarge instance would be a cost effective option.
   ```


Once you have a pretrained model, it is now possible to pass some input images
into the encoder part of the Vision Transformer, and produce vector embeddings
which contain a semantic representation of the image.

3. Run model inference to generate the embeddings.

   ```bash
   python trainer.py predict --ckpt_path=checkpoints/clay-small-70MT-1100T-10E.ckpt \
                             --trainer.precision=bf16-mixed \
                             --data.data_dir=s3://clay-tiles-02/02/27WXN \
                             --data.batch_size=32 \
                             --data.num_workers=16
   ```

   This should output a GeoParquet file containing the embeddings for MGRS tile
   27WXN (recall that each 10000x10000 pixel MGRS tile contains hundreds of
   smaller 512x512 chips), saved to the `data/embeddings/` folder. See the next
   sub-section for details about the embeddings file.

   ```{note}
   For those interested in how the embeddings were computed, the predict step
   above does the following:

   1. Pass the 13-band GeoTIFF input into the Vision Transformer's encoder, to
      produce raw embeddings of shape (B, 1538, 768), where B is the batch_size,
      1538 is the patch dimension and 768 is the embedding length. The patch
      dimension itself is a concatenation of 1536 (6 band groups x 16x16
      spatial patches of size 32x32 pixels each in a 512x512 image) + 2 (latlon
      embedding and time embedding) = 1538.
   2. The mean or average is taken across the 1536 patch dimension, yielding an
      output embedding of shape (B, 768).

   More details of how this is implemented can be found by inspecting the
   `predict_step` method in the `model_clay.py` file.
   ```
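
The shape arithmetic and the averaging step from the note above can be reproduced with NumPy. This is only a sketch on random data, not the `predict_step` implementation: it assumes that the two extra tokens (latlon and time) sit at the end of the patch axis, which the text does not state.

```python
import numpy as np

# Patch count as described above: 6 band groups x 16x16 spatial patches
n_band_groups = 6
patches_per_side = 512 // 32  # 32-pixel patches along each side of a 512-pixel chip
n_spatial = n_band_groups * patches_per_side**2  # 1536
n_tokens = n_spatial + 2  # plus latlon and time embeddings -> 1538

batch_size, embedding_length = 2, 768
rng = np.random.default_rng(seed=0)
raw = rng.normal(size=(batch_size, n_tokens, embedding_length))  # stand-in encoder output

# Assumption: the latlon and time tokens are the last two along axis 1,
# so the first 1536 entries are the spatial patches to average over.
pooled = raw[:, :n_spatial, :].mean(axis=1)

print(raw.shape, "->", pooled.shape)
```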


### Format of the embeddings file

The vector embeddings are stored in a single column within a
[GeoParquet](https://geoparquet.org) file (*.gpq), with other columns
containing spatiotemporal metadata. This file format is built on top of the
popular Apache Parquet columnar storage format designed for fast analytics,
and it is highly interoperable across different tools like QGIS,
GeoPandas (Python), sfarrow (R), and more.

#### Filename convention

The embeddings file utilizes the following naming convention:

```
{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
```

Example: `27WXN_20200101_20231231_v001.gpq`

| Variable | Description |
|--|--|
| MGRS | The spatial location of the file's contents in the [Military Grid Reference System (MGRS)](https://en.wikipedia.org/wiki/Military_Grid_Reference_System), given as a 5-character string |
| MINDATE | The minimum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format |
| MAXDATE | The maximum acquisition date of the Sentinel-2 images used to generate the embeddings, given in YYYYMMDD format |
| VERSION | Version of the generated embeddings, given as a 3-digit number |
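
A filename following this convention can be split into its parts with the standard library. The sketch below is illustrative only: `parse_embeddings_filename` and `EmbeddingsFilename` are hypothetical names, and the pattern assumes the MGRS code consists of digits and uppercase letters.

```python
import re
from datetime import date, datetime
from typing import NamedTuple


class EmbeddingsFilename(NamedTuple):
    mgrs: str
    min_date: date
    max_date: date
    version: int


# {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
PATTERN = re.compile(
    r"^(?P<mgrs>[0-9A-Z]{5})_(?P<min>\d{8})_(?P<max>\d{8})_v(?P<version>\d{3})\.gpq$"
)


def parse_embeddings_filename(name: str) -> EmbeddingsFilename:
    """Split an embeddings filename into MGRS code, date range and version."""
    match = PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unexpected embeddings filename: {name}")
    return EmbeddingsFilename(
        mgrs=match["mgrs"],
        min_date=datetime.strptime(match["min"], "%Y%m%d").date(),
        max_date=datetime.strptime(match["max"], "%Y%m%d").date(),
        version=int(match["version"]),
    )


parsed = parse_embeddings_filename("27WXN_20200101_20231231_v001.gpq")
print(parsed)
```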


#### Table schema

Each row within the GeoParquet table is generated from a 512x512 pixel image,
and contains a record of the embeddings, spatiotemporal metadata, and a link to
the GeoTIFF file used as the source image for the embedding. The table looks
something like this:

|         source_url          |    date    |      embeddings      |   geometry   |
|-----------------------------|------------|----------------------|--------------|
| s3://.../.../claytile_*.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) |
| s3://.../.../claytile_*.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
| s3://.../.../claytile_*.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |

Details of each column are as follows:

- `source_url` ([string](https://arrow.apache.org/docs/python/generated/pyarrow.string.html)) - The full URL to the 13-band GeoTIFF image the embeddings were derived from.
- `date` ([date32](https://arrow.apache.org/docs/python/generated/pyarrow.date32.html)) - Acquisition date of the Sentinel-2 image used to generate the embeddings, in YYYY-MM-DD format.
- `embeddings` ([FixedShapeTensorArray](https://arrow.apache.org/docs/python/generated/pyarrow.FixedShapeTensorArray.html)) - The vector embeddings given as a 1-D tensor or list with a length of 768.
- `geometry` ([binary](https://arrow.apache.org/docs/python/generated/pyarrow.binary.html)) - The spatial bounding box of the 13-band image, provided in a [WKB](https://en.wikipedia.org/wiki/Well-known_text_representation_of_geometry#Well-known_binary) Polygon representation.


```{note}
Additional technical details of the GeoParquet file:
- GeoParquet specification [v1.0.0](https://geoparquet.org/releases/v1.0.0)
- The coordinate reference system of the geometries is `OGC:CRS84`.
```


### Embeddings Factory

If you don't have embeddings, you'll need to use the "Embeddings Factory". Given a location, a time, and a Clay model, it generates the embeddings for each input data bundle.


## Working with embeddings

A Clay embedding filename will look like this `33PWP_20181021_20200114_v001.gpq` which is a concatenation of the following:

* `33PWP` - the location of the input data it comes from, in MGRS format.
* `20181021` - the earliest date for any band of the input data it comes from
* `20200114` - the latest date for any band of the input data it comes from
* `v001` - the embedding version number.
* `.gpq` - the file extension, geoparquet.

Each file contains one row per chip that the MGRS tile was split into. Each row stores the embedding vector in the `embeddings` column, alongside the metadata columns. The length of the vector depends on the Clay model used to generate the embeddings.

In [None]:
# | export
class EmbeddingsHandler:
    def __init__(
        self,
        path: Path,  # Path to the file or folder with files
        max_files: int = None,
    ):  # Max number of files to load, randomly chosen
        self.path = Path(path)
        self.gdf = None
        self.files = None

        # handle path
        if self.path.is_dir():
            self.files = list(self.path.glob("*.gpq"))
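            assert len(self.files) > 0, "No gpq files found in path"
            if max_files is not None:
                rng = np.random.default_rng()
                # Sample at most max_files files; never more than are available
                n_files = min(max_files, len(self.files))
                self.files = list(rng.choice(self.files, size=n_files, replace=False))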
        else:
            self.files = [self.path]
            assert self.path.suffix == ".gpq", "File must be a gpq file"
        self.load_geoparquet_folder()

    def load_geoparquet_folder(
        self,
    ):
        "Load geoparquet files calling read_geoparquet_file in parallel"
        with ProcessPoolExecutor() as executor:
            gdfs = list(
                tqdm(
                    executor.map(self.read_geoparquet_file, self.files),
                    total=len(self.files),
                )
            )
        print(f"Total rows: {sum([len(gdf) for gdf in gdfs])}\n Merging dataframes...")
        gdf = pd.concat(gdfs, ignore_index=True)
        # Drop a leftover 'index' column if the files contain one
        gdf = gdf.drop(columns="index", errors="ignore")
        self.gdf = gdf
        print("Done!\n Total rows: ", len(self.gdf))

    def read_geoparquet_file(self, 
                             file: Path):  # Path to the geoparquet file
        """
        Reads a geoparquet file and returns a dataframe with the embeddings.
        """
        assert file.exists(), "Path does not exist"
        # check pattern of file name like 33PWP_20181021_20200114_v001.gpq
        assert file.suffix == ".gpq", "File must be a gpq file"
        parts = file.stem.split("_")
        n_parts = len("33PWP_20181021_20200114_v001".split("_"))
        assert len(parts) == n_parts, "File name must have 4 parts"
        location, start_date, end_date, version = parts

        # read file
        gdf = gpd.read_parquet(file)
        gdf = gdf.to_crs("EPSG:3857")

        # add centroid x and y columns
        gdf["x"] = gdf.geometry.centroid.x
        gdf["y"] = gdf.geometry.centroid.y

        # set columns for the values of location, start_date, end_date, version
        gdf["location"] = location
        gdf["start_date"] = datetime.strptime(start_date, "%Y%m%d")
        gdf["end_date"] = datetime.strptime(end_date, "%Y%m%d")
        gdf["version"] = version
        return gdf

    def transform_crs(self, crs="epsg:3857"):  # CRS to transform to
        """
        Transforms the CRS of the dataframe.
        """
        self.gdf = self.gdf.to_crs(crs)

    def plot_locations(
        self,
        figsize: [int, int] = (10, 10),  # Size of the plot
        alpha: float = 0.2,  # Transparency of the points
        max_rows: int = 10000,  # Random max number of rows to plot
        bounds: List[int] = None, # Bounds of the plot [xmin, ymin, xmax, ymax]
        indices: List[int] = None # Indices of the rows to plot
    ):
        """
        Plots the dataframe on a map with an OSM underlay.
        """

        # Default to all indices if none are provided
        if indices is None:
            indices = self.gdf.index.values

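        if max_rows is not None and len(indices) > max_rows:
            # Deduplicate on geometry without mutating self.gdf, so that the
            # sampled indices are guaranteed to exist when plotting
            unique = self.gdf.loc[indices].drop_duplicates(subset=["geometry"])
            indices = unique.index.values
            if len(indices) > max_rows:
                rng = np.random.default_rng()
                indices = rng.choice(indices, size=max_rows, replace=False)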
        ax = self.gdf.loc[indices].plot(
                figsize=figsize, alpha=alpha, edgecolor='k', markersize=1
            )

        # If bounds are provided, set the bounds of the plot
        if bounds is not None:
            ax.set_xlim(bounds[0], bounds[2])
            ax.set_ylim(bounds[1], bounds[3])

        ctx.add_basemap(ax, source=ctx.providers.OpenStreetMap.Mapnik)
        ax.set_axis_off()
        plt.show()

    def fetch_and_plot_image(
        self,
        index: int,  # index of the row to plot
        local_folder: Path,  # Local folder to save the image
        force_fetch: bool,  # Whether to force fetching the image
        bands: List[int] = [3, 4, 2],
    ):  # Bands to read
        """
        Fetches an image from a URL or local path, reads the RGB bands, and
        returns them as a normalized array ready for plotting.
        """
        row = self.gdf.loc[index]

        # Fetch when no local copy is recorded (covers both None and NaN)
        if pd.isna(row["local_path"]):
            force_fetch = True

        if force_fetch:
            # print(f"Fetching image for row {index}")
            url = row["source_url"]
            local_path = local_folder / Path(url).name
            assert local_folder.exists(), f"Local folder {local_folder} does not exist"
            with rasterio.open(url) as src:
                # print(f"Reading {bands} bands from {url}")
                rgb = src.read(bands)
                # Capture georeferencing while the source dataset is still open
                crs, transform = src.crs, src.transform
            with rasterio.open(
                local_path,
                "w",
                driver="GTiff",
                height=rgb.shape[1],
                width=rgb.shape[2],
                count=len(bands),
                dtype=rgb.dtype,
                crs=crs,
                transform=transform,
            ) as dst:
                # print(f"Writing {bands} bands to {local_path}")
                dst.write(rgb)
                self.gdf.loc[self.gdf["source_url"] == url, "local_path"] = str(local_path)
        else:
            #print(f"Reading local image for row {index}")
            local_path = row["local_path"]
            with rasterio.open(local_path) as src:
                # print(f"Reading {bands} bands from {local_path}")
                rgb = src.read()
        rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
        rgb = np.transpose(rgb, [1, 2, 0])

        #clip values on each band to 10-90 percentile
        percentiles = np.percentile(rgb, [10, 90], axis=(0, 1))
        rgb = np.clip(rgb, percentiles[0], percentiles[1])
        return rgb

    def rgb_imgs(
        self,
        row_indices: Union[int, List[int]],  # Indices of the rows to plot
        local_folder: Path = None,  # Local folder to save the image
        force_fetch: bool = False,
        skip_plot: bool = False # skip plotting, but save them to local_folder
    ):
        """
        Plots RGB images for specified rows,
        either from local storage or by fetching them.
        """

        if isinstance(row_indices, int):
            row_indices = [row_indices]

        if "local_path" not in self.gdf.columns:
            self.gdf["local_path"] = None

        if force_fetch and local_folder is None:
            if self.gdf["local_path"].notnull().any():
                existing_files = self.gdf[self.gdf["local_path"].notnull()].iloc[0]
                local_folder = Path(existing_files["local_path"]).parent
            else:
                raise ValueError("local_folder must be provided if force_fetch is True")

        with ThreadPoolExecutor(max_workers=50) as executor:
            results = list(tqdm(executor.map(
                lambda idx: self.fetch_and_plot_image(idx,
                                                      local_folder,
                                                      force_fetch),
                                                      row_indices),
                                                      total=len(row_indices)))

        if not skip_plot:
            self._plot_images(results, row_indices)

    def _plot_images(self,
                     images  , # list of images
                     indices): # list of indices
        """
        Plots the images from the results of the fetch_and_plot_image method.
        """
        num_images = len(images)
        num_cols = min(3, num_images)
        num_rows = -(-num_images // num_cols)  # Ceiling division

        # Create a figure and a set of subplots
        fig, axes = plt.subplots(nrows=num_rows,
                                 ncols=num_cols,
                                 figsize=(5 * num_cols, 5 * num_rows))
        axes = axes.flatten() if num_images > 1 else [axes]

        for idx, image in enumerate(images):
            if image is not None:
                ax = axes[idx]
                ax.imshow(image)
                ax.axis('off')
                ax.set_title(f"Index: {indices[idx]}")

        # Turn off axes for any unused subplots
        for ax in axes[num_images:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def prep_posgres(self,
                     db_url: str): # Postgres database URL
        """
        Prepares a Postgres database for the embeddings.
        """
        self.db_url = db_url
        self.engine = create_engine(db_url) if db_url else None

    def save_to_postgres(self):
        """
        Saves the geodataframe to PostgreSQL.
        """

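        # getattr guards against save_to_postgres being called before prep_posgres
        if getattr(self, "db_url", None) is None or getattr(self, "engine", None) is None:
            raise ValueError("Database URL not provided or engine not initialized. "
                             "Call .prep_posgres(db_url) first.")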

        # Convert 'embeddings' from list to numpy array
        self.gdf['embeddings'] = self.gdf['embeddings'].apply(lambda x: np.array(x))

        # Define column types for PostgreSQL
        column_types = {
            'geometry': Geometry('POLYGON', srid=3857),
            'embeddings': ARRAY(FLOAT),
            'source_url': VARCHAR,
            'local_path': VARCHAR,
            'date': DATE,
            'x': FLOAT,
            'y': FLOAT,
            'location': VARCHAR,
            'start_date': DATE,
            'end_date': DATE,
            'version': VARCHAR
        }

        #check that all keys in column_types are in gdf.columns and print missing keys
        missing_keys = set(column_types.keys()) - set(self.gdf.columns)
        if len(missing_keys) > 0:
            print(f"Missing keys: {missing_keys}")
            raise ValueError("Missing keys in gdf.columns")

        # Save to PostgreSQL
        self.gdf.to_sql('embeddings', self.engine, if_exists='replace', index=False, dtype=column_types)

`EmbeddingsHandler` has several methods to help you work with embeddings.

This is how you can load embeddings from a single file or from a folder of files, optionally limiting the number of files to load:

In [None]:
#show_doc(EmbeddingsHandler.read_geoparquet_file)

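For example, this is how to load embeddings with `max_files=10`. Here the path points to a single file, so only that file is read; with a folder, up to 10 randomly chosen files would be loaded: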


In [None]:
embeddings_path = Path("../fixtures/sample_embedding/01WCP_20170701_20210603_v001.gpq")
embeddings = EmbeddingsHandler(embeddings_path, max_files=10)

100%|██████████| 1/1 [00:00<00:00, 24.82it/s]

Total rows: 5
 Merging dataframes...
Done!
 Total rows:  5





Then you can plot the embeddings:

In [None]:
#show_doc(EmbeddingsHandler.plot_locations)


In [None]:
embeddings.plot_locations()

In [None]:
embeddings.plot_locations(indices=[0,1,2,3], max_rows=2)

If the total area is too large, you can zoom in on the area around a single embedding to see it in detail:

In [None]:
# Get the coordinates of one geometry
first_geometry = embeddings.gdf.loc[0].geometry
# Create a 100 km buffer around the first geometry (EPSG:3857 units are metres)
buffer = first_geometry.buffer(100 * 1000)  # 100 x 1 km

bounds = buffer.bounds

# Call the plot method with the bounds
embeddings.plot_locations(bounds=bounds)

Note that we are using a transparency of `alpha=0.2`. Darker shades mark locations where several embeddings are stacked on top of each other, i.e. from different times.
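
The `bounds` passed to `plot_locations` are expressed in the CRS of the dataframe, which is `EPSG:3857` after loading. If you only have a longitude and latitude, a plot window can be derived with the standard spherical Web Mercator formulas. The sketch below is an illustration under that assumption; the helper names `lonlat_to_web_mercator` and `square_bounds` are hypothetical and not part of `EmbeddingsHandler`.

```python
import math
from typing import Tuple

EARTH_RADIUS_M = 6378137.0  # sphere radius used by EPSG:3857


def lonlat_to_web_mercator(lon: float, lat: float) -> Tuple[float, float]:
    """Project longitude/latitude in degrees to EPSG:3857 x/y in metres."""
    x = EARTH_RADIUS_M * math.radians(lon)
    y = EARTH_RADIUS_M * math.log(math.tan(math.pi / 4 + math.radians(lat) / 2))
    return x, y


def square_bounds(lon: float, lat: float, half_width_m: float) -> Tuple[float, float, float, float]:
    """Return [xmin, ymin, xmax, ymax] for a square window centred on lon/lat."""
    x, y = lonlat_to_web_mercator(lon, lat)
    return (x - half_width_m, y - half_width_m, x + half_width_m, y + half_width_m)


# Window of +/- 100 km (in projected metres) around Copenhagen
bounds = square_bounds(12.5, 55.6, 100_000)
print(bounds)
```

The result can be passed as `embeddings.plot_locations(bounds=bounds)`. Keep in mind that Web Mercator stretches distances away from the equator, so 100 km in projected units covers less than 100 km on the ground at high latitudes.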

To retrieve the RGB image for a given embedding, you can use the `rgb_imgs` method. The first time, it uses the `S3` URL to pull only the RGB bands, then saves them locally for faster retrieval later.

You must specify the rows you want to retrieve and, on the first call, the output folder in which to save the images, unless an existing local folder can be reused.

In [None]:
local_path = Path("tmp/rgbs/")
local_path.mkdir(parents=True,exist_ok=True)

embeddings.rgb_imgs(
    [0,1,2], local_folder=Path(local_path)
)

You can skip the `local_folder` argument if you already have other RGB images saved locally.

In [None]:
embeddings.rgb_imgs(2)

If needed you can `force_fetch` from the `S3` location again.


In [None]:
embeddings.rgb_imgs(0, force_fetch=True)

In [None]:
#| hide
nbdev.nbdev_export()