In [None]:
import os
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterio.windows import Window
import geopandas as gpd
from shapely.ops import unary_union
from sklearn.model_selection import train_test_split, class_weight
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from tensorflow.keras.utils import to_categorical
from glob import glob
from random import sample, choice, uniform
from skimage.transform import rotate
from tqdm import tqdm

In [None]:
# -----------------------------------------------------------------------------
# 1. USER PARAMETERS
# -----------------------------------------------------------------------------
PATCH_SIZE    = 32
HALF_PATCH    = PATCH_SIZE // 2
CLASS_INDICES = {'Sand': 0, 'SAV': 1}
NUM_CLASSES   = len(CLASS_INDICES)
BATCH_SIZE    = 32
EPOCHS        = 5
LEARNING_RATE = 1e-4

TRAIN_MOSAIC  = r"C:\Users\00097030\Git\WA-coast-SAV\Data\Mosaic\Mosaic_clip.tif"
PTS_SHP       = r"C:\Users\00097030\Git\WA-coast-SAV\Data\BOSS\BOSS_update.shp"
EXTENT_SHP    = r"C:\Users\00097030\Git\WA-coast-SAV\Data\extent\generalised_extent_update.shp"
PATCH_DIR     = r"C:\Users\00097030\Git\WA-coast-SAV\Data\image_patches"
PREDICT_IMG   = r"C:\Users\00097030\Git\WA-coast-SAV\Data\Mosaic\kndavi_50JLL_clip.tif"
MODEL_DIR     = r"C:\Users\00097030\Git\WA-coast-SAV\Data\models"

In [None]:
# -----------------------------------------------------------------------------
# 2. EXTRACT 1-BAND PATCHES AROUND Labeled Points
# -----------------------------------------------------------------------------
# - Clip the mosaic to the study extent
extent_geom = unary_union(gpd.read_file(EXTENT_SHP).geometry)
with rasterio.open(TRAIN_MOSAIC) as src:
    clipped, transform = mask(src, [extent_geom], crop=True, indexes=[1])
    base_profile = src.profile

# - Update metadata for fixed-size, single-band patches
patch_profile = base_profile.copy()
patch_profile.update({
    'height': PATCH_SIZE,
    'width': PATCH_SIZE,
    'count': 1,
    'dtype': clipped.dtype
})

# - Split points into training/validation
points = gpd.read_file(PTS_SHP)
points = points[points['Class'].isin(CLASS_INDICES)]
train_pts, val_pts = train_test_split(
    points, test_size=0.2, stratify=points['Class'], random_state=0
)

# - Create output folders
for split, subset in (("train", train_pts), ("val", val_pts)):
    for cls in CLASS_INDICES:
        os.makedirs(os.path.join(PATCH_DIR, split, cls), exist_ok=True)

    # - Extract and save each patch
    for idx, row in subset.iterrows():
        x, y = row.geometry.x, row.geometry.y
        col, row_i = map(int, (~transform) * (x, y))
        window = Window(col-HALF_PATCH, row_i-HALF_PATCH, PATCH_SIZE, PATCH_SIZE)

        # skip if patch would go outside the raster
        if (window.col_off < 0 or window.row_off < 0 or
            window.col_off + window.width  > clipped.shape[2] or
            window.row_off + window.height > clipped.shape[1]):
            continue

        patch = clipped[0,
                        int(window.row_off):int(window.row_off + window.height),
                        int(window.col_off):int(window.col_off + window.width)]

        out_meta = patch_profile.copy()
        out_meta.update({'transform': rasterio.windows.transform(window, transform)})
        out_fp = os.path.join(PATCH_DIR, split, row['Class'], f"{split}_{idx}.tif")

        with rasterio.open(out_fp, 'w', **out_meta) as dst:
            dst.write(patch[np.newaxis, ...])

In [None]:
# -----------------------------------------------------------------------------
# 3.MEAN & STD 
# -----------------------------------------------------------------------------
train_files = glob(os.path.join(PATCH_DIR, "train", "*", "*.tif"))
sum_, sum_sq, cnt = 0.0, 0.0, 0

for fp in train_files:
    with rasterio.open(fp) as src:
        band = src.read(1, masked=True)
    values = band.compressed().astype(np.float64)
    sum_   += values.sum()
    sum_sq += (values**2).sum()
    cnt    += values.size

MEAN = sum_ / cnt
VAR  = sum_sq / cnt - MEAN**2
STD  = float(np.sqrt(max(VAR, 0)))

In [None]:
# -----------------------------------------------------------------------------
# . COMPUTE CLASS WEIGHTS TO HANDLE IMBALANCE
# -----------------------------------------------------------------------------
#labels = np.array([
#    CLASS_INDICES[os.path.basename(os.path.dirname(f))]
#    for f in train_files
#])
#cw = class_weight.compute_class_weight("balanced",
#                                       classes=np.unique(labels),
#                                       y=labels)
#CLASS_WEIGHTS = {i: w for i, w in enumerate(cw)}

