Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 8, 2024
1 parent 506a664 commit aae436c
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions nbs/240508-inference-naip.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
Expand All @@ -18,21 +19,16 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from einops import rearrange\n",
"import torch\n",
"import stackstac\n",
"from pystac_client import Client\n",
"import boto3\n",
"import xarray as xr\n",
"import numpy as np\n",
"import os\n",
"import rioxarray\n",
"from box import Box\n",
"import torch\n",
"import yaml\n",
"from box import Box\n",
"from pystac_client import Client\n",
"\n",
"from src.datamodule import ClayDataModule\n",
"from src.model_clay_v1 import ClayMAEModule"
]
},
Expand All @@ -44,11 +40,10 @@
"outputs": [],
"source": [
"def plot_rgb(stack):\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
" )\n",
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
" plt.show()\n",
" \n",
"\n",
"\n",
"def normalize_latlon(lat, lon):\n",
" lat = lat * np.pi / 180\n",
" lon = lon * np.pi / 180\n",
Expand Down Expand Up @@ -124,7 +119,7 @@
"\n",
" # The first embedding is the class token, which is the\n",
" # overall single embedding. We extract that for PCA below.\n",
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
" return unmsk_patch[:, 0, :].cpu().numpy()"
]
},
{
Expand Down Expand Up @@ -61598,11 +61593,10 @@
" assets = item.assets\n",
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
" print(\"dataset: \", dataset)\n",
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
" stackstac_datasets.append(dataset)\n",
" granule_names.append(granule_name)\n",
" \n",
" \n",
"\n",
"\n",
"# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
"def tile_dataset(dataset, granule_name):\n",
Expand All @@ -61628,26 +61622,27 @@
" y_end = y_start + 256\n",
"\n",
" # Extract the tile from the cropped dataset\n",
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
" tile = cropped_dataset.isel(\n",
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
" )\n",
" print(tile.shape)\n",
"\n",
" # Save the tile as a GeoTIFF\n",
" tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
" #tile.rio.to_raster(tile_path)\n",
" # tile.rio.to_raster(tile_path)\n",
" print(tile)\n",
" tiles.append(tile)\n",
" \n",
"\n",
" return tiles\n",
" \n",
"\n",
"\n",
"# Tile each dataset\n",
"for dataset, granule_name in zip(stackstac_datasets[0:2], granule_names[0:2]):\n",
" tiles = tile_dataset(dataset, granule_name)\n",
" \n",
"\n",
"tile_0 = tiles[0]\n",
"\n",
"plot_rgb(tile_0)\n"
"plot_rgb(tile_0)"
]
},
{
Expand Down

0 comments on commit aae436c

Please sign in to comment.