Commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 15, 2024
1 parent 0c72b7e commit 82a68fb
Showing 2 changed files with 87 additions and 81 deletions.
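The diff below is formatting-only: the bot re-sorted the imports, normalized string quotes, re-wrapped long lines, and adjusted blank lines between functions. As a minimal sketch, assuming the repository's .pre-commit-config.yaml (which pre-commit.ci requires) is present in a local checkout, the same fixes could be reproduced before pushing:

```python
# Minimal sketch of reproducing the bot's auto-fixes locally.
# Assumption (not stated in the commit): a .pre-commit-config.yaml sits at the repository root.
import subprocess

# Run every configured hook against all tracked files; auto-fixing hooks
# (formatters, import sorters) rewrite files in place and exit non-zero
# when they modify anything, so check=False avoids raising in that case.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)
```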
82 changes: 43 additions & 39 deletions nbs/v1-inference-simsearch-naip-stacchip.ipynb
@@ -17,6 +17,7 @@
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
@@ -27,34 +28,28 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import math\n",
"import boto3\n",
"import yaml\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import geopandas as gpd\n",
"import lancedb\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pystac_client\n",
"import rioxarray # noqa: F401\n",
"import shapely\n",
"import torch\n",
"import stackstac\n",
"from pystac_client import Client\n",
"import pystac_client\n",
"import yaml\n",
"from box import Box\n",
"import lancedb\n",
"from pathlib import Path\n",
"import shapely\n",
"from einops import rearrange\n",
"from torchvision.transforms import v2\n",
"from stacchip.processors.prechip import normalize_timestamp\n",
"from stacchip.indexer import NoStatsChipIndexer\n",
"from stacchip.chipper import Chipper\n",
"from stacchip.indexer import NoStatsChipIndexer\n",
"from stacchip.processors.prechip import normalize_timestamp\n",
"from torchvision.transforms import v2\n",
"\n",
"from src.datamodule import ClayDataModule\n",
"from src.model_clay import CLAYModule\n"
"from src.model_clay import CLAYModule"
]
},
{
@@ -65,7 +60,9 @@
"outputs": [],
"source": [
"# Query STAC catalog for NAIP data\n",
"catalog = pystac_client.Client.open(\"https://planetarycomputer.microsoft.com/api/stac/v1\") #\"https://earth-search.aws.element84.com/v1\")\n",
"catalog = pystac_client.Client.open(\n",
" \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n",
") # \"https://earth-search.aws.element84.com/v1\")\n",
"\n",
"\n",
"items = catalog.search(\n",
@@ -99,7 +96,7 @@
"\n",
" # Get first chip for the \"image\" asset key\n",
" for chip_id in random.sample(range(0, len(chipper)), 5):\n",
" chips.append(chipper[chip_id][\"image\"])\n"
" chips.append(chipper[chip_id][\"image\"])"
]
},
{
@@ -128,7 +125,7 @@
}
],
"source": [
"fig, ax = plt.subplots(1, 1, gridspec_kw={'wspace': 0.01, 'hspace': 0.01}, squeeze=True)\n",
"fig, ax = plt.subplots(1, 1, gridspec_kw={\"wspace\": 0.01, \"hspace\": 0.01}, squeeze=True)\n",
"\n",
"chip = chips[0]\n",
"# Visualize the data\n",
@@ -160,11 +157,10 @@
" Parameters:\n",
" stack (xarray.DataArray): The input data array containing band information.\n",
" \"\"\"\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
" )\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
" plt.show()\n",
" \n",
"\n",
"\n",
"def normalize_latlon(lat, lon):\n",
" \"\"\"\n",
" Normalize latitude and longitude to a range between -1 and 1.\n",
@@ -181,6 +177,7 @@
"\n",
" return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
"\n",
"\n",
"def load_model(ckpt, device=\"cuda\"):\n",
" \"\"\"\n",
" Load a pretrained Clay model from a checkpoint.\n",
@@ -194,11 +191,16 @@
" \"\"\"\n",
" torch.set_default_device(device)\n",
" model = CLAYModule.load_from_checkpoint(\n",
" ckpt, metadata_path=\"../configs/metadata.yaml\", shuffle=False, mask_ratio=0, model_size=\"medium\"\n",
" ckpt,\n",
" metadata_path=\"../configs/metadata.yaml\",\n",
" shuffle=False,\n",
" mask_ratio=0,\n",
" model_size=\"medium\",\n",
" )\n",
" model.eval()\n",
" return model.to(device)\n",
"\n",
"\n",
"def prep_datacube(stack, lat, lon, device):\n",
" \"\"\"\n",
" Prepare a data cube for model input.\n",
@@ -260,6 +262,7 @@
" \"waves\": torch.tensor(waves, device=device),\n",
" }\n",
"\n",
"\n",
"def generate_embeddings(model, datacube):\n",
" \"\"\"\n",
" Generate embeddings from the model using the data cube.\n",
@@ -275,8 +278,8 @@
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
"\n",
" # The first embedding is the class token, which is the\n",
" # overall single embedding. \n",
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
" # overall single embedding.\n",
" return unmsk_patch[:, 0, :].cpu().numpy()"
]
},
{
@@ -337,10 +340,10 @@
" # Calculate the centroid\n",
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
" \n",
"\n",
" lon = centroid_x.item()\n",
" lat = centroid_y.item()\n",
" \n",
"\n",
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
" embeddings_ = generate_embeddings(model, datacube)\n",
" embeddings.append(embeddings_)\n",
@@ -359,15 +362,14 @@
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
"\n",
" # Create the GeoDataFrame\n",
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
" gdf = gpd.GeoDataFrame(\n",
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
" )\n",
"\n",
" # Reproject to WGS84 (lon/lat coordinates)\n",
" gdf = gdf.to_crs(epsg=4326)\n",
"\n",
" outpath = (\n",
" f\"{outdir_embeddings}/\"\n",
" f\"{fname[:-4]}.gpq\"\n",
" )\n",
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
" print(\n",
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
@@ -438,7 +440,7 @@
"data = []\n",
"# Dataframe to find overlaps within\n",
"gdfs = []\n",
"idx = 0\n",
"idx = 0\n",
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
" gdf = gpd.read_parquet(emb)\n",
" gdf[\"year\"] = gdf.date.dt.year\n",
@@ -459,7 +461,7 @@
" \"box\": row[\"box\"].bounds,\n",
" }\n",
" )\n",
" idx += 1\n"
" idx += 1"
]
},
{
@@ -560,8 +562,10 @@
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
" row = df.iloc[i]\n",
" path = row[\"path\"]\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n",
" chip = chip.squeeze().transpose('x', 'y', 'band')\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
" band=[\"red\", \"green\", \"blue\"]\n",
" )\n",
" chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n",
" ax.imshow(chip)\n",
" ax.set_title(f\"{row['idx']}\")\n",
" ax.set_axis_off()\n",
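The loop above assembles per-chip records (an index and the chip's bounding box are visible in the diff) ahead of the similarity-search step that gives the notebook its name. As a minimal sketch — not the notebook's own code — of how such records could be indexed and queried with LanceDB, assuming each record also carries its embedding under a "vector" key and using placeholder names throughout:

```python
# Minimal sketch of a nearest-neighbour query over the rows built above.
# Assumptions (not from the commit): each dict in `data` stores its embedding
# under a "vector" key, and the database path / table name are placeholders.
import lancedb

db = lancedb.connect("embeddings.lancedb")
tbl = db.create_table("clay-v1-naip", data=data, mode="overwrite")

# Use one chip's own embedding as the query and fetch its five closest matches.
result = tbl.search(data[0]["vector"]).limit(5).to_pandas()
print(result.head())
```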