Class weights: {0: 0.75, 1: 1.5}


In [None]:
# -----------------------------------------------------------------------------
# 4. SIMPLE GENERATOR 
# -----------------------------------------------------------------------------
def generator(files, augment=False):
    while True:
        batch = sample(files, BATCH_SIZE)
        X, Y = [], []
        for fp in batch:
            with rasterio.open(fp) as src:
                img = src.read(1, masked=True).filled(MEAN).astype(np.float32)
            img = (img - MEAN) / (STD + 1e-6)
            if augment:
                if choice([True, False]):
                    img = np.fliplr(img)
                if choice([True, False]):
                    img = np.flipud(img)
                angle = uniform(-45, 45)
                img = rotate(img, angle, mode='reflect', preserve_range=True)
            cls = os.path.basename(os.path.dirname(fp))
            X.append(img[..., np.newaxis])
            Y.append(CLASS_INDICES[cls])
        yield np.stack(X), to_categorical(Y, NUM_CLASSES)

train_gen = generator(train_files, augment=True)
val_files  = glob(os.path.join(PATCH_DIR, "val", "*", "*.tif"))
val_gen    = generator(val_files, augment=False)

In [None]:
# -----------------------------------------------------------------------------
# 5. MODEL TRAINING
# -----------------------------------------------------------------------------
input_tensor = Input((PATCH_SIZE, PATCH_SIZE, 1))
base_model   = DenseNet201(include_top=False, weights=None, input_tensor=input_tensor)
x            = GlobalAveragePooling2D()(base_model.output)
output       = Dense(NUM_CLASSES, activation='softmax')(x)
model        = Model(base_model.input, output)

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy']
)

os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_cb = ModelCheckpoint(
    os.path.join(MODEL_DIR, "dense_sav_{epoch:02d}-{val_categorical_accuracy:.3f}.h5"),
    monitor='val_categorical_accuracy',
    save_best_only=True,
    verbose=1
)
earlystop_cb  = EarlyStopping(
    monitor='val_categorical_accuracy',
    patience=10,
    restore_best_weights=True
)
tensorboard_cb = TensorBoard(log_dir='./logs')

steps_per_epoch   = len(train_files) // BATCH_SIZE
validation_steps  = len(val_files)   // BATCH_SIZE

model.fit(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_gen,
    validation_steps=validation_steps,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb, earlystop_cb, tensorboard_cb]
)


Epoch 1/5
Epoch 1: val_categorical_accuracy improved from -inf to 0.67188, saving model to C:\Users\00097030\Git\WA-coast-SAV\Data\models\dense_sav_01-0.672.h5
Epoch 2/5
Epoch 2: val_categorical_accuracy improved from 0.67188 to 0.75000, saving model to C:\Users\00097030\Git\WA-coast-SAV\Data\models\dense_sav_02-0.750.h5
Epoch 3/5
Epoch 3: val_categorical_accuracy did not improve from 0.75000
Epoch 4/5
Epoch 4: val_categorical_accuracy did not improve from 0.75000
Epoch 5/5
Epoch 5: val_categorical_accuracy improved from 0.75000 to 0.76562, saving model to C:\Users\00097030\Git\WA-coast-SAV\Data\models\dense_sav_05-0.766.h5


<keras.callbacks.History at 0x2b08b8ea740>

In [None]:
# -----------------------------------------------------------------------------
# 6. SINGLE-BAND pred
# -----------------------------------------------------------------------------
with rasterio.open(PREDICT_IMG) as src:
    img   = src.read(1, masked=True).filled(MEAN).astype(np.float32)
    meta  = src.profile.copy()

H0,W0 = img.shape
pad   = HALF
img_p = np.pad(img,((pad,pad),(pad,pad)),mode='reflect')
img_p = (img_p-MEAN)/(STD+1e-6)

prob = np.zeros((H0,W0,NUM_CLASSES),dtype=np.float32)
buf  = np.zeros((W0,PATCH_SIZE,PATCH_SIZE,1),dtype=np.float32)

for i in tqdm(range(pad,pad+H0),desc="Inferring rows"):
    for j in range(pad,pad+W0):
        buf[j-pad,...,0] = img_p[i-pad:i+pad,j-pad:j+pad]
    pr = model.predict(buf, batch_size=256, verbose=0)
    prob[i-pad,:,:] = pr

label    = np.argmax(prob,axis=-1).astype(np.uint8)
prob_max = prob.max(axis=-1).astype(np.float32)


Inferring rows: 100%|██████████| 544/544 [02:24<00:00,  3.76it/s]


In [23]:
out_label_fp  = os.path.join("Data","Mosaic","classified_label_2.tif")
out_prob_fp  = os.path.join("Data","Mosaic","classified_prob_2.tif")
# write outputs
meta.update(count=1,dtype='uint8')
with rasterio.open(out_label_fp,'w',**meta) as dst: dst.write(label,1)
meta.update(dtype='float32')
with rasterio.open(out_prob_fp,'w',**meta) as dst: dst.write(prob_max,1)