## Data Preparation for Training

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import rasterio as rio
from scipy import ndimage

In [None]:
# Raster paths used throughout the notebook.
in_file = "KENNETH_20250109_post_fire.tif"       # source raster that is read first
out_file = "KENNETH_20250109_interpolated.tif"   # raster written after nodata filling
mask_file = "KENNETH_20250109_mask.tif"          # mask raster shown next to the bands

In [None]:
# Read every band as float32 so that nodata pixels can be marked with NaN.
with rio.open(in_file) as img:
    profile = img.profile
    data = img.read().astype(np.float32)
    data[data == img.nodata] = np.nan

filled_data = np.zeros_like(data)

# Fill each band independently: every NaN pixel takes the value of the
# closest non-NaN pixel in the same band.
for idx, layer in enumerate(data):
    gaps = np.isnan(layer)  # True where the pixel is NaN
    if gaps.any():
        # With return_indices=True the transform yields, for each pixel, the
        # coordinates of the nearest pixel where `gaps` is False.
        nearest = ndimage.distance_transform_edt(
            gaps, return_distances=False, return_indices=True
        )
        layer = layer[tuple(nearest)]
    filled_data[idx] = layer

# The output is stored as float32 with NaN declared as the nodata value.
profile.update(dtype=rio.float32, nodata=np.nan)

with rio.open(out_file, 'w', **profile) as out:
    out.write(filled_data)

In [None]:
# A notebook cell only echoes its last expression, so the shapes are printed
# explicitly instead of being evaluated and silently discarded.
print(f"data shape: {data.shape}")
print(f"filled_data shape: {filled_data.shape}")
# Sanity check: True means at least one NaN is still present after filling.
np.isnan(filled_data).any()

#### Band normalization

https://github.com/allenai/satlas/blob/main/Normalization.md

In [None]:
# Per-band min-max normalization of the raster stored in `out_file`.
with rio.open(out_file) as f:
    bands = f.read()

# Use a floating-point result array so normalized values are not truncated
# if the raster happens to be stored with an integer dtype.
norm_data = np.zeros(bands.shape, dtype=np.result_type(bands.dtype, np.float32))

for i, band in enumerate(bands.astype(norm_data.dtype, copy=False)):
    # Compute the extremes once per band rather than on every use.
    band_min = np.min(band)
    band_max = np.max(band)
    print(f"Max in origin band {i}: {band_max}, Min in origin band {i}: {band_min}")
    # Perform per-band normalization
    band_range = band_max - band_min
    if band_range == 0:
        # Constant band: dividing by a zero range would fill the band with NaN,
        # so the band is left at 0 instead.
        norm_data[i] = 0
    else:
        norm_data[i] = (band - band_min) / band_range
    print(f"Max in normalized band {i}: {np.max(norm_data[i])}, Min in normalized band {i}: {np.min(norm_data[i])}")


In [None]:
# Read all bands of the mask raster into memory.
with rio.open(mask_file) as mask_src:
    mask = mask_src.read()
# Show the mask shape once length-1 axes are removed.
np.squeeze(mask).shape

In [None]:
# Show the five normalized bands and the mask on a 2 x 3 grid of panels.
fig, ax = plt.subplots(2, 3, figsize=(10, 8))

band_titles = ("Blue", "Green", "Red", "Near Infrared", "Short Wave Infrared")
panels = ax.flat  # row-major iteration: ax[0, 0], ax[0, 1], ..., ax[1, 2]

# The first five panels each show one normalized band.
for band_image, title, panel in zip(norm_data, band_titles, panels):
    panel.imshow(band_image)
    panel.set_title(title)

# The remaining panel, ax[1, 2], shows the first band of the mask.
mask_panel = next(panels)
mask_panel.imshow(mask[0])
mask_panel.set_title("Mask")    # 0 is the background
plt.show()