Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 10, 2024
1 parent 3d0ea79 commit 93f4174
Showing 1 changed file with 59 additions and 65 deletions.
124 changes: 59 additions & 65 deletions nbs/240508-inference-naip.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
Expand All @@ -18,34 +19,27 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import math\n",
"import boto3\n",
"import yaml\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import geopandas as gpd\n",
"from shapely import Point\n",
"from sklearn import decomposition\n",
"import lancedb\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"import numpy as np\n",
"import pandas as pd\n",
"import rioxarray # noqa: F401\n",
"import shapely # .geometry import Point, Polygon, box\n",
"import torch\n",
"import stackstac\n",
"from pystac_client import Client\n",
"import pystac_client\n",
"import xarray as xr\n",
"import yaml\n",
"from box import Box\n",
"import lancedb\n",
"from pathlib import Path\n",
"import shapely #.geometry import Point, Polygon, box\n",
"from einops import rearrange\n",
"from torchvision.transforms import v2\n",
"from pystac_client import Client\n",
"from stacchip.processors.prechip import normalize_timestamp\n",
"from torchvision.transforms import v2\n",
"\n",
"from src.datamodule import ClayDataModule\n",
"from src.model_clay_v1 import ClayMAEModule\n"
"from src.model_clay_v1 import ClayMAEModule"
]
},
{
Expand All @@ -56,11 +50,10 @@
"outputs": [],
"source": [
"def plot_rgb(stack):\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
" )\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
" plt.show()\n",
" \n",
"\n",
"\n",
"def normalize_latlon(lat, lon):\n",
" lat = lat * np.pi / 180\n",
" lon = lon * np.pi / 180\n",
Expand Down Expand Up @@ -130,13 +123,13 @@
"\n",
"\n",
"def generate_embeddings(model, datacube):\n",
" #print(datacube)\n",
" # print(datacube)\n",
" with torch.no_grad():\n",
" unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)\n",
"\n",
" # The first embedding is the class token, which is the\n",
" # overall single embedding. We extract that for PCA below.\n",
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
" return unmsk_patch[:, 0, :].cpu().numpy()"
]
},
{
Expand Down Expand Up @@ -172,11 +165,10 @@
"for item in items.get_all_items():\n",
" assets = item.assets\n",
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
" stackstac_datasets.append(dataset)\n",
" granule_names.append(granule_name)\n",
" \n",
" \n",
"\n",
"\n",
"# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
"def tile_dataset(dataset, granule_name):\n",
Expand All @@ -202,72 +194,73 @@
" y_end = y_start + 256\n",
"\n",
" # Extract the tile from the cropped dataset\n",
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
" \n",
" tile = cropped_dataset.isel(\n",
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
" )\n",
"\n",
" # Calculate the centroid\n",
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
" \n",
"\n",
" # Print or use the centroid coordinates\n",
" #print(\"Centroid X:\", centroid_x.item())\n",
" #print(\"Centroid Y:\", centroid_y.item())\n",
" \n",
" # print(\"Centroid X:\", centroid_x.item())\n",
" # print(\"Centroid Y:\", centroid_y.item())\n",
"\n",
" lon = centroid_x.item()\n",
" lat = centroid_y.item()\n",
"\n",
" tile = tile.assign_coords(band=['red','green','blue','nir'])\n",
" tile = tile.assign_coords(band=[\"red\", \"green\", \"blue\", \"nir\"])\n",
" tile_save = tile\n",
"\n",
" time_coord = xr.DataArray(['2020-01-01'], dims='time', name='time')\n",
" time_coord = xr.DataArray([\"2020-01-01\"], dims=\"time\", name=\"time\")\n",
"\n",
" # Assign the time coordinate to the DataArray\n",
" tile = tile.expand_dims(time=[0])\n",
" tile = tile.assign_coords(time=time_coord)\n",
"\n",
" gsd_coord = xr.DataArray([0.6], dims='gsd', name='gsd')\n",
" gsd_coord = xr.DataArray([0.6], dims=\"gsd\", name=\"gsd\")\n",
"\n",
" # Assign the time coordinate to the DataArray\n",
" tile = tile.expand_dims(gsd=[0])\n",
" tile = tile.assign_coords(gsd=gsd_coord)\n",
"\n",
" tile_name = f\"{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
" #name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n",
" # name_coord = xr.DataArray(tile_name, dims='filename', name='filename')\n",
"\n",
" # Assign the time coordinate to the DataArray\n",
" #tile = tile.expand_dims(filename=[0])\n",
" #tile = tile.assign_coords(filename=name_coord)\n",
" # tile = tile.expand_dims(filename=[0])\n",
" # tile = tile.assign_coords(filename=name_coord)\n",
"\n",
" #print(tile)\n",
" # print(tile)\n",
"\n",
" # Save the tile as a GeoTIFF\n",
" tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
" tile_save.rio.to_raster(tile_path)\n",
" tiles.append(tile)\n",
" tile_names.append(tile_name)\n",
" \n",
"\n",
" return tiles, tile_names\n",
" \n",
"\n",
"\n",
"make_tiles = False\n",
"\n",
"if make_tiles:\n",
" tiles_ = []\n",
" tile_names_ = []\n",
" \n",
" \n",
"\n",
" # Tile each dataset\n",
" for dataset, granule_name in zip(stackstac_datasets, granule_names):\n",
" tiles, tile_names = tile_dataset(dataset, granule_name)\n",
" tiles_.append(tiles)\n",
" tile_names_.append(tile_names)\n",
" #tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n",
" # tiles, tile_names = tile_dataset(stackstac_datasets[0], granule_names[0])\n",
" tiles__ = [tile for tile in tiles for tile_ in tiles_]\n",
" tile_names__ = [tile for tile in tile_names for tile_ in tile_names_]\n",
"else:\n",
" tiles__ = []\n",
" tile_names__ = []\n",
" for filename in os.listdir(save_dir):\n",
" if filename.endswith(\".tif\"): \n",
" if filename.endswith(\".tif\"):\n",
" tile_names__.append(filename)\n",
" file_path = os.path.join(save_dir, filename)\n",
" data_array = rioxarray.open_rasterio(file_path)\n",
Expand Down Expand Up @@ -317,7 +310,7 @@
"source": [
"model = load_model(\n",
" # ckpt=\"s3://clay-model-ckpt/v0.5.3/mae_v0.5.3_epoch-29_val-loss-0.3073.ckpt\",\n",
" #ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n",
" # ckpt=\"../checkpoints/v0.5.3/mae_v0.5.3_epoch-08_val-loss-0.3150.ckpt\",\n",
" ckpt=\"s3://clay-model-ckpt/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\",\n",
" device=\"cuda\",\n",
")\n",
Expand All @@ -327,14 +320,14 @@
" # Calculate the centroid\n",
" centroid_x = (tile.x * tile).sum() / tile.sum()\n",
" centroid_y = (tile.y * tile).sum() / tile.sum()\n",
" \n",
"\n",
" # Print or use the centroid coordinates\n",
" #print(\"Centroid X:\", centroid_x.item())\n",
" #print(\"Centroid Y:\", centroid_y.item())\n",
" \n",
" # print(\"Centroid X:\", centroid_x.item())\n",
" # print(\"Centroid Y:\", centroid_y.item())\n",
"\n",
" lon = centroid_x.item()\n",
" lat = centroid_y.item()\n",
" \n",
"\n",
" datacube = prep_datacube(tile, lat, lon, model.device)\n",
" embeddings_ = generate_embeddings(model, datacube)\n",
" embeddings.append(embeddings_)\n",
Expand All @@ -355,21 +348,20 @@
" box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n",
"\n",
" # Create the GeoDataFrame\n",
" gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\")\n",
" gdf = gpd.GeoDataFrame(\n",
" data, geometry=[box_emb], crs=f\"EPSG:{tile.rio.crs.to_epsg()}\"\n",
" )\n",
"\n",
" # Reproject to WGS84 (lon/lat coordinates)\n",
" gdf = gdf.to_crs(epsg=4326)\n",
"\n",
" outpath = (\n",
" f\"{outdir_embeddings}/\"\n",
" f\"{fname[:-4]}.gpq\"\n",
" )\n",
" outpath = f\"{outdir_embeddings}/\" f\"{fname[:-4]}.gpq\"\n",
" gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n",
" print(\n",
" f\"Saved {len(gdf)} rows of embeddings of \"\n",
" f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n",
" )\n",
" i=i+1"
" i = i + 1"
]
},
{
Expand Down Expand Up @@ -435,7 +427,7 @@
"data = []\n",
"# Dataframe to find overlaps within\n",
"gdfs = []\n",
"idx = 0\n",
"idx = 0\n",
"for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n",
" gdf = gpd.read_parquet(emb)\n",
" gdf[\"year\"] = gdf.date.dt.year\n",
Expand All @@ -456,7 +448,7 @@
" \"box\": row[\"box\"].bounds,\n",
" }\n",
" )\n",
" idx = idx+1"
" idx = idx + 1"
]
},
{
Expand Down Expand Up @@ -572,21 +564,23 @@
"source": [
"def plot(df, cols=10):\n",
" fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n",
" i=0\n",
" i = 0\n",
" for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n",
" row = df.iloc[i]\n",
" path = row[\"path\"]\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(band=['red', 'green', 'blue']) #[1,2,3])\n",
" #chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n",
" chip = rioxarray.open_rasterio(f\"{save_dir}/{path}.tif\").sel(\n",
" band=[\"red\", \"green\", \"blue\"]\n",
" ) # [1,2,3])\n",
" # chip = tiles__[row[\"idx\"]].sel(band=['red', 'green', 'blue'])\n",
" width = chip.shape[-1]\n",
" height = chip.shape[-1]\n",
" chip = chip.squeeze()\n",
" chip = chip.transpose('x', 'y', 'band')\n",
" chip = chip.transpose(\"x\", \"y\", \"band\")\n",
"\n",
" ax.imshow(chip)\n",
" ax.set_title(f\"{row['idx']}\")\n",
" ax.set_axis_off()\n",
" i=i+1\n",
" i = i + 1\n",
" plt.tight_layout()\n",
" fig.savefig(\"similar.png\")"
]
Expand Down

0 comments on commit 93f4174

Please sign in to comment.