From eb55e2366e899477ab08de567a36d0f13ba04569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 19:34:13 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/patch_level_cloud_cover.ipynb | 261 +++++++++++++++++------------ 1 file changed, 156 insertions(+), 105 deletions(-) diff --git a/docs/patch_level_cloud_cover.ipynb b/docs/patch_level_cloud_cover.ipynb index c32c9531..120daad6 100644 --- a/docs/patch_level_cloud_cover.ipynb +++ b/docs/patch_level_cloud_cover.ipynb @@ -23,29 +23,27 @@ "metadata": {}, "outputs": [], "source": [ + "import glob\n", + "from pathlib import Path\n", + "\n", "import geopandas as gpd\n", - "import pystac_client\n", - "import shapely\n", - "import stackstac\n", - "import torch\n", + "import lancedb\n", "import matplotlib.pyplot as plt\n", "import numpy\n", "import pandas as pd\n", - "import xarray as xr\n", + "import pystac_client\n", "import rasterio\n", "import rioxarray # noqa: F401\n", - "import pyarrow as pa\n", - "import pickle\n", - "import lancedb\n", - "import glob\n", - "from pathlib import Path\n", - "from shapely.geometry import Point, Polygon, box\n", - "from rasterio.enums import Resampling\n", + "import shapely\n", + "import stackstac\n", + "import torch\n", "from rasterio.enums import Resampling\n", + "from shapely.geometry import Polygon, box\n", + "\n", "from src.datamodule import ClayDataModule\n", "from src.model_clay import CLAYModule\n", "\n", - "pd.set_option('display.max_colwidth', None)\n", + "pd.set_option(\"display.max_colwidth\", None)\n", "\n", "BAND_GROUPS_L2A = {\n", " \"rgb\": [\"red\", \"green\", \"blue\"],\n", @@ -89,10 +87,10 @@ "outputs": [], "source": [ "# sample cluster\n", - "bbox_bl = (177.4199,-17.8579)\n", - "bbox_tl = (177.4156,-17.6812)\n", - "bbox_br = (177.5657,-17.8572)\n", - "bbox_tr = (177.5657,-17.6812)" + "bbox_bl = (177.4199, -17.8579)\n", + "bbox_tl = (177.4156, -17.6812)\n", + "bbox_br = (177.5657, -17.8572)\n", + "bbox_tr = (177.5657, -17.6812)" ] }, { @@ -111,7 +109,9 @@ "outputs": [], "source": [ "# Define area of interest\n", - "area_of_interest = shapely.box(xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1])\n", + "area_of_interest = shapely.box(\n", + " xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1]\n", + ")\n", "\n", "# Define temporal range\n", "daterange: dict = [\"2021-01-01T00:00:00Z\", \"2021-12-31T23:59:59Z\"]" @@ -182,7 +182,7 @@ " fill_value=0,\n", " assets=BAND_GROUPS_L2A[\"rgb\"] + BAND_GROUPS_L2A[\"scl\"],\n", " resampling=Resampling.nearest,\n", - " xy_coords='center',\n", + " xy_coords=\"center\",\n", ")\n", "\n", "stack_L2A = stack_L2A.compute()\n", @@ -217,13 +217,13 @@ " # Write tile to output dir, whilst dropping the SCL band in the process\n", " for tile in stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]):\n", " date = str(tile.time.values)[:10]\n", - " \n", + "\n", " name = \"{dir}/claytile_{date}.tif\".format(\n", " dir=outdir,\n", " date=date.replace(\"-\", \"\"),\n", " )\n", " tile.rio.to_raster(name, compress=\"deflate\")\n", - " \n", + "\n", " with rasterio.open(name, \"r+\") as rst:\n", " rst.update_tags(date=date)" ] @@ -246,6 +246,7 @@ "source": [ "# Function to count cloud pixels in a subset\n", "\n", + "\n", "def count_cloud_pixels(subset_scl, cloud_labels):\n", " cloud_pixels = 0\n", " for label in cloud_labels:\n", @@ -273,7 +274,7 @@ ], "source": [ "# Define the chunk size 
for tiling\n", - "chunk_size = {'x': 32, 'y': 32} # Adjust the chunk size as needed\n", + "chunk_size = {\"x\": 32, \"y\": 32} # Adjust the chunk size as needed\n", "\n", "# Tile the data\n", "ds_chunked_L2A = stack_L2A.chunk(chunk_size)\n", @@ -291,46 +292,47 @@ "cloud_pcts = {}\n", "\n", "# Get the geospatial transform and CRS\n", - "transform = ds_chunked_L2A.attrs['transform']\n", - "crs = ds_chunked_L2A.attrs['crs']\n", + "transform = ds_chunked_L2A.attrs[\"transform\"]\n", + "crs = ds_chunked_L2A.attrs[\"crs\"]\n", "\n", - "for x in range((ds_chunked_L2A.sizes['x'] // chunk_size['x']) + 1):\n", - " for y in range((ds_chunked_L2A.sizes['y'] // chunk_size['y']) + 1):\n", + "for x in range((ds_chunked_L2A.sizes[\"x\"] // chunk_size[\"x\"]) + 1):\n", + " for y in range((ds_chunked_L2A.sizes[\"y\"] // chunk_size[\"y\"]) + 1):\n", " # Compute chunk coordinates\n", - " x_start = x * chunk_size['x']\n", - " y_start = y * chunk_size['y']\n", - " x_end = min(x_start + chunk_size['x'], ds_chunked_L2A.sizes['x'])\n", - " y_end = min(y_start + chunk_size['y'], ds_chunked_L2A.sizes['y'])\n", - " \n", + " x_start = x * chunk_size[\"x\"]\n", + " y_start = y * chunk_size[\"y\"]\n", + " x_end = min(x_start + chunk_size[\"x\"], ds_chunked_L2A.sizes[\"x\"])\n", + " y_end = min(y_start + chunk_size[\"y\"], ds_chunked_L2A.sizes[\"y\"])\n", + "\n", " # Compute chunk geospatial bounds\n", " lon_start, lat_start = transform * (x_start, y_start)\n", " lon_end, lat_end = transform * (x_end, y_end)\n", - " #print(lon_start, lat_start, lon_end, lat_end, x, y)\n", + " # print(lon_start, lat_start, lon_end, lat_end, x, y)\n", "\n", " # Store chunk bounds\n", " chunk_bounds[(x, y)] = {\n", - " 'lon_start': lon_start, 'lat_start': lat_start,\n", - " 'lon_end': lon_end, 'lat_end': lat_end\n", + " \"lon_start\": lon_start,\n", + " \"lat_start\": lat_start,\n", + " \"lon_end\": lon_end,\n", + " \"lat_end\": lat_end,\n", " }\n", "\n", " # Extract the subset of the SCL band\n", - " subset_scl = ds_chunked_L2A[:,:,5][:, y_start:y_end, x_start:x_end]\n", - " \n", + " subset_scl = ds_chunked_L2A[:, :, 5][:, y_start:y_end, x_start:x_end]\n", + "\n", " # Count the cloud pixels in the subset\n", " cloud_pct = count_cloud_pixels(subset_scl, SCL_CLOUD_LABELS)\n", - " \n", + "\n", " # Store the cloud percent for this chunk\n", " cloud_pcts[(x, y)] = cloud_pct\n", "\n", "# Print chunk bounds\n", - "#for key, value in chunk_bounds.items():\n", - " #print(f\"Chunk {key}: {value}\")\n", + "# for key, value in chunk_bounds.items():\n", + "# print(f\"Chunk {key}: {value}\")\n", "\n", "# Print cloud counts\n", "for key, value in cloud_pcts.items():\n", " if value > 0:\n", - " print(f\"Chunk {key}: Cloud percentage = {value}\")\n", - "\n" + " print(f\"Chunk {key}: Cloud percentage = {value}\")" ] }, { @@ -422,9 +424,9 @@ } ], "source": [ - "print(len(embeddings[0])) # embeddings is a list\n", - "print(embeddings[0].shape) # with date and lat/lon\n", - "print(embeddings[0][:, :-2, :].shape) # remove date and lat/lon" + "print(len(embeddings[0])) # embeddings is a list\n", + "print(embeddings[0].shape) # with date and lat/lon\n", + "print(embeddings[0][:, :-2, :].shape) # remove date and lat/lon" ] }, { @@ -435,7 +437,7 @@ "outputs": [], "source": [ "# remove date and lat/lon and reshape to disaggregated patches\n", - "embeddings_patch = embeddings[0][:, :-2, :].reshape([1,16,16,768]) " + "embeddings_patch = embeddings[0][:, :-2, :].reshape([1, 16, 16, 768])" ] }, { @@ -525,25 +527,36 @@ " for j in 
range(embeddings_patch_avg_group.shape[1]):\n", " embeddings_output_patch = embeddings_patch_avg_group[i, j]\n", "\n", - " item_ = [element for element in list(chunk_bounds.items()) if element[0] == (i,j)]\n", - " box_ = [item_[0][1]['lon_start'], item_[0][1]['lat_start'],item_[0][1]['lon_end'], item_[0][1]['lat_end']]\n", - " cloud_pct_ = [element for element in list(cloud_pcts.items()) if element[0] == (i,j)]\n", + " item_ = [\n", + " element for element in list(chunk_bounds.items()) if element[0] == (i, j)\n", + " ]\n", + " box_ = [\n", + " item_[0][1][\"lon_start\"],\n", + " item_[0][1][\"lat_start\"],\n", + " item_[0][1][\"lon_end\"],\n", + " item_[0][1][\"lat_end\"],\n", + " ]\n", + " cloud_pct_ = [\n", + " element for element in list(cloud_pcts.items()) if element[0] == (i, j)\n", + " ]\n", " source_url = batch[\"source_url\"]\n", " date = batch[\"date\"]\n", " data = {\n", " \"source_url\": batch[\"source_url\"][0],\n", - " \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\").astype(dtype=\"date32[day][pyarrow]\"),\n", + " \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\").astype(\n", + " dtype=\"date32[day][pyarrow]\"\n", + " ),\n", " \"embeddings\": [numpy.ascontiguousarray(embeddings_output_patch)],\n", - " \"cloud_cover\": cloud_pct_[0][1]\n", + " \"cloud_cover\": cloud_pct_[0][1],\n", " }\n", - " \n", + "\n", " # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)\n", " # The box_ list is encoded as [bottom left x, bottom left y, top right x, top right y]\n", " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n", - " \n", + "\n", " # Create the GeoDataFrame\n", " gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{epsg}\")\n", - " \n", + "\n", " # Reproject to WGS84 (lon/lat coordinates)\n", " gdf = gdf.to_crs(epsg=4326)\n", "\n", @@ -588,22 +601,26 @@ "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n", " gdf = gpd.read_parquet(emb)\n", " gdf[\"year\"] = gdf.date.dt.year\n", - " gdf[\"tile\"] = gdf[\"source_url\"].apply(lambda x: Path(x).stem.rsplit(\"/\")[-1].rsplit(\"_\")[0])\n", - " gdf[\"idx\"] = '_'.join(emb.split(\"/\")[-1].split(\"_\")[2:]).replace('.gpq', '')\n", + " gdf[\"tile\"] = gdf[\"source_url\"].apply(\n", + " lambda x: Path(x).stem.rsplit(\"/\")[-1].rsplit(\"_\")[0]\n", + " )\n", + " gdf[\"idx\"] = \"_\".join(emb.split(\"/\")[-1].split(\"_\")[2:]).replace(\".gpq\", \"\")\n", " gdf[\"box\"] = [box(*geom.bounds) for geom in gdf.geometry]\n", " gdfs.append(gdf)\n", - " \n", - " for _,row in gdf.iterrows():\n", - " data.append({\n", - " \"vector\": row[\"embeddings\"],\n", - " \"path\": row[\"source_url\"],\n", - " \"tile\": row[\"tile\"],\n", - " \"date\": row[\"date\"],\n", - " \"year\": int(row[\"year\"]),\n", - " \"cloud_cover\": row[\"cloud_cover\"],\n", - " \"idx\": row[\"idx\"],\n", - " \"box\": row[\"box\"].bounds,\n", - " })" + "\n", + " for _, row in gdf.iterrows():\n", + " data.append(\n", + " {\n", + " \"vector\": row[\"embeddings\"],\n", + " \"path\": row[\"source_url\"],\n", + " \"tile\": row[\"tile\"],\n", + " \"date\": row[\"date\"],\n", + " \"year\": int(row[\"year\"]),\n", + " \"cloud_cover\": row[\"cloud_cover\"],\n", + " \"idx\": row[\"idx\"],\n", + " \"box\": row[\"box\"].bounds,\n", + " }\n", + " )" ] }, { @@ -640,9 +657,9 @@ "epsg = items_L2A[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point from lon/lat to UTM projection\n", - "box_embedding = gpd.GeoDataFrame(crs=\"OGC:CRS84\", geometry=[area_of_interest_embedding]).to_crs(\n", - " epsg\n", - ")\n", + "box_embedding = 
gpd.GeoDataFrame(\n",
+    "    crs=\"OGC:CRS84\", geometry=[area_of_interest_embedding]\n",
+    ").to_crs(epsg)\n",
     "geom_embedding = box_embedding.iloc[0].geometry\n",
     "\n",
     "# Create bounds of the correct size, the model\n",
@@ -665,7 +682,7 @@
     ")\n",
     "\n",
     "stack_embedding = stack_embedding.compute()\n",
-    "#assert stack_embedding.shape == (1, 4, 512, 512)\n",
+    "# assert stack_embedding.shape == (1, 4, 512, 512)\n",
     "\n",
     "stack_embedding.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n",
     "    row=\"time\", rgb=\"band\", vmin=0, vmax=2000\n",
     ")\n",
@@ -771,7 +788,9 @@
    "outputs": [],
    "source": [
     "# Sample cloudy patch using index (7,0)\n",
-    "cloudy_patch_idx = [element for element in list(cloud_pcts.items()) if element[0] == (7, 0)][0][0]"
+    "cloudy_patch_idx = [\n",
+    "    element for element in list(cloud_pcts.items()) if element[0] == (7, 0)\n",
+    "][0][0]"
@@ -783,11 +802,18 @@
    "source": [
     "# Search for other cloudy patches\n",
     "chips_cloudy = [\n",
-    "    {\"tile\": \"claytile\", \"idx\": f\"{'_'.join(map(str, cloudy_patch_idx))}\", \"year\": 2021},\n",
+    "    {\n",
+    "        \"tile\": \"claytile\",\n",
+    "        \"idx\": f\"{'_'.join(map(str, cloudy_patch_idx))}\",\n",
+    "        \"year\": 2021,\n",
+    "    },\n",
     "]\n",
     "filter_cloudy = \" OR \".join(\n",
-    "    [f\"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}\" for chip in chips_cloudy]\n",
-    "    )\n",
+    "    [\n",
+    "        f\"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']}\"\n",
+    "        for chip in chips_cloudy\n",
+    "    ]\n",
+    ")\n",
     "\n",
     "v_cloudy = tbl.search().where(filter_cloudy).to_pandas().iloc[0][\"vector\"]"
@@ -818,7 +844,9 @@
    "outputs": [],
    "source": [
     "# Sample non-cloudy patch using index (10,10)\n",
-    "non_cloudy_patch_idx = [element for element in list(cloud_pcts.items()) if element[0] == (10, 10)][0][0]"
+    "non_cloudy_patch_idx = [\n",
+    "    element for element in list(cloud_pcts.items()) if element[0] == (10, 10)\n",
+    "][0][0]"
@@ -830,12 +858,19 @@
    "source": [
     "# Search for other non-cloudy patches\n",
     "chips = [\n",
-    "    {\"tile\": \"claytile\", \"idx\": f\"{'_'.join(map(str, non_cloudy_patch_idx))}\", \"year\": 2021},\n",
+    "    {\n",
+    "        \"tile\": \"claytile\",\n",
+    "        \"idx\": f\"{'_'.join(map(str, non_cloudy_patch_idx))}\",\n",
+    "        \"year\": 2021,\n",
+    "    },\n",
     "]\n",
     "filter_non_cloudy = \" OR \".join(\n",
-    "    [f\"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']} \\\n",
-    "    AND cloud_cover == 0\" for chip in chips]\n",
-    "    )\n",
+    "    [\n",
+    "        f\"(tile == '{chip['tile']}' AND idx == '{chip['idx']}') AND year == {chip['year']} \\\n",
+    "    AND cloud_cover == 0\"\n",
+    "        for chip in chips\n",
+    "    ]\n",
+    ")\n",
     "\n",
     "v_non_cloudy = tbl.search().where(filter_non_cloudy).to_pandas().iloc[0][\"vector\"]"
@@ -908,37 +943,39 @@
     "    # Define the window size\n",
     "    window_size = (32, 32)\n",
     "\n",
-    "    idxs_windows = {'idx': [], 'window': []}\n",
+    "    idxs_windows = {\"idx\": [], \"window\": []}\n",
     "\n",
     "    # Iterate over the image in 32x32 windows\n",
     "    for col in range(0, width, window_size[0]):\n",
     "        for row in range(0, height, window_size[1]):\n",
     "            # Define the window\n",
     "            window = ((row, row + window_size[1]), (col, col + window_size[0]))\n",
-    "            \n",
+    "\n",
     "            # Read the data within the window\n",
     "            data = chip.read(window=window)\n",
-    "            \n",
+    "\n",
     "            # Get the index of the window\n",
     "            index = (col // window_size[0], row // window_size[1])\n",
-    "            \n",
+    "\n",
     "            # Process the window data here\n",
     "            # For example, print the index and the shape of the window 
data\n", - " #print(\"Index:\", index)\n", - " #print(\"Window Shape:\", data.shape)\n", - "\n", - " idxs_windows['idx'].append('_'.join(map(str, index)))\n", - " idxs_windows['window'].append(data)\n", - " \n", - " #print(idxs_windows)\n", - " \n", + " # print(\"Index:\", index)\n", + " # print(\"Window Shape:\", data.shape)\n", + "\n", + " idxs_windows[\"idx\"].append(\"_\".join(map(str, index)))\n", + " idxs_windows[\"window\"].append(data)\n", + "\n", + " # print(idxs_windows)\n", + "\n", " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n", " idx = row[\"idx\"]\n", " # Find the corresponding window based on the idx\n", - " window_index = idxs_windows['idx'].index(idx)\n", - " window_data = idxs_windows['window'][window_index]\n", - " #print(window_data.shape)\n", - " subset_img = numpy.clip((window_data.transpose(1,2,0)[:, :, :3]/10_000) * 3, 0,1)\n", + " window_index = idxs_windows[\"idx\"].index(idx)\n", + " window_data = idxs_windows[\"window\"][window_index]\n", + " # print(window_data.shape)\n", + " subset_img = numpy.clip(\n", + " (window_data.transpose(1, 2, 0)[:, :, :3] / 10_000) * 3, 0, 1\n", + " )\n", " ax.imshow(subset_img)\n", " ax.set_title(f\"{tile}/{idx}\")\n", " ax.set_axis_off()\n", @@ -1032,26 +1069,40 @@ "source": [ "# Make geodataframe of the search results\n", "# cloudy\n", - "result_cloudy_boxes = [Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]) for bbox in result_cloudy['box']]\n", + "result_cloudy_boxes = [\n", + " Polygon(\n", + " [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n", + " )\n", + " for bbox in result_cloudy[\"box\"]\n", + "]\n", "result_cloudy_gdf = gpd.GeoDataFrame(result_cloudy, geometry=result_cloudy_boxes)\n", "result_cloudy_gdf.crs = \"EPSG:4326\"\n", "# non-cloudy\n", - "result_non_cloudy_boxes = [Polygon([(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]) for bbox in result_non_cloudy['box']]\n", - "result_non_cloudy_gdf = gpd.GeoDataFrame(result_non_cloudy, geometry=result_non_cloudy_boxes)\n", + "result_non_cloudy_boxes = [\n", + " Polygon(\n", + " [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n", + " )\n", + " for bbox in result_non_cloudy[\"box\"]\n", + "]\n", + "result_non_cloudy_gdf = gpd.GeoDataFrame(\n", + " result_non_cloudy, geometry=result_non_cloudy_boxes\n", + ")\n", "result_non_cloudy_gdf.crs = \"EPSG:4326\"\n", "\n", "# Plot the AOI in RGB\n", - "stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(row=\"time\", rgb=\"band\", vmin=0, vmax=2000)\n", + "stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", + " row=\"time\", rgb=\"band\", vmin=0, vmax=2000\n", + ")\n", "\n", "# Overlay the bounding boxes of the patches identified from the similarity search\n", - "result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color='red', alpha=0.5)\n", - "result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color='blue', alpha=0.5)\n", + "result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"red\", alpha=0.5)\n", + "result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"blue\", alpha=0.5)\n", "\n", "\n", "# Set plot title and labels\n", - "plt.title('Sentinel-2 with cloudy and non-cloudy embeddings')\n", - "plt.xlabel('Longitude')\n", - "plt.ylabel('Latitude')\n", + "plt.title(\"Sentinel-2 with cloudy and non-cloudy embeddings\")\n", + "plt.xlabel(\"Longitude\")\n", + "plt.ylabel(\"Latitude\")\n", "\n", "# Show the plot\n", "plt.show()"
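
A note on the cloud accounting reformatted in the hunks above: count_cloud_pixels returns a raw pixel count, yet the loop stores that count in cloud_pcts and the final loop prints it as a "Cloud percentage". A minimal sketch of a true per-chunk percentage, assuming subset_scl is a numpy array of SCL class codes and SCL_CLOUD_LABELS is the notebook's list of cloud classes (the helper name below is hypothetical, not part of the patch):

    import numpy

    def cloud_percentage(subset_scl, cloud_labels):
        # Share of pixels whose SCL class is a cloud class, scaled to percent.
        scl = numpy.asarray(subset_scl)
        cloudy = numpy.isin(scl, cloud_labels).sum()
        return 100.0 * cloudy / scl.size

Dividing by scl.size also keeps edge chunks honest, since the min(...) clamping in the loop makes the last row and column of chunks smaller than 32 x 32.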
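The chunk-bounds loop multiplies the raster's affine transform (taken from ds_chunked_L2A.attrs["transform"]) by (x_start, y_start) pixel indices. Two details there are easy to misread: the results are named lon_start/lat_start, but the transform is expressed in the stack's UTM CRS, so the values are easting/northing until the GeoDataFrame is reprojected with to_crs(epsg=4326); and transform * (col, row) returns the upper-left corner of that pixel, which is why pairing (x_start, y_start) with (x_end, y_end) spans the whole chunk. A small sketch with a hypothetical 10 m UTM grid:

    from affine import Affine

    # Hypothetical Sentinel-2 grid transform: 10 m pixels, UTM origin at
    # easting 600000, northing 8030000 (not lon/lat).
    transform = Affine(10.0, 0.0, 600000.0, 0.0, -10.0, 8030000.0)

    x_ul, y_ul = transform * (0, 0)    # upper-left corner of pixel (0, 0)
    x_lr, y_lr = transform * (32, 32)  # far corner of the first 32 x 32 chunk
    print((x_ul, y_ul), (x_lr, y_lr))  # (600000.0, 8030000.0) (600320.0, 8029680.0)

The same naming caveat applies to the final overlay plot, which draws in the UTM CRS while labelling the axes Longitude and Latitude.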
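The reshape to [1, 16, 16, 768] and the 32 x 32 chunking are two views of one grid: a 512 x 512 tile cut into 32-pixel windows yields 16 x 16 positions, which is why chunk_bounds, cloud_pcts, and the patch embeddings can all be keyed by the same (i, j) pair. A sketch of the correspondence, assuming the convention of the plotting loop above, where index = (col // 32, row // 32):

    TILE = 512
    WIN = 32
    GRID = TILE // WIN  # 16 patches per axis, matching the [1, 16, 16, 768] reshape

    def window_for_patch(i, j, win=WIN):
        # Pixel window ((row_start, row_stop), (col_start, col_stop)) covered by
        # grid position (i, j); i indexes columns and j indexes rows, as above.
        return ((j * win, (j + 1) * win), (i * win, (i + 1) * win))

    assert window_for_patch(7, 0) == ((0, 32), (224, 256))  # the sampled cloudy patch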
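Finally, the tbl.search().where(...) calls in this patch only fetch single query vectors by metadata filter; the similarity queries that produce result_cloudy and result_non_cloudy sit outside the hunks shown. For orientation, a minimal sketch of that step with the lancedb API, with the database path and table name assumed (the records are the data list built above):

    import lancedb

    db = lancedb.connect("embeddings")           # assumed database location
    tbl = db.create_table("clay_patches", data)  # assumed table name

    # Nearest patch embeddings to each query vector (L2 distance by default);
    # a smaller _distance column value means a more similar patch.
    result_cloudy = tbl.search(v_cloudy).limit(10).to_pandas()
    result_non_cloudy = tbl.search(v_non_cloudy).limit(10).to_pandas()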