In [2]:

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
from rasterio.warp import Resampling, reproject
import torch

from omnicloudmask import (
    predict_from_array,            # predicción desde un array (3 bandas: R,G,NIR)
    predict_from_load_func,        # predicción guardando GeoTIFFs a disco
    load_s2,                       # función de carga para Sentinel-2 (.SAFE)
    # load_ls8                     # Landsat 8
)

# Report whether PyTorch can see a CUDA device (and which one); later cells pick
# batch size, device and dtype from this.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


CUDA available: True
GPU: NVIDIA GeForce RTX 3050 Laptop GPU


## Mode A — Sentinel-2 `.SAFE` scenes → GeoTIFF masks on disk (`predict_from_load_func`)

In [None]:
# Sentinel-2 `.SAFE` product folders to process (placeholders: replace with real paths).
# To process several scenes in one call, add more entries, e.g.:
#     r"PATH\A\TU\ESCENA_S2_2.SAFE",
scene_paths = [
    Path(raw_path)
    for raw_path in (
        r"PATH\A\TU\ESCENA_S2_1.SAFE",
    )
]

In [None]:
# Mode A: run OmniCloudMask over each Sentinel-2 .SAFE scene and write one GeoTIFF mask per scene.
use_gpu = torch.cuda.is_available()
run_device = "cuda" if use_gpu else "cpu"

# Skip placeholder / missing paths so "Restart & Run All" does not crash
# (same exists() guard used by the mode B cells).
existing_scenes = [p for p in scene_paths if p.exists()]
for missing in (p for p in scene_paths if not p.exists()):
    print("Escena no encontrada, se omite:", missing)

if existing_scenes:
    pred_paths = predict_from_load_func(
        scene_paths=existing_scenes,
        load_func=load_s2,
        # Performance options:
        batch_size=4 if use_gpu else 1,    # raise it if you have plenty of VRAM
        inference_device=run_device,
        # "bf16" is a precision (dtype), not a device. Only request it on the GPU;
        # fall back to full precision on CPU.
        inference_dtype="bf16" if use_gpu else torch.float32,
        # torch.compile speeds up inference; set False if compilation fails
        # (e.g. missing Triton on some Windows setups).
        compile_models=True,
        mosaic_device=run_device,
    )
else:
    pred_paths = []

print("Máscaras generadas (GeoTIFF):")
for p in pred_paths:
    print(" -", p)

In [None]:
# Preview: read the first predicted mask and display a centred crop of at most
# PREVIEW_SIZE x PREVIEW_SIZE pixels.
PREVIEW_SIZE = 1024  # max side length (pixels) of the preview window

