In [None]:
import numpy as np
import rasterio
from tensorflow.keras.models import load_model

# ============================
# Load Keras model
# ============================
# Keras 3's `load_model` picks its loader from the file extension and only
# accepts ".keras" (native format) or ".h5"/".hdf5" (legacy HDF5), so the
# extension must be spelled exactly.
model = load_model("my_unet_model.keras")  # <-- change to your model path
patch_size = 256  # <-- size your UNet expects (height=width)

# ============================
# Load GeoTIFF satellite image
# ============================
input_tif = "input_image.tif"  # <-- change to your input image path
with rasterio.open(input_tif) as src:
    # All bands in one array, laid out as (bands, height, width), plus the
    # dataset profile, which is reused later when writing the output.
    img, profile = src.read(), src.profile

# ============================
# Preprocess image
# ============================
# rasterio returns (bands, H, W); reorder to channels-last (H, W, bands),
# which matches the (1, h, w, c) patches fed to the model below.
img = np.transpose(img, (1, 2, 0))
H, W, C = img.shape

# Normalize (depends on training): dividing by 255 assumes the model was
# trained on 8-bit imagery scaled the same way -- TODO confirm per dataset.
img = img.astype("float32") / 255.0

# Pad bottom/right so both dimensions are multiples of patch_size. The outer
# modulo yields a pad of 0 (not patch_size) when a dimension already divides
# evenly. "reflect" mirrors edge content rather than filling with zeros.
pad_h = (patch_size - H % patch_size) % patch_size
pad_w = (patch_size - W % patch_size) % patch_size
img_padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

new_H, new_W, _ = img_padded.shape

# ============================
# Function to extract patches
# ============================
def get_patches(image, patch_size):
    """Cut an (rows, cols, channels) array into square tiles, row-major.

    Tiles start at (0, 0) and step by ``patch_size`` along both spatial axes.
    Pad the image to a multiple of ``patch_size`` first. Otherwise the edge
    tiles are smaller, and the ragged list cannot be stacked into a regular
    array.

    Args:
        image: Array indexed as (row, col, channel).
        patch_size: Tile height and width, also used as the stride.

    Returns:
        A tuple ``(tiles, corners)``. ``tiles`` is ``np.array`` of the tile
        list. ``corners`` is a list of the ``(row, col)`` top-left offsets of
        each tile, in the same order.
    """
    row_starts = range(0, image.shape[0], patch_size)
    col_starts = range(0, image.shape[1], patch_size)
    corners = [(r, c) for r in row_starts for c in col_starts]
    tiles = [image[r:r + patch_size, c:c + patch_size, :] for r, c in corners]
    return np.array(tiles), corners

patches, coords = get_patches(img_padded, patch_size)

# ============================
# Run prediction on patches
# ============================
# Run all patches through one `model.predict` call and let Keras batch them.
# Calling `predict` once per patch in a Python loop pays its per-call setup
# cost for every patch. Lower `pred_batch_size` if the device runs out of
# memory.
pred_batch_size = 8
preds = model.predict(patches, batch_size=pred_batch_size, verbose=0)
# Drop singleton axes from each patch's output. For example, a (h, w, 1)
# single-channel map becomes (h, w), as the previous per-patch
# squeeze of (1, h, w, 1) did.
all_preds = [np.squeeze(pred) for pred in preds]

# ============================
# Reconstruct full image
# ============================
# Paste each patch prediction back at the top-left corner it was cut from.
# This assumes every squeezed prediction is 2-D (patch_size x patch_size).
# A multi-channel output would fail to assign into this 2-D canvas.
pred_full = np.zeros((new_H, new_W), dtype=np.float32)
for (row, col), tile in zip(coords, all_preds):
    pred_full[row:row + patch_size, col:col + patch_size] = tile

# Drop the padded border so the result lines up with the source raster.
pred_full = pred_full[:H, :W]

# ============================
# Save prediction as GeoTIFF
# ============================
# Start from the input's profile to keep its georeferencing. Then describe a
# single float32 band at the original, unpadded size.
profile.update(
    dtype=rasterio.float32,
    count=1,
    height=H,
    width=W,
    compress="lzw",
    # The input's nodata value belongs to the input bands. Inherited here, it
    # would flag legitimate predictions as missing (e.g. 0.0 when the input
    # used nodata=0). Every output pixel is a real prediction.
    nodata=None,
)
# A colour interpretation inherited from a multi-band input (e.g. RGB or
# YCbCr) does not apply to a single-band output, so drop it if present.
profile.pop("photometric", None)

output_tif = "prediction_output.tif"
with rasterio.open(output_tif, "w", **profile) as dst:
    dst.write(pred_full, 1)

print(f"Prediction saved to {output_tif}")
