In [None]:
import rasterio
import numpy as np
import cv2
import os
from tqdm import tqdm
from PIL import Image
import tensorflow as tf
from pathlib import Path
import geopandas as gpd
from rasterio.features import rasterize


# === Tiling configuration ===
tile_size = 256  # Edge length, in pixels, of each square tile

# === Input and output locations (everything hangs off base_dir) ===
base_dir = Path(r"D:\\UNETTest\\real data")
raster_path = base_dir.joinpath("raster.tif")
vector_path = base_dir.joinpath("vector.shp")
tiles_dir = base_dir.joinpath("tiles")
images_dir, masks_dir = (tiles_dir.joinpath(sub) for sub in ("images", "masks"))
predictions_dir = base_dir.joinpath("predictions")  # Directory to save GeoTIFF files

# === Make sure every output directory exists (no error if already there) ===
for out_dir in (images_dir, masks_dir, predictions_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# === Load image data ===
print("Loading image data...")
with rasterio.open(raster_path) as src:
    image = src.read(1)  # First band only, returned as a (H, W) array
    transform = src.transform  # Affine pixel -> map transform of the source raster
    crs = src.crs  # Coordinate reference system of the source raster
    out_shape = (src.height, src.width)

# === Rasterize the mask ===
print("Rasterizing mask...")
gdf = gpd.read_file(vector_path)  # Load the vector mask data

# rasterize() places geometries using their raw coordinates and the raster's
# transform; it does not reproject. If the vector layer is stored in another
# CRS the burned mask would be shifted or empty, so align it first.
if gdf.crs is not None and crs is not None and gdf.crs != crs:
    print(f"Reprojecting vector data from {gdf.crs} to the raster CRS {crs}...")
    gdf = gdf.to_crs(crs)

# Every geometry is burned with value 1; pixels outside all geometries stay 0.
mask = rasterize(
    [(geom, 1) for geom in gdf.geometry],
    out_shape=out_shape,
    transform=transform,
    fill=0,
    dtype=np.uint8
)

# === Slice data and save each tile's coordinate system ===
print("Starting to slice...")
h, w = mask.shape
count = 0


def _to_uint8(band):
    """Convert a single raster band to uint8 for tile export.

    The scaling is chosen once from the value range of the whole band so that
    all tiles share the same mapping:

    * uint8 input is returned untouched;
    * values within [0, 1] are multiplied by 255 (the original behaviour);
    * values within [0, 255] are only cast, not rescaled;
    * anything else (e.g. elevations) is min-max stretched to [0, 255].

    Non-finite samples are written as 0.
    NOTE(review): a nodata sentinel such as -9999 would be included in the
    min/max stretch; mask it beforehand if the raster uses one.
    """
    if band.dtype == np.uint8:
        return band

    # Promote integer rasters so the arithmetic below cannot wrap around.
    data = band if np.issubdtype(band.dtype, np.floating) else band.astype(np.float64)

    finite = np.isfinite(data)
    if not finite.any():
        return np.zeros(band.shape, dtype=np.uint8)

    lo = data[finite].min()
    hi = data[finite].max()
    if lo >= 0 and hi <= 1:
        scaled = data * 255
    elif lo >= 0 and hi <= 255:
        scaled = data
    elif hi > lo:
        scaled = (data - lo) / (hi - lo) * 255
    else:
        # Constant band outside [0, 255]: there is no contrast to preserve.
        scaled = np.zeros_like(data)

    scaled = np.nan_to_num(scaled, nan=0.0, posinf=255.0, neginf=0.0)
    return scaled.clip(0, 255).astype(np.uint8)


image_u8 = _to_uint8(image)

# `crs` and `transform` were already read when the raster was loaded, so the
# file does not need to be opened again here. The range bounds visit only
# positions where a full tile_size x tile_size window fits; partial tiles at
# the right/bottom edge are skipped, as before.
for i in tqdm(range(0, h - tile_size + 1, tile_size)):
    for j in range(0, w - tile_size + 1, tile_size):
        img_tile = image_u8[i:i + tile_size, j:j + tile_size]
        mask_tile = mask[i:i + tile_size, j:j + tile_size]

        # Shift the raster's affine transform by (col offset, row offset) so
        # the tile keeps its position in map coordinates.
        new_transform = transform * transform.translation(j, i)

        tile_profile = dict(
            driver='GTiff', height=tile_size, width=tile_size, count=1,
            dtype=np.uint8, crs=crs, transform=new_transform,
        )

        tile_name = f"tile_{i}_{j}.tif"
        with rasterio.open(images_dir / tile_name, 'w', **tile_profile) as dst:
            dst.write(img_tile, 1)

        # The mask is stored with its raw 0/1 values.
        tile_name_mask = f"mask_{i}_{j}.tif"
        with rasterio.open(masks_dir / tile_name_mask, 'w', **tile_profile) as dst:
            dst.write(mask_tile, 1)

        count += 1

print(f"Slicing complete, a total of {count} image and mask tile pairs were generated.")


In [None]:
import os
import numpy as np
import rasterio
import cv2
from pathlib import Path
from sklearn.model_selection import train_test_split
from rasterio.transform import from_origin
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model

# === NOTE (Berk): point the four paths below at your own copy of the data ===
# Tile size expected by the network, in pixels.
IMG_HEIGHT = 256
IMG_WIDTH = 256

# Tiles to read.
IMAGE_DIR = r"D:\UNETTest\real data\tiles\images"
MASK_DIR = r"D:\UNETTest\real data\tiles\masks"

# Where the validation split is exported.
VAL_IMAGE_OUT = r"D:\UNETTest\real data\val_export\images"
VAL_MASK_OUT = r"D:\UNETTest\real data\val_export\masks"

# Create the validation export directories if they are missing.
for export_dir in (VAL_IMAGE_OUT, VAL_MASK_OUT):
    os.makedirs(export_dir, exist_ok=True)

# === Load images and masks, and return filenames ===
def load_data(image_dir, mask_dir, tile_names):
    """Read image/mask tile pairs and return them as model-ready arrays.

    For every name in ``tile_names`` ending in ".tif", the first band of
    ``image_dir/<name>`` and of the matching mask (same name with "tile"
    replaced by "mask", looked up in ``mask_dir``) is read.

    Args:
        image_dir: Directory containing the image tiles.
        mask_dir: Directory containing the mask tiles.
        tile_names: Iterable of image tile file names to load.

    Returns:
        A tuple ``(images, masks, filenames)`` where ``images`` has shape
        (N, IMG_HEIGHT, IMG_WIDTH, 1) with each tile min-max scaled to [0, 1],
        ``masks`` has the same shape with float32 values 0.0/1.0, and
        ``filenames`` lists the loaded image names in the same order.
    """
    images, masks, filenames = [], [], []
    print("📥 Loading data...")
    target_size = (IMG_WIDTH, IMG_HEIGHT)  # cv2.resize expects (width, height)

    for filename in tile_names:
        if not filename.endswith(".tif"):
            continue

        img_path = os.path.join(image_dir, filename)
        mask_path = os.path.join(mask_dir, filename.replace("tile", "mask"))

        with rasterio.open(img_path) as src:
            img = src.read(1)

        with rasterio.open(mask_path) as src:
            mask = src.read(1)

        img = cv2.resize(img, target_size)
        lo, hi = np.min(img), np.max(img)
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            # A constant tile has no range to stretch; dividing by zero here
            # would fill the tile with NaN and poison the training loss.
            img = np.zeros(img.shape, dtype=np.float64)
        img = np.expand_dims(img, axis=-1)

        # Nearest-neighbour keeps the labels crisp; an interpolating resize
        # followed by "> 0" would grow every labelled region.
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        mask = (mask > 0).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)

        images.append(img)
        masks.append(mask)
        filenames.append(filename)

    return np.array(images), np.array(masks), filenames

