diff --git a/notebooks/c2smsfloods_unet.ipynb b/notebooks/c2smsfloods_unet.ipynb
index 48772706..7506b351 100644
--- a/notebooks/c2smsfloods_unet.ipynb
+++ b/notebooks/c2smsfloods_unet.ipynb
@@ -7,22 +7,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
     "import random\n",
-    "from datetime import datetime, timedelta\n",
     "\n",
     "import boto3\n",
     "import numpy\n",
     "import rasterio\n",
     "import rioxarray\n",
-    "import xarray as xr\n",
-    "from matplotlib import pyplot as plt\n",
-    "import pystac\n",
-    "import pystac_client\n",
     "import torch\n",
-    "from torch.utils.data import TensorDataset, DataLoader\n",
+    "from matplotlib import pyplot as plt\n",
     "from pytorch_lightning import LightningModule, Trainer\n",
-    "import torchvision.transforms as T"
+    "from torch.utils.data import DataLoader"
    ]
   },
   {
@@ -32,7 +26,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def list_objects_recursive(client, bucket_name, prefix=''):\n",
+    "def list_objects_recursive(client, bucket_name, prefix=\"\"):\n",
     "    \"\"\"\n",
     "    List all objects (file keys) in an S3 bucket recursively under a specified prefix.\n",
     "\n",
@@ -44,14 +38,14 @@
     "    Returns:\n",
     "    - list: A list of file keys (object keys) found under the specified prefix.\n",
     "    \"\"\"\n",
-    "    paginator = client.get_paginator('list_objects_v2')\n",
+    "    paginator = client.get_paginator(\"list_objects_v2\")\n",
     "\n",
     "    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)\n",
     "\n",
     "    file_keys = []\n",
     "    for page in page_iterator:\n",
-    "        if 'Contents' in page:\n",
-    "            file_keys.extend([obj['Key'] for obj in page['Contents']])\n",
+    "        if \"Contents\" in page:\n",
+    "            file_keys.extend([obj[\"Key\"] for obj in page[\"Contents\"]])\n",
     "\n",
     "    return file_keys"
    ]
   },
@@ -77,6 +71,7 @@
     "flood_events = []\n",
     "positions = []\n",
     "\n",
+    "\n",
     "def get_image_granules(bucket_name, prefix):\n",
     "    \"\"\"\n",
     "    Get granules of N-dim datacube and label images from an S3 bucket.\n",
@@ -89,40 +84,43 @@
     "    - tuple: None.\n",
     "    \"\"\"\n",
     "    # Initialize Boto3 S3 client\n",
-    "    s3 = boto3.client('s3')\n",
+    "    s3 = boto3.client(\"s3\")\n",
     "\n",
     "    # List objects in the specified prefix (directory) in the bucket\n",
     "    files_in_s3 = list_objects_recursive(s3, bucket_name, prefix)\n",
     "\n",
     "    # Filter S2 and S1 images\n",
-    "    #S2_labels = [i for i in files_in_s3_gt if '/s2/' in i and 'LabelWater.tif' in i]\n",
-    "    S1_labels = [i for i in files_in_s3 if 'LabelWater.tif' in i]\n",
+    "    # S2_labels = [i for i in files_in_s3_gt if '/s2/' in i and 'LabelWater.tif' in i]\n",
+    "    S1_labels = [i for i in files_in_s3 if \"LabelWater.tif\" in i]\n",
     "    datacube_images = [f\"{i[:-15]}.tif\" for i in S1_labels]\n",
-    "    \n",
+    "\n",
     "    for i in datacube_images[0:100]:\n",
-    "        position = '_'.join(i.split('/')[-1].split('_')[-3:-1])\n",
+    "        position = \"_\".join(i.split(\"/\")[-1].split(\"_\")[-3:-1])\n",
     "        positions.append(position)\n",
-    "        flood_event = i.split('/')[-2]\n",
-    "        flood_events.append(flood_event) \n",
+    "        flood_event = i.split(\"/\")[-2]\n",
+    "        flood_events.append(flood_event)\n",
     "        # Load the image file from S3 directly into memory using rasterio\n",
     "        obj = s3.get_object(Bucket=bucket_name, Key=i)\n",
-    "        with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n",
+    "        with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n",
     "            with memfile.open() as dataset:\n",
     "                data_array = rioxarray.open_rasterio(dataset)\n",
