In [1]:
from ukis_pysat.raster import Image
from pathlib import Path
import io
import warnings
import matplotlib.pyplot as plt
import numpy as np
import random
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from pathlib import Path
from rasterio.windows import Window
from rasterio.warp import reproject, Resampling, calculate_default_transform
import rasterio as rio
import torch
from tqdm import tqdm
from torch import optim
from src.logger.logger import create_logger
from src.dataset.dataset import get_dataloader
from src.models.unet import UNet
from src.models.unet_plus import UNetPlusPlus
from src.dataset.dataset import SegDataset
from src.train import (
    criterion,
    get_metrics,
    get_device,
    set_seed,
    configure_deterministic_behavior,
)


# Silence every warning for the rest of the kernel session.
warnings.filterwarnings("ignore")

# Root folder containing one sub-directory per test case; list those case folders.
data_path = Path("/eodc/private/tuwgeo/users/mabdelaa/data/watmap_test_cases_final")
imgs_path = list(filter(Path.is_dir, data_path.iterdir()))
imgs_path

  check_for_updates()


KeyboardInterrupt: 

In [52]:
# Root folder with one sub-directory per test case; ``collect_images`` looks
# inside each case for the ``s1_mu`` / ``s2_mu`` / ``mu`` sub-folders.
data_base_path = Path("/eodc/private/tuwgeo/users/mabdelaa/data/watmap_test_cases_final")


def collect_images(subfolder: str, extensions=(".tif", ".tiff"), base_path=None):
    """Collect raster files from ``<case>/<subfolder>`` for every case directory.

    Args:
        subfolder: Name of the sub-directory to search inside each case
            directory (e.g. ``"s1_mu"``).
        extensions: Accepted file suffixes (a tuple/list, or a single string).
            Matching is case-insensitive on both sides.
        base_path: Root directory whose immediate sub-directories are the
            cases. Defaults to the module-level ``data_base_path``.

    Returns:
        A sorted list of matching file paths. Case directories that do not
        contain ``subfolder`` are skipped instead of raising
        ``FileNotFoundError``; sorting makes the order independent of the
        filesystem's ``iterdir`` order.
    """
    root = data_base_path if base_path is None else Path(base_path)
    if isinstance(extensions, str):
        extensions = (extensions,)
    allowed = {ext.lower() for ext in extensions}

    images = []
    for case_dir in root.iterdir():
        image_dir = case_dir / subfolder
        # Ignore stray files at the root and cases without this subfolder.
        if not (case_dir.is_dir() and image_dir.is_dir()):
            continue
        images.extend(
            img
            for img in image_dir.iterdir()
            if img.is_file() and img.suffix.lower() in allowed
        )
    return sorted(images)


def predictive_entropy_log2_reversed(mu, eps=1e-8):
    """Return the reversed binary entropy (confidence score) of probabilities ``mu``.

    Computes the base-2 binary entropy
    ``H(mu) = -(mu*log2(mu) + (1-mu)*log2(1-mu))``, clamps it to ``[0, 1]``
    and returns ``1 - H``. A value of 1 means confident (mu near 0 or 1) and
    0 means maximally uncertain (mu == 0.5).

    Args:
        mu: Probabilities, presumably in ``[0, 1]``. Accepts a numpy array, a
            torch tensor (any device, with or without ``requires_grad``), or
            anything ``torch.as_tensor`` understands (Python scalars, lists).
        eps: Small offset added inside both logarithms to avoid ``log2(0)``.

    Returns:
        numpy.ndarray with the same shape as ``mu``. For inputs in ``[0, 1]``,
        its values lie in ``[0, 1]``.
    """
    # as_tensor shares memory with numpy inputs (like from_numpy) and also
    # accepts existing tensors, Python scalars and lists.
    mu = torch.as_tensor(mu)
    H = -(mu * torch.log2(mu + eps) + (1 - mu) * torch.log2(1 - mu + eps))
    # eps can push H marginally outside [0, 1]; clamp it back.
    H = torch.clamp(H, 0.0, 1.0)
    # detach/cpu so tensors that require grad or live on a GPU convert cleanly.
    return (1.0 - H).detach().cpu().numpy()