# === Save validation data, retaining spatial information ===
def save_val_data(X_val, y_val, filenames_val, original_image_dir, original_mask_dir):
    """Export the validation split as georeferenced single-band GeoTIFFs.

    Each sample is scaled by 255 and cast to uint8. The CRS and affine
    transform are copied from the tile of the same name in
    ``original_image_dir`` and applied to both the image and its mask.
    Images go to ``VAL_IMAGE_OUT``, masks to ``VAL_MASK_OUT``.

    Args:
        X_val: Validation images; sample ``i`` belongs to ``filenames_val[i]``.
        y_val: Validation masks, in the same order as ``X_val``.
        filenames_val: Image tile file names, in the same order as ``X_val``.
        original_image_dir: Directory holding the original image tiles.
        original_mask_dir: Accepted for symmetry; not used by this function.
    """

    def _write_single_band(out_path, band, band_crs, band_transform):
        # One-band GeoTIFF whose size and dtype are taken from the array itself.
        with rasterio.open(
            out_path, "w", driver="GTiff", height=band.shape[0],
            width=band.shape[1], count=1, dtype=band.dtype,
            crs=band_crs, transform=band_transform
        ) as dst:
            dst.write(band, 1)

    print("💾 Saving validation images and masks (retaining spatial reference)...")
    for idx, sample in enumerate(X_val):
        image = (sample.squeeze() * 255).astype(np.uint8)
        mask = (y_val[idx].squeeze() * 255).astype(np.uint8)

        tile_file = filenames_val[idx]
        mask_file = tile_file.replace("tile", "mask")

        # Spatial reference comes from the original image tile.
        with rasterio.open(os.path.join(original_image_dir, tile_file)) as src_img:
            transform, crs = src_img.transform, src_img.crs

        _write_single_band(os.path.join(VAL_IMAGE_OUT, tile_file), image, crs, transform)
        _write_single_band(os.path.join(VAL_MASK_OUT, mask_file), mask, crs, transform)

    print(f"✅ {len(X_val)} validation images and masks (with spatial reference) saved.")

# === Main logic ===
N_SAMPLED_TILES = 70  # Upper bound on the number of tiles used for this run
SAMPLING_SEED = 42    # Fixed so the same tiles are drawn on every run

# Sorted so the seeded draw below does not depend on the order in which the
# operating system happens to list the directory.
tile_names = sorted(f for f in os.listdir(IMAGE_DIR) if f.endswith(".tif"))
if not tile_names:
    raise FileNotFoundError(f"No .tif tiles found in {IMAGE_DIR}")

# Sampling without replacement fails when more items are requested than exist,
# so cap the request at the number of available tiles.
n_sampled = min(N_SAMPLED_TILES, len(tile_names))
sampling_rng = np.random.default_rng(SAMPLING_SEED)
sampled_tiles = sampling_rng.choice(tile_names, n_sampled, replace=False)

# Print selected filenames
print("Selected tile filenames:", sampled_tiles)

# Load the sampled tiles, then split arrays and file names together so that
# filenames_val[i] keeps referring to X_val[i].
X, y, filenames = load_data(IMAGE_DIR, MASK_DIR, sampled_tiles)
X_train, X_val, y_train, y_val, filenames_train, filenames_val = train_test_split(
    X, y, filenames, test_size=0.2, random_state=42
)

# Print statistics
print("✅ Number of training images:", len(X_train))
print("✅ Number of validation images:", len(X_val))
print("📊 Number of active pixels in training mask:", np.sum(y_train))
print("📊 Number of active pixels in validation mask:", np.sum(y_val))

# Save validation data
save_val_data(X_val, y_val, filenames_val, IMAGE_DIR, MASK_DIR)


In [None]:
import random
import matplotlib.pyplot as plt

# Draw one sample index uniformly from the loaded tiles
rand_index = random.randrange(len(X))

# Take the single channel of the chosen image and of its mask
rand_img = X[rand_index][:, :, 0]
rand_mask = y[rand_index][:, :, 0]

print(f"🎲 Randomly selected index: {rand_index}")
print(f"🟡 Number of non-zero pixels in the mask: {np.sum(rand_mask)}")

# Side-by-side view: image on the left, ground-truth mask on the right
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 4))

ax_img.imshow(rand_img, cmap='gray')
ax_img.set_title("📷 Original Image (DEM)")
ax_img.axis('off')

# Fixed 0..1 range so an all-zero mask still renders black
ax_mask.imshow(rand_mask, cmap='gray', vmin=0, vmax=1)
ax_mask.set_title("🟢 Mask (Ground Truth)")
ax_mask.axis('off')

fig.tight_layout()
plt.show()


In [None]:
def build_unet(input_size=(IMG_HEIGHT, IMG_WIDTH, 1)):
    """Build and compile a three-level U-Net for binary segmentation.

    The encoder applies double 3x3 convolutions with 64, 128 and 256 filters,
    each followed by 2x2 max pooling; the bottleneck uses 512 filters. The
    decoder mirrors the encoder with 2x2 up-sampling, concatenation with the
    matching encoder output and double convolutions (256, 128, 64 filters).
    A final 1x1 convolution with sigmoid activation yields one channel.

    Args:
        input_size: Shape of one input sample, passed to ``Input``.

    Returns:
        The model, compiled with Adam, binary cross-entropy and accuracy.
    """

    def double_conv(tensor, n_filters):
        # Two consecutive 3x3 ReLU convolutions that keep the spatial size.
        for _ in range(2):
            tensor = Conv2D(n_filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    # Encoder: remember each level's output for the skip connections.
    skips = []
    x = inputs
    for n_filters in (64, 128, 256):
        x = double_conv(x, n_filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 512)

    # Decoder: walk the skip connections from deepest to shallowest.
    for n_filters, skip in zip((256, 128, 64), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = concatenate([x, skip])
        x = double_conv(x, n_filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    return model


In [None]:
import rasterio
import numpy as np
import cv2
import os
import tensorflow as tf
import matplotlib.pyplot as plt

class VisualCallback(tf.keras.callbacks.Callback):
    """Export and plot validation predictions once the last epoch finishes.

    At the end of epoch ``max_epochs`` every validation image is predicted,
    binarised, written as a georeferenced GeoTIFF into ``save_path`` and shown
    next to its input and ground-truth mask.

    The model output is used directly as a probability: the network built by
    ``build_unet`` already ends in a sigmoid, so no further activation is
    applied here. (Applying a second sigmoid would squeeze all outputs into
    roughly [0.5, 0.73].)

    Args:
        val_data: Tuple ``(images, masks)`` of validation arrays.
        save_path: Directory that receives the prediction GeoTIFFs.
        template_raster_path: Path of any file inside the directory holding
            the georeferenced validation tiles; only its directory is used.
        max_epochs: 1-based epoch number after which the export runs.
        tile_size: Stored on the instance; not used by the export itself.
        filenames: Tile file names; ``filenames[i]`` must be the tile that
            ``val_data[0][i]`` was loaded from, because its CRS and transform
            are copied onto prediction ``i``.
        threshold: Probability at or above which a pixel is written as 255.
            The default 0.3516 equals logit(0.587), i.e. it selects the same
            pixels as the previous "sigmoid(output) >= 0.587" rule.

    Raises:
        ValueError: If fewer file names than validation images are given.
    """

    def __init__(self, val_data, save_path, template_raster_path, max_epochs, tile_size, filenames,
                 threshold=0.3516):
        super().__init__()
        # Fail before training starts rather than after the final epoch.
        if len(filenames) < len(val_data[0]):
            raise ValueError(
                f"Got {len(filenames)} filenames for {len(val_data[0])} validation images"
            )
        self.val_data = val_data
        self.save_path = save_path
        self.template_raster_path = template_raster_path
        self.max_epochs = max_epochs
        self.tile_size = tile_size
        self.filenames = filenames
        self.threshold = threshold

    @staticmethod
    def _fmt(value):
        """Format a metric for the summary line; missing metrics show as n/a."""
        return "n/a" if value is None else f"{value:.4f}"

    def on_epoch_end(self, epoch, logs=None):
        """Run the export when ``epoch`` (0-based) is the final epoch."""
        if epoch + 1 != self.max_epochs:
            return

        logs = logs or {}
        print(
            f"\n📈 Epoch {epoch+1} | Loss: {self._fmt(logs.get('loss'))} "
            f"| Val_Loss: {self._fmt(logs.get('val_loss'))}"
        )

        val_images, val_masks = self.val_data[0], self.val_data[1]
        tile_dir = os.path.dirname(self.template_raster_path)

        # One batched, silent prediction instead of one predict() call (and
        # one progress bar) per image. Channel 0 holds the probability map.
        probabilities = self.model.predict(np.asarray(val_images), verbose=0)[..., 0]

        for idx, (img, mask, prob) in enumerate(zip(val_images, val_masks, probabilities)):
            pred_bin = np.where(prob >= self.threshold, 255, 0).astype(np.uint8)

            tile_name = self.filenames[idx]
            raster_path = os.path.join(tile_dir, tile_name)
            with rasterio.open(raster_path) as src:
                profile = src.profile
                transform = src.transform
                crs = src.crs
                # Size comes from the metadata; no need to read the pixels.
                original_shape = (src.height, src.width)
            print(f"✅ Processing: {tile_name}, Original image size: {original_shape}")

            # Nearest-neighbour keeps the output strictly 0/255.
            pred_resized = cv2.resize(
                pred_bin, (original_shape[1], original_shape[0]), interpolation=cv2.INTER_NEAREST
            )

            profile.pop('nodata', None)
            profile.update(dtype=rasterio.uint8, count=1, compress='lzw',
                           transform=transform, crs=crs)

            # Drop the tile's own extension so the output does not end up
            # named "...tif_epoch_N.tif".
            stem = os.path.splitext(tile_name)[0]
            out_path = os.path.join(self.save_path, f"prediction_{stem}_epoch_{epoch+1}.tif")
            with rasterio.open(out_path, 'w', **profile) as dst:
                dst.write(pred_resized, 1)

            print(f"💾 Saved prediction {idx+1} to: {out_path}")

            fig, (ax_img, ax_mask, ax_pred) = plt.subplots(1, 3, figsize=(18, 6))

            ax_img.imshow(np.squeeze(img), cmap='gray')
            ax_img.set_title(f"Original Image {idx+1}")
            ax_img.axis('off')

            ax_mask.imshow(np.squeeze(mask), cmap='gray')
            ax_mask.set_title(f"Ground Truth Mask {idx+1}")
            ax_mask.axis('off')

            ax_pred.imshow(pred_resized, cmap='gray', vmin=0, vmax=255)
            ax_pred.set_title(f"Prediction Result {idx+1}")
            ax_pred.axis('off')

            plt.show()
            # Release the figure; otherwise one stays alive per validation tile.
            plt.close(fig)


In [None]:
# Save paths
PREDICT_SAVE_PATH = r"D:\UNETTest\real data\predictions"
MODEL_SAVE_DIR = r"D:\UNETTest\real data\models"

# Use the first exported validation tile (by name) as the template raster.
# Only .tif files are considered so that a sidecar file in the folder is never
# picked, and an empty folder produces a clear error instead of an IndexError.
_val_tiles = sorted(f for f in os.listdir(VAL_IMAGE_OUT) if f.endswith(".tif"))
if not _val_tiles:
    raise FileNotFoundError(
        f"No exported validation tiles (.tif) found in {VAL_IMAGE_OUT}; "
        "run the validation export cell first."
    )
TEMPLATE_RASTER_PATH = os.path.join(VAL_IMAGE_OUT, _val_tiles[0])

for _out_dir in (PREDICT_SAVE_PATH, MODEL_SAVE_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Custom model save callback (one per epoch)
class SaveEveryEpoch(tf.keras.callbacks.Callback):
    """Save the full model to ``save_dir`` at the end of every epoch.

    Files are named ``unet_epoch_<NN>.keras`` with a 1-based, zero-padded
    epoch number.

    Args:
        save_dir: Directory that receives the model files; created if missing.
    """

    def __init__(self, save_dir):
        # Let the Keras base class set up its own state before adding ours.
        super().__init__()
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Write the current model; ``epoch`` is 0-based, file names 1-based."""
        model_path = os.path.join(self.save_dir, f"unet_epoch_{epoch+1:02d}.keras")
        self.model.save(model_path)
        print(f"💾 Model saved: {model_path}")

# Start model training
EPOCHS = 20      # Single source of truth for fit() and for the export trigger
BATCH_SIZE = 16

model = build_unet()
print("✅ Model structure:")
model.summary()

# VisualCallback copies the CRS/transform of filenames[i] onto the prediction
# for X_val[i], so the names must be in the same order as X_val. The names
# returned by train_test_split have exactly that order; a directory listing of
# VAL_IMAGE_OUT does not (and may also contain tiles from earlier runs).
filenames = list(filenames_val)

# The callback exports and plots the validation predictions after the final
# epoch, hence max_epochs must equal the epochs passed to fit().
visual_cb = VisualCallback(
    val_data=(X_val, y_val),  # Validation data
    save_path=PREDICT_SAVE_PATH,  # Directory for the prediction GeoTIFFs
    template_raster_path=TEMPLATE_RASTER_PATH,  # Locates the georeferenced validation tiles
    max_epochs=EPOCHS,
    tile_size=256,  # Tile size
    filenames=filenames
)

# NOTE(review): SaveEveryEpoch is defined and MODEL_SAVE_DIR is created, but
# the callback is not registered here, so no per-epoch checkpoints are
# written. Add SaveEveryEpoch(MODEL_SAVE_DIR) to `callbacks` if that was the
# intent.
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[visual_cb]
)