-    "                #print(data_array.values)\n",
+    "                # print(data_array.values)\n",
     "                pair = [i, data_array.values]\n",
     "                image_array_values.append(pair)\n",
     "    for i in S1_labels[0:100]:\n",
     "        # Load the image file from S3 directly into memory using rasterio\n",
     "        obj = s3.get_object(Bucket=bucket_name, Key=i)\n",
-    "        with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n",
+    "        with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n",
     "            with memfile.open() as dataset:\n",
     "                data_array = rioxarray.open_rasterio(dataset)\n",
-    "                #print(data_array.values)\n",
+    "                # print(data_array.values)\n",
    "                pair = [i, data_array.values]\n",
     "                label_array_values.append(pair)\n",
     "\n",
-    "get_image_granules(bucket_name='clay-benchmark', prefix='c2smsfloods/datacube/chips_512/')"
+    "\n",
+    "get_image_granules(\n",
+    "    bucket_name=\"clay-benchmark\", prefix=\"c2smsfloods/datacube/chips_512/\"\n",
+    ")"
    ]
   },
   {
@@ -141,6 +139,7 @@
    ],
    "source": [
     "import random\n",
+    "\n",
     "random.seed(9) # set a seed for reproducibility\n",
     "\n",
     "# put 1/3 of chips into the validation set\n",
@@ -149,18 +148,27 @@
     "train_chip_ids_values = [i for i in chip_ids_values if i not in val_chip_ids_values]\n",
     "train_images_values = []\n",
     "train_labels_values = []\n",
-    "for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n",
+    "for i, j, k in zip(\n",
+    "    image_array_values, label_array_values, range(len(image_array_values))\n",
+    "):\n",
     "    if k in train_chip_ids_values:\n",
     "        train_images_values.append(i)\n",
     "        train_labels_values.append(j)\n",
     "val_images_values = []\n",
     "val_labels_values = []\n",
-    "for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n",
+    "for i, j, k in zip(\n",
+    "    image_array_values, label_array_values, range(len(image_array_values))\n",
+    "):\n",
     "    if k in val_chip_ids_values:\n",
     "        val_images_values.append(i)\n",
     "        val_labels_values.append(j)\n",
-    "    \n",
-    "print(len(train_images_values), len(val_images_values), len(train_labels_values), len(val_labels_values))"
+    "\n",
+    "print(\n",
+    "    len(train_images_values),\n",
+    "    len(val_images_values),\n",
+    "    len(train_labels_values),\n",
+    "    len(val_labels_values),\n",
+    ")"
    ]
   },
   {
@@ -180,7 +188,6 @@
     "\n",
     "    fig, axs = plt.subplots(1, 3, figsize=(30, 30))\n",
     "\n",
-    "\n",
     "    rgb = (\n",
     "        numpy.array(\n",
     "            [\n",
@@ -251,40 +258,66 @@
    ],
    "source": [
     "class DataModule_Floods(LightningModule):\n",
-    "    def __init__(self, train_images, val_images, train_labels, val_labels, batch_size=2, num_workers=4):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        train_images,\n",
+    "        val_images,\n",
+    "        train_labels,\n",
+    "        val_labels,\n",
+    "        batch_size=2,\n",
+    "        num_workers=4,\n",
+    "    ):\n",
     "        super().__init__()\n",
     "        self.train_images = train_images\n",
     "        self.val_images = val_images\n",
     "        self.train_labels = train_labels\n",
     "        self.val_labels = val_labels\n",
     "        self.batch_size = batch_size\n",
-    "        self.num_workers = num_workers \n",
-    "        \n",
+    "        self.num_workers = num_workers\n",
+    "\n",
     "        self.train_dataset = self.get_dataset(self.train_images, self.train_labels)\n",
     "        self.val_dataset = self.get_dataset(self.val_images, self.val_labels)\n",
     "\n",
     "    def get_dataset(self, images, labels):\n",
-    "        print(images[0][1].transpose(1,2,0).shape, labels[0][1].transpose(1,2,0).shape)\n",
-    "        features = [torch.from_numpy(numpy.array(item[1]).transpose(1,2,0)) for item in images] # Convert NumPy array to PyTorch tensor\n",
-    "        targets = [torch.from_numpy(numpy.array(item[1]).transpose(1,2,0)) for item in labels] \n",
-    "        print(len(features),len(targets))\n",
+    "        print(\n",
+    "            images[0][1].transpose(1, 2, 0).shape, labels[0][1].transpose(1, 2, 0).shape\n",
+    "        )\n",
+    "        features = [\n",
+    "            torch.from_numpy(numpy.array(item[1]).transpose(1, 2, 0)) for item in images\n",
+    "        ] # Convert NumPy array to PyTorch tensor\n",
+    "        targets = [\n",
+    "            torch.from_numpy(numpy.array(item[1]).transpose(1, 2, 0)) for item in labels\n",
+    "        ]\n",
+    "        print(len(features), len(targets))\n",
     "        return features, targets\n",
     "\n",
     "    def train_dataloader(self):\n",
     "        # Return the training DataLoader\n",
-    "        #transform = transforms.Compose([transforms.ToTensor()])\n",
-    "        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n",
+    "        # transform = transforms.Compose([transforms.ToTensor()])\n",
+    "        train_loader = DataLoader(\n",
+    "            self.train_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            shuffle=True,\n",
+    "        )\n",
     "        return train_loader\n",
     "\n",
     "    def val_dataloader(self):\n",
     "        # Return the training DataLoader\n",
-    "        #transform = transforms.Compose([transforms.ToTensor()])\n",
-    "        val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)\n",
+    "        # transform = transforms.Compose([transforms.ToTensor()])\n",
+    "        val_loader = DataLoader(\n",
+    "            self.val_dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            shuffle=False,\n",
+    "        )\n",
     "        return val_loader\n",
-    "    \n",
-    "datamodule_floods = DataModule_Floods(train_images_values, val_images_values, train_labels_values, val_labels_values)\n",
-    "datamodule_floods.setup(stage='fit')\n",
-    "\n"
+    "\n",
+    "\n",
+    "datamodule_floods = DataModule_Floods(\n",
+    "    train_images_values, val_images_values, train_labels_values, val_labels_values\n",
+    ")\n",
+    "datamodule_floods.setup(stage=\"fit\")"
    ]
   },
   {
@@ -302,13 +335,13 @@
     "            nn.ReLU(inplace=True),\n",
     "            nn.Conv2d(64, 64, kernel_size=3, padding=1),\n",
     "            nn.ReLU(inplace=True),\n",
-    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
+    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
     "        )\n",
     "        self.decoder = nn.Sequential(\n",
     "            nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
     "            nn.ReLU(inplace=True),\n",
     "            nn.Conv2d(128, out_channels, kernel_size=3, padding=1),\n",
-    "            nn.Sigmoid()\n",
+    "            nn.Sigmoid(),\n",
     "        )\n",
     "\n",
     "    def forward(self, x):\n",
@@ -318,6 +351,7 @@
     "        x = self.decoder(x1)\n",
     "        return x\n",
     "\n",
+    "\n",
     "class SegmentationModel(LightningModule):\n",
     "    def __init__(self, model, datamodule):\n",
     "        super().__init__()\n",
@@ -331,15 +365,15 @@
     "        x, y = batch # x: input image, y: ground truth mask\n",
     "        y_pred = self.model(x)\n",
     "        loss = torch.nn.functional.cross_entropy(y_pred, y)\n",
-    "        self.log('train_loss', loss)\n",
+    "        self.log(\"train_loss\", loss)\n",
     "        return loss\n",
     "\n",
     "    def validation_step(self, batch, batch_idx):\n",
-    "        print(len(batch)) #.shape)\n",
+    "        print(len(batch)) # .shape)\n",
     "        x, y = batch # x: input image, y: ground truth mask\n",
     "        y_pred = self.model(x)\n",
     "        val_loss = torch.nn.functional.cross_entropy(y_pred, y)\n",
-    "        self.log('val_loss', val_loss)\n",
+    "        self.log(\"val_loss\", val_loss)\n",
    "\n",
     "    def configure_optimizers(self):\n",
     "        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
@@ -351,11 +385,12 @@
     "    def val_dataloader(self):\n",
     "        return self.datamodule.val_dataloader()\n",
     "\n",
-    "model = UNet(13, 1) \n",
+    "\n",
+    "model = UNet(13, 1)\n",
     "segmentation_model = SegmentationModel(model, datamodule_floods)\n",
     "\n",
-    "trainer = Trainer(max_epochs=10) \n",
-    "trainer.fit(segmentation_model)\n"
+    "trainer = Trainer(max_epochs=10)\n",
+    "trainer.fit(segmentation_model)"
    ]
   }
  ],