[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 13, 2023
1 parent 98cbe84 commit 7eb5468
Showing 1 changed file with 87 additions and 52 deletions.
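The hook configuration that produced these fixes lives in the repository's .pre-commit-config.yaml and is not part of this diff, so the following is an illustration only: the quote normalization, blank-line insertion, and line wrapping in the additions below are characteristic of a Black-style formatter. A minimal sketch, assuming the black package is installed:

import black

# Illustration only: format a one-line stub the way the additions in this diff
# are formatted. Black normalizes single quotes to double quotes (as in the
# list_objects_recursive signature below) and wraps calls that exceed the
# configured line length.
snippet = "def list_objects_recursive(client, bucket_name, prefix=''):\n    pass\n"
print(black.format_str(snippet, mode=black.Mode()))
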
139 changes: 87 additions & 52 deletions notebooks/c2smsfloods_unet.ipynb
@@ -7,22 +7,16 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"from datetime import datetime, timedelta\n",
"\n",
"import boto3\n",
"import numpy\n",
"import rasterio\n",
"import rioxarray\n",
"import xarray as xr\n",
"from matplotlib import pyplot as plt\n",
"import pystac\n",
"import pystac_client\n",
"import torch\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"from matplotlib import pyplot as plt\n",
"from pytorch_lightning import LightningModule, Trainer\n",
"import torchvision.transforms as T"
"from torch.utils.data import DataLoader"
]
},
{
@@ -32,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"def list_objects_recursive(client, bucket_name, prefix=''):\n",
"def list_objects_recursive(client, bucket_name, prefix=\"\"):\n",
" \"\"\"\n",
" List all objects (file keys) in an S3 bucket recursively under a specified prefix.\n",
"\n",
@@ -44,14 +38,14 @@
" Returns:\n",
" - list: A list of file keys (object keys) found under the specified prefix.\n",
" \"\"\"\n",
" paginator = client.get_paginator('list_objects_v2')\n",
" paginator = client.get_paginator(\"list_objects_v2\")\n",
"\n",
" page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)\n",
"\n",
" file_keys = []\n",
" for page in page_iterator:\n",
" if 'Contents' in page:\n",
" file_keys.extend([obj['Key'] for obj in page['Contents']])\n",
" if \"Contents\" in page:\n",
" file_keys.extend([obj[\"Key\"] for obj in page[\"Contents\"]])\n",
"\n",
" return file_keys"
]
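As a usage sketch of the helper reformatted above (the bucket and prefix are the ones this notebook queries further down; working AWS credentials are assumed):

import boto3

s3 = boto3.client("s3")
# Page through every object key under the chips prefix and collect the keys.
keys = list_objects_recursive(s3, "clay-benchmark", "c2smsfloods/datacube/chips_512/")
print(len(keys), keys[:3])
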
@@ -77,6 +71,7 @@
"flood_events = []\n",
"positions = []\n",
"\n",
"\n",
"def get_image_granules(bucket_name, prefix):\n",
" \"\"\"\n",
" Get granules of N-dim datacube and label images from an S3 bucket.\n",
@@ -89,40 +84,43 @@
" - tuple: None.\n",
" \"\"\"\n",
" # Initialize Boto3 S3 client\n",
" s3 = boto3.client('s3')\n",
" s3 = boto3.client(\"s3\")\n",
"\n",
" # List objects in the specified prefix (directory) in the bucket\n",
" files_in_s3 = list_objects_recursive(s3, bucket_name, prefix)\n",
"\n",
" # Filter S2 and S1 images\n",
" #S2_labels = [i for i in files_in_s3_gt if '/s2/' in i and 'LabelWater.tif' in i]\n",
" S1_labels = [i for i in files_in_s3 if 'LabelWater.tif' in i]\n",
" # S2_labels = [i for i in files_in_s3_gt if '/s2/' in i and 'LabelWater.tif' in i]\n",
" S1_labels = [i for i in files_in_s3 if \"LabelWater.tif\" in i]\n",
" datacube_images = [f\"{i[:-15]}.tif\" for i in S1_labels]\n",
" \n",
"\n",
" for i in datacube_images[0:100]:\n",
" position = '_'.join(i.split('/')[-1].split('_')[-3:-1])\n",
" position = \"_\".join(i.split(\"/\")[-1].split(\"_\")[-3:-1])\n",
" positions.append(position)\n",
" flood_event = i.split('/')[-2]\n",
" flood_events.append(flood_event) \n",
" flood_event = i.split(\"/\")[-2]\n",
" flood_events.append(flood_event)\n",
" # Load the image file from S3 directly into memory using rasterio\n",
" obj = s3.get_object(Bucket=bucket_name, Key=i)\n",
" with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n",
" with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n",
" with memfile.open() as dataset:\n",
" data_array = rioxarray.open_rasterio(dataset)\n",
" #print(data_array.values)\n",
" # print(data_array.values)\n",
" pair = [i, data_array.values]\n",
" image_array_values.append(pair)\n",
" for i in S1_labels[0:100]:\n",
" # Load the image file from S3 directly into memory using rasterio\n",
" obj = s3.get_object(Bucket=bucket_name, Key=i)\n",
" with rasterio.io.MemoryFile(obj['Body'].read()) as memfile:\n",
" with rasterio.io.MemoryFile(obj[\"Body\"].read()) as memfile:\n",
" with memfile.open() as dataset:\n",
" data_array = rioxarray.open_rasterio(dataset)\n",
" #print(data_array.values)\n",
" # print(data_array.values)\n",
" pair = [i, data_array.values]\n",
" label_array_values.append(pair)\n",
"\n",
"get_image_granules(bucket_name='clay-benchmark', prefix='c2smsfloods/datacube/chips_512/')"
"\n",
"get_image_granules(\n",
" bucket_name=\"clay-benchmark\", prefix=\"c2smsfloods/datacube/chips_512/\"\n",
")"
]
},
{
@@ -141,6 +139,7 @@
],
"source": [
"import random\n",
"\n",
"random.seed(9) # set a seed for reproducibility\n",
"\n",
"# put 1/3 of chips into the validation set\n",
@@ -149,18 +148,27 @@
"train_chip_ids_values = [i for i in chip_ids_values if i not in val_chip_ids_values]\n",
"train_images_values = []\n",
"train_labels_values = []\n",
"for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n",
"for i, j, k in zip(\n",
" image_array_values, label_array_values, range(len(image_array_values))\n",
"):\n",
" if k in train_chip_ids_values:\n",
" train_images_values.append(i)\n",
" train_labels_values.append(j)\n",
"val_images_values = []\n",
"val_labels_values = []\n",
"for i,j,k in zip(image_array_values, label_array_values, range(len(image_array_values))):\n",
"for i, j, k in zip(\n",
" image_array_values, label_array_values, range(len(image_array_values))\n",
"):\n",
" if k in val_chip_ids_values:\n",
" val_images_values.append(i)\n",
" val_labels_values.append(j)\n",
" \n",
"print(len(train_images_values), len(val_images_values), len(train_labels_values), len(val_labels_values))"
"\n",
"print(\n",
" len(train_images_values),\n",
" len(val_images_values),\n",
" len(train_labels_values),\n",
" len(val_labels_values),\n",
")"
]
},
{
@@ -180,7 +188,6 @@
"\n",
" fig, axs = plt.subplots(1, 3, figsize=(30, 30))\n",
"\n",
"\n",
" rgb = (\n",
" numpy.array(\n",
" [\n",
@@ -251,40 +258,66 @@
],
"source": [
"class DataModule_Floods(LightningModule):\n",
" def __init__(self, train_images, val_images, train_labels, val_labels, batch_size=2, num_workers=4):\n",
" def __init__(\n",
" self,\n",
" train_images,\n",
" val_images,\n",
" train_labels,\n",
" val_labels,\n",
" batch_size=2,\n",
" num_workers=4,\n",
" ):\n",
" super().__init__()\n",
" self.train_images = train_images\n",
" self.val_images = val_images\n",
" self.train_labels = train_labels\n",
" self.val_labels = val_labels\n",
" self.batch_size = batch_size\n",
" self.num_workers = num_workers \n",
" \n",
" self.num_workers = num_workers\n",
"\n",
" self.train_dataset = self.get_dataset(self.train_images, self.train_labels)\n",
" self.val_dataset = self.get_dataset(self.val_images, self.val_labels)\n",
"\n",
" def get_dataset(self, images, labels):\n",
" print(images[0][1].transpose(1,2,0).shape, labels[0][1].transpose(1,2,0).shape)\n",
" features = [torch.from_numpy(numpy.array(item[1]).transpose(1,2,0)) for item in images] # Convert NumPy array to PyTorch tensor\n",
" targets = [torch.from_numpy(numpy.array(item[1]).transpose(1,2,0)) for item in labels] \n",
" print(len(features),len(targets))\n",
" print(\n",
" images[0][1].transpose(1, 2, 0).shape, labels[0][1].transpose(1, 2, 0).shape\n",
" )\n",
" features = [\n",
" torch.from_numpy(numpy.array(item[1]).transpose(1, 2, 0)) for item in images\n",
" ] # Convert NumPy array to PyTorch tensor\n",
" targets = [\n",
" torch.from_numpy(numpy.array(item[1]).transpose(1, 2, 0)) for item in labels\n",
" ]\n",
" print(len(features), len(targets))\n",
" return features, targets\n",
"\n",
" def train_dataloader(self):\n",
" # Return the training DataLoader\n",
" #transform = transforms.Compose([transforms.ToTensor()])\n",
" train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n",
" # transform = transforms.Compose([transforms.ToTensor()])\n",
" train_loader = DataLoader(\n",
" self.train_dataset,\n",
" batch_size=self.batch_size,\n",
" num_workers=self.num_workers,\n",
" shuffle=True,\n",
" )\n",
" return train_loader\n",
"\n",
" def val_dataloader(self):\n",
" # Return the training DataLoader\n",
" #transform = transforms.Compose([transforms.ToTensor()])\n",
" val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)\n",
" # transform = transforms.Compose([transforms.ToTensor()])\n",
" val_loader = DataLoader(\n",
" self.val_dataset,\n",
" batch_size=self.batch_size,\n",
" num_workers=self.num_workers,\n",
" shuffle=False,\n",
" )\n",
" return val_loader\n",
" \n",
"datamodule_floods = DataModule_Floods(train_images_values, val_images_values, train_labels_values, val_labels_values)\n",
"datamodule_floods.setup(stage='fit')\n",
"\n"
"\n",
"\n",
"datamodule_floods = DataModule_Floods(\n",
" train_images_values, val_images_values, train_labels_values, val_labels_values\n",
")\n",
"datamodule_floods.setup(stage=\"fit\")"
]
},
{
@@ -302,13 +335,13 @@
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(64, 64, kernel_size=3, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.MaxPool2d(kernel_size=2, stride=2)\n",
" nn.MaxPool2d(kernel_size=2, stride=2),\n",
" )\n",
" self.decoder = nn.Sequential(\n",
" nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(128, out_channels, kernel_size=3, padding=1),\n",
" nn.Sigmoid()\n",
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
@@ -318,6 +351,7 @@
" x = self.decoder(x1)\n",
" return x\n",
"\n",
"\n",
"class SegmentationModel(LightningModule):\n",
" def __init__(self, model, datamodule):\n",
" super().__init__()\n",
@@ -331,15 +365,15 @@
" x, y = batch # x: input image, y: ground truth mask\n",
" y_pred = self.model(x)\n",
" loss = torch.nn.functional.cross_entropy(y_pred, y)\n",
" self.log('train_loss', loss)\n",
" self.log(\"train_loss\", loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" print(len(batch)) #.shape)\n",
" print(len(batch)) # .shape)\n",
" x, y = batch # x: input image, y: ground truth mask\n",
" y_pred = self.model(x)\n",
" val_loss = torch.nn.functional.cross_entropy(y_pred, y)\n",
" self.log('val_loss', val_loss)\n",
" self.log(\"val_loss\", val_loss)\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)\n",
@@ -351,11 +385,12 @@
" def val_dataloader(self):\n",
" return self.datamodule.val_dataloader()\n",
"\n",
"model = UNet(13, 1) \n",
"\n",
"model = UNet(13, 1)\n",
"segmentation_model = SegmentationModel(model, datamodule_floods)\n",
"\n",
"trainer = Trainer(max_epochs=10) \n",
"trainer.fit(segmentation_model)\n"
"trainer = Trainer(max_epochs=10)\n",
"trainer.fit(segmentation_model)"
]
}
],
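
To reproduce fixes like these locally before pushing, the configured hooks can be run against the working tree; a minimal sketch, assuming pre-commit is installed and the repository's .pre-commit-config.yaml (not shown here) defines the hooks:

import subprocess

# Run every configured pre-commit hook on all files; this approximates what
# pre-commit.ci does before pushing an auto-fix commit such as this one.
result = subprocess.run(
    ["pre-commit", "run", "--all-files"], capture_output=True, text=True
)
print(result.stdout)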