Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 18, 2024
1 parent fb39e4b commit 6e7e393
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 47 deletions.
10 changes: 5 additions & 5 deletions nbs/v1-inference-simsearch-naip-stacchip.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
" )\n",
" plt.show()\n",
" \n",
"\n",
"def normalize_latlon(lat, lon):\n",
" \"\"\"\n",
" Normalize latitude and longitude to a range between -1 and 1.\n",
Expand Down Expand Up @@ -328,7 +328,7 @@
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
"\n",
" # The first embedding is the class token, which is the\n",
" # overall single embedding. \n",
" # overall single embedding.\n",
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
]
},
Expand Down Expand Up @@ -496,16 +496,16 @@
"embeddings = []\n",
"i = 0\n",
"for tile, fname, centroid in zip(chips, str(range(len(chips))), chip_xy):\n",
" lon, lat = chip_xy[0][0], chip_xy[0][1] \n",
" lon, lat = chip_xy[0][0], chip_xy[0][1]\n",
"\n",
" date = datetime.datetime.strptime(f'{YEAR}-06-01', '%Y-%m-%d')\n",
" date_ = time.mktime(date.timetuple())\n",
" gsd = 0.6\n",
" \n",
"\n",
" datacube = prep_datacube(np.array(tile[\"image\"]), lat, lon, pd.to_datetime(f'{YEAR}-06-01'), gsd, model.device)\n",
" embeddings_ = generate_embeddings(model, datacube)\n",
" embeddings.append(embeddings_)\n",
" \n",
"\n",
" data = {\n",
" \"source_url\": str(fname),\n",
" \"date\": date_.astype(\n",
Expand Down
86 changes: 44 additions & 42 deletions nbs/v1-inference-simsearch-naip.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
Expand All @@ -27,32 +28,27 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import math\n",
"import boto3\n",
"import yaml\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import geopandas as gpd\n",
"import lancedb\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"import numpy as np\n",
"import pandas as pd\n",
"import rioxarray # noqa: F401\n",
"import shapely\n",
"import torch\n",
"import stackstac\n",
"from pystac_client import Client\n",
"import pystac_client\n",
"import xarray as xr\n",
"import yaml\n",
"from box import Box\n",
"import lancedb\n",
"from pathlib import Path\n",
"import shapely\n",
"from einops import rearrange\n",
"from torchvision.transforms import v2\n",
"from pystac_client import Client\n",
"from stacchip.processors.prechip import normalize_timestamp\n",
"from torchvision.transforms import v2\n",
"\n",
"from src.datamodule import ClayDataModule\n",
"from src.model_clay_v1 import ClayMAEModule\n"
"from src.model_clay_v1 import ClayMAEModule"
]
},
{
Expand Down Expand Up @@ -89,7 +85,7 @@
"for item in items.get_all_items():\n",
" assets = item.assets\n",
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
" stackstac_datasets.append(dataset)\n",
" granule_names.append(granule_name)"
]
Expand All @@ -116,11 +112,10 @@
" Parameters:\n",
" stack (xarray.DataArray): The input data array containing band information.\n",
" \"\"\"\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
" )\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
" plt.show()\n",
" \n",
"\n",
"\n",
"def normalize_latlon(lat, lon):\n",
" \"\"\"\n",
" Normalize latitude and longitude to a range between -1 and 1.\n",
Expand All @@ -137,6 +132,7 @@
"\n",
" return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n",
"\n",
"\n",
"def load_model(ckpt, device=\"cuda\"):\n",
" \"\"\"\n",
" Load a pretrained Clay model from a checkpoint.\n",
Expand All @@ -155,6 +151,7 @@
" model.eval()\n",
" return model.to(device)\n",
"\n",
"\n",
"def prep_datacube(stack, lat, lon, device):\n",
" \"\"\"\n",
" Prepare a data cube for model input.\n",
Expand Down Expand Up @@ -216,6 +213,7 @@
" \"waves\": torch.tensor(waves, device=device),\n",
" }\n",
"\n",
"\n",
"def generate_embeddings(model, datacube):\n",
" \"\"\"\n",
" Generate embeddings from the model using the data cube.\n",
Expand All @@ -231,9 +229,10 @@
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
"\n",
" # The first embedding is the class token, which is the\n",
" # overall single embedding. \n",
" # overall single embedding.\n",
" return unmsk_patch[:, 0, :].cpu().numpy()\n",
"\n",
"\n",
"def tile_dataset(dataset, granule_name):\n",
" \"\"\"\n",
" Tile dataset into 256x256 image chips and drop any excess border regions.\n",
Expand Down Expand Up @@ -266,23 +265,25 @@
" y_end = y_start + 256\n",
"\n",
" # Extract the tile from the cropped dataset\n",
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
" \n",
" tile = cropped_dataset.isel(\n",
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
" )\n",
"\n",
" # Calculate the centroid\n",
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
" \n",
"\n",
" lon = centroid_x.item()\n",
" lat = centroid_y.item()\n",
"\n",
" tile = tile.assign_coords(band=['red','green','blue','nir'])\n",
" tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n",
" tile_save = tile\n",
"\n",
" time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n",
" time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n",
" tile = tile.expand_dims(time=[0])\n",
" tile = tile.assign_coords(time=time_coord)\n",
"\n",
" gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n",
" gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n",
" tile = tile.expand_dims(gsd=[0])\n",
" tile = tile.assign_coords(gsd=gsd_coord)\n",
"\n",
Expand All @@ -292,8 +293,8 @@
" tile_save.rio.to_raster(tile_path)\n",
" tiles.append(tile)\n",
" tile_names.append(tile_name)\n",
" \n",
" return tiles, tile_names\n"
"\n",
" return tiles, tile_names"
]
},
{
Expand Down Expand Up @@ -322,7 +323,7 @@
" tiles__ = []\n",
" tile_names__ = []\n",
" for filename in os.listdir(save_dir):\n",
" if filename.endswith(\".tif\"): \n",
" if filename.endswith(\".tif\"):\n",
" tile_names__.append(filename)\n",
" file_path = os.path.join(save_dir, filename)\n",
" data_array = rioxarray.open_rasterio(file_path)\n",
Expand Down Expand Up @@ -408,10 +409,10 @@
" # Calculate the centroid\n",
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
" \n",
"\n",
" lon = centroid_x.item()\n",
" lat = centroid_y.item()\n",
" \n",
"\n",
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
" embeddings_ = generate_embeddings(model, datacube)\n",
" embeddings.append(embeddings_)\n",
Expand All @@ -430,15 +431,14 @@
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
"\n",
" # Create the GeoDataFrame\n",
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
" gdf = gpd.GeoDataFrame(\n",
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
" )\n",
"\n",
" # Reproject to WGS84 (lon/lat coordinates)\n",
" gdf = gdf.to_crs(epsg=4326)\n",
"\n",
" outpath = (\n",
" f\"{outdir_embeddings}/\"\n",
" f\"{fname[:-4]}.gpq\"\n",
" )\n",
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
" print(\n",
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
Expand Down Expand Up @@ -520,7 +520,7 @@
"data = []\n",
"# Dataframe to find overlaps within\n",
"gdfs = []\n",
"idx = 0\n",
"idx = 0\n",
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
" gdf = gpd.read_parquet(emb)\n",
" gdf[\"year\"] = gdf.date.dt.year\n",
Expand All @@ -541,7 +541,7 @@
" \"box\": row[\"box\"].bounds,\n",
" }\n",
" )\n",
" idx += 1\n"
" idx += 1"
]
},
{
Expand Down Expand Up @@ -663,8 +663,10 @@
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
" row = df.iloc[i]\n",
" path = row[\"path\"]\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue'])\n",
" chip = chip.squeeze().transpose('x', 'y', 'band')\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
" band=[\"red\", \"green\", \"blue\"]\n",
" )\n",
" chip = chip.squeeze().transpose(\"x\", \"y\", \"band\")\n",
" ax.imshow(chip)\n",
" ax.set_title(f\"{row['idx']}\")\n",
" ax.set_axis_off()\n",
Expand Down

0 comments on commit 6e7e393

Please sign in to comment.