def read_image(path: Path) -> "tuple[np.ndarray, dict]":
    """Read every band of a raster file together with its rasterio profile.

    Args:
        path: Path to a raster file readable by ``rasterio.open``.

    Returns:
        ``(img, profile)``:
        - ``img`` is the array from ``DatasetReader.read()`` with all bands,
          shaped ``(bands, rows, cols)``.
        - ``profile`` is the dataset's dict-like rasterio profile.
    """
    with rio.open(path) as src:
        # Both values are evaluated before the dataset is closed.
        return src.read(), src.profile


def create_entropy_path(input_img: Path) -> Path:
    """Map a ``mu`` raster path to the path its entropy raster is saved to.

    ``<case>/s1_mu/<name>_mu.tif`` becomes ``<case>/s1_entropy/<name>_entropy.tif``.
    The standalone token ``mu`` is replaced by ``entropy`` in both the parent
    folder name and the file name. Only ``mu`` occurrences not adjacent to
    another letter are replaced, so ``mumbai_mu.tif`` maps to
    ``mumbai_entropy.tif`` rather than ``entropymbai_entropy.tif``.

    Args:
        input_img: Path of the input ``mu`` raster.

    Returns:
        The output path. Its directory is not created here.

    Raises:
        ValueError: If neither the folder nor the file name contains a ``mu``
            token, because the output path would then overwrite the input.
    """
    import re

    mu_token = re.compile(r"(?<![A-Za-z])mu(?![A-Za-z])")

    def to_entropy(name: str) -> str:
        return mu_token.sub("entropy", name)

    save_path = (
        input_img.parent.parent
        / to_entropy(input_img.parent.name)
        / to_entropy(input_img.name)
    )
    if save_path == input_img:
        raise ValueError(
            f"No 'mu' token in {input_img}; refusing to overwrite the input."
        )
    return save_path


def save_image(path: Path, img: np.ndarray, profile: dict):
    """Write ``img`` as a single-band, LZW-compressed float32 GeoTIFF.

    Args:
        path: Destination file. Missing parent folders are created.
        img: Array written as band 1 after casting to float32. It is expected
            to be 2-D (rows, cols), since it is written to a single band index.
        profile: Rasterio profile of the source raster (georeferencing, size,
            ...). It is copied, not modified. ``count``, ``dtype``, ``driver``
            and ``compress`` are overridden only in the copy.
    """
    # Build a new dict so the caller's profile is left untouched.
    out_profile = {
        **profile,
        "count": 1,
        "dtype": "float32",
        "driver": "GTiff",
        "compress": "lzw",
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    with rio.open(path, "w", **out_profile) as dst:
        dst.write(img.astype("float32"), 1)


def process_img(input_img: Path):
    """Compute the reversed-entropy raster for one ``mu`` image and save it.

    Steps:
    1. Read ``input_img``.
    2. Derive the output path with ``create_entropy_path``.
    3. Convert the first band (``img[0]``) with ``predictive_entropy_log2_reversed``.
    4. Write the result with the source profile.
    """
    bands, src_profile = read_image(input_img)
    out_path = create_entropy_path(input_img)
    save_image(out_path, predictive_entropy_log2_reversed(bands[0]), src_profile)


# Gather the probability rasters from every case's s1_mu, s2_mu and mu
# sub-folders (in that order), then convert each to a reversed-entropy raster.
s1_mu_imgs, s2_mu_imgs, mu_imgs = (
    collect_images(folder) for folder in ("s1_mu", "s2_mu", "mu")
)
all_mu_imgs = [*s1_mu_imgs, *s2_mu_imgs, *mu_imgs]

for img in tqdm(all_mu_imgs, desc="Processing images"):
    process_img(img)

Processing images: 100%|██████████| 768/768 [08:18<00:00,  1.54it/s]
