In [9]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, UpSampling2D, concatenate
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt


from get_file_matches import get_tif_file_matches

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
import rasterio
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split

In [10]:
def load_raster_data(file_path):
    """
    Read all bands of a raster file and return them with channels last.

    Parameters:
    - file_path: Path of a raster file that rasterio can open.

    Returns:
    - Array produced by `read()` with its first axis (the channel axis)
      moved to the last position.
    """
    with rasterio.open(file_path) as dataset:
        bands_first = dataset.read()

    # Shift the channel axis from the front to the back of the array.
    channels_last = np.moveaxis(bands_first, source=0, destination=-1)
    return channels_last

def load_data(matches, resize_shape=(256, 256)):
    """
    Load every (stacked raster, ground-truth mask) pair and resize it.

    Parameters:
    - matches: Mapping from stacked TIF path to ground-truth mask TIF path.
    - resize_shape: Target (height, width) applied to images and masks.

    Returns:
    - X: Array of resized stacked rasters, one entry per pair.
    - Y: Array of resized float32 masks with a trailing channel axis of size 1.
    """
    X, Y = [], []
    for stacked_tif, ground_truth_tif in matches.items():
        # Load stacked TIF and ground truth mask
        stacked = load_raster_data(stacked_tif)
        ground_truth = load_raster_data(ground_truth_tif)[..., 0]  # Use the first channel

        # Imagery is continuous, so the default (bilinear) interpolation is fine.
        stacked = tf.image.resize(stacked, resize_shape).numpy()

        # Masks hold class labels: bilinear interpolation would invent
        # fractional values along class borders, and a later `mask > 0`
        # binarisation would turn all of them into foreground (dilating the
        # mask). Nearest-neighbour keeps only label values that really occur.
        # Casting to float32 first keeps the returned dtype identical to what
        # the bilinear resize used to produce.
        mask = ground_truth[..., np.newaxis].astype(np.float32)
        ground_truth = tf.image.resize(mask, resize_shape, method="nearest").numpy()

        X.append(stacked)
        Y.append(ground_truth)

    return np.array(X), np.array(Y)

In [11]:

def visualize_prediction(model, X, Y):
    """
    Show input, ground truth and predicted mask side by side for each sample.

    Parameters:
    - model: Trained segmentation model exposing `predict`.
    - X: Input data (stacked raster); the first three bands are drawn as RGB.
    - Y: Ground truth masks.
    """
    predicted_masks = model.predict(X)

    for sample_idx in range(len(X)):
        plt.figure(figsize=(15, 5))

        # One (title, image, imshow keyword arguments) entry per panel,
        # listed left to right.
        panels = (
            ("Input (First 3 Bands)", X[sample_idx, :, :, :3], {}),
            ("Ground Truth", Y[sample_idx].squeeze(), {"cmap": "gray"}),
            ("Prediction", predicted_masks[sample_idx].squeeze(), {"cmap": "gray"}),
        )
        for position, (panel_title, image, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.title(panel_title)
            plt.imshow(image, **imshow_kwargs)

        plt.show()

def evaluate_predictions(model, X, Y, threshold=0.5):
    """
    Evaluate predictions of the segmentation model using IoU, Dice, Accuracy, Precision, and Recall.

    Each metric is computed per image on the flattened masks, and the
    per-image values are then averaged.

    Parameters:
    - model: Trained segmentation model.
    - X: Input images (stacked raster).
    - Y: Ground truth masks holding binary labels (the notebook passes
      `(Y > 0)` masks).
    - threshold: Threshold to binarize predicted mask (default = 0.5).

    Returns:
    - metrics: Dictionary mapping each metric name to the list of per-image values.
    - mean_metrics: Dictionary mapping each metric name to its mean over all images.
    """
    metrics = {'IoU': [], 'Dice': [], 'Accuracy': [], 'Precision': [], 'Recall': []}

    # Generate predictions for X
    predictions = model.predict(X)
    predictions = (predictions > threshold).astype(np.uint8)  # Binarize predictions

    for i in range(len(X)):
        y_true = Y[i].ravel()
        y_pred = predictions[i].ravel()

        # Use the same zero-division convention for every ratio metric.
        # Without it, IoU and Dice fall back to 0 (with a warning) on a tile
        # where neither mask contains a positive pixel, while Precision and
        # Recall report 1 for that very same tile.
        iou = jaccard_score(y_true, y_pred, zero_division=1)
        dice = f1_score(y_true, y_pred, zero_division=1)
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=1)
        recall = recall_score(y_true, y_pred, zero_division=1)

        metrics['IoU'].append(iou)
        metrics['Dice'].append(dice)
        metrics['Accuracy'].append(accuracy)
        metrics['Precision'].append(precision)
        metrics['Recall'].append(recall)

    # Calculate the mean of all metrics
    mean_metrics = {key: np.mean(values) for key, values in metrics.items()}
    print("\n==== Model Evaluation Metrics ====")
    for metric, value in mean_metrics.items():
        print(f"{metric}: {value:.4f}")

    return metrics, mean_metrics

