# Convert raster to tiled ML-ready data

ers and tif files are common for geoscience. We need a quick way to tile these and save them to our standard numpy tile dataset.

This notebook defines a function which:
1. converts a tif to a pytorch tensor
1. pads any nan areas with reflection padding 
1. unfolds the tensor in each direction (read torch .fold() / .unfold())
1. stacks those into a "batch" of tiles
1. generates a selection of indices for validation and training data
1. saves each to a seperate train/val folder
1. Generates some QA/QC figures

In [8]:
# import logging
from pathlib import Path

import colorcet as cc
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch
import tifffile
from PIL import Image


def img_to_tiles(
    lr_raster_path,
    hr_raster_path,
    lr_out_prefix="lr_tiles",
    hr_out_prefix="hr_tiles",
    lr_s=32,  # lr tile res
    hr_s=128,  # hr tile res
    nan_val=-999_999,
    ext="npy",
    norm=False,
    single_output_folder=False,
):

    """
    Specify a LR and HR image pair. This will tile and save them to tensor
    arrays or image tiles ready to load in pytorch.

    Normalisation is hacky and requires further work.

    """
    if norm:
        max_ = 1000  # Aus Magmap v7 histogram arbitrary clip/stats
        min_ = -1000
        mean_ = 0
        std_ = 250

        # norm = lambda i: (i + 4403.574) / 18454.907 # values from training
        # unnorm = lambda i: (i * 18454.907) - 4403.574

        norm = lambda i: (i - min_) / (max_ - min_)  # values from training
        unnorm = lambda i: (i * (max_ - min_)) + min_

        print(f"Using {max_=}, {min_=}, {mean_=}, {std_=} for min max norm")

    else:
        norm = unnorm = lambda i: i  # NULL OP
        print("Not Normalising ...")
        # print("Using temp /8 norm for PPRDC, which outputs in +- 7. ish")

    if Path(lr_raster_path).suffix == ".tif":
        lr = tifffile.imread(lr_raster_path)
        hr = tifffile.imread(hr_raster_path)
    elif Path(lr_raster_path).suffix == ".ers":
        lr = np.array(rio.open(lr_raster_path).read(1))
        hr = np.array(rio.open(hr_raster_path).read(1))

    lr[lr == nan_val] = np.nan
    hr[hr == nan_val] = np.nan

    lr_tensor = torch.as_tensor(norm(lr), dtype=torch.float32).unsqueeze(0)
    hr_tensor = torch.as_tensor(norm(hr), dtype=torch.float32).unsqueeze(0)

    # lr_original_extent = lr_tensor.shape
    # hr_original_extent = hr_tensor.shape

    # Expand raster size to a common multiple of lr size (which is a multiple of hr)

    # padded_w = (lr_tensor.shape[0] + (lr_s - (lr_tensor.shape[0] % lr_s)))
    # padded_h = (lr_tensor.shape[1] + (lr_s - (lr_tensor.shape[1] % lr_s)))

    lr_tensor = torch.nn.functional.pad(
        lr_tensor,
        (
            0,
            (lr_s - (lr_tensor.shape[1] % lr_s)),
            0,
            (lr_s - (lr_tensor.shape[0] % lr_s)),
        ),
        mode="reflect",
    )
    hr_tensor = torch.nn.functional.pad(
        hr_tensor,
        (
            0,
            (hr_s - (hr_tensor.shape[1] % hr_s)),
            0,
            (hr_s - (hr_tensor.shape[0] % hr_s)),
        ),
        mode="reflect",
    )

    hr_tensor = hr_tensor[0]
    lr_tensor = lr_tensor[0]

    # Math to determine and execute the tiling process
    lr_tiles_per_row = lr_tensor.shape[1] // lr_s
    hr_tiles_per_row = hr_tensor.shape[1] // hr_s
    lr_tiles_per_column = lr_tensor.shape[0] // lr_s
    hr_tiles_per_column = hr_tensor.shape[0] // hr_s

    lr_patches = lr_tensor.unfold(0, lr_s, lr_s).unfold(1, lr_s, lr_s)
    hr_patches = hr_tensor.unfold(0, hr_s, hr_s).unfold(1, hr_s, hr_s)

    lr_patches = lr_patches.contiguous().view(
        lr_tiles_per_row * lr_tiles_per_column, -1, lr_s, lr_s
    )
    hr_patches = hr_patches.contiguous().view(
        hr_tiles_per_row * hr_tiles_per_column, -1, hr_s, hr_s
    )

    print(f"{len(lr_patches)=}")
    print(f"{len(hr_patches)=}")

    print(f"{lr_patches.shape=}")
    print(f"{hr_patches.shape=}")

    #  Drop tiles that contain mostly/any nan values, convert rest to some value
    allowed_nan_pct = 0.05
    nan_fill_val = mean_  # nan_val
    valid_mask = (
        torch.count_nonzero(lr_patches.isnan(), dim=(2, 3)) / lr_s ** 2
    ) <= allowed_nan_pct

    lr_patches_masked = lr_patches[valid_mask]
    hr_patches_masked = hr_patches[valid_mask]  # HR and LR indices need to match
    print(
        f"Dropped {sum(~valid_mask).item()} mostly NaN tiles (> {allowed_nan_pct*100}% nan)"
    )
    # nan_val = torch.tensor(0)  # np.nanmean(hr_patches_masked))
    # lr_patches_masked[lr_patches_masked == torch.nan] = nan_val
    # hr_patches_masked[hr_patches_masked == torch.nan] = nan_val
    hr_patches_masked = torch.nan_to_num(hr_patches_masked, nan=nan_fill_val)
    lr_patches_masked = torch.nan_to_num(lr_patches_masked, nan=nan_fill_val)
    print(f"Reverted any NaNs in remaining tiles to {nan_val}")

    ## PPDRC handling
    # # Same for excessive zeros, included at border is Nan value is 0.
    # allowed_zero_pct = 0.10
    # print(f"{(lr_patches_masked == 0).shape=}")

    # valid_mask = (
    #     torch.count_nonzero(lr_patches_masked == 0, dim=(1, 2)) / lr_s ** 2
    # ) <= allowed_zero_pct
    # lr_patches_masked = lr_patches_masked[valid_mask]
    # hr_patches_masked = hr_patches_masked[valid_mask]  # HR and LR indices need to match
    # print(
    #     f"Dropped {sum(~valid_mask).item()} mostly zero tiles (> {allowed_zero_pct*100}% Zero valued)"
    # )

    # Random split the train/val tiles for this dataset
    from numpy.random import default_rng

    rng = default_rng(seed=21)

    val_pct = 0.15  # 15% Val, 85% Train
    val_num = int(np.round(val_pct * len(lr_patches_masked)))  # len(lr_patches)
    val_indices = sorted(
        rng.choice(len(lr_patches_masked), size=val_num, replace=False)
    )
    print(f"{val_indices=}")

    train_indices = [i for i in range(len(lr_patches_masked)) if i not in val_indices]
    print(f"{train_indices=}")
    print(
        f"There are {sum(i in train_indices for i in val_indices)} val indices in your train indices :)"
    )

    lr_patches_train = lr_patches_masked[train_indices].numpy().astype(np.float32)
    hr_patches_train = hr_patches_masked[train_indices].numpy().astype(np.float32)
    lr_patches_val = lr_patches_masked[val_indices].numpy().astype(np.float32)
    hr_patches_val = hr_patches_masked[val_indices].numpy().astype(np.float32)

    print(f"{len(lr_patches_train)=}")
    print(f"{len(lr_patches_val)=}")

    print("None of these values should show nans!")
    print(f"{np.min(lr_patches_train)=}")
    print(f"{np.max(lr_patches_train)=}")
    print(f"{np.mean(lr_patches_train)=}")
    print(f"{np.std(lr_patches_train)=}")
    # print(f"{np.nanmin(hr_patches_train)=}")
    # print(f"{np.nanmax(hr_patches_train)=}")

    if single_output_folder:
        single_output_folder = Path(single_output_folder)
        train_dir = single_output_folder / "train"
        val_dir = single_output_folder / "val"
    else:
        train_dir = Path(hr_raster_path).parent / ext / "train"
        val_dir = Path(hr_raster_path).parent / ext / "val"

    (train_dir / "hr").mkdir(parents=True, exist_ok=True)
    (train_dir / "lr").mkdir(parents=True, exist_ok=True)
    (val_dir / "hr").mkdir(parents=True, exist_ok=True)
    (val_dir / "lr").mkdir(parents=True, exist_ok=True)

    hr_file_name = f"hr/{Path(hr_raster_path).stem}_hr"
    lr_file_name = f"lr/{Path(lr_raster_path).stem}_lr"

    if "npy" in ext:
        np.save(train_dir / f"{hr_file_name}.{ext}", hr_patches_train)
        np.save(train_dir / f"{lr_file_name}.{ext}", lr_patches_train)
        np.save(val_dir / f"{hr_file_name}.{ext}", hr_patches_val)
        np.save(val_dir / f"{lr_file_name}.{ext}", lr_patches_val)

    if "tif" in ext:
        for i in range(len(lr_patches_train)):
            tifffile.imsave(
                train_dir / f"{hr_file_name}_{i}.{ext}", hr_patches_train[i]
            )
            tifffile.imsave(
                train_dir / f"{lr_file_name}_{i}.{ext}", lr_patches_train[i]
            )

        for i in range(len(lr_patches_val)):
            tifffile.imsave(val_dir / f"{hr_file_name}_{i}.{ext}", hr_patches_val[i])
            tifffile.imsave(val_dir / f"{lr_file_name}_{i}.{ext}", lr_patches_val[i])

        print(hr_patches_val[i].shape)
        print(lr_patches_val[i].shape)

    print(f"\nSaved to {val_dir.parent.absolute()}")

    return lr_patches, hr_patches, val_indices


