In [2]:
import sys
sys.path.append('../src')
import shutil
import os
import s3fs
import fs
from tqdm import tqdm
import hvac
from minio import Minio
from utils.satellite_image import SatelliteImage
from osgeo import gdal
import geemap

Authentification à Google Earth Engine


In [3]:
# service_account = (
#     "slums-detection-sa@ee-insee-sentinel.iam.gserviceaccount.com"
# )
# credentials = ee.ServiceAccountCredentials(
#     service_account, "GCP_credentials.json"
# )

# # Initialize the library.
# ee.Initialize(credentials)

import ee

# Trigger the authentication flow.
ee.Authenticate()

# Initialize the library.
ee.Initialize()


Successfully saved authorization token.


In [4]:
def get_s2_sr_cld_col(aoi, start_date, end_date):
    # Import and filter S2 SR.
    s2_sr_col = (
        ee.ImageCollection("COPERNICUS/S2_SR")
        .filterBounds(aoi)
        .filterDate(start_date, end_date)
        .filter(ee.Filter.lte("CLOUDY_PIXEL_PERCENTAGE", CLOUD_FILTER))
    )

    # Import and filter s2cloudless.
    s2_cloudless_col = (
        ee.ImageCollection("COPERNICUS/S2_CLOUD_PROBABILITY")
        .filterBounds(aoi)
        .filterDate(start_date, end_date)
    )

    # Join the filtered s2cloudless collection to the SR collection by the 'system:index' property.
    return ee.ImageCollection(
        ee.Join.saveFirst("s2cloudless").apply(
            **{
                "primary": s2_sr_col,
                "secondary": s2_cloudless_col,
                "condition": ee.Filter.equals(
                    **{
                        "leftField": "system:index",
                        "rightField": "system:index",
                    }
                ),
            }
        )
    )

In [5]:
def add_cloud_bands(img):
    # Get s2cloudless image, subset the probability band.
    cld_prb = ee.Image(img.get("s2cloudless")).select("probability")

    # Condition s2cloudless by the probability threshold value.
    is_cloud = cld_prb.gt(CLD_PRB_THRESH).rename("clouds")

    # Add the cloud probability layer and cloud mask as image bands.
    return img.addBands(ee.Image([cld_prb, is_cloud]))

In [6]:
def add_shadow_bands(img):
    # Identify water pixels from the SCL band.
    not_water = img.select("SCL").neq(6)

    # Identify dark NIR pixels that are not water (potential cloud shadow pixels).
    SR_BAND_SCALE = 1e4
    dark_pixels = (
        img.select("B8")
        .lt(NIR_DRK_THRESH * SR_BAND_SCALE)
        .multiply(not_water)
        .rename("dark_pixels")
    )

    # Determine the direction to project cloud shadow from clouds (assumes UTM projection).
    shadow_azimuth = ee.Number(90).subtract(
        ee.Number(img.get("MEAN_SOLAR_AZIMUTH_ANGLE"))
    )

    # Project shadows from clouds for the distance specified by the CLD_PRJ_DIST input.
    cld_proj = (
        img.select("clouds")
        .directionalDistanceTransform(shadow_azimuth, CLD_PRJ_DIST * 10)
        .reproject(**{"crs": img.select(0).projection(), "scale": 100})
        .select("distance")
        .mask()
        .rename("cloud_transform")
    )

    # Identify the intersection of dark pixels with cloud shadow projection.
    shadows = cld_proj.multiply(dark_pixels).rename("shadows")

    # Add dark pixels, cloud projection, and identified shadows as image bands.
    return img.addBands(ee.Image([dark_pixels, cld_proj, shadows]))

In [7]:
def add_cld_shdw_mask(img):
    # Add cloud component bands.
    img_cloud = add_cloud_bands(img)

    # Add cloud shadow component bands.
    img_cloud_shadow = add_shadow_bands(img_cloud)

    # Combine cloud and shadow mask, set cloud and shadow as value 1, else 0.
    is_cld_shdw = (
        img_cloud_shadow.select("clouds")
        .add(img_cloud_shadow.select("shadows"))
        .gt(0)
    )

    # Remove small cloud-shadow patches and dilate remaining pixels by BUFFER input.
    # 20 m scale is for speed, and assumes clouds don't require 10 m precision.
    is_cld_shdw = (
        is_cld_shdw.focalMin(2)
        .focalMax(BUFFER * 2 / 20)
        .reproject(**{"crs": img.select([0]).projection(), "scale": 20})
        .rename("cloudmask")
    )

    # Add the final cloud-shadow mask to the image.
    return img_cloud_shadow.addBands(is_cld_shdw)

