diff --git a/notebooks/231220-custom-mae-embeddings-finetune.ipynb b/notebooks/231220-custom-mae-embeddings-finetune.ipynb index 26836ca7..5adcda54 100644 --- a/notebooks/231220-custom-mae-embeddings-finetune.ipynb +++ b/notebooks/231220-custom-mae-embeddings-finetune.ipynb @@ -7,7 +7,6 @@ "metadata": {}, "outputs": [], "source": [ - "import sys\n", "from __future__ import annotations" ] }, @@ -17,9 +16,7 @@ "id": "2d5517a8-7ff0-4340-90e3-3d37cf6ab11b", "metadata": {}, "outputs": [], - "source": [ - "from pathlib import Path" - ] + "source": [] }, { "cell_type": "code", @@ -38,21 +35,14 @@ "metadata": {}, "outputs": [], "source": [ - "from src.model_clay_eval import CLAYModule\n", - "import src.datamodule\n", - "#from src.datamodule import ClayDataset, ClayDataModule\n", - "from src.datamodule_eval_local import ClayDataset, ClayDataModule\n", - "import pandas as pd\n", - "import random\n", "import matplotlib.pyplot as plt\n", - "from torch.utils.data import DataLoader\n", "import numpy as np\n", - "import einops\n", - "from sklearn.decomposition import PCA\n", - "from sklearn.metrics.pairwise import cosine_similarity\n", - "import rasterio as rio\n", - "from einops import rearrange, reduce\n", - "import torch" + "import torch\n", + "from einops import rearrange\n", + "\n", + "# from src.datamodule import ClayDataset, ClayDataModule\n", + "from src.datamodule_eval_local import ClayDataModule\n", + "from src.model_clay_eval import CLAYModule" ] }, { @@ -62,8 +52,10 @@ "metadata": {}, "outputs": [], "source": [ - "model = CLAYModule.load_from_checkpoint(\"../clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n", - "model.eval();" + "model = CLAYModule.load_from_checkpoint(\n", + " \"../clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.0\n", + ")\n", + "model.eval()" ] }, { @@ -524,11 +516,11 @@ "from einops import rearrange\n", "\n", "embeddings = emb[0]\n", - "embeddings = embeddings[:,:-2,:]\n", + "embeddings = embeddings[:, :-2, :]\n", "latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n", "latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n", "latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n", - "print(latent.shape)\n" + "print(latent.shape)" ] }, { @@ -590,7 +582,7 @@ } ], "source": [ - "plt.imshow(batch[\"pixels\"][0].permute(1,2,0)[:,:,1].detach().numpy(), cmap=\"bwr\")" + "plt.imshow(batch[\"pixels\"][0].permute(1, 2, 0)[:, :, 1].detach().numpy(), cmap=\"bwr\")" ] }, { @@ -611,9 +603,9 @@ } ], "source": [ - "fig, axs = plt.subplots(1,10,figsize=(10,5))\n", - "for i,ax in enumerate(axs.flatten()):\n", - " ax.imshow(latent[0][i+10].detach().numpy(), cmap=\"bwr\")" + "fig, axs = plt.subplots(1, 10, figsize=(10, 5))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " ax.imshow(latent[0][i + 10].detach().numpy(), cmap=\"bwr\")" ] }, { @@ -624,8 +616,8 @@ "outputs": [], "source": [ "encoder = torch.nn.Sequential(\n", - " torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n", - " )\n", + " torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n", + ")\n", "\n", "decoder = torch.nn.Sequential(\n", " torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n", @@ -640,7 +632,7 @@ " torch.nn.ReLU(inplace=True),\n", " torch.nn.Upsample(scale_factor=2),\n", " torch.nn.ConvTranspose2d(8, 1, kernel_size=3, padding=1),\n", - " torch.nn.Upsample(scale_factor=2)\n", + " torch.nn.Upsample(scale_factor=2),\n", ")\n", "\n", "\n", @@ -953,13 +945,14 @@ } ], "source": [ - "from pytorch_lightning import LightningModule, Trainer\n", "from einops import rearrange\n", + "from 
pytorch_lightning import LightningModule, Trainer\n", + "\n", "\n", "class UNet(torch.nn.Module):\n", " def __init__(self, in_channels, out_channels):\n", - " super().__init__() \n", - " \n", + " super().__init__()\n", + "\n", " self.decoder = torch.nn.Sequential(\n", " torch.nn.Conv2d(4608, 64, kernel_size=1, padding=0),\n", " torch.nn.Upsample(scale_factor=2),\n", @@ -973,11 +966,10 @@ " torch.nn.ReLU(inplace=True),\n", " torch.nn.Upsample(scale_factor=2),\n", " torch.nn.ConvTranspose2d(8, 1, kernel_size=3, padding=1),\n", - " torch.nn.Upsample(scale_factor=2)\n", + " torch.nn.Upsample(scale_factor=2),\n", " )\n", "\n", - "\n", - " def forward(self,x):\n", + " def forward(self, x):\n", " x = self.decoder(x)\n", " return x\n", "\n", @@ -987,7 +979,7 @@ " super().__init__()\n", " self.model = model\n", " self.datamodule = datamodule\n", - " \n", + "\n", " def forward(self, x):\n", " return self.model(x)\n", "\n", @@ -999,14 +991,16 @@ " batch[\"latlon\"] = batch[\"latlon\"].to(model_clay.device)\n", " emb = model_clay.model.encoder(batch)\n", " embeddings = emb[0]\n", - " embeddings = embeddings[:,:-2,:]\n", + " embeddings = embeddings[:, :-2, :]\n", " latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n", " latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n", " latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n", " prediction = self.model(latent)\n", " print(\"Prediction shape:\", prediction.shape)\n", " print(\"Label shape:\", y.shape)\n", - " loss = torch.nn.functional.binary_cross_entropy_with_logits(prediction.to(dtype=torch.float32), y)\n", + " loss = torch.nn.functional.binary_cross_entropy_with_logits(\n", + " prediction.to(dtype=torch.float32), y\n", + " )\n", " loss = torch.tensor(loss, requires_grad=True)\n", " self.log(\"train_loss\", loss)\n", " print(\"train_loss\", loss)\n", @@ -1019,7 +1013,7 @@ " batch[\"latlon\"] = batch[\"latlon\"].to(model_clay.device)\n", " emb = model_clay.model.encoder(batch)\n", " embeddings = emb[0]\n", - " embeddings = embeddings[:,:-2,:]\n", + " embeddings = embeddings[:, :-2, :]\n", " latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n", " latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n", " latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n", @@ -1037,20 +1031,24 @@ " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n", " return optimizer\n", - " \n", + "\n", " def train_dataloader(self):\n", " return self.datamodule.train_dataloader()\n", "\n", " def val_dataloader(self):\n", " return self.datamodule.val_dataloader()\n", "\n", + "\n", "dm = ClayDataModule(data_dir=data_dir, batch_size=2)\n", "dm.setup()\n", - "#val_dl = iter(dm.val_dataloader())\n", + "# val_dl = iter(dm.val_dataloader())\n", "\n", "model_unet = UNet(13, 1)\n", - "model_clay = CLAYModule.load_from_checkpoint(\"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n", - "model_clay.eval();\n", + "model_clay = CLAYModule.load_from_checkpoint(\n", + " \"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\",\n", + " mask_ratio=0.0,\n", + ")\n", + "model_clay.eval()\n", "segmentation_model = SegmentationModel(model_unet, dm)\n", "\n", "trainer = Trainer(max_epochs=3)\n", @@ -1133,19 +1131,19 @@ " model.eval()\n", " with torch.no_grad():\n", " for batch in dataloader:\n", - " #print(batch)\n", + " # print(batch)\n", " x_val, y_val = batch[\"pixels\"], 
batch[\"labels\"]\n", - " \n", + "\n", " emb = model_clay.model.encoder(batch)\n", " embeddings = emb[0]\n", - " embeddings = embeddings[:,:-2,:]\n", + " embeddings = embeddings[:, :-2, :]\n", " latent = rearrange(embeddings, \"b (g l) d -> b g l d\", g=6)\n", " latent = rearrange(latent, \"b g (h w) d -> b g h w d\", h=16, w=16)\n", " latent = rearrange(latent, \"b g h w d -> b (g d) h w\")\n", - " #y = y[0, :, :, :]\n", + " # y = y[0, :, :, :]\n", " prediction = model(latent)\n", - " #y_pred = model(batch)\n", - " #_, prediction = torch.max(y_pred, dim=1)\n", + " # y_pred = model(batch)\n", + " # _, prediction = torch.max(y_pred, dim=1)\n", "\n", " # Convert tensors to NumPy arrays for plotting\n", " x_val_np = x_val.cpu().numpy()\n", @@ -1170,7 +1168,7 @@ " )\n", " ),\n", " (1, 2, 0),\n", - " ) #.clip(0, 6000) / 6000\n", + " ) # .clip(0, 6000) / 6000\n", " ) # x_val_np[i], (1, 2, 0))) # Plot input images\n", " axes[i, 1].imshow(\n", " np.transpose(y_val_np[i], (1, 2, 0))\n", @@ -1179,10 +1177,14 @@ "\n", " plt.show()\n", "\n", - "model_clay = CLAYModule.load_from_checkpoint(\"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\", mask_ratio=0.)\n", + "\n", + "model_clay = CLAYModule.load_from_checkpoint(\n", + " \"/Users/lillythomas/Documents/work/clay/lt/benchmark/seg/clay-small-70MT-1100T-10E.ckpt\",\n", + " mask_ratio=0.0,\n", + ")\n", "model_clay.eval()\n", - "#val_dataloader = datamodule_floods.val_dataloader()\n", - "plot_predictions(model_unet, model_clay, dm.val_dataloader())\n" + "# val_dataloader = datamodule_floods.val_dataloader()\n", + "plot_predictions(model_unet, model_clay, dm.val_dataloader())" ] }, { diff --git a/notebooks/c2smsfloods_unet.ipynb b/notebooks/c2smsfloods_unet.ipynb index f1a51d7d..8aa54590 100644 --- a/notebooks/c2smsfloods_unet.ipynb +++ b/notebooks/c2smsfloods_unet.ipynb @@ -7,22 +7,16 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "import random\n", - "from datetime import datetime, timedelta\n", "\n", "import boto3\n", "import numpy\n", "import rasterio\n", "import rioxarray\n", - "import xarray as xr\n", - "from matplotlib import pyplot as plt\n", - "import pystac\n", - "import pystac_client\n", "import torch\n", - "from torch.utils.data import TensorDataset, DataLoader\n", + "from matplotlib import pyplot as plt\n", "from pytorch_lightning import LightningModule, Trainer\n", - "import torchvision.transforms as T\n", + "from torch.utils.data import DataLoader\n", "from torchvision.transforms import v2" ] }, @@ -33,7 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "def list_objects_recursive(client, bucket_name, prefix=''):\n", + "def list_objects_recursive(client, bucket_name, prefix=\"\"):\n", " \"\"\"\n", " List all objects (file keys) in an S3 bucket recursively under a specified prefix.\n", "\n", @@ -45,14 +39,14 @@ " Returns:\n", " - list: A list of file keys (object keys) found under the specified prefix.\n", " \"\"\"\n", - " paginator = client.get_paginator('list_objects_v2')\n", + " paginator = client.get_paginator(\"list_objects_v2\")\n", "\n", " page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)\n", "\n", " file_keys = []\n", " for page in page_iterator:\n", - " if 'Contents' in page:\n", - " file_keys.extend([obj['Key'] for obj in page['Contents']])\n", + " if \"Contents\" in page:\n", + " file_keys.extend([obj[\"Key\"] for obj in page[\"Contents\"]])\n", "\n", " return file_keys" ] @@ -69,6 +63,7 @@ "flood_events = []\n", "positions = []\n", "\n", + "\n", "def 
get_image_granules(bucket_name, prefix):\n", " \"\"\"\n", " Get granules of N-dim datacube and label images from an S3 bucket.\n", @@ -81,39 +76,42 @@ " - tuple: None.\n", " \"\"\"\n", " # Initialize Boto3 S3 client\n", - " s3 = boto3.client('s3')\n", + " s3 = boto3.client(\"s3\")\n", "\n", " # List objects in the specified prefix (directory) in the bucket\n", " files_in_s3 = list_objects_recursive(s3, bucket_name, prefix)\n", "\n", " # Filter S2 and S1 images\n", - " S1_labels = [i for i in files_in_s3 if 'LabelWater.tif' in i]\n", + " S1_labels = [i for i in files_in_s3 if \"LabelWater.tif\" in i]\n", " datacube_images = [f\"{i[:-15]}.tif\" for i in S1_labels]\n", - " \n", + "\n", " for i in datacube_images[0:100]:\n", - " position = '_'.join(i.split('/')[-1].split('_')[-3:-1])\n", + " position = \"_\".join(i.split(\"/\")[-1].split(\"_\")[-3:-1])\n", " positions.append(position)\n", - " flood_event = i.split('/')[-2]\n", - " flood_events.append(flood_event) \n", + " flood_event = i.split(\"/\")[-2]\n", + " flood_events.append(flood_event)\n", " # Load the image file from S3 directly into memory using rasterio\n", " obj = s3.get_object(Bucket=bucket_name, Key=i)\n", - " with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n", + " with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n", " with memfile.open() as dataset:\n", " data_array = rioxarray.open_rasterio(dataset)\n", - " #print(data_array.values)\n", + " # print(data_array.values)\n", " pair = [i, data_array.values]\n", " image_array_values.append(pair)\n", " for i in S1_labels[0:100]:\n", " # Load the image file from S3 directly into memory using rasterio\n", " obj = s3.get_object(Bucket=bucket_name, Key=i)\n", - " with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n", + " with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n", " with memfile.open() as dataset:\n", " data_array = rioxarray.open_rasterio(dataset)\n", - " #print(data_array.values)\n", + " # print(data_array.values)\n", " pair = [i, data_array.values]\n", " label_array_values.append(pair)\n", "\n", - "get_image_granules(bucket_name='clay-benchmark', prefix='c2smsfloods/datacube/chips_512/')" + "\n", + "get_image_granules(\n", + " bucket_name=\"clay-benchmark\", prefix=\"c2smsfloods/datacube/chips_512/\"\n", + ")" ] }, { @@ -132,6 +130,7 @@ ], "source": [ "import random\n", + "\n", "random.seed(9) # set a seed for reproducibility\n", "\n", "# put 1/3 of chips into the validation set\n", @@ -140,18 +139,27 @@ "train_chip_ids_values = [i for i in chip_ids_values if i not in val_chip_ids_values]\n", "train_images_values = []\n", "train_labels_values = []\n", - "for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n", + "for i, j, k in zip(\n", + " image_array_values, label_array_values, range(len(image_array_values))\n", + "):\n", " if k in train_chip_ids_values:\n", " train_images_values.append(i)\n", " train_labels_values.append(j)\n", "val_images_values = []\n", "val_labels_values = []\n", - "for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n", + "for i, j, k in zip(\n", + " image_array_values, label_array_values, range(len(image_array_values))\n", + "):\n", " if k in val_chip_ids_values:\n", " val_images_values.append(i)\n", " val_labels_values.append(j)\n", - " \n", - "print(len(train_images_values), len(val_images_values), len(train_labels_values), len(val_labels_values))" + "\n", + "print(\n", + " len(train_images_values),\n", + " 
len(val_images_values),\n", + " len(train_labels_values),\n", + " len(val_labels_values),\n", + ")" ] }, { @@ -171,7 +179,6 @@ "\n", " fig, axs = plt.subplots(1, 3, figsize=(30, 30))\n", "\n", - "\n", " rgb = (\n", " numpy.array(\n", " [\n", @@ -261,50 +268,80 @@ " 0.380075,\n", " 630.602233,\n", " ]\n", - " def __init__(self, train_images, val_images, train_labels, val_labels, batch_size=2, num_workers=4):\n", + "\n", + " def __init__(\n", + " self,\n", + " train_images,\n", + " val_images,\n", + " train_labels,\n", + " val_labels,\n", + " batch_size=2,\n", + " num_workers=4,\n", + " ):\n", " super().__init__()\n", " self.train_images = train_images\n", " self.val_images = val_images\n", " self.train_labels = train_labels\n", " self.val_labels = val_labels\n", " self.batch_size = batch_size\n", - " self.num_workers = num_workers \n", + " self.num_workers = num_workers\n", " self.tfm = v2.Compose(\n", " [\n", " v2.Normalize(mean=self.MEAN, std=self.STD),\n", " ]\n", " )\n", - " \n", - " self.train_dataset = self.get_dataset(self.train_images, self.train_labels, transform=self.tfm)\n", - " self.val_dataset = self.get_dataset(self.val_images, self.val_labels, transform=self.tfm)\n", + "\n", + " self.train_dataset = self.get_dataset(\n", + " self.train_images, self.train_labels, transform=self.tfm\n", + " )\n", + " self.val_dataset = self.get_dataset(\n", + " self.val_images, self.val_labels, transform=self.tfm\n", + " )\n", "\n", " def get_dataset(self, images, labels, transform=None):\n", - " #print(images[0][1].transpose(1,2,0).shape, labels[0][1].transpose(1,2,0).shape)\n", - " features = [torch.tensor(numpy.array(item[1])) for item in images] # Convert NumPy array to PyTorch tensor\n", - " targets = [torch.tensor(numpy.array(item[1])) for item in labels] \n", - " print(len(features),len(targets))\n", + " # print(images[0][1].transpose(1,2,0).shape, labels[0][1].transpose(1,2,0).shape)\n", + " features = [\n", + " torch.tensor(numpy.array(item[1])) for item in images\n", + " ] # Convert NumPy array to PyTorch tensor\n", + " targets = [torch.tensor(numpy.array(item[1])) for item in labels]\n", + " print(len(features), len(targets))\n", " if transform:\n", " # convert to float16 and normalize\n", " features = [transform(feature) for feature in features]\n", - " dataset = torch.utils.data.TensorDataset(torch.stack(features), torch.stack(targets))\n", + " dataset = torch.utils.data.TensorDataset(\n", + " torch.stack(features), torch.stack(targets)\n", + " )\n", " return dataset\n", "\n", " def train_dataloader(self):\n", " # Return the training DataLoader\n", - " #transform = transforms.Compose([transforms.ToTensor()])\n", - " train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True, num_workers=self.num_workers, shuffle=True)\n", + " # transform = transforms.Compose([transforms.ToTensor()])\n", + " train_loader = DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " drop_last=True,\n", + " num_workers=self.num_workers,\n", + " shuffle=True,\n", + " )\n", " return train_loader\n", "\n", " def val_dataloader(self):\n", " # Return the training DataLoader\n", - " #transform = transforms.Compose([transforms.ToTensor()])\n", - " val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, drop_last=True, num_workers=self.num_workers, shuffle=False)\n", + " # transform = transforms.Compose([transforms.ToTensor()])\n", + " val_loader = DataLoader(\n", + " self.val_dataset,\n", + " batch_size=self.batch_size,\n", + " 
drop_last=True,\n", + " num_workers=self.num_workers,\n", + " shuffle=False,\n", + " )\n", " return val_loader\n", - " \n", - " \n", - "datamodule_floods = DataModule_Floods(train_images_values, val_images_values, train_labels_values, val_labels_values)\n", - "datamodule_floods.setup(stage='fit')\n", - "\n" + "\n", + "\n", + "datamodule_floods = DataModule_Floods(\n", + " train_images_values, val_images_values, train_labels_values, val_labels_values\n", + ")\n", + "datamodule_floods.setup(stage=\"fit\")" ] }, { @@ -933,7 +970,7 @@ " torch.nn.ReLU(inplace=True),\n", " torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", " torch.nn.ReLU(inplace=True),\n", - " torch.nn.MaxPool2d(kernel_size=2, stride=2)\n", + " torch.nn.MaxPool2d(kernel_size=2, stride=2),\n", " )\n", " self.decoder = torch.nn.Sequential(\n", " torch.nn.ConvTranspose2d(64, 128, kernel_size=3, padding=1),\n", @@ -942,7 +979,7 @@ " torch.nn.ReLU(inplace=True),\n", " torch.nn.Upsample(scale_factor=2),\n", " torch.nn.Conv2d(512, out_channels, kernel_size=3, padding=1),\n", - " torch.nn.Softmax(dim=1)\n", + " torch.nn.Softmax(dim=1),\n", " )\n", "\n", " def forward(self, x):\n", @@ -952,6 +989,7 @@ " x = self.decoder(x1)\n", " return x\n", "\n", + "\n", "class SegmentationModel(LightningModule):\n", " def __init__(self, model, datamodule):\n", " super().__init__()\n", @@ -960,10 +998,10 @@ "\n", " def forward(self, x):\n", " return self.model(x)\n", - " \n", + "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", - " #x = torch.tensor(x, requires_grad=True)\n", + " # x = torch.tensor(x, requires_grad=True)\n", " y = y.to(dtype=torch.float32)\n", " y = y.squeeze()\n", " print(\"Shapes - x:\", x.shape, \"y:\", y.shape)\n", @@ -972,13 +1010,13 @@ " print(\"Prediction shape:\", prediction.shape)\n", " print(\"Label shape:\", y.shape)\n", " loss = torch.nn.functional.cross_entropy(prediction.to(dtype=torch.float32), y)\n", - " loss = torch.tensor(loss, requires_grad = True)\n", - " self.log('train_loss', loss)\n", + " loss = torch.tensor(loss, requires_grad=True)\n", + " self.log(\"train_loss\", loss)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", - " #x = torch.tensor(x, requires_grad=True)\n", + " # x = torch.tensor(x, requires_grad=True)\n", " y = y.to(dtype=torch.float32)\n", " y = y.squeeze()\n", " print(\"Shapes - x:\", x.shape, \"y:\", y.shape)\n", @@ -986,9 +1024,11 @@ " _, prediction = torch.max(y_pred, dim=1)\n", " print(\"Prediction shape:\", prediction.shape)\n", " print(\"Label shape:\", y.shape)\n", - " val_loss = torch.nn.functional.cross_entropy(prediction.to(dtype=torch.float32), y)\n", - " val_loss = torch.tensor(val_loss, requires_grad = True)\n", - " self.log('val_loss', val_loss)\n", + " val_loss = torch.nn.functional.cross_entropy(\n", + " prediction.to(dtype=torch.float32), y\n", + " )\n", + " val_loss = torch.tensor(val_loss, requires_grad=True)\n", + " self.log(\"val_loss\", val_loss)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n", @@ -1000,11 +1040,12 @@ " def val_dataloader(self):\n", " return self.datamodule.val_dataloader()\n", "\n", - "model = UNet(13, 2) \n", + "\n", + "model = UNet(13, 2)\n", "segmentation_model = SegmentationModel(model, datamodule_floods)\n", "\n", - "trainer = Trainer(max_epochs=3) \n", - "trainer.fit(segmentation_model)\n" + "trainer = Trainer(max_epochs=3)\n", + "trainer.fit(segmentation_model)" ] }, { @@ -1437,6 +1478,7 @@ "import torch\n", 
"from torch.utils.data import DataLoader\n", "\n", + "\n", "def plot_predictions(model, dataloader):\n", " model.eval()\n", " with torch.no_grad():\n", @@ -1456,22 +1498,36 @@ "\n", " for i in range(num_samples):\n", " print(numpy.unique(y_val_np[i]))\n", - " #print(numpy.stack((x_val_np[i][2,:,:], x_val_np[i][1,:,:], x_val_np[i][0,:,:])))\n", - " #axes[i, 0].imshow(numpy.transpose(numpy.stack((x_val_np[i][2,:,:], x_val_np[i][1,:,:], x_val_np[i][0,:,:])), (1, 2, 0)).clip(0, 3000) / 3000)#x_val_np[i], (1, 2, 0))) # Plot input images\n", - " axes[i, 0].imshow(numpy.transpose(numpy.stack((x_val_np[i][2,:,:], x_val_np[i][1,:,:], x_val_np[i][0,:,:])), (1, 2, 0)))#x_val_np[i], (1, 2, 0))) # Plot input images\n", - " axes[i, 1].imshow(numpy.transpose(y_val_np[i], (1, 2, 0))) # Plot ground truths\n", + " # print(numpy.stack((x_val_np[i][2,:,:], x_val_np[i][1,:,:], x_val_np[i][0,:,:])))\n", + " # axes[i, 0].imshow(numpy.transpose(numpy.stack((x_val_np[i][2,:,:], x_val_np[i][1,:,:], x_val_np[i][0,:,:])), (1, 2, 0)).clip(0, 3000) / 3000)#x_val_np[i], (1, 2, 0))) # Plot input images\n", + " axes[i, 0].imshow(\n", + " numpy.transpose(\n", + " numpy.stack(\n", + " (\n", + " x_val_np[i][2, :, :],\n", + " x_val_np[i][1, :, :],\n", + " x_val_np[i][0, :, :],\n", + " )\n", + " ),\n", + " (1, 2, 0),\n", + " )\n", + " ) # x_val_np[i], (1, 2, 0))) # Plot input images\n", + " axes[i, 1].imshow(\n", + " numpy.transpose(y_val_np[i], (1, 2, 0))\n", + " ) # Plot ground truths\n", " axes[i, 2].imshow(prediction_np[i]) # Plot model predictions\n", "\n", " plt.show()\n", "\n", + "\n", "val_dataloader = datamodule_floods.val_dataloader()\n", "\n", "# Load the trained model\n", - "#loaded_model = UNet(13, 2) # Initialize the model architecture\n", - "#loaded_model.load_state_dict(torch.load('path_to_your_trained_model.pth')) # Load trained weights\n", + "# loaded_model = UNet(13, 2) # Initialize the model architecture\n", + "# loaded_model.load_state_dict(torch.load('path_to_your_trained_model.pth')) # Load trained weights\n", "\n", "# Run predictions and plot results\n", - "plot_predictions(model, val_dataloader)\n" + "plot_predictions(model, val_dataloader)" ] }, { diff --git a/src/datamodule_eval.py b/src/datamodule_eval.py index 95999a2f..ee687ff0 100644 --- a/src/datamodule_eval.py +++ b/src/datamodule_eval.py @@ -2,17 +2,13 @@ LightningDataModule to load Earth Observation data from GeoTIFF files using rasterio. 
""" +import glob import math import os -import random -from pathlib import Path -from typing import List, Literal -import glob import lightning as L import numpy as np import rasterio -import rioxarray import torch import torchdata from torch.utils.data import DataLoader, Dataset @@ -21,6 +17,7 @@ os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR" os.environ["GDAL_HTTP_MERGE_CONSECUTIVE_RANGES"] = "YES" + class ClayDataset(Dataset): def __init__(self, chips_path, chips_label_path, transform=None): super().__init__() @@ -56,19 +53,22 @@ def normalize_latlon(self, lon, lat): return lon, lat def read_chip(self, chip_path, chip_path_label, date, bounds, centroid, epsg): - chip = chip_path # rasterio.open(chip_path) - chip_label = chip_path_label # rasterio.open(chip_path_label) + chip = chip_path # rasterio.open(chip_path) + chip_label = chip_path_label # rasterio.open(chip_path_label) # read timestep & normalize year, month, day = self.normalize_timestamp(date) # read lat,lon from UTM to WGS84 & normalize - lon, lat = centroid[0], centroid[1], # longitude, latitude + lon, lat = ( + centroid[0], + centroid[1], + ) # longitude, latitude lon, lat = self.normalize_latlon(lon, lat) return { "labels": chip_label, - "pixels": chip, #chip.read(), + "pixels": chip, # chip.read(), # Raw values "bbox": bounds, "epsg": epsg, @@ -77,10 +77,8 @@ def read_chip(self, chip_path, chip_path_label, date, bounds, centroid, epsg): "latlon": (lat, lon), "timestep": (year, month, day), } - - - def get_image_granules(self, chips_path, chips_label_path, idx): + def get_image_granules(self, chips_path, chips_label_path, idx): chip_path = chips_path[idx] chip_label_path = chips_label_path[idx] @@ -105,19 +103,59 @@ def get_image_granules(self, chips_path, chips_label_path, idx): epsg = chip_data_array.crs.to_epsg() chip_label_path_data_array = rasterio.open(chip_label_path) label_array_values = chip_label_path_data_array.read() - return image_array_values, label_array_values, flood_event, position, date, bounds, centroid, epsg, filename - + return ( + image_array_values, + label_array_values, + flood_event, + position, + date, + bounds, + centroid, + epsg, + filename, + ) + def get_benchmark_data(self, chips_path, chips_label_path, idx): - image_array_values, label_array_values, flood_events, positions, dates, bounds_, centroids, epsgs, filenames = \ - self.get_image_granules(chips_path, chips_label_path, idx) - return image_array_values, label_array_values, flood_events, positions, dates, bounds_, centroids, epsgs, filenames + ( + image_array_values, + label_array_values, + flood_events, + positions, + dates, + bounds_, + centroids, + epsgs, + filenames, + ) = self.get_image_granules(chips_path, chips_label_path, idx) + return ( + image_array_values, + label_array_values, + flood_events, + positions, + dates, + bounds_, + centroids, + epsgs, + filenames, + ) def __getitem__(self, idx): - #image_array_values, label_array_values, flood_events, positions, dates, bounds_, centroids, epsgs, filenames = \ + # image_array_values, label_array_values, flood_events, positions, dates, bounds_, centroids, epsgs, filenames = \ # self.get_benchmark_data(self.chips_path, self.chips_label_path) - image_array_values, label_array_values, flood_event, position, date, bounds, centroid, epsg, filename = \ - self.get_benchmark_data(self.chips_path, self.chips_label_path, idx) - cube = self.read_chip(image_array_values, label_array_values, date, bounds, centroid, epsg) + ( + image_array_values, + label_array_values, + flood_event, + 
position,
+            date,
+            bounds,
+            centroid,
+            epsg,
+            filename,
+        ) = self.get_benchmark_data(self.chips_path, self.chips_label_path, idx)
+        cube = self.read_chip(
+            image_array_values, label_array_values, date, bounds, centroid, epsg
+        )
 
         # remove nans and convert to tensor
         cube["labels"] = torch.as_tensor(data=cube["labels"], dtype=torch.float32)
@@ -130,7 +168,7 @@ def __getitem__(self, idx):
         try:
             cube["source_url"] = str(self.chip_path.absolute())
         except AttributeError:
-            cube["source_url"] = filename #chip_path
+            cube["source_url"] = filename  # chip_path
 
         if self.transform:
             # convert to float16 and normalize
@@ -188,25 +226,37 @@ def __init__(
         self.split_ratio = 0.8
         self.tfm = v2.Compose([v2.Normalize(mean=self.MEAN, std=self.STD)])
 
-    def setup(self, stage='fit'):
+    def setup(self, stage="fit"):
         # Get list of GeoTIFF filepaths from s3 bucket or data/ folder
-        #if self.data_dir.startswith("s3://"):
+        # if self.data_dir.startswith("s3://"):
         #    dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir])
         #    chips_path = list(dp.list_files_by_s3(masks="*.tif"))
-        #else: # if self.data_dir is a local data path
-
-        chips_path = glob.glob(f"{self.data_dir}/**/*_rtc.tif") #list(Path(self.data_dir).glob("**/*_rtc.tif"))
-        chips_label_path = glob.glob(f"{self.data_dir}/**/*_LabelWater.tif")#list(Path(self.data_dir).glob("**/*_LabelWater.tif"))
+        # else:  # if self.data_dir is a local data path
+
+        chips_path = glob.glob(
+            f"{self.data_dir}/**/*_rtc.tif"
+        )  # list(Path(self.data_dir).glob("**/*_rtc.tif"))
+        chips_label_path = glob.glob(
+            f"{self.data_dir}/**/*_LabelWater.tif"
+        )  # list(Path(self.data_dir).glob("**/*_LabelWater.tif"))
         print(f"Total number of chips: {len(chips_path)}")
-        #print(f"All chips: {chips_path}")
+        # print(f"All chips: {chips_path}")
 
         if stage == "fit":
-            #random.shuffle(chips_path)
+            # random.shuffle(chips_path)
            split = int(len(chips_path) * self.split_ratio)
-            #print("Splits: ", chips_path[:split], chips_label_path[:split])
+            # print("Splits: ", chips_path[:split], chips_label_path[:split])
 
-            self.trn_ds = ClayDataset(chips_path=chips_path[:split], chips_label_path=chips_label_path[:split], transform=self.tfm)
-            self.val_ds = ClayDataset(chips_path=chips_path[split:], chips_label_path=chips_label_path[:split], transform=self.tfm)
+            self.trn_ds = ClayDataset(
+                chips_path=chips_path[:split],
+                chips_label_path=chips_label_path[:split],
+                transform=self.tfm,
+            )
+            self.val_ds = ClayDataset(
+                chips_path=chips_path[split:],
+                chips_label_path=chips_label_path[split:],
+                transform=self.tfm,
+            )
 
         elif stage == "predict":
             self.prd_ds = ClayDataset(chips_path=chips_path, transform=self.tfm)
@@ -399,4 +449,4 @@ def predict_dataloader(self) -> torch.utils.data.DataLoader:
             dataset=self.datapipe_predict,
             batch_size=None,  # handled in datapipe already
             num_workers=self.num_workers,
-        )
\ No newline at end of file
+        )