In [9]:
# survey_search = "*.tif"
survey_search = "P*"

for survey_path in Path(r"C:\Luke\data\Paper_2").glob(survey_search):
    if not survey_path.is_dir():
        continue
    print(survey_path)
    root = survey_path

    # lr_patches, hr_patches, val_indices = tifs_to_tensors(
    #     f"{root}/{survey_name}_x4_200_LR.tif",
    #     f"{root}/{survey_name}_x1_50_HR.tif",
    #     ext="tif",
    #     norm=True,
    # )

    lr_patches, hr_patches, val_indices = img_to_tiles(
        hr_raster_path=f"{next(root.glob('*1.ers'))}",
        lr_raster_path=f"{next(root.glob('*4.ers'))}",
        # hr_raster_path=next(survey_path.glob("*0200.tif")),
        # lr_raster_path=next(survey_path.glob("*0050.tif")),
        ext="tif",
        norm=True,
        single_output_folder="C:/Luke/data/Paper_2/lr64_combined",
        nan_val=-999999,
        lr_s=64,  # lr tile res
        hr_s=256,  # hr tile res
    )


C:\Luke\data\Paper_2\P578
Using max_=1000, min_=-1000, mean_=0, std_=250 for min max norm
len(lr_patches)=99
len(hr_patches)=99
lr_patches.shape=torch.Size([99, 1, 64, 64])
hr_patches.shape=torch.Size([99, 1, 256, 256])
Dropped 0 mostly NaN tiles (> 5.0% nan)
Reverted any NaNs in remaining tiles to -999999
val_indices=[8, 23, 25, 27, 31, 33, 41, 53, 58, 59, 63, 67, 93, 94, 97]
train_indices=[0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 26, 28, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 60, 61, 62, 64, 65, 66, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 95, 96, 98]
There are 0 val indices in your train indices :)
len(lr_patches_train)=84
len(lr_patches_val)=15
None of these values should show nans!
np.min(lr_patches_train)=-0.064983584
np.max(lr_patches_train)=1.4908749
np.mean(lr_patches_train)=0.46885294
np.std(lr_patches_train)=0.10291