In [8]:
def apply_cld_shdw_mask(img):
    # Subset the cloudmask band and invert it so clouds/shadow are 0, else 1.
    not_cld_shdw = img.select("cloudmask").Not()

    # Subset reflectance bands and update their masks, return the result.
    return img.select("B.*").updateMask(not_cld_shdw)

Téléchargement en local puis mise en ligne de données

In [24]:
def export_s2_no_cloud(
    DOM,
    AOIs,
    EPSGs,
    start_date,
    end_date,
    cloud_filter,
    cloud_prb_thresh,
    nir_drk_thresh,
    cld_prj_dist,
    buffer,
):
        
    AOI = ee.Geometry.BBox(**AOIs[DOM])
    s2_sr_cld_col = get_s2_sr_cld_col(AOI, START_DATE, END_DATE)
    s2_sr_median = (
        s2_sr_cld_col.map(add_cld_shdw_mask).map(apply_cld_shdw_mask).median()
    )

    fishnet = geemap.fishnet(AOI, rows=4, cols=4, delta=0.5)
    geemap.download_ee_image_tiles(
        s2_sr_median,
        fishnet,
        f'{DOM}_{start_date[0:4]}/',
        prefix="data_",
        crs=f"EPSG:{EPSGs[DOM]}",
        scale=10,
        num_threads=50,
    )

    upload_satelliteImages(
        f'{DOM}_{start_date[0:4]}',
        f'projet-slums-detection/Donnees/SENTINEL2/{DOM.upper()}/TUILES_{start_date[0:4]}',
        250)
    
    shutil.rmtree(f"{DOM}_{start_date[0:4]}",ignore_errors=True)

    

Connexion à MinIO

In [29]:
def exportToMinio(image,rpath):
    client = hvac.Client(
            url='https://vault.lab.sspcloud.fr', token=os.environ["VAULT_TOKEN"]
        )

    secret = os.environ["VAULT_MOUNT"] + os.environ["VAULT_TOP_DIR"] + "/s3"
    mount_point, secret_path = secret.split("/", 1)
    secret_dict = client.secrets.kv.read_secret_version(
        path=secret_path, mount_point=mount_point
    )

    os.environ["AWS_ACCESS_KEY_ID"] = secret_dict["data"]["data"][
        "ACCESS_KEY_ID"
    ]
    os.environ["AWS_SECRET_ACCESS_KEY"] = secret_dict["data"]["data"][
        "SECRET_ACCESS_KEY"
    ]

    try:
        del os.environ['AWS_SESSION_TOKEN']
    except KeyError:
        pass

    fs = s3fs.S3FileSystem(
        client_kwargs={'endpoint_url': 'https://'+'minio.lab.sspcloud.fr'},
        key=os.environ["AWS_ACCESS_KEY_ID"],
        secret=os.environ["AWS_SECRET_ACCESS_KEY"]
    )
    
    return fs.put(image,rpath,True)

Mise en ligne de données préalablement téléchargées en local

In [30]:
def upload_satelliteImages(
    lpath,
    rpath,
    dim
):
    images_paths = os.listdir(lpath)

    for i in range(len(images_paths)):
        images_paths[i] = lpath+'/'+images_paths[i]

    list_satelliteImages = [
        SatelliteImage.from_raster(
            filename,
            dep = "973",
            n_bands = 12
        ) for filename in tqdm(images_paths)]

    splitted_list_images = [im for sublist in tqdm(list_satelliteImages) for im in sublist.split(dim)]

    for i in range(len(splitted_list_images)):
        image = splitted_list_images[i]

        transf = image.transform
        in_ds = gdal.Open(images_paths[1])
        proj = in_ds.GetProjection()

        array = image.array

        driver = gdal.GetDriverByName("GTiff")
        out_ds = driver.Create(f'image{i}.tif', array.shape[2], array.shape[1], array.shape[0], gdal.GDT_Float64)
        out_ds.SetGeoTransform([transf[2],transf[0],transf[1],transf[5],transf[3],transf[4]])
        out_ds.SetProjection(proj)

        for j in range(array.shape[0]):
            out_ds.GetRasterBand(j+1).WriteArray(array[j,:,:])

        out_ds = None
        
        exportToMinio(f'image{i}.tif',rpath)
        os.remove(f'image{i}.tif')

