# Convert raster to tiled ML-ready data

ERS and TIF rasters are common in geoscience. We need a quick way to tile them and save the tiles in our standard numpy/folder tile-dataset layout.

This notebook defines functions which:
1. convert an ers/tif to a pytorch tensor
1. pads the raster's bottom/right edges with NaN so it splits into whole tiles, then replaces NaNs with a constant fill value
1. unfolds the tensor along each spatial dimension (see torch `Tensor.unfold()`)
1. stacks those into a "batch" of tiles
1. generates a random selection of indices for training, validation and test data
1. saves each to a separate train/val/test folder
1. Generates some QA/QC figures

In [2]:
# Initialise the valid mask dictionary in a cell above this one... so we can rerun the cell below.
# valid_d = None

import logging
import joblib
from pathlib import Path

import colorcet as cc
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch
import tifffile
from PIL import Image

# Configure the root logger: INFO and above, prefixed with level and timestamp.
logging.basicConfig(
    format="%(levelname)s:%(asctime)s: %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the functions in this notebook.
logger = logging.getLogger(name=__name__)
logger.info("##### Starting new log")


class Norm:
    """Apply a pre-fit scikit-learn transformer, loaded with joblib, to arrays.

    Every value of the input array is treated as one sample of a single
    feature: the array is reshaped to a ``(-1, 1)`` column for the
    transformer and reshaped back to its original shape afterwards.
    """

    def __init__(self, t_file_path):
        """Load the joblib-persisted transformer stored at ``t_file_path``.

        joblib.load unpickles the file, so only load trusted transformer files.
        """
        self.transformer = joblib.load(t_file_path)
        logger.info(f"Using {t_file_path} to transform your input array")

    def transform(self, arr):
        """Return ``arr`` passed through the transformer, with ``arr``'s shape.

        Values are cast to float64 before being handed to the transformer.
        """
        original_shape = arr.shape
        # C-order single-feature column; astype copies, so arr is untouched.
        column = arr.reshape(-1, 1).astype(np.float64)
        transformed = self.transformer.transform(column)
        return transformed.reshape(original_shape)


# Survey sub-directories (glob pattern) inside survey_dir, and the tile output root.
survey_search = "p*"
survey_dir = Path("C:/Luke/data/multiscale_TMI/ers_surveys")
tile_dir = Path("C:/Luke/data/multiscale_TMI/tiles/")
# Pre-fit transformer applied to every raster before tiling.
norm = Norm("../utils/AUS_MAGPMAP_v7_ONSHORE_QuantileTransformer.joblib")
# Scale factors of the gridded rasters, in ascending order.
scales = list(range(1, 5))


def img_to_tiles(
    raster_path: Path,
    scale: int,
    hr_s: int = 240,
    norm=False,
    nan_val=-999_999,
    nan_fill_val: float = None,
    valid_mask=None,
    out_prefix: str = "",
    out_ext="npy",
    single_output_folder: Path = None,
    lcm=12,  # lowest common multiple of all scales
    pad_mode="constant",
    method="replace",
):
    """Tile one raster of a multi-scale group and save the tiles to disk.

    The raster (scale ``scale`` within a group of scales) is read. Its no-data
    sentinel is converted to NaN and the raster is optionally normalised. It is
    then padded with NaN on the bottom/right edges and cut into non-overlapping
    square tiles of ``hr_s // scale`` pixels, numbered row-major. NaNs are
    replaced with ``nan_fill_val`` and the tiles are written as ``.npy`` or
    ``.tif`` files, ready to load in PyTorch.

    We assume an input raster of any shape, square tiles (h = w), and integer
    scale factors.

    ``lcm`` picks a padded size (a multiple of ``lcm``) that stays consistent
    across scales. For example, an lcm of 12 (scales 1, 2, 3, 4) pads an HR
    size of 110 to 120, giving LR sizes of 60, 40 and 30. Without it, the HR
    size could be padded to e.g. 110, 111 or 116, and each scale could end up
    with a different tile count. The size is then padded again to a multiple
    of the tile size, so each dimension holds a whole number of tiles.

    Args:
        raster_path: ``.tif``/``.tiff`` raster (read with tifffile) or ``.ers``
            raster (band 1, read with rasterio).
        scale: integer scale factor of this raster within its group.
        hr_s: tile size in pixels at scale 1; this raster is tiled at
            ``hr_s // scale``. A warning is logged if ``hr_s % scale != 0``,
            because the tiles then cover a different extent than at scale 1.
        norm: optional callable applied to the raster array before tiling
            (e.g. ``Norm.transform``). Falsy values skip normalisation.
        nan_val: no-data sentinel in the raster; those cells become NaN.
        nan_fill_val: value that NaNs are replaced with (``None`` means 0.0).
        valid_mask: unused; reserved for the unimplemented "remove" method.
        out_prefix: string prepended to every output file name.
        out_ext: extension containing "npy" (one stacked array per raster) or
            "tif" (one file per tile).
        single_output_folder: output root; defaults to
            ``raster_path.parent / out_ext``.
        lcm: lowest common multiple of all scales in the group.
        pad_mode: unused; padding is always constant NaN.
        method: NaN handling strategy; only "replace" is implemented.

    Returns:
        None. Tiles are written under ``<output root>/<scale>-0/``.

    Raises:
        ValueError: unsupported raster suffix or ``out_ext``.
        NotImplementedError: ``method`` other than "replace".
    """

    def _pad(t: torch.Tensor, s: int) -> torch.Tensor:
        """Pad a (C, H, W) tensor with NaN on its bottom/right edges.

        H and W are first rounded up past a multiple of ``lcm``, then past a
        multiple of ``s``. NOTE(review): when a size is already an exact
        multiple, a whole extra ``lcm`` (or ``s``) of NaN is still added. This
        is left unchanged because changing it alters the tile count per scale;
        change it only for all scales at once.
        """
        h, w = t.shape[1], t.shape[2]
        target_h = h + lcm - (h % lcm)
        target_w = w + lcm - (w % lcm)
        target_h = target_h + s - (target_h % s)
        target_w = target_w + s - (target_w % s)
        # F.pad takes the number of pixels to add, not the target size.
        pad_h, pad_w = int(target_h - h), int(target_w - w)

        logger.debug(
            f"Padded input: {w}+{pad_w} x {h}+{pad_h} "
            f"to {w + pad_w} x {h + pad_h} "
        )
        return torch.nn.functional.pad(
            t, (0, pad_w, 0, pad_h), mode="constant", value=float("nan")
        )

    def _tile(t: torch.Tensor, s: int) -> torch.Tensor:
        """Split a 2-D (H, W) tensor into (N, 1, s, s) tiles, row-major."""
        tpr = t.shape[1] // s  # tiles per row
        tpc = t.shape[0] // s  # tiles per column
        # (H, W) -> (tpc, tpr, s, s): each [i, j] is the tile at row i, col j.
        tiles = t.unfold(0, s, s).unfold(1, s, s)
        return tiles.contiguous().view(tpr * tpc, -1, s, s)

    def _handle_nans(tiles: torch.Tensor) -> np.ndarray:
        """Replace NaNs in every tile with ``nan_fill_val``; return a NumPy array.

        ML struggles with NaNs. The alternative strategy, removing every tile
        that contains NaNs, is not implemented. It would need a valid mask
        collated across all scales (element-wise OR) so that HR/LR tiles stay
        paired.
        """
        n_nan_tiles = int(torch.isnan(tiles).flatten(1).any(dim=1).sum())
        fill = 0.0 if nan_fill_val is None else nan_fill_val
        # nan_to_num also maps +/-inf to the dtype's largest finite values.
        filled = torch.nan_to_num(tiles, nan=fill).numpy()
        logger.info(f"Changed NaNs in {n_nan_tiles} tiles to {nan_fill_val}")
        return filled

    def _write_files(tiles: np.ndarray):
        """Save tiles under <output root>/<scale>-0/ as .npy or per-tile .tif."""
        out_root = single_output_folder or raster_path.parent / out_ext
        scale_dir = out_root / f"{scale}-0"
        stem = f"{out_prefix}{raster_path.stem}"

        bad = [i for i, tile in enumerate(tiles) if not np.isfinite(tile).all()]
        if bad:
            logger.warning(
                f"{scale_dir / stem}: tiles {bad} contain NaN/inf valued cells."
            )

        scale_dir.mkdir(parents=True, exist_ok=True)
        if "npy" in out_ext:
            np.save(scale_dir / f"{stem}.{out_ext}", tiles)
        else:  # "tif" in out_ext, validated on entry
            for i, tile in enumerate(tiles):
                tifffile.imwrite(scale_dir / f"{stem}_{i}.{out_ext}", tile)
        logger.info(
            f"Tiles min/max/mean: {tiles.min():0.2f}/{tiles.max():0.2f}/{tiles.mean():0.2f}"
        )
        logger.info(f"Saved {len(tiles)} tiles to {out_root.absolute()}")

    # Validate options before doing any expensive I/O.
    if "npy" not in out_ext and "tif" not in out_ext:
        raise ValueError(f"Unsupported out_ext {out_ext!r}; expected 'npy' or 'tif'.")
    if method != "replace":
        raise NotImplementedError(
            f"NaN handling method {method!r} is not implemented; use 'replace'."
        )

    logger.info(f"Began processing {raster_path}")

    suffix = raster_path.suffix.lower()
    if suffix in (".tif", ".tiff"):
        raster = tifffile.imread(raster_path)
    elif suffix == ".ers":
        # Close the dataset as soon as band 1 has been read.
        with rio.open(raster_path) as src:
            raster = src.read(1)
    else:
        raise ValueError(
            f"Unsupported raster type {raster_path.suffix!r} for {raster_path}; "
            "expected .tif/.tiff or .ers"
        )

    # Compare against the sentinel in the raster's own dtype, then move to
    # float so NaN can be stored (integer rasters cannot hold NaN).
    nodata = raster == nan_val
    raster = raster.astype(np.float64)
    raster[nodata] = np.nan

    if norm:
        raster = norm(raster)

    s = hr_s // scale  # tile size at this scale
    if hr_s % scale:
        logger.warning(
            f"hr_s={hr_s} is not divisible by scale={scale}: tiles are {s} px "
            f"and cover {s * scale} HR pixels instead of {hr_s}."
        )

    t = torch.as_tensor(raster, dtype=torch.float32).unsqueeze(0)  # (1, H, W)
    t = _pad(t, s).squeeze(0)  # (H', W')
    tiles = _tile(t, s)  # (N, 1, s, s)
    tiles = _handle_nans(tiles)
    _write_files(tiles.astype(np.float32))

    return None


# Tile every survey directory under `survey_dir` that matches `survey_search`.
# Each survey must contain one ERS raster per scale whose name ends in
# "<scale>.ers" (e.g. "p1505_2.ers" for scale 2).
# NOTE(review): img_to_tiles uses a tile size of hr_s // scale, so with
# hr_s=128 the scale-3 tiles are 42 px and cover 126 HR pixels, not 128.
# TODO: the unimplemented "remove" NaN method would need a per-survey valid
# mask accumulated across scales here.

group_lcm = np.lcm.reduce(scales)  # the same for every survey; compute once

for survey_path in sorted(survey_dir.glob(survey_search)):
    if not survey_path.is_dir():
        continue

    # Resolve every scale's raster before writing anything, so an incomplete
    # survey is skipped whole rather than leaving tiles for only some scales.
    raster_paths = {
        scale: next(iter(sorted(survey_path.glob(f"*{scale:d}.ers"))), None)
        for scale in scales
    }
    missing = [scale for scale, path in raster_paths.items() if path is None]
    if missing:
        logger.warning(
            f"Skipping {survey_path}: no '*<scale>.ers' raster found for scales {missing}"
        )
        continue

    for scale in scales:
        img_to_tiles(
            raster_path=raster_paths[scale],
            scale=scale,
            hr_s=128,
            norm=norm.transform,
            nan_val=-999_999,
            nan_fill_val=0,
            out_ext="tif",  # output extension
            single_output_folder=tile_dir,
            lcm=group_lcm,
            method="replace",
        )


INFO:2022-04-29 12:04:22,596: ##### Starting new log
INFO:2022-04-29 12:04:22,599: Using ../utils/AUS_MAGPMAP_v7_ONSHORE_QuantileTransformer.joblib to transform your input array
INFO:2022-04-29 12:04:22,602: Began processing C:\Luke\data\multiscale_TMI\ers_surveys\p1505\p1505_1.ers
INFO:2022-04-29 12:04:23,663: Changed NaNs in some tiles to 0
INFO:2022-04-29 12:04:24,431: Tiles min/max/mean: 0.00/2.64/0.59
INFO:2022-04-29 12:04:24,432: Saved 405 tiles to C:\Luke\data\multiscale_TMI\tiles
INFO:2022-04-29 12:04:24,444: Began processing C:\Luke\data\multiscale_TMI\ers_surveys\p1505\p1505_2.ers
INFO:2022-04-29 12:04:24,789: Changed NaNs in some tiles to 0
INFO:2022-04-29 12:04:25,457: Tiles min/max/mean: 0.00/2.64/0.60
INFO:2022-04-29 12:04:25,458: Saved 405 tiles to C:\Luke\data\multiscale_TMI\tiles
INFO:2022-04-29 12:04:25,463: Began processing C:\Luke\data\multiscale_TMI\ers_surveys\p1505\p1505_3.ers
INFO:2022-04-29 12:04:25,645: Changed NaNs in some tiles to 0
INFO:2022-04-29 12:04:26,

In [None]:
## ArbSR

# ArbSR requires a specific dataset layout (see the matlab version of this script)
# ArbSR uses a single HR file and n, m scale downsamplings.
# We use a set of pre-gridded files, split into n directories of m scale.
# This script organises these directories as per the expectations of ArbSR.
# ArbSR uses scales np.arange(1.5, 4.5, 0.5); we have np.arange(2.0, 5.0, 1.0) == [2, 3, 4].


def process_dir(
    tile_dir: Path,
    out_dir: Path = None,
    val_pct: float = 0.10,  # 10% Val, 85% Train, 5% test
    tst_pct: float = 0.05,
):
    """Move per-scale tile directories into ArbSR's train/val/test layout.

    ``tile_dir`` holds one sub-directory per scale named ``"<int>-<frac>"``
    (e.g. ``"1-0"``, ``"2-0"``; the name is parsed as a float after replacing
    ``-`` with ``.``). Each directory must contain the same number of tiles,
    and the i-th file (sorted by name) of every directory is treated as the
    same tile. Files are MOVED (renamed), not copied, to::

        tile_dir/out_dir/<train|val|test>/HR/<i:05d>.tif              (scale 1)
        tile_dir/out_dir/<train|val|test>/LR/X<s>_X<s>/<i:05d>.tif    (scale s)

    ``out_dir`` itself and any directory whose name contains ``"old_"`` are
    skipped. The split uses a fixed seed (21), so it is reproducible.

    Args:
        tile_dir: root directory containing the scale directories.
        out_dir: output directory relative to ``tile_dir``; ``None`` means
            ``"processed"``.
        val_pct: fraction of tiles assigned to validation.
        tst_pct: fraction of tiles assigned to test. The remainder is train,
            and the three sets are disjoint.

    Raises:
        FileNotFoundError: ``tile_dir`` or ``tile_dir / "1-0"`` does not exist.
        ValueError: a non-directory entry in ``tile_dir``, a scale directory
            name that cannot be parsed, or a scale directory whose tile count
            differs from ``"1-0"``. All checks run before any file is moved.
    """
    # Same logger object as the module-level `logger` (identical name).
    log = logging.getLogger(__name__)
    rng = np.random.default_rng(seed=21)
    if not tile_dir.exists():
        raise FileNotFoundError(f"Error, {tile_dir.absolute().as_posix()} not found!")
    out_root = tile_dir / ("processed" if out_dir is None else out_dir)
    hr_dir = tile_dir / "1-0"

    def _split_indices():
        """Return (tile count, {"test"/"val"/"train": sorted disjoint indices})."""
        tot_num = sum(1 for _ in hr_dir.iterdir())
        val_num = int(np.round(val_pct * tot_num))
        tst_num = int(np.round(tst_pct * tot_num))
        # Slicing one permutation yields disjoint sets of exactly the requested
        # sizes; two independent draws could overlap and silently shrink "val".
        order = rng.permutation(tot_num).tolist()
        indices = {
            "test": sorted(order[:tst_num]),
            "val": sorted(order[tst_num : tst_num + val_num]),
            "train": sorted(order[tst_num + val_num :]),
        }
        log.info(
            f"{tot_num} tiles, Split: "
            f"{len(indices['train'])}/{len(indices['val'])}/{len(indices['test'])}"
        )
        log.debug(f'\n{indices["train"]=}\n{indices["val"]=}\n{indices["test"]=}')
        return tot_num, indices

    def _collect_scale_dirs(tot_num: int):
        """Validate and list every scale directory as (dir, scale, sorted files)."""
        collected = []
        # Materialise the listing first: out_root may be created inside tile_dir.
        for scale_dir in sorted(tile_dir.iterdir()):
            if scale_dir == out_root or "old_" in scale_dir.name:
                continue
            if not scale_dir.is_dir():
                raise ValueError(f"Unexpected files found: {scale_dir}")
            scale = float(scale_dir.name.replace("-", "."))
            files = sorted(scale_dir.iterdir())
            if len(files) != tot_num:
                raise ValueError(
                    f"{scale_dir} has {len(files)} tiles but {hr_dir} has "
                    f"{tot_num}; tiles cannot be paired across scales."
                )
            collected.append((scale_dir, scale, files))
        return collected

    def _rearrange_files(scale_dir: Path, scale: float, files: list, indices: dict):
        """Move one scale's tiles to out_root/<dset>/(HR | LR/X<s>_X<s>)/."""
        for dset, dset_indices in indices.items():  # test, val, train
            if scale == 1:
                out_path = out_root / dset / "HR"
            else:
                out_path = out_root / dset / "LR" / f"X{scale:.2f}_X{scale:.2f}"
            out_path.mkdir(exist_ok=True, parents=True)
            for i, idx in enumerate(dset_indices):
                files[idx].rename(out_path / f"{i:05d}.tif")
        log.info(f"{scale_dir} processed and output to {out_root.absolute()}")

    tot_num, indices = _split_indices()
    for scale_dir, scale, files in _collect_scale_dirs(tot_num):
        _rearrange_files(scale_dir, scale, files, indices)


# Moves (renames) the tiles out of tile_dir/<scale>-0/ into tile_dir/processed/.
# Rerunning after the move finds no HR tiles left in tile_dir/1-0/.
process_dir(tile_dir=tile_dir, out_dir="processed")


In [None]:
# So you have line data from a geophysical survey. You've decimated the lines
# and gridded the result at several scale factors, e.g. keeping every 2nd, 3rd or 4th line.
# You now want to turn this survey into individual tiles for ML. So that each tile
# covers the same extent, you need to tile them at dimensions relative to their
# scale factor. The dimensions are therefore a decimal (or fractional) scale smaller
# than the original 1x scale grid.
# You may note that 256/3 is not a pleasant number for a discrete count of pixels.
# So we use 240, which goes to 240, 120, 80, and 60 pixels per dimension at
#             1, 2, 3, and 4 times scale, respectively.

# An alternative would be to interpolate, but it's unclear how that would affect ArbSR...
# It's easier in the image world, because LR images are just bicubic-downsampled on the fly.


In [None]:
# def check_tiles(data_path, index=0, ext="np", s=256):
#     # if "np" in ext:
#     #     lr_tile = np.load(lr_path)[index][0]
#     #     hr_tile = np.load(hr_path)[index][0]
#     # elif "tif" in ext:
#     #     lr_tile = tifffile.imread(f"{lr_path}").squeeze()
#     #     hr_tile = tifffile.imread(f"{hr_path}").squeeze()
#     data_path = Path(data_path)

#     if "tif" in ext:
#         lr_tile = tifffile.imread(f"{next(data_path.glob(f'**/lr/{i}.tif'))}").squeeze()
#         hr_tile = tifffile.imread(f"{next(data_path.glob(f'**/hr/{i}.tif'))}").squeeze()

#     us = np.array(Image.fromarray(lr_tile).resize((s, s)))

#     plt.figure(figsize=(20, 10))
#     plt.subplot(1, 3, 1)
#     plt.imshow(us, vmin=hr_tile.min(), vmax=hr_tile.max())
#     plt.colorbar()
#     plt.subplot(1, 3, 2)
#     plt.imshow(hr_tile)
#     plt.colorbar()
#     plt.subplot(1, 3, 3)
#     plt.imshow(hr_tile - us, cmap=cc.cm.CET_D7, vmin=-0.5, vmax=0.5)
#     plt.colorbar()