if len(pred_paths) > 0:
    with rio.open(pred_paths[0]) as src:
        mask_full = src.read(1)  # 0..3 (Clear, Thick, Thin, Shadow)
        print("shape:", mask_full.shape, "CRS:", src.crs, "transform:", src.transform)

    h, w = mask_full.shape
    # Centre the window. Clamp at 0 so scenes smaller than PREVIEW_SIZE are shown whole.
    i0 = max((h - PREVIEW_SIZE) // 2, 0)
    j0 = max((w - PREVIEW_SIZE) // 2, 0)
    i1 = min(i0 + PREVIEW_SIZE, h)
    j1 = min(j0 + PREVIEW_SIZE, w)

    fig, ax = plt.subplots(figsize=(6, 6))
    # Class codes are categorical: "nearest" avoids blending neighbouring classes on screen.
    im = ax.imshow(mask_full[i0:i1, j0:j1], vmin=0, vmax=3, interpolation="nearest")
    ax.set_title("OmniCloudMask — preview (S2)")
    fig.colorbar(im, ax=ax, label="class (0..3)")
    fig.tight_layout()
    plt.show()

## Mode B — local rasters → in-memory prediction (`predict_from_array`)

### Alt 1 — one multi-band GeoTIFF with bands in (R, G, NIR) order

In [None]:
# Mode B / Alt 1: predict from a single multi-band GeoTIFF already holding (R, G, NIR)
# and save the mask with the same georeferencing.
tif_3band = Path(r"PATH\A\TU\RASTER_3BAND_RGNIR.tif")   # <-- change it if you use it
save_mask_tif = Path(r"PATH\A\SALIDA\mask_ocm.tif")

if tif_3band.exists():
    with rio.open(tif_3band) as src:
        arr = src.read()          # shape: (bands, H, W)
        profile = src.profile     # a fresh copy; safe to use after the file is closed

    # The first three bands must be in (R, G, NIR) order; reorder here if they are not.
    assert arr.shape[0] >= 3, "Se esperan al menos 3 bandas (R,G,NIR)"
    rgnir = arr[:3, :, :].astype(np.float32)

    # OmniCloudMask normalises dynamically across sensors, so rescaling is not strictly
    # needed when the bands are reflectance/TOA values.
    pred_mask = predict_from_array(rgnir)
    # predict_from_array may return a leading singleton axis (1, H, W). rasterio's
    # write(..., 1) and imshow both need a 2-D array.
    mask_2d = (pred_mask[0] if pred_mask.ndim == 3 else pred_mask).astype(np.uint8)

    # Single uint8 band, same georeferencing as the input. nodata=None: an inherited
    # nodata value is either invalid for uint8 (e.g. a float sentinel) or, if it is 0,
    # would hide every class-0 ("Clear") pixel in GIS software.
    profile.update(
        driver="GTiff", count=1, dtype=rio.uint8, nodata=None,
        compress="deflate", predictor=2,
    )
    save_mask_tif.parent.mkdir(parents=True, exist_ok=True)
    with rio.open(save_mask_tif, "w", **profile) as dst:
        dst.write(mask_2d, 1)

    print("Máscara guardada en:", save_mask_tif)

    # Quick look
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(mask_2d, vmin=0, vmax=3, interpolation="nearest")
    ax.set_title("OmniCloudMask — local 3-band mask")
    fig.colorbar(im, ax=ax, label="class (0..3)")
    fig.tight_layout()
    plt.show()
else:
    print("No existe el raster de entrada, se omite:", tif_3band)

### Alt 2 — three single-band GeoTIFFs (RED, GREEN, NIR), resampled onto the RED grid

In [None]:
# Mode B / Alt 2: predict from three single-band rasters (RED, GREEN, NIR). GREEN and NIR
# are resampled onto the RED band's grid before stacking.
R_path = Path(r"PATH\A\TU\RED.tif")
G_path = Path(r"PATH\A\TU\GREEN.tif")
N_path = Path(r"PATH\A\TU\NIR.tif")
save_mask_sep = Path(r"PATH\A\SALIDA\mask_ocm_from_3tifs.tif")


def reproject_to(ref_path, src_path):
    """Resample band 1 of ``src_path`` onto the grid (CRS, transform, size) of ``ref_path``.

    Uses bilinear resampling. Returns a float32 array of shape (ref.height, ref.width).
    """
    with rio.open(ref_path) as ref, rio.open(src_path) as src:
        # Start from zeros rather than np.empty, so pixels the warp does not write
        # (outside the source footprint) are deterministic. 0 is presumably read as
        # no-data by OmniCloudMask's default no_data_value; confirm for your data.
        dest = np.zeros((ref.height, ref.width), dtype=np.float32)
        reproject(
            source=rio.band(src, 1),
            destination=dest,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=ref.transform,
            dst_crs=ref.crs,
            resampling=Resampling.bilinear,
        )
    return dest


if R_path.exists() and G_path.exists() and N_path.exists():
    with rio.open(R_path) as rsrc:
        prof = rsrc.profile
        red = rsrc.read(1).astype(np.float32)

    green = reproject_to(R_path, G_path)
    nir = reproject_to(R_path, N_path)

    stack_3 = np.stack([red, green, nir], axis=0)  # (3, H, W)
    pred_mask = predict_from_array(stack_3)
    # predict_from_array may return (1, H, W); write(..., 1) needs a 2-D array.
    mask_2d = (pred_mask[0] if pred_mask.ndim == 3 else pred_mask).astype(np.uint8)

    # nodata=None: the red band's nodata may be invalid for uint8, or may equal 0,
    # which would hide class 0 ("Clear").
    prof.update(
        driver="GTiff", count=1, dtype=rio.uint8, nodata=None,
        compress="deflate", predictor=2,
    )
    save_mask_sep.parent.mkdir(parents=True, exist_ok=True)
    with rio.open(save_mask_sep, "w", **prof) as dst:
        dst.write(mask_2d, 1)

    print("Máscara guardada en:", save_mask_sep)
else:
    print("Faltan bandas de entrada (RED/GREEN/NIR), se omite.")