<a href="https://colab.research.google.com/drive/1m8LUoa1n7SDC6N5eqCGOTcC-nwPQfIoT"><img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open and Execute in Google Colaboratory"></a>

<br>

<a href="https://cloudsen12.github.io/"><img align="left" src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white"></a>


<br><br>

<!--COURSE_INFORMATION-->
<img align="left" style="padding-right:10px;" src="https://cloudsen12plus.github.io/assets/logo.webp" width=10% >


>>>> *This notebook is part of the paper [CloudSEN12+: The largest dataset of expert-labeled pixels for cloud and cloud shadow detection in Sentinel-2](https://cloudsen12.github.io/); the content is available [on GitHub](https://github.com/cloudsen12)* and released under the [CC0 1.0 Universal - Creative Commons](https://creativecommons.org/publicdomain/zero/1.0/deed.en) license.

<br>

- See our paper [here](https://www.sciencedirect.com/science/article/pii/S2352340924008163).

- See the CloudSEN12 website [here](https://cloudsen12.github.io/).

- See CloudSEN12 in Science Data Bank [here](https://www.scidb.cn/en/detail?dataSetId=2036f4657b094edfbb099053d6024b08&version=V1).

- See CloudSEN12 on GitHub [here](https://github.com/cloudsen12).

- See CloudSEN12 in Google Earth Engine [here](https://samapriya.github.io/awesome-gee-community-datasets/projects/cloudsen12/).

- See CloudApp [here](https://cloudsen12.github.io/en/blog/cloudapp/).

The CloudSEN12 dataset and the pre-trained models are released under a [CC0 1.0 Universal - Creative Commons](https://creativecommons.org/publicdomain/zero/1.0/deed.en) license.

## **1. Create a DataModule**

*Streaming samples through MLSTAC __slows down the Dataset's__ \_\_getitem\_\_ calls, because every sample is retrieved remotely each time it is requested. For better I/O performance, consider downloading the data first.*
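
One way to soften the streaming cost is to keep a local copy of each sample after its first read. The sketch below is a generic on-disk cache and not MLSTAC's own download mechanism; `cached_fetch`, `fake_remote_read`, and the `ROI_demo` key are hypothetical names used only for illustration, and the remote read is replaced by a stand-in function.

```python
import tempfile
from pathlib import Path

import numpy as np


def cached_fetch(key: str, fetch, cache_dir: Path) -> np.ndarray:
    """Return the array for `key`, reading from disk if it was fetched before."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    path = cache_dir / f"{key}.npy"
    if path.exists():
        return np.load(path)
    array = np.asarray(fetch(key))  # slow read, performed once per key
    np.save(path, array)
    return array


# Hypothetical stand-in for a remote read; counts how often it is really called.
calls = {"n": 0}


def fake_remote_read(key: str) -> np.ndarray:
    calls["n"] += 1
    return np.full((14, 8, 8), len(key), dtype=np.uint16)


cache_root = Path(tempfile.mkdtemp())
first = cached_fetch("ROI_demo", fake_remote_read, cache_root)   # triggers the read
second = cached_fetch("ROI_demo", fake_remote_read, cache_root)  # served from disk
```

In a real dataset class, the `fetch` argument would wrap whatever call retrieves a sample, so that only the first epoch pays the streaming cost.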

In [None]:
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import mlstac

# Create a Dataset object (the DataLoader wraps it later, in the DataModule).
class CoreDataset(torch.utils.data.Dataset):
    def __init__(self, subset: pd.DataFrame):
        # Work on a re-indexed copy so that positional access is consistent
        # and the caller's DataFrame is left untouched.
        self.subset = subset.reset_index(drop=True)

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index: int):
        # Retrieve the data from HuggingFace
        sample = mlstac.get_data(dataset=self.subset.iloc[index], quiet=False).squeeze()

        # Load all 13 Sentinel-2 bands and scale the values by 1/10000
        X = sample[0:13, :, :].astype(np.float32) / 10000

        # Load the target, stored right after the 13 spectral bands
        y = sample[13, :, :].astype(np.int64)

        return X, y

class CoreDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 4):
        super().__init__()

        # Load the metadata from the MLSTAC Collection file
        metadata = mlstac.load(snippet="isp-uv-es/CloudSEN12Plus").metadata

        # Split the metadata into train, validation and test sets,
        # keeping only samples with label_type "high" and proj_shape 509
        keep = (metadata["label_type"] == "high") & (metadata["proj_shape"] == 509)
        self.train_dataset = metadata[keep & (metadata["split"] == "train")]
        self.validation_dataset = metadata[keep & (metadata["split"] == "validation")]
        self.test_dataset = metadata[keep & (metadata["split"] == "test")]

        # Define the batch_size
        self.batch_size = batch_size

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=CoreDataset(self.train_dataset),
            batch_size=self.batch_size,
            shuffle=True
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=CoreDataset(self.validation_dataset),
            batch_size=self.batch_size
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=CoreDataset(self.test_dataset),
            batch_size=self.batch_size
        )
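
To make the tensor layout produced by `__getitem__` concrete, the following sketch applies the same slicing and scaling to a synthetic array. The 64x64 patch size and the random values are assumptions chosen to keep the example small; they are not the actual patch size or contents of the dataset.

```python
import numpy as np

# Synthetic stand-in for one sample: 13 spectral bands followed by 1 label band.
rng = np.random.default_rng(0)
bands = rng.integers(0, 10000, size=(13, 64, 64), dtype=np.uint16)
labels = rng.integers(0, 4, size=(1, 64, 64), dtype=np.uint16)
sample = np.concatenate([bands, labels], axis=0)  # shape (14, 64, 64)

# Same operations as in __getitem__: first 13 bands as scaled float32 input,
# the remaining band as an int64 class map.
X = sample[0:13, :, :].astype(np.float32) / 10000
y = sample[13, :, :].astype(np.int64)

print(X.shape, X.dtype, y.shape, y.dtype)
```

The input keeps the channel-first layout expected by PyTorch convolutional models, and the target has no channel axis, which is the form `torch.nn.CrossEntropyLoss` expects for class-index targets.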

## **2. Define a Model**

In [None]:
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch

class litmodel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = smp.Unet(encoder_name="mobilenet_v2", encoder_weights=None, classes=4, in_channels=13)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        # Log the loss so that trainer.test() has a metric to report
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
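
The model returns one logit per class and pixel, so a class map is obtained by taking the argmax over the channel axis. The sketch below illustrates this with NumPy on a tiny hand-made array. The class order in `CLASS_NAMES` is an assumption about the label encoding and should be checked against the dataset documentation; `logits_to_mask` is a hypothetical helper name.

```python
import numpy as np

# Assumed class order for the 4 output channels (verify against the dataset docs).
CLASS_NAMES = ["clear", "thick cloud", "thin cloud", "cloud shadow"]


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Convert (batch, classes, H, W) logits into a (batch, H, W) class map."""
    return logits.argmax(axis=1)


# One 2x2 image in which each pixel favours a different class.
logits = np.zeros((1, 4, 2, 2), dtype=np.float32)
logits[0, 0, 0, 0] = 5.0
logits[0, 1, 0, 1] = 5.0
logits[0, 2, 1, 0] = 5.0
logits[0, 3, 1, 1] = 5.0

mask = logits_to_mask(logits)
named = [[CLASS_NAMES[c] for c in row] for row in mask[0]]
print(mask[0])
print(named)
```

With PyTorch tensors the equivalent call would be `y_hat.argmax(dim=1)`.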

## **3. Define the Trainer**

In [None]:
# Define the callbacks
callbacks = [
    pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=3
    ),
    pl.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=10,
        mode="min"
    )
]

# Define the trainer
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=callbacks,
    accelerator="auto",
    precision="16-mixed"
)

# Define the datamodule
datamodule = CoreDataModule(batch_size=16)

# Define the model
model = litmodel()

# Start the training
trainer.fit(model=model, datamodule=datamodule)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:512: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | Unet             | 6.6 M  | train
1 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
6.6 M     Trainable params
0         Non-trainable params
6.6 M     Total params
26.529    Total estimated model params size (MB)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Reading datapoint: ROI_0009__20190513T180929_20190513T181344_T12SWH
Reading datapoint: ROI_0009__20190528T180921_20190528T182048_T12SWH
Reading datapoint: ROI_0009__20191027T180439_20191027T180526_T12SWH
Reading datapoint: ROI_0009__20200212T181451_20200212T181553_T12SWH


Training: |          | 0/? [00:00<?, ?it/s]

Reading datapoint: ROI_1400__20191004T131249_20191004T131243_T23LRL
Reading datapoint: ROI_11144__20181226T053229_20181226T053230_T45UVU
Reading datapoint: ROI_5145__20200114T201759_20200114T201801_T08VNM
Reading datapoint: ROI_10961__20190525T043701_20190525T044442_T47ULQ
Reading datapoint: ROI_0936__20191030T081009_20191030T082935_T35LPK
Reading datapoint: ROI_7063__20200224T075919_20200224T082412_T35JNK
Reading datapoint: ROI_2537__20190112T140049_20190112T141130_T20HNG
Reading datapoint: ROI_0391__20190618T123701_20190618T123701_T33XWF
Reading datapoint: ROI_1952__20200729T031539_20200729T032541_T50TLP
Reading datapoint: ROI_1333__20190830T165849_20190830T171517_T14QND
Reading datapoint: ROI_1057__20190921T155939_20190921T160251_T18UVB
Reading datapoint: ROI_6833__20200511T140101_20200511T140057_T21JTF
Reading datapoint: ROI_1200__20200505T151711_20200505T151755_T19PDN
Reading datapoint: ROI_7640__20190908T072619_20190908T075454_T36KWU
Reading datapoint: ROI_9251__20200122T084241_2
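
The `EarlyStopping` callback configured above stops training once `val_loss` has failed to improve for `patience` consecutive validation checks. The following is a simplified pure-Python sketch of that rule, not the library's implementation; `stopping_epoch` and the loss values are hypothetical.

```python
def stopping_epoch(val_losses, patience=10, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best value and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Three improving epochs followed by twelve epochs without improvement.
losses = [1.0, 0.8, 0.7] + [0.75] * 12
stop_at = stopping_epoch(losses, patience=10)
print(stop_at)
```

With `patience=10` and `max_epochs=100`, training therefore ends either after 100 epochs or after ten validation checks in a row without a new best `val_loss`, whichever comes first.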

In [None]:
# Evaluate the model on the validation set
valid_metrics = trainer.validate(model, datamodule=datamodule, verbose=True)
print(valid_metrics)

In [None]:
# Evaluate the model on the test set
test_metrics = trainer.test(model, datamodule=datamodule, verbose=True)
print(test_metrics)
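
Beyond the values returned by `trainer.validate` and `trainer.test`, segmentation quality is commonly summarized with the per-class intersection over union (IoU). The sketch below computes it with NumPy from a confusion matrix. It is an illustrative addition rather than part of the original pipeline; `per_class_iou` is a hypothetical helper, and the tiny label arrays are made up.

```python
import numpy as np


def per_class_iou(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int = 4) -> np.ndarray:
    """Return the IoU of each class; classes absent from both inputs give NaN."""
    # Confusion matrix: rows are true classes, columns are predicted classes.
    idx = y_true.ravel() * num_classes + y_pred.ravel()
    cm = np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

    tp = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(union > 0, tp / union, np.nan)


y_true = np.array([[0, 0], [1, 3]])
y_pred = np.array([[0, 1], [1, 3]])
iou = per_class_iou(y_true, y_pred)
print(iou)
```

Averaging the non-NaN entries, for example with `np.nanmean(iou)`, gives a mean IoU over the classes that actually occur.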

In [None]:
smp_model = model.model
# if push_to_hub=True, model will be saved to repository with this name
smp_model.save_pretrained('./unet_test')