Filtres sur le téléchargement (CRS, emprise, caractéristiques du stack)

In [27]:
AOIs = {
    "Guadeloupe": {
        "west": -61.811124,
        "south": 15.828534,
        "east": -60.998518,
        "north": 16.523944,
    },
    "Martinique": {
        "west": -61.264617,
        "south": 14.378599,
        "east": -60.781573,
        "north": 14.899453,
    },
    "Mayotte": {
        "west": 45.013633,
        "south": -13.006619,
        "east": 45.308891,
        "north": -12.633022,
    },
    "Guyane": {
        "west": -52.883,
        "south": 4.148,
        "east": -51.813,
        "north": 5.426
    }
}

EPSGs = {"Guadeloupe": "4559", "Martinique": "4559", "Mayotte": "4471", "Guyane": "4235"}

START_DATE = "2022-05-01"
END_DATE = "2022-09-01"
CLOUD_FILTER = 60
CLD_PRB_THRESH = 40
NIR_DRK_THRESH = 0.15
CLD_PRJ_DIST = 2
BUFFER = 50

Téléchargement en local puis mise en lignes de données

In [31]:
# export_s2_no_cloud(
#     "Guadeloupe",
#     AOIs,
#     EPSGs,
#     START_DATE,
#     END_DATE,
#     CLOUD_FILTER,
#     CLD_PRB_THRESH,
#     NIR_DRK_THRESH,
#     CLD_PRJ_DIST,
#     BUFFER,
# ) 

# export_s2_no_cloud(
#     "Martinique",
#     AOIs,
#     EPSGs,
#     START_DATE,
#     END_DATE,
#     CLOUD_FILTER,
#     CLD_PRB_THRESH,
#     NIR_DRK_THRESH,
#     CLD_PRJ_DIST,
#     BUFFER,
# )

# export_s2_no_cloud(
#     "Mayotte",
#     AOIs,
#     EPSGs,
#     START_DATE,
#     END_DATE,
#     CLOUD_FILTER,
#     CLD_PRB_THRESH,
#     NIR_DRK_THRESH,
#     CLD_PRJ_DIST,
#     BUFFER,
# )

# export_s2_no_cloud(
#     "Guyane",
#     AOIs,
#     EPSGs,
#     START_DATE,
#     END_DATE,
#     CLOUD_FILTER,
#     CLD_PRB_THRESH,
#     NIR_DRK_THRESH,
#     CLD_PRJ_DIST,
#     BUFFER,
# )

Downloading 1/4: Mayotte_2022/data_1.tif


data_1.tif: |          | 0.00/73.8M (raw) [  0.0%] in 00:00 (eta:     ?)

Downloading 2/4: Mayotte_2022/data_2.tif


data_2.tif: |          | 0.00/73.8M (raw) [  0.0%] in 00:00 (eta:     ?)

Downloading 3/4: Mayotte_2022/data_3.tif


data_3.tif: |          | 0.00/73.9M (raw) [  0.0%] in 00:00 (eta:     ?)

Downloading 4/4: Mayotte_2022/data_4.tif


data_4.tif: |          | 0.00/73.8M (raw) [  0.0%] in 00:00 (eta:     ?)

100%|██████████| 4/4 [00:00<00:00,  5.34it/s]
100%|██████████| 4/4 [00:00<00:00, 6311.97it/s]


Upload de données déjà téléchargées pour 2021 et par DROM

In [None]:
upload_satelliteImages(
    "Guadeloupe_2021",
    'projet-slums-detection/Donnees/SENTINEL2/GUADELOUPE/TUILES_2021',
    250)

upload_satelliteImages(
    "Martinique_2021",
    'projet-slums-detection/Donnees/SENTINEL2/MARTINIQUE/TUILES_2021',
    250)

upload_satelliteImages(
    "Mayotte_2021",
    'projet-slums-detection/Donnees/SENTINEL2/MAYOTTE/TUILES_2021',
    250)

Upload de données déjà téléchargées pour 2022 et par DROM

In [None]:
upload_satelliteImages(
    "Guadeloupe_2022",
    'projet-slums-detection/Donnees/SENTINEL2/GUADELOUPE/TUILES_2022',
    250)

upload_satelliteImages(
    "Martinique_2022",
    'projet-slums-detection/Donnees/SENTINEL2/MARTINIQUE/TUILES_2022',
    250)