In [7]:
def check_tiles(data_path, index=0, ext="np", s=256):
    # if "np" in ext:
    #     lr_tile = np.load(lr_path)[index][0]
    #     hr_tile = np.load(hr_path)[index][0]
    # elif "tif" in ext:
    #     lr_tile = tifffile.imread(f"{lr_path}").squeeze()
    #     hr_tile = tifffile.imread(f"{hr_path}").squeeze()
    data_path = Path(data_path)

    if "tif" in ext:
        lr_tile = tifffile.imread(f"{next(data_path.glob(f'**/lr/{i}.tif'))}").squeeze()
        hr_tile = tifffile.imread(f"{next(data_path.glob(f'**/hr/{i}.tif'))}").squeeze()

    us = np.array(Image.fromarray(lr_tile).resize((s, s)))

    plt.figure(figsize=(20, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(us, vmin=hr_tile.min(), vmax=hr_tile.max())
    plt.colorbar()
    plt.subplot(1, 3, 2)
    plt.imshow(hr_tile)
    plt.colorbar()
    plt.subplot(1, 3, 3)
    plt.imshow(hr_tile - us, cmap=cc.cm.CET_D7, vmin=-0.5, vmax=0.5)
    plt.colorbar()


In [4]:
# def ers_to_tifs(survey_path, out_dir=""):
#     out_dir = Path(out_dir)
#     out_dir.mkdir(parents=True, exist_ok=True)

#     ers_arr = np.array(rio.open(survey_path).read(1))
#     tifffile.imsave(out_dir / f"{survey_path.stem}.tif", ers_arr)
#     print(f'Saved {Path(out_dir / f"{survey_path.stem}.tif").absolute()}')


# survey_search = "*4.ers"
# for survey_path in Path("C:/Luke/PhD/Oasis Montaj/ArbSR").glob(survey_search):
#     print(survey_path)
#     # survey_path
#     ers_to_tifs(survey_path, out_dir="PPDRC")

# # Test no loss of information:
# # np.max(tifffile.imread(next(Path(r"C:\Luke\PhD\paper2\SRvey\utils\PPDRC").glob("*.tif"))) - np.array(rio.open(r"C:\Luke\PhD\Oasis Montaj\ArbSR\p681_1.ers").read(1)))


In [12]:
from pathlib import Path
import matplotlib.pyplot as plt
import tifffile
import numpy as np

ims = []
for im in Path(r"C:\Luke\data\Paper 2\PPDRC\lr32\train").glob("**\*.tif"):
    ims.append(tifffile.imread(im))

ims = np.array(ims)[0]


  ims = np.array(ims)[0]


In [10]:
print(ims.max())
print(ims.min())
print(ims.mean())


0.51393044
-0.4547087
0.056754638


In [16]:
survey_search = "**/*"

for survey_path in Path(r"C:\Luke\data\Paper 2\PPDRC\lr32\train").glob(survey_search):
    print(tifffile.imread(next(survey_path.glob("**/*0050.tif"))).max())
    print(tifffile.imread(next(survey_path.glob("**/*0050.tif"))).max())
    print(tifffile.imread(next(survey_path.glob("**/*0200.tif"))).min())
    print(tifffile.imread(next(survey_path.glob("**/*0200.tif"))).min())


StopIteration: 