This notebook uses a 2x super-resolution model to convert the CloudSEN12 high dataset to 5 m spatial resolution.

The only thing you need to change in this notebook is `base_dataset_dir`, which should point to a local drive with 300 GB of available storage.

In [None]:
from multiprocessing import Pool
from pathlib import Path
from typing import Optional

import numpy as np
import rasterio as rio
import torch
from huggingface_hub import hf_hub_download
from rasterio.transform import Affine
from spandrel import ModelLoader
from tqdm.auto import tqdm

In [None]:
# Root folder for the dataset: the input folder and both output folders used by
# this notebook are created/looked up directly beneath it. Edit this path to match
# your machine.
base_dataset_dir = Path("/media/nick/4TB Working 7/Datasets/OCM datasets")

In [None]:
# Folder the source images and labels are read from.
cloudsen12_high_dir = base_dataset_dir.joinpath("CloudSEN12 high")

In [None]:
# Output folder for the full-size super-resolved images (created if missing).
super_res_raw_dir = base_dataset_dir.joinpath("CloudSEN12 high super res raw")
super_res_raw_dir.mkdir(exist_ok=True)

In [None]:
# Output folder for the image and label tiles cut from the super-resolved images
# (created if missing).
super_res_tile_dir = base_dataset_dir.joinpath("CloudSEN12 high super res tiles")
super_res_tile_dir.mkdir(exist_ok=True)

In [None]:
# Every GeoTIFF in the source folder whose name matches the train / 509 / high /
# L1C image pattern; the count is shown as the cell output.
cs12_high_train_l1c_images = [
    *cloudsen12_high_dir.glob("*train_509_high_image_l1c*.tif")
]
len(cs12_high_train_l1c_images)

In [None]:
# Settings shared by the cells below.
upscale_factor = 2  # scale applied to the geotransform and to the label upsampling
patch_size, overlap = 509, 50  # patch_size is the tile edge length in pixels
device = torch.device("cuda")
inference_dtype = torch.bfloat16  # dtype the model weights and inputs are cast to

In [None]:
# Fetch the checkpoint file from the Hugging Face Hub; the returned value is what
# gets handed to the loader below.
model_path = hf_hub_download(
    repo_id="Phips/2xNomosUni_esrgan_multijpg",
    filename="2xNomosUni_esrgan_multijpg.safetensors",
)

# Load the checkpoint with spandrel.
model = ModelLoader().load_from_file(model_path)
# Move the model to the inference device, switch it to eval mode and cast it to the
# inference dtype. Kept as the last expression so the cell displays the result.
model.to(device).eval().to(inference_dtype)