upload_satelliteImages(
    "Mayotte_2022",
    'projet-slums-detection/Donnees/SENTINEL2/MAYOTTE/TUILES_2022',
    250)

## Pipeline de train avec données Sentinel2

In [8]:
from utils.utils import update_storage_access
from datetime import datetime
from utils.labeler import RILLabeler, BDTOPOLabeler

In [39]:
config = {
    "tile size": 200,
    "source train": "SENTINEL2",
    "type labeler": "RIL",  # None if source train != PLEIADE
    "buffer size": 10,  # None if BDTOPO
    "year": 2022,
    "territory": "martinique",
    "dep": "972",
    "n bands": 3,
    "n channels train": 3,
}

config_train = {
    "lr": 0.0001,
    "momentum": 0.9,
    "module": "deeplabv3",
    "batch size": 2,
    "max epochs": 100,
}

# params
n_channel_train = config["n channels train"]

tile_size = config["tile size"]
n_bands = config["n bands"]
dep = config["dep"]
territory = config["territory"]
year = config["year"]
buffer_size = config["buffer size"]
source_train = config["source train"]
type_labeler = config["type labeler"]

module = config_train["module"]
batch_size = config_train["batch size"]

train_directory_name = "../splitted_data"

update_storage_access()
os.environ["MLFLOW_S3_ENDPOINT_URL"] = "https://minio.lab.sspcloud.fr"

# DL des données du territoire dont on se sert pour l'entraînement
# On peut faire une liste de couples années/territoire également
# Plus tard décliner avec change detection etc..
if type_labeler == "RIL":
    date = datetime.strptime(
        str(year).split("-")[-1] + "0101", "%Y%m%d"
    )
    labeler = RILLabeler(date, dep=dep, buffer_size=buffer_size)

if type_labeler == "BDTOPO":
    date = datetime.strptime(
        str(year).split("-")[-1] + "0101", "%Y%m%d"
    )
    labeler = BDTOPOLabeler(date, dep=dep)


### Load data

In [40]:
from utils.utils import get_root_path, get_environment

In [11]:
update_storage_access()
root_path = get_root_path()
environment = get_environment()

bucket = environment["bucket"]
path_s3 = environment["sources"]["SENTINEL2"][year][territory]
path_local = os.path.join(
    root_path, environment["local-path"]["SENTINEL2"][year][territory]
)

fs = s3fs.S3FileSystem(
    client_kwargs={"endpoint_url": "https://minio.lab.sspcloud.fr"}
)
print("download " + territory + " " + str(year) + " in " + path_local)
fs.download(
    rpath=f"{bucket}/{path_s3}", lpath=f"{path_local}", recursive=True
)  

download martinique 2022 in /home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022


Create and save segmentation masks

In [41]:
list_name = os.listdir(path_local)
list_name

['image227.tif',
 'image433.tif',
 'image140.tif',
 'image312.tif',
 'image340.tif',
 'image164.tif',
 'image489.tif',
 'image260.tif',
 'image43.tif',
 'image515.tif',
 'image474.tif',
 'image100.tif',
 'image364.tif',
 'image475.tif',
 'image418.tif',
 'image208.tif',
 'image376.tif',
 'image335.tif',
 'image185.tif',
 'image46.tif',
 'image98.tif',
 'image160.tif',
 'image467.tif',
 'image544.tif',
 'image103.tif',
 'image388.tif',
 'image121.tif',
 'image19.tif',
 'image446.tif',
 'image555.tif',
 'image175.tif',
 'image210.tif',
 'image215.tif',
 'image152.tif',
 'image120.tif',
 'image292.tif',
 'image21.tif',
 'image302.tif',
 'image134.tif',
 'image416.tif',
 'image144.tif',
 'image69.tif',
 'image485.tif',
 'image266.tif',
 'image254.tif',
 'image567.tif',
 'image315.tif',
 'image6.tif',
 'image241.tif',
 'image181.tif',
 'image569.tif',
 'image347.tif',
 'image478.tif',
 'image311.tif',
 'image178.tif',
 'image37.tif',
 'image14.tif',
 'image307.tif',
 'image559.tif',
 'image

In [42]:
list_path = [path_local + "/" + name for name in list_name]
list_path

['/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image227.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image433.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image140.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image312.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image340.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image164.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image489.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image260.tif',
 '/home/onyxia/work/detection-bidonvilles/notebooks/../data/SENTINEL2/MARTINIQUE/TUILES_2022/image43.tif',
 '/home/onyxia/work/detection

In [43]:
import numpy as np
from pathlib import Path
environment = get_environment()

In [44]:
  
output_masks_path = os.path.join(
    root_path, environment["local-path"]["SENTINEL2-LABELS"][year][territory]
)
if not os.path.exists(output_masks_path):
    os.makedirs(output_masks_path)
for path, file_name in zip(list_path, tqdm(list_name)):  # tqdm ici
    satellite_image = SatelliteImage.from_raster(
        file_path=path, dep=None, date=None, n_bands=n_bands
    )
    mask = labeler.create_segmentation_label(satellite_image)
    np.save(
        output_masks_path + "/" + Path(file_name).stem + ".npy", mask
    )

100%|█████████▉| 575/576 [00:43<00:00, 13.31it/s]


Build dataset

In [45]:
from datas.components.dataset import PleiadeDataset

In [46]:
labels = os.listdir(output_masks_path)
images = os.listdir(path_local)
list_path_labels = np.sort(
    [output_masks_path + "/" + name for name in labels]
)
list_path_images = np.sort(
    [path_local + "/"  + name for name in images]
)

dataset = PleiadeDataset(list_path_images, list_path_labels)

In [61]:
import albumentations as album
from albumentations.pytorch.transforms import ToTensorV2
import torch

dataset_test = PleiadeDataset(list_path_images[0], list_path_labels[0])
image_size = (250, 250)

transforms_augmentation = album.Compose(
    [
        album.Resize(300, 300, always_apply=True),
        album.RandomResizedCrop(
            *image_size, scale=(0.7, 1.0), ratio=(0.7, 1)
        ),
        album.HorizontalFlip(),
        album.VerticalFlip(),
        album.Normalize(mean=(0.5,0.406,0.456,0.485,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5),std=(0.225,0.229,0.225,0.224,0.225,0.225,0.225,0.225,0.225,0.225,0.225,0.225)),
        ToTensorV2(),
    ]
)

transforms_preprocessing = album.Compose(
    [
        album.Resize(*image_size, always_apply=True),
        album.Normalize(mean=(0.5,0.406,0.456,0.485,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5),std=(0.225,0.229,0.225,0.224,0.225,0.225,0.225,0.225,0.225,0.225,0.225,0.225)),
        ToTensorV2(),
    ]
)

# Instanciation modèle et paramètres d'entraînement
optimizer = torch.optim.SGD
optimizer_params = {
    "lr": config_train["lr"],
    "momentum": config_train["momentum"],
}
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
scheduler_params = {}
scheduler_interval = "epoch"


In [48]:
batch_size

2

In [49]:
n_channel_train = 3

In [50]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from utils.gestion_data import instantiate_module
from datas.datamodule import DataModule
from models.segmentation_module import SegmentationModule
import pytorch_lightning as pl

model = instantiate_module(module, n_channel_train)

batch_size = 2
data_module = DataModule(
    dataset=dataset,
    transforms_augmentation=transforms_augmentation,
    transforms_preprocessing=transforms_preprocessing,
    num_workers=1,
    batch_size=batch_size,
    dataset_test=dataset_test,
)

lightning_module = SegmentationModule(
    model=model,
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    scheduler=scheduler,
    scheduler_params=scheduler_params,
    scheduler_interval=scheduler_interval,
)

checkpoint_callback = ModelCheckpoint(
    monitor="validation_loss", save_top_k=1, save_last=True, mode="max"
)

early_stop_callback = EarlyStopping(
    monitor="validation_loss", mode="max", patience=3
)

lr_monitor = LearningRateMonitor(logging_interval="step")

strategy = "auto"
list_callbacks = [lr_monitor, checkpoint_callback, early_stop_callback]

torch.cuda.empty_cache()

trainer = pl.Trainer(
    callbacks=list_callbacks,
    max_epochs=config_train["max epochs"],
    num_sanity_val_steps=2,
    strategy=strategy,
    log_every_n_steps=2,
)

trainer.fit(lightning_module, datamodule=data_module)
trainer.test(lightning_module, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type             | Params
-------------------------------------------
0 | model | DeepLabv3Module  | 61.0 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
61.0 M    Trainable params
0         Non-trainable params
61.0 M    Total params
243.965   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

RasterioIOError: Caught RasterioIOError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "rasterio/_base.pyx", line 308, in rasterio._base.DatasetBase.__init__
  File "rasterio/_base.pyx", line 219, in rasterio._base.open_dataset
  File "rasterio/_err.pyx", line 221, in rasterio._err.exc_wrap_pointer
rasterio._err.CPLE_OpenFailedError: '/' not recognized as a supported file format.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/mamba/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/mamba/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/mamba/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/onyxia/work/detection-bidonvilles/notebooks/../src/datas/components/dataset.py", line 114, in __getitem__
    img = SatelliteImage.from_raster(
  File "/home/onyxia/work/detection-bidonvilles/notebooks/../src/utils/satellite_image.py", line 180, in from_raster
    with rasterio.open(file_path) as raster:
  File "/opt/mamba/lib/python3.10/site-packages/rasterio/env.py", line 451, in wrapper
    return f(*args, **kwds)
  File "/opt/mamba/lib/python3.10/site-packages/rasterio/__init__.py", line 304, in open
    dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  File "rasterio/_base.pyx", line 310, in rasterio._base.DatasetBase.__init__
rasterio.errors.RasterioIOError: '/' not recognized as a supported file format.


Test more bands

In [58]:
from torch.utils.data import Dataset

class SentinelDataset(Dataset):
    """
    Custom Dataset class.
    """

    def __init__(
        self,
        list_paths_images,
        list_paths_labels,
        transforms = None,
        n_bands: int = 12
    ):
        """
        Constructor.

        Args:
            list_paths_images (List): list of path of the images
            list_paths_labels (List): list of paths containing the labels
            transforms (Compose) : list of transforms
        """
        self.list_paths_images = list_paths_images
        self.list_paths_labels = list_paths_labels
        self.transforms = transforms
        self.n_bands = n_bands

    def __getitem__(self, idx):
        """_summary_

        Args:
            idx (_type_): _description_

        Returns:
            _type_: _description_
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        pathim = self.list_paths_images[idx]
        pathlabel = self.list_paths_labels[idx]

        img = SatelliteImage.from_raster(
            file_path=pathim, dep=None, date=None, n_bands=self.n_bands
        ).array

        img = np.transpose(img.astype(float), [1, 2, 0])
        label = torch.tensor(np.load(pathlabel))

        if self.transforms:
            sample = self.transforms(image=img, label=label)
            img = sample["image"]
            label = sample["label"]
        else:
            img = torch.tensor(img.astype(float))
            img = img.permute([2, 0, 1])
            label = torch.tensor(label)

        img = img.type(torch.float)
        label = label.type(torch.LongTensor)
        dic = {"pathimage": pathim, "pathlabel": pathlabel}
        return img, label, dic

    def __len__(self):
        return len(self.list_paths_images)


In [59]:
n_channel_train = 12

In [62]:
model = instantiate_module(module, n_channel_train)
dataset = SentinelDataset(list_path_images, list_path_labels)
dataset_test = SentinelDataset(list_path_images[0], list_path_labels[0])

data_module = DataModule(
    dataset=dataset,
    transforms_augmentation=transforms_augmentation,
    transforms_preprocessing=transforms_preprocessing,
    num_workers=1,
    batch_size=batch_size,
    dataset_test=dataset_test,
)

lightning_module = SegmentationModule(
    model=model,
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    scheduler=scheduler,
    scheduler_params=scheduler_params,
    scheduler_interval=scheduler_interval,
)

checkpoint_callback = ModelCheckpoint(
    monitor="validation_loss", save_top_k=1, save_last=True, mode="max"
)

early_stop_callback = EarlyStopping(
    monitor="validation_loss", mode="max", patience=3
)

lr_monitor = LearningRateMonitor(logging_interval="step")

strategy = "auto"
list_callbacks = [lr_monitor, checkpoint_callback, early_stop_callback]

torch.cuda.empty_cache()

trainer = pl.Trainer(
    callbacks=list_callbacks,
    max_epochs=config_train["max epochs"],
    num_sanity_val_steps=2,
    strategy=strategy,
    log_every_n_steps=2,
)

trainer.fit(lightning_module, datamodule=data_module)
trainer.test(lightning_module, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type             | Params
-------------------------------------------
0 | model | DeepLabv3Module  | 61.0 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
61.0 M    Trainable params
0         Non-trainable params
61.0 M    Total params
244.078   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])