In [18]:
def build_deeplabv3(input_shape):
    """
    Build a DeepLabV3-style binary segmentation model on a ResNet50 backbone.

    A custom stride-2 convolution stem is placed in front of a randomly
    initialised ResNet50. The deepest ResNet feature map is merged with an
    intermediate one at 1/16 of the input resolution, reduced to a single
    sigmoid channel and upsampled back to the input size.

    Parameters:
    - input_shape: (height, width, channels) of the model input. Height and
      width must be multiples of 64 (or None) so that the two merged feature
      maps line up and the output has the same size as the input.

    Returns:
    - Keras Model mapping the input to a (height, width, 1) probability mask.

    Raises:
    - ValueError: If a known spatial dimension is not a multiple of 64.
    """
    total_stride = 64   # downsampling of the deepest map: stem (2x) * ResNet50 (32x)
    merge_stride = 16   # resolution, relative to the input, at which the maps are merged

    for size in input_shape[:2]:
        if size is not None and size % total_stride != 0:
            raise ValueError(
                f"Spatial input dimensions must be multiples of {total_stride}, "
                f"got input_shape={tuple(input_shape)}"
            )

    inputs = Input(input_shape)

    # Extra stem in front of the backbone. It is prepended to ResNet50 (which
    # still keeps its own first convolution) and halves the resolution.
    x = Conv2D(64, (7, 7), strides=(2, 2), padding="same", name='conv1_conv_modified')(inputs)
    x = BatchNormalization(name='bn_conv1_modified')(x)
    x = Activation("relu")(x)

    # Full ResNet50 backbone fed with the stem output.
    base_model = ResNet50(
        weights=None,  # We cannot use pre-trained weights due to the custom input layer
        include_top=False,
        input_tensor=x
    )

    # Intermediate (index 0) and deepest (index 1) feature maps.
    layer_names = ['conv4_block6_2_relu', 'conv5_block3_out']
    base_layers = [base_model.get_layer(name).output for name in layer_names]

    # ASPP-like module (a single 1x1 branch) on the deepest feature map
    b0 = Conv2D(256, (1, 1), padding="same", activation=None)(base_layers[1])
    b0 = BatchNormalization()(b0)
    b0 = Activation("relu")(b0)

    # Bring both feature maps to the common 1/16 resolution and merge them.
    x = UpSampling2D((4, 4), interpolation="bilinear")(b0)
    base_layer_resized = UpSampling2D(size=(2, 2), interpolation="bilinear")(base_layers[0])
    x = concatenate([x, base_layer_resized], axis=-1)
    x = Conv2D(128, (3, 3), padding="same", activation=None)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Final segmentation output
    x = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)

    # The map is at 1/16 of the input size, so the factor that restores the
    # input resolution is the constant 16. The previous `input_shape // 16`
    # only equals 16 for 256x256 inputs and yields a wrongly sized output
    # for every other shape.
    outputs = UpSampling2D((merge_stride, merge_stride), interpolation="bilinear")(x)

    return Model(inputs, outputs)


In [None]:

root_dir = "data/Tschernitz"
folder1 = "output"
folder2 = "ground_truth_masks/tree_masks"

# Seed Python, NumPy and TensorFlow in one call so that weight initialisation
# and batch shuffling are repeatable across "Restart & Run All".
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Match stacked TIF and ground truth mask
matches = get_tif_file_matches(root_dir, folder1, folder2, contains1="stacked", contains2="merged")

# Load data, scale the imagery, binarise the masks and split off a validation set.
# NOTE(review): dividing by 255 assumes 8-bit band values - confirm against the rasters.
X, Y = load_data(matches)
X_train, X_val, Y_train, Y_val = train_test_split(X / 255.0, (Y > 0).astype(np.float32), test_size=0.2, random_state=SEED)

# Build and compile DeepLabV3+ model
deeplab_model = build_deeplabv3(X_train.shape[1:])
deeplab_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
history = deeplab_model.fit(
    X_train, Y_train,
    validation_data=(X_val, Y_val),
    batch_size=8,
    epochs=25
)

# Training curves: labelled so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Loss')
ax.plot(history.history['val_loss'], label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Binary cross-entropy', title='DeepLabV3+ training history')
ax.legend()
plt.show()

# Qualitative and quantitative evaluation on the validation split
visualize_prediction(deeplab_model, X_val, Y_val)
metrics_deeplab, mean_metrics_deeplab = evaluate_predictions(deeplab_model, X_val, Y_val, threshold=0.5)
print(mean_metrics_deeplab)





In [None]:
from albumentations import Compose, RandomRotate90, HorizontalFlip, ShiftScaleRotate, RandomBrightnessContrast, GaussNoise
# Targeted augmentation pipeline
def get_augmentation_pipeline():
    """
    Build the albumentations pipeline used to augment the training tiles.

    Returns:
    - Compose object applying, in order: random 90-degree rotation, horizontal
      flip, small shift/scale/rotate, brightness/contrast jitter and Gaussian
      noise, each with its own probability.
    """
    geometric_transforms = [
        RandomRotate90(p=0.5),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.01, scale_limit=0.05, rotate_limit=15, p=0.5),
    ]
    # NOTE(review): var_limit=(10.0, 50.0) looks like a 0-255 intensity scale,
    # while the notebook feeds images divided by 255 - confirm how the
    # installed albumentations version scales noise for float inputs.
    photometric_transforms = [
        RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    ]
    return Compose(geometric_transforms + photometric_transforms)

# Apply augmentation
def augment_dataset(X, Y, num_versions=3):
    """
    Create augmented copies of every image/mask pair.

    Parameters:
    - X: Input images, indexed along the first axis.
    - Y: Masks aligned with X.
    - num_versions: Number of augmented variants generated per sample.

    Returns:
    - Tuple (augmented images, augmented masks) as arrays, ordered sample by
      sample with the variants of one sample adjacent to each other.
    """
    pipeline = get_augmentation_pipeline()

    # Each call transforms the image and its mask together.
    results = [
        pipeline(image=X[sample_idx], mask=Y[sample_idx])
        for sample_idx in range(len(X))
        for _ in range(num_versions)
    ]

    images = np.array([result['image'] for result in results])
    masks = np.array([result['mask'] for result in results])
    return images, masks


# Seed Python, NumPy and TensorFlow so that weight initialisation and batch
# shuffling are repeatable. Whether the augmentation draws are covered as
# well depends on which random source the installed albumentations version
# uses - TODO confirm.
tf.keras.utils.set_random_seed(42)

# Apply targeted augmentation
X_train_aug, Y_train_aug = augment_dataset(X_train, Y_train)

# Combine original and augmented data (optional, if you want to keep originals)
X_train_final = np.concatenate((X_train, X_train_aug), axis=0)
Y_train_final = np.concatenate((Y_train, Y_train_aug), axis=0)

# Train the model with the augmented data
# NOTE(review): build_unet is not defined in the cells shown here - make sure
# the cell defining it runs before this one on a fresh kernel.
input_shape = X_train_final.shape[1:]
unet_model_refined_aug = build_unet(input_shape)
unet_model_refined_aug.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

history_refined_aug = unet_model_refined_aug.fit(
    X_train_final, Y_train_final,
    validation_data=(X_val, Y_val),
    batch_size=8,
    epochs=25
)

# Training curves: labelled so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(history_refined_aug.history['loss'], label='Loss')
ax.plot(history_refined_aug.history['val_loss'], label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Binary cross-entropy', title='U-Net (augmented data) training history')
ax.legend()
plt.show()

# Visualize refined augmentation predictions
visualize_prediction(unet_model_refined_aug, X_val, Y_val)
metrics_refined_aug, mean_metrics_refined_aug = evaluate_predictions(unet_model_refined_aug, X_val, Y_val, threshold=0.5)