In [None]:
def get_tensor_stats(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Summarise a tensor with mean/std over its last two dims plus global extrema.

    Args:
        input: Tensor with at least two dimensions.

    Returns:
        A ``(means, stds, maximum, minimum)`` tuple. ``means`` and ``stds`` are
        reduced over the last two dimensions, which are kept with size 1;
        ``maximum`` and ``minimum`` are taken over the whole tensor.
    """
    spatial_dims = (-2, -1)
    per_slice_mean = torch.mean(input, dim=spatial_dims, keepdim=True)
    per_slice_std = torch.std(input, dim=spatial_dims, keepdim=True)
    return per_slice_mean, per_slice_std, torch.max(input), torch.min(input)


def transfer_colour_stats(
    source: torch.Tensor,
    target: Optional[torch.Tensor] = None,
    target_std: Optional[torch.Tensor] = None,
    target_mean: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Rescale ``source`` so its mean/std over the last two dims match a target.

    Args:
        source: Tensor to be re-coloured.
        target: Optional tensor to take the statistics from. When given, its mean
            and std over the last two dims override ``target_mean``/``target_std``.
        target_std: Standard deviation to impose when ``target`` is not given.
        target_mean: Mean to impose when ``target`` is not given.

    Returns:
        ``source`` standardised with its own statistics, then scaled by the target
        std and shifted by the target mean.

    Raises:
        ValueError: If ``target`` is missing and either ``target_std`` or
            ``target_mean`` is missing too.
    """
    spatial_dims = (-2, -1)
    own_mean = torch.mean(source, dim=spatial_dims, keepdim=True)
    own_std = torch.std(source, dim=spatial_dims, keepdim=True)

    if target is None:
        if target_std is None or target_mean is None:
            raise ValueError("Either target or target_std and target_mean must be provided")
    else:
        target_mean = torch.mean(target, dim=spatial_dims, keepdim=True)
        target_std = torch.std(target, dim=spatial_dims, keepdim=True)

    # The small epsilon keeps the division finite for constant-valued slices.
    standardised = (source - own_mean) / (own_std + 1e-8)
    return standardised * target_std + target_mean

In [None]:
# Super-resolve every source image and write the result as a GeoTIFF whose pixel
# size is divided by `upscale_factor`.
for image_path in tqdm(cs12_high_train_l1c_images):
    out_path = super_res_raw_dir / image_path.name.replace("image", "image_super_res")
    if out_path.exists():
        # Already written by an earlier run, so an interrupted run can be resumed.
        continue

    # Context manager so the dataset handle is closed as soon as it has been read.
    with rio.open(image_path) as src:
        image_array = src.read()
        profile = src.profile

    # Keep the statistics and rescaling in float32: bfloat16 carries only 8
    # significant bits, which would coarsely quantise the channel means/stds and
    # the values written to disk. Only the model input is cast to inference_dtype.
    image_tensor = torch.from_numpy(image_array).to(device).to(torch.float32)

    input_channel_mean, input_channel_std, max_scene_value, min_scene_value = (
        get_tensor_stats(image_tensor)
    )

    # NOTE(review): this divides by the scene maximum rather than (max - min), so
    # the scaled values top out below 1 whenever the minimum is non-zero. Left
    # unchanged on purpose; confirm whether full min-max scaling was intended.
    model_input = (image_tensor - min_scene_value) / max_scene_value
    model_input = model_input.to(inference_dtype)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        pred_tensor = model(model_input.unsqueeze(0)).squeeze(0)

    # Restore the per-channel mean/std of the input scene on the prediction.
    pred_tensor = transfer_colour_stats(
        source=pred_tensor.to(torch.float32),
        target_std=input_channel_std,
        target_mean=input_channel_mean,
    )

    pred_array = pred_tensor.numpy(force=True)
    if np.issubdtype(image_array.dtype, np.integer):
        # Clip before casting: out-of-range floats (e.g. slight negatives from
        # overshoot) would otherwise wrap around in an integer cast.
        dtype_limits = np.iinfo(image_array.dtype)
        pred_array = np.clip(pred_array, dtype_limits.min, dtype_limits.max)
    pred_array = pred_array.astype(image_array.dtype)

    # Same origin, pixel size divided by the upscale factor.
    output_transform = profile["transform"] * Affine.scale(1 / upscale_factor)

    output_profile = profile.copy()
    output_profile.update(
        {
            "dtype": pred_array.dtype,
            "count": pred_array.shape[0],
            "height": pred_array.shape[1],
            "width": pred_array.shape[2],
            "transform": output_transform,
        }
    )
    with rio.open(out_path, "w", **output_profile) as dst:
        dst.write(pred_array)


In [None]:
# Every super-resolved image written to the raw output folder; the count is shown
# as the cell output.
super_images = [*super_res_raw_dir.glob("*image_super_res*.tif")]
len(super_images)

In [None]:
# Cut each super-resolved image, and its label upsampled to the same grid, into a
# 2x2 grid of patch_size tiles and save every tile.
def process_super_image(super_image):
    """Write the four image tiles and four label tiles for one super-resolved image.

    The matching label is looked up in ``cloudsen12_high_dir`` by replacing
    ``image_super_res`` with ``label`` and dropping ``_l1c`` from the file name.
    The label is upsampled by ``upscale_factor`` along both spatial axes by
    repeating pixels. Tiles are ``patch_size`` x ``patch_size`` pixels, named with
    ``tile_{row}_{col}``, and written to ``super_res_tile_dir``. Tiles that already
    exist are left untouched.

    Args:
        super_image: Path of a super-resolved GeoTIFF.

    Returns:
        None.
    """
    # Context managers so both dataset handles are closed once their data is read;
    # this function runs many times inside worker processes.
    with rio.open(super_image) as src:
        profile = src.profile
        array = src.read()
    original_transform = profile["transform"]

    label_path = cloudsen12_high_dir / super_image.name.replace(
        "image_super_res", "label"
    ).replace("_l1c", "")
    with rio.open(label_path) as label_src:
        label_source_profile = label_src.profile
        label_array = label_src.read()
    # Repeat every label pixel along the row and column axes so the label matches
    # the pixel grid of the super-resolved image.
    label_array = label_array.repeat(upscale_factor, axis=1).repeat(
        upscale_factor, axis=2
    )

    for col in [0, 1]:
        for row in [0, 1]:
            # Shift the origin by whole tiles, expressed in pixels of the
            # super-resolved grid; shared by the image tile and the label tile.
            new_transform = original_transform * Affine.translation(
                col * patch_size, row * patch_size
            )
            out_path = super_res_tile_dir / super_image.name.replace(
                "image_super_res", f"image_super_res_tile_{row}_{col}"
            )
            if not out_path.exists():
                array_clip = array[
                    :,
                    row * patch_size : row * patch_size + patch_size,
                    col * patch_size : col * patch_size + patch_size,
                ]
                export_profile = profile.copy()
                export_profile.update(
                    {
                        "height": patch_size,
                        "width": patch_size,
                        "transform": new_transform,
                    }
                )
                with rio.open(out_path, "w", **export_profile) as dst:
                    dst.write(array_clip)

            label_out_path = super_res_tile_dir / label_path.name.replace(
                "label", f"label_super_res_tile_{row}_{col}"
            )
            if not label_out_path.exists():
                label_array_clip = label_array[
                    :,
                    row * patch_size : row * patch_size + patch_size,
                    col * patch_size : col * patch_size + patch_size,
                ]

                label_profile = label_source_profile.copy()

                label_profile.update(
                    {
                        "height": patch_size,
                        "width": patch_size,
                        "transform": new_transform,
                    }
                )
                with rio.open(label_out_path, "w", **label_profile) as dst:
                    dst.write(label_array_clip)

In [None]:
# Tile all super-resolved images using four worker processes; the lazy result
# iterator is drained through tqdm so the progress bar advances per image.
with Pool(processes=4) as pool:
    tile_results = pool.imap(process_super_image, super_images)
    list(
        tqdm(tile_results, total=len(super_images), desc="Processing super images")
    )