# ShearLineCNN
> Shear Line Classification using CNN.

## Revision History

| #   | Date       | Action                                           | Modified by        |
|-----|------------|--------------------------------------------------|--------------------|
| 030 | 2025-06-01 | Revert to Adadelta                               | rmaniego           |
| 029 | 2025-06-01 | Apply LR Scheduler                               | rmaniego           |
| 028 | 2025-06-01 | Decrease `EarlyStopping` patience                | rmaniego           |
| 027 | 2025-06-01 | Increase validation partition (60:40)            | rmaniego           |
| 026 | 2025-06-01 | Apply stronger regularization                    | rmaniego           |
| 025 | 2025-06-01 | Tune hyperparameters                             | rmaniego           |
| 024 | 2025-06-01 | Expand architecture                              | rmaniego           |
| 023 | 2025-06-01 | Fix `Conv2D->BN->Activation` sequence            | rmaniego           |
| 022 | 2025-06-01 | Fix map incorrect flipping                       | rmaniego           |
| 021 | 2025-06-01 | Reduce dataset dimension                         | rmaniego           |
| 020 | 2025-06-01 | Investigate dataset                              | rmaniego           |
| 019 | 2025-05-30 | Fix differing evaluations                        | rmaniego           |
| 018 | 2025-05-30 | Update test visualizations                       | rmaniego           |
| 017 | 2025-05-27 | Fix `no-shear` dataset                           | rmaniego           |
| 016 | 2025-05-27 | Tune hyperparameters                             | rmaniego           |
| 015 | 2025-05-22 | Decrease initial LR                              | rmaniego           |
| 014 | 2025-05-22 | Change to Tversky loss                           | rmaniego           |
| 013 | 2025-05-21 | Add data augmentation                            | rmaniego           |
| 012 | 2025-05-21 | Add data preprocessing                           | rmaniego           |
| 011 | 2025-05-20 | Migrate to U-Net architecture                    | rmaniego           |
| 010 | 2025-05-16 | Optimize architecture                            | rmaniego           |
| 009 | 2025-05-16 | Improve architecture                             | rmaniego           |
| 008 | 2025-05-15 | Fix model metrics                                | rmaniego           |
| 007 | 2025-05-03 | Fix testing evaluation                           | rmaniego           |
| 006 | 2025-05-03 | Fix dataset loader                               | rmaniego           |
| 005 | 2025-05-03 | Fix segmentation dataset                         | rmaniego           |
| 004 | 2025-04-10 | Fix architecture to match dataset                | rmaniego           |
| 003 | 2025-04-10 | Update architecture base codes                   | rmaniego           |
| 002 | 2025-04-09 | Prepare dataset                                  | rmaniego           |
| 001 | 2025-03-29 | Create GitHub repository                         | rmaniego           |

## Step 1: Mount Google Drive

**Notes:**
 - Mounting requires GDrive permissions.
 - Make and commit changes in the local repository.
 - Re-run this cell after every new commit to the repository.
 - The Colab copy is read-only unless write access is granted through a GitHub fine-grained personal access token (FGPAT).

```bash
pip install jupyterlab
pip install notebook
jupyter notebook
```

**GitHub Personal Access Tokens (PAT)**
1. Go to `https://github.com/settings/tokens`.
2. In the sidebar, select `Fine-grained tokens`.
3. Fill in the required details and limit read/write access.
4. Copy the generated `PAT` into a local environment variable.
5. Add the same token to the Google Colab secrets (the notebook reads the secret named `ShearLineCNN`).
6. Once the token expires, move the old repo in GDrive to the trash.
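
For local runs, the token stored in step 4 can be read from the environment before it is spliced into the clone URL. The sketch below is a minimal illustration, not the notebook's own code: the variable name `SHEARLINECNN_PAT` and the helpers `build_clone_url` and `redact` are hypothetical, and the repository address is the one cloned in Step 3.

```python
import os

# Hypothetical variable name; use whatever name the PAT was stored under.
PAT_ENV_VAR = "SHEARLINECNN_PAT"
REPO_HOST_PATH = "github.com/rmaniego/ShearLineCNN.git"


def build_clone_url(token):
    """Return an HTTPS clone URL, embedding the token only when one is set."""
    if token:
        return f"https://{token}@{REPO_HOST_PATH}"
    return f"https://{REPO_HOST_PATH}"


def redact(url, token):
    """Mask the token so the URL can be printed or logged safely."""
    return url.replace(token, "***") if token else url


token = os.environ.get(PAT_ENV_VAR)  # None when the variable is not defined
clone_url = build_clone_url(token)
print(redact(clone_url, token))
```

Printing only the redacted form keeps the token out of saved notebook outputs.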

In [None]:
import os

github_fgpat = None
live_on_colab = False
environment_ready = False

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

try:
    from google.colab import drive, userdata

    drive.mount("/content/drive")

    live_on_colab = True
    github_fgpat = userdata.get("ShearLineCNN")
    print("Running on Google Colaboratory...")
except ImportError:
    print("Running locally...")

## Step 2: Check Colab Compute Engine Backend
**Note:** Run this cell to verify the hardware (HW) accelerator allocation, and record the reported details for use in the manuscript.

HW accelerator availability may vary, so time each session and confirm that it is connected to the expected runtime environment in every iteration. Options include:
1. NVIDIA A100 Tensor Core GPU - high-performance deep learning training (recommended).
2. NVIDIA L4 Tensor Core GPU - optimized for AI inference tasks with high performance and efficiency (preferred during hyperparameter fine-tuning).
3. NVIDIA T4 Tensor Core GPU - cost-effective, versatile, and suitable for a variety of tasks.


In [None]:
if live_on_colab:
    gpu_info = !nvidia-smi
    gpu_info = "\n".join(gpu_info)
    if gpu_info.find("failed") >= 0:
        print("Not connected to a GPU")
    else:
        print(gpu_info)

    from psutil import virtual_memory

    ram_gb = virtual_memory().total / 1e9
    print(f"Your runtime has {ram_gb:.1f} gigabytes of available RAM")

    if ram_gb < 20:
        print("Not using a high-RAM runtime")
    else:
        print("You are using a high-RAM runtime!")
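
Outside IPython the `!nvidia-smi` shell escape is unavailable, so the same check can be sketched with `subprocess`. This is an illustration under assumptions rather than part of the notebook: `query_gpus` is a hypothetical helper, and it relies on the `--query-gpu=name,memory.total` and `--format=csv,noheader` options of `nvidia-smi`.

```python
import shutil
import subprocess


def query_gpus():
    """Return one 'name, memory' string per visible GPU, or [] if none is found."""
    if shutil.which("nvidia-smi") is None:
        return []  # driver tools are not installed on this machine
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True, text=True, check=False, timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return []
    if result.returncode != 0:
        return []  # nvidia-smi is present but could not reach a GPU
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus_found = query_gpus()
print(gpus_found if gpus_found else "Not connected to a GPU")
```

The returned strings can be pasted into the run log so that each training iteration records which accelerator it used.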

## Step 3: Change Working Directory

**Notes:**  
1. Before continuing, check your GDrive storage usage; cloning into nearly full storage may leave the repository incomplete or affect existing contents.  
2. Run this cell to ensure the notebook uses the latest version of the project repository.  

In [None]:
if live_on_colab:
    NB = "/content/drive/MyDrive/Colab Notebooks"
    os.makedirs(NB, exist_ok=True)
    os.chdir(NB)

    def update_repo():

        REPO = f"{NB}/ShearLineCNN"
        if not os.path.isdir(REPO):
            !git clone https://{github_fgpat}@github.com/rmaniego/ShearLineCNN.git
            os.chdir(REPO)
            return

        os.chdir(REPO)
        !git reset --hard HEAD
        !git pull origin main

    update_repo()

print(os.getcwd())
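
The storage check in note 1 can be partly automated before cloning. This is a minimal sketch, assuming `shutil.disk_usage` reports meaningful figures for the target path (whether it reflects the actual Drive quota on a mounted GDrive is not verified here); `has_free_space` and the 2 GB threshold are hypothetical.

```python
import shutil


def has_free_space(path, min_free_gb=2.0):
    """Return (ok, free_gb) for the filesystem that holds `path`."""
    usage = shutil.disk_usage(path)  # named tuple: total, used, free (bytes)
    free_gb = usage.free / 1e9
    return free_gb >= min_free_gb, free_gb


ok, free_gb = has_free_space(".")
print(f"Free space: {free_gb:.1f} GB ({'enough' if ok else 'too little'} to clone)")
```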

## Step 4: Install Dependencies
**Note:** Run this cell every time the `Google Colab` runtime environment reconnects.

In [None]:
if live_on_colab:
    %pip install -U jupyterlab
    %pip install -U notebook
    %pip install -U opencv-python
    %pip install -U scikit-learn
    %pip install -U scikit-image
    %pip install -U tensorflow
    %pip install -U matplotlib
    %pip install -U seaborn
    %pip install -U cartopy
print("Environment is ready...")

## Step 5: Import the Packages  

Import all third-party libraries required for the model to run successfully.

In [None]:
import glob
import json
import time
import random
import warnings
from datetime import datetime

warnings.filterwarnings("ignore", category=RuntimeWarning, message="os.fork()")
warnings.filterwarnings("ignore", category=UserWarning, message="Your `PyDataset` class should call")
warnings.filterwarnings("ignore", category=UserWarning, message="warn")

import cv2
import numpy as np
import tensorflow as tf
from skimage.morphology import disk
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.regularizers import l2
from sklearn.model_selection import train_test_split
from scipy.ndimage import binary_dilation, binary_erosion
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Conv2D, Activation, BatchNormalization, MaxPooling2D, Dropout, Conv2DTranspose, Concatenate, Cropping2D
from tensorflow.keras.optimizers import Adadelta, AdamW
from tensorflow.keras.utils import plot_model
from IPython.display import Image
import matplotlib.pyplot as plt
import cartopy.crs as ccrs


gpus = tf.config.list_physical_devices("GPU")

if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
    print("GPU detected. Running on GPU.")
else:
    print("No GPU detected. Running on CPU.")

## Step 6: Load Datasets  

Load and prepare the training and testing datasets.
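
The loading cell widens each one-pixel annotation with a disk-shaped structuring element of radius `(target_width - 1) // 2`. The pure-Python sketch below illustrates that relationship on a toy mask. It is not the notebook's implementation, which uses `skimage.morphology.disk` and `scipy.ndimage.binary_dilation`; `disk_offsets` and `dilate` are hypothetical helpers.

```python
def disk_offsets(radius):
    """Offsets (dy, dx) whose Euclidean distance from the centre is <= radius."""
    return [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * dy + dx * dx <= radius * radius
    ]


def dilate(mask, radius):
    """Binary dilation of a 2-D list of 0/1 values with a disk footprint."""
    rows, cols = len(mask), len(mask[0])
    offsets = disk_offsets(radius)
    out = [[0] * cols for _ in range(rows)]
    for y in range(rows):
        for x in range(cols):
            if mask[y][x]:
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < rows and 0 <= nx < cols:
                        out[ny][nx] = 1
    return out


target_width = 7                  # same value as in the notebook
radius = (target_width - 1) // 2  # 3 pixels on each side of the line

# Toy 11 x 11 mask with a one-pixel-wide horizontal line on row 5.
mask = [[1 if y == 5 else 0 for _ in range(11)] for y in range(11)]
widened = dilate(mask, radius)

# Thickness of the widened line, measured down the middle column.
thickness = sum(row[5] for row in widened)
print(thickness)
```

The measured thickness equals `target_width`, which is why the notebook requires an odd value: the line grows by the same radius on both sides.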

In [None]:
def load_features(source, target, category):
    basenames, features, labels = [], [], []

    sources = glob.glob(f"{source}/{category}/*.json")
    for i, source_path in enumerate(sources):
        filename = os.path.basename(source_path)

        with open(source_path, "r", encoding="utf-8") as file:
            data1 = np.array(json.load(file))

        target_path = f"{target}/{category}/{filename}"
        with open(target_path, "r", encoding="utf-8") as file:
            data2 = np.array(json.load(file))

        features.append(data1)
        labels.append(data2)
        basenames.append(filename)

    return basenames, np.array(features), np.array(labels)

def visualize_map(i, filename, features, true_mask, predicted_mask=None, fp_flag=False, figsize=(2, 1), extent=(115, 150, 5, 45)):

    label = " (FP)" if fp_flag else ""
    print(f"#{i+1}: {filename}")

    fig = plt.figure(figsize=figsize)
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.set_extent(extent)
    ax.set_facecolor("#303030")
    ax.coastlines(resolution="10m", linewidth=0.5, color="k")

    ny, nx = features.shape[:2]
    lon0 = np.linspace(extent[0], extent[1], nx)
    lat0 = np.linspace(extent[2], extent[3], ny)
    lon, lat = np.meshgrid(lon0, lat0)

    ax.contourf(lon, lat, features, levels=20, cmap="coolwarm", transform=ccrs.PlateCarree(), zorder=1)

    true_overlay = np.ma.masked_where(true_mask == 0, true_mask)
    ax.imshow(true_overlay, cmap="Greys", alpha=0.8, extent=extent, transform=ccrs.PlateCarree(), zorder=2)

    if predicted_mask is not None:
        pred_overlay = np.ma.masked_where(predicted_mask == 0, predicted_mask)
        ax.imshow(pred_overlay, cmap="Wistia", alpha=0.5, extent=extent, transform=ccrs.PlateCarree(), zorder=3)

    ax.set_title(f"Spatial Map{label}")
    plt.show()
    print("---\n")


#################################################
# Label Dilation (Widening)                     #
#-----------------------------------------------#
# Dilate shear line annotation to help model    #
# converge faster, since a 1px line is too thin #
# for the model to classify easily.             #
#################################################

def label_line_widening(labels_array, target_width=5):
    """
    Modifies the line width in a batch of binary segmentation masks in-memory
    to a specified target width using morphological dilation.
    """

    if target_width < 1 or target_width % 2 == 0:
        raise ValueError("target_width must be a positive odd integer (e.g., 1, 3, 5).")

    dilation_radius = (target_width - 1) // 2

    if dilation_radius == 0:
        return labels_array.astype(labels_array.dtype)

    struct_elem = disk(dilation_radius)

    modified_labels = np.zeros_like(labels_array, dtype=labels_array.dtype)
    num_samples = labels_array.shape[0]

    for i in range(num_samples):
        label_2d = labels_array[i].squeeze()

        modified_label_2d = binary_dilation(label_2d, structure=struct_elem).astype(labels_array.dtype)

        if len(labels_array.shape) == 4:
            modified_labels[i] = np.expand_dims(modified_label_2d, axis=-1)
        else:
            modified_labels[i] = modified_label_2d

    return modified_labels


#################################################
# Label Erosion (Shrinking)                     #
#-----------------------------------------------#
# Erode shear line annotation to approximate    #
# original thin lines. Useful for reversing     #
# prior dilation step for analysis or           #
# visualization.                                #
#################################################

def label_line_shrinking(labels_array, target_width=5):
    if target_width < 1 or target_width % 2 == 0:
        raise ValueError("target_width must be a positive odd integer (e.g., 1, 3, 5).")

    erosion_radius = (target_width - 1) // 2

    if erosion_radius == 0:
        return labels_array.astype(labels_array.dtype)

    struct_elem = disk(erosion_radius)

    modified_labels = np.zeros_like(labels_array, dtype=labels_array.dtype)
    num_samples = labels_array.shape[0]

    for i in range(num_samples):
        label_2d = labels_array[i].squeeze()

        modified_label_2d = binary_erosion(label_2d, structure=struct_elem).astype(labels_array.dtype)

        if len(labels_array.shape) == 4:
            modified_labels[i] = np.expand_dims(modified_label_2d, axis=-1)
        else:
            modified_labels[i] = modified_label_2d

    return modified_labels


basenames_list, features_list, labels_list = [], [], []

categories = ["no-shear", "shear"]
segmentation_source = "data/segmentation/source"
segmentation_target = "data/segmentation/target"
for category in categories:
    basenames, features, labels = load_features(segmentation_source, segmentation_target, category)
    if features.size > 0:
        basenames_list.extend(basenames)
        features_list.append(features)
        labels_list.append(labels)

features = np.vstack(features_list)
labels = np.vstack(labels_list)
basenames = np.array(basenames_list)

#################################################
# Perturbation-based Dataset Augmentation       #
#################################################

target_width = 7  # odd number
labels = labels.astype(np.uint8)
labels = label_line_widening(labels, target_width=target_width)

test_ratio = 0.1
target_total_size = 10_000
train_val_target_size = int(target_total_size * (1 - test_ratio))

binary_class_labels = np.array([1 if labels[i].any() else 0 for i in range(len(labels))])

# 90:10 split for testing (using indices) with stratification
total_indices = np.arange(len(basenames))
train_val_indices, test_indices = train_test_split(
    total_indices,
    test_size=test_ratio,
    shuffle=True,
    stratify=binary_class_labels
)

features_train_val = features[train_val_indices]
labels_train_val = labels[train_val_indices]
basenames_train_val = basenames[train_val_indices]

current_train_val_size = len(features_train_val)
if current_train_val_size < train_val_target_size:
    num_augmentations_needed = train_val_target_size - current_train_val_size

    augmented_features_list = []
    augmented_labels_list = []
    augmented_basenames_list = []
    noise_std_dev = 0.2

    idx_to_augment = list(range(current_train_val_size))
    random.shuffle(idx_to_augment)

    aug_count = 0
    while aug_count < num_augmentations_needed:
        if not idx_to_augment:
            idx_to_augment = list(range(current_train_val_size))
            random.shuffle(idx_to_augment)

        original_idx = idx_to_augment.pop(0)

        original_feature = features_train_val[original_idx]
        original_label = labels_train_val[original_idx]
        original_basename = basenames_train_val[original_idx]

        noise = np.random.normal(loc=0.0, scale=noise_std_dev, size=original_feature.shape)
        aug_feature = original_feature + noise
        aug_label = original_label

        augmented_features_list.append(aug_feature)
        augmented_labels_list.append(aug_label)
        augmented_basenames_list.append(f"{original_basename}_aug{aug_count}_noise{noise_std_dev}")

        aug_count += 1
        if aug_count % 100 == 0:
            print(f"Generated {aug_count} augmented samples...")

    augmented_features = np.array(augmented_features_list)
    augmented_labels = np.array(augmented_labels_list)
    augmented_basenames = np.array(augmented_basenames_list)

    features_train_val = np.concatenate((features_train_val, augmented_features), axis=0)
    labels_train_val = np.concatenate((labels_train_val, augmented_labels), axis=0)
    basenames_train_val = np.concatenate((basenames_train_val, augmented_basenames), axis=0)


#################################################
# Dataset Splitting                             #
#################################################

train_val_labels = np.array([1 if labels_train_val[i].any() else 0 for i in range(len(labels_train_val))])

# 80:20 split for train-validation (on train-val set) with stratification
train_val_size = len(features_train_val)
indices_train, indices_val = train_test_split(
    np.arange(train_val_size),
    test_size=0.2,
    shuffle=True,
    stratify=train_val_labels
)

basenames_train = basenames_train_val[indices_train]
features_train = features_train_val[indices_train]
labels_train = labels_train_val[indices_train]

basenames_val = basenames_train_val[indices_val]
features_val = features_train_val[indices_val]
labels_val = labels_train_val[indices_val]

basenames_test = basenames[test_indices]
features_test = features[test_indices]
labels_test = labels[test_indices]

n_train = len(features_train)
n_val = len(features_val)
n_test = len(features_test)

print(f"[Train Samples: {n_train}]")
# for i, (filename, features, true_mask) in enumerate(zip(basenames_train, features_train, labels_train)):
#     visualize_map(i, filename, features, true_mask)

print(f"\n[Validation Samples: {n_val}]")
# for i, (filename, features, true_mask) in enumerate(zip(basenames_val, features_val, labels_val)):
#     visualize_map(i, filename, features, true_mask)

print(f"\n[Test Samples: {n_test}]")
# for i, (filename, features, true_mask) in enumerate(zip(basenames_test, features_test, labels_test)):
#     visualize_map(i, filename, features, true_mask)

n_dataset = n_train + n_val + n_test
print(f"\nTOTAL: {n_dataset}")

print("Dataset ready...")
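
The sample counts printed above follow from simple arithmetic on `target_total_size` and the two split ratios. The sketch below reproduces that arithmetic. The number of originally loaded samples is not stated in the notebook, so `n_original = 1_000` is a purely hypothetical value, and the rounding of `train_test_split` is approximated with `math.ceil`.

```python
import math

target_total_size = 10_000  # from the notebook
test_ratio = 0.1            # 90:10 hold-out split
val_ratio = 0.2             # 80:20 train-validation split
n_original = 1_000          # hypothetical number of loaded samples

# The hold-out test set is taken from the original samples only.
n_test = math.ceil(n_original * test_ratio)
n_train_val_original = n_original - n_test

# Noise-augmented copies top the train-val pool up to its target size.
train_val_target_size = int(target_total_size * (1 - test_ratio))
n_augmented = max(0, train_val_target_size - n_train_val_original)
n_train_val = n_train_val_original + n_augmented

n_val = math.ceil(n_train_val * val_ratio)
n_train = n_train_val - n_val

print(n_test, n_augmented, n_train, n_val)
```

Under these assumed numbers the test set stays at 100 unaugmented samples while the train-val pool grows to 9,000, so the printed total would be 9,100 rather than 10,000.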

## Step 7: Define the Architecture  

Define the U-Net convolutional neural network and the Tversky loss used for pixel-wise shear line classification.
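
Because 161 and 141 are not divisible by 32, each `padding="same"` pooling step rounds up and each stride-2 transposed convolution overshoots the matching encoder feature map, which is why the decoder crops before every concatenation. The sketch below traces those sizes in plain Python, assuming that `MaxPooling2D(2, padding="same")` yields `ceil(n / 2)` and that `Conv2DTranspose(..., strides=2, padding="same")` yields `2 * n`; `trace_crops` is a hypothetical helper.

```python
import math


def trace_crops(height, width, depth=5):
    """Return the bottleneck size and the (rows, cols) cropped at each decoder step."""
    # Encoder: record the size before every pooling step (the skip connections).
    skips = []
    h, w = height, width
    for _ in range(depth):
        skips.append((h, w))
        h, w = math.ceil(h / 2), math.ceil(w / 2)  # "same" pooling rounds up
    bottleneck = (h, w)

    # Decoder: double the size, then crop down to the matching skip connection.
    crops = []
    for skip_h, skip_w in reversed(skips):
        h, w = 2 * h, 2 * w
        crops.append((h - skip_h, w - skip_w))
        h, w = skip_h, skip_w
    return bottleneck, crops


bottleneck, crops = trace_crops(161, 141)
print(bottleneck)  # (6, 5)
print(crops)       # [(1, 1), (1, 0), (1, 0), (1, 1), (1, 1)]
```

Each pair lists how many rows and columns must be removed, which matches the `Cropping2D` arguments of `crop1` to `crop5` in the cell below.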

In [None]:
ALPHA = 0.95
BETA = 0.05

def tversky_index(y_true, y_pred, alpha=ALPHA, beta=BETA, smooth=K.epsilon()):
    """
    Tversky index for binary segmentation.
    -------------------------------------------------------------------
    Salehi, S. S. M., Erdogmus, D., & Gholipour, A. (2017).
    Tversky loss function for image segmentation
    using 3D fully convolutional deep networks (No. arXiv:1706.05721).
    arXiv. https://doi.org/10.48550/arXiv.1706.05721
    """
    y_true_f = K.cast(K.flatten(y_true), "float32")
    y_pred_f = K.cast(K.flatten(y_pred), "float32")
    tp = K.sum(y_true_f * y_pred_f)
    fp = K.sum((1 - y_true_f) * y_pred_f)
    fn = K.sum(y_true_f * (1 - y_pred_f))
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

def tversky_loss(y_true, y_pred):
    """Tversky loss (1 - Tversky index)."""
    return 1 - tversky_index(y_true, y_pred)


#################################################
# Custom U-Net Architecture                     #
#-----------------------------------------------#
# Uses skip connections by concatenating        #
# encoder outputs with corresponding decoder    #
# layers to preserve spatial information and    #
# improve segmentation accuracy.                #
#################################################

# Layer declarations
convXa = Conv2D(32, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bnXa = BatchNormalization()
actXa = Activation("relu")
convXb = Conv2D(32, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bnXb = BatchNormalization()
actXb = Activation("relu")
poolX = MaxPooling2D(2, padding="same")
dropoutX = Dropout(0.3)

conv0a = Conv2D(64, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn0a = BatchNormalization()
act0a = Activation("relu")
conv0b = Conv2D(64, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn0b = BatchNormalization()
act0b = Activation("relu")
pool0 = MaxPooling2D(2, padding="same")
dropout0 = Dropout(0.3)

conv1a = Conv2D(128, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn1a = BatchNormalization()
act1a = Activation("relu")
conv1b = Conv2D(128, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn1b = BatchNormalization()
act1b = Activation("relu")
pool1 = MaxPooling2D(2, padding="same")
dropout1 = Dropout(0.3)

conv2a = Conv2D(256, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn2a = BatchNormalization()
act2a = Activation("relu")
conv2b = Conv2D(256, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn2b = BatchNormalization()
act2b = Activation("relu")
pool2 = MaxPooling2D(2, padding="same")
dropout2 = Dropout(0.3)

conv3a = Conv2D(512, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn3a = BatchNormalization()
act3a = Activation("relu")
conv3b = Conv2D(512, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn3b = BatchNormalization()
act3b = Activation("relu")
pool3 = MaxPooling2D(2, padding="same")
dropout3 = Dropout(0.3)

bottleneck1a = Conv2D(1024, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn4a = BatchNormalization()
act4a = Activation("relu")
bottleneck1b = Conv2D(1024, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn4b = BatchNormalization()
act4b = Activation("relu")
bottleneck1c = Conv2D(1024, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn4c = BatchNormalization()
act4c = Activation("relu")
dropout4 = Dropout(0.3)

up1 = Conv2DTranspose(512, 2, strides=2, padding="same")
crop1 = Cropping2D(((1, 0), (1, 0)))
concat1 = Concatenate()
conv4a = Conv2D(512, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn5a = BatchNormalization()
act5a = Activation("relu")
conv4b = Conv2D(512, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn5b = BatchNormalization()
act5b = Activation("relu")
dropout5 = Dropout(0.3)

up2 = Conv2DTranspose(256, 2, strides=2, padding="same")
crop2 = Cropping2D(((1, 0), (0, 0)))
concat2 = Concatenate()
conv5a = Conv2D(256, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn6a = BatchNormalization()
act6a = Activation("relu")
conv5b = Conv2D(256, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn6b = BatchNormalization()
act6b = Activation("relu")
dropout6 = Dropout(0.3)

up3 = Conv2DTranspose(128, 2, strides=2, padding="same")
crop3 = Cropping2D(((1, 0), (0, 0)))
concat3 = Concatenate()
conv6a = Conv2D(128, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn7a = BatchNormalization()
act7a = Activation("relu")
conv6b = Conv2D(128, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn7b = BatchNormalization()
act7b = Activation("relu")
dropout7 = Dropout(0.3)

up4 = Conv2DTranspose(64, 2, strides=2, padding="same")
crop4 = Cropping2D(((1, 0), (1, 0)))
concat4 = Concatenate()
conv7a = Conv2D(64, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn8a = BatchNormalization()
act8a = Activation("relu")
conv7b = Conv2D(64, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn8b = BatchNormalization()
act8b = Activation("relu")
dropout8 = Dropout(0.3)

up5 = Conv2DTranspose(32, 2, strides=2, padding="same")
crop5 = Cropping2D(((1, 0), (1, 0)))
concat5 = Concatenate()
conv8a = Conv2D(32, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn9a = BatchNormalization()
act9a = Activation("relu")
conv8b = Conv2D(32, 3, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-03))
bn9b = BatchNormalization()
act9b = Activation("relu")
dropout9 = Dropout(0.3)

final = Conv2D(1, 1, activation="sigmoid")


#################################################
# Forward Pass with concatenation.              #
#################################################

inputs = Input(shape=(161, 141, 1))

outputsX = convXa(inputs)
outputsX = bnXa(outputsX)
outputsX = actXa(outputsX)
outputsX = convXb(outputsX)
outputsX = bnXb(outputsX)
outputsX = actXb(outputsX)
pooledX = poolX(outputsX)
pooledX = dropoutX(pooledX)

outputs0 = conv0a(pooledX)
outputs0 = bn0a(outputs0)
outputs0 = act0a(outputs0)
outputs0 = conv0b(outputs0)
outputs0 = bn0b(outputs0)
outputs0 = act0b(outputs0)
pooled0 = pool0(outputs0)
pooled0 = dropout0(pooled0)

outputs1 = conv1a(pooled0)
outputs1 = bn1a(outputs1)
outputs1 = act1a(outputs1)
outputs1 = conv1b(outputs1)
outputs1 = bn1b(outputs1)
outputs1 = act1b(outputs1)
pooled1 = pool1(outputs1)
pooled1 = dropout1(pooled1)

outputs2 = conv2a(pooled1)
outputs2 = bn2a(outputs2)
outputs2 = act2a(outputs2)
outputs2 = conv2b(outputs2)
outputs2 = bn2b(outputs2)
outputs2 = act2b(outputs2)
pooled2 = pool2(outputs2)
pooled2 = dropout2(pooled2)

outputs3 = conv3a(pooled2)
outputs3 = bn3a(outputs3)
outputs3 = act3a(outputs3)
outputs3 = conv3b(outputs3)
outputs3 = bn3b(outputs3)
outputs3 = act3b(outputs3)
pooled3 = pool3(outputs3)
pooled3 = dropout3(pooled3)

outputs = bottleneck1a(pooled3)
outputs = bn4a(outputs)
outputs = act4a(outputs)
outputs = bottleneck1b(outputs)
outputs = bn4b(outputs)
outputs = act4b(outputs)
outputs = bottleneck1c(outputs)
outputs = bn4c(outputs)
outputs = act4c(outputs)
outputs = dropout4(outputs)

outputs = up1(outputs)
outputs = crop1(outputs)
outputs = concat1([outputs, outputs3])
outputs = conv4a(outputs)
outputs = bn5a(outputs)
outputs = act5a(outputs)
outputs = conv4b(outputs)
outputs = bn5b(outputs)
outputs = act5b(outputs)
outputs = dropout5(outputs)

outputs = up2(outputs)
outputs = crop2(outputs)
outputs = concat2([outputs, outputs2])
outputs = conv5a(outputs)
outputs = bn6a(outputs)
outputs = act6a(outputs)
outputs = conv5b(outputs)
outputs = bn6b(outputs)
outputs = act6b(outputs)
outputs = dropout6(outputs)

outputs = up3(outputs)
outputs = crop3(outputs)
outputs = concat3([outputs, outputs1])
outputs = conv6a(outputs)
outputs = bn7a(outputs)
outputs = act7a(outputs)
outputs = conv6b(outputs)
outputs = bn7b(outputs)
outputs = act7b(outputs)
outputs = dropout7(outputs)

outputs = up4(outputs)
outputs = crop4(outputs)
outputs = concat4([outputs, outputs0])
outputs = conv7a(outputs)
outputs = bn8a(outputs)
outputs = act8a(outputs)
outputs = conv7b(outputs)
outputs = bn8b(outputs)
outputs = act8b(outputs)
outputs = dropout8(outputs)

outputs = up5(outputs)
outputs = crop5(outputs)
outputs = concat5([outputs, outputsX])
outputs = conv8a(outputs)
outputs = bn9a(outputs)
outputs = act9a(outputs)
outputs = conv8b(outputs)
outputs = bn9b(outputs)
outputs = act9b(outputs)
outputs = dropout9(outputs)

outputs = final(outputs)

model = Model(inputs, outputs)

# Compile Model
optimizer = Adadelta(learning_rate=1.0, rho=0.95)  # ρ = 0.9-0.95
model.compile(optimizer=optimizer, loss=tversky_loss, metrics=[tversky_index])
model.summary()


MODELS = "models"
ARCHITECTURE = "UNet"
model_path = f"{MODELS}/shearline.{ARCHITECTURE}.png"

os.makedirs(MODELS, exist_ok=True)  # ensure the output directory exists before writing the diagram
plot_model(model, to_file=model_path, show_shapes=False)
Image(model_path, width=300)

# Notes: 2025-05-15
# BCE is unreliable under class imbalance, where 0s dominate 1s across the spatial maps.
# After a series of experiments, Dice proved unable to suppress false positives.
# A more appropriate option is the Tversky loss.
# It is a generalization of Dice and allows flexible penalization of FP and FN,
# so it can further refine the predictions by explicitly reducing FP.
# Dice is a special case of Tversky when α = β = 0.5; Tversky = TP / (TP + α*FP + β*FN).
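
The effect of the chosen weights can be checked by hand. The sketch below is a plain-Python restatement of the formula in the notes (without the `smooth` term), applied to made-up pixel counts; `tversky` is a hypothetical helper and the counts are illustrative, not results from the model.

```python
def tversky(tp, fp, fn, alpha, beta):
    """Tversky index from pixel counts: TP / (TP + alpha*FP + beta*FN)."""
    return tp / (tp + alpha * fp + beta * fn)


ALPHA, BETA = 0.95, 0.05  # weights used in the notebook

# Two made-up predictions with the same number of wrong pixels (20 each).
over_predicted = tversky(tp=80, fp=20, fn=0, alpha=ALPHA, beta=BETA)   # extra pixels
under_predicted = tversky(tp=80, fp=0, fn=20, alpha=ALPHA, beta=BETA)  # missed pixels

# With alpha = beta = 0.5 the index reduces to the Dice coefficient.
dice = 2 * 80 / (2 * 80 + 20 + 0)
as_dice = tversky(tp=80, fp=20, fn=0, alpha=0.5, beta=0.5)

print(f"{over_predicted:.4f} {under_predicted:.4f} {as_dice:.4f} {dice:.4f}")
```

With these weights, 20 false-positive pixels lower the index far more than 20 false-negative pixels, which is the intended bias toward suppressing FP.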

## Step 8: Train the Model  

> Feed the training and validation datasets to the compiled CNN model.  

**Note:** Re-run Step 7 to rebuild and recompile the model before running this step.

In [None]:
EPOCHS = 1000
BATCH_SIZE = 32

ANALYSIS = "analysis"
DATASET = "data"
TEST = "test"

os.makedirs(MODELS, exist_ok=True)
os.makedirs(ANALYSIS, exist_ok=True)

training_timestamp = int(time.time())

early_stopping = EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True)
history = model.fit(
    features_train,
    labels_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(features_val, labels_val),
    callbacks=[early_stopping]
)

training_duration = (int(time.time()) - training_timestamp) / 60
print(f"Training completed in {training_duration:.2f} minutes.")


fullpath = f"{MODELS}/shearline.{ARCHITECTURE}_{training_timestamp}.keras"
model.save(fullpath)

with open(f"{ANALYSIS}/metrics_{training_timestamp}.json", "w") as f:
    json.dump({
        "loss": history.history["loss"],
        "tversky_index": history.history["tversky_index"],
        "val_loss": history.history["val_loss"],
        "val_tversky_index": history.history["val_tversky_index"]
    }, f, indent=4)

print(f"Model training complete and saved to '{fullpath}'")
print(f"Training and validation metrics saved to '{ANALYSIS}/metrics_{training_timestamp}.json'")
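
The stopping rule configured above can be traced on a short list of validation losses. The sketch below is a simplified re-implementation for illustration only, assuming the callback's default `min_delta=0` and lower-is-better monitoring; `early_stop` and the loss values are hypothetical.

```python
def early_stop(val_losses, patience):
    """Return (stopped_epoch, best_epoch), both 1-indexed.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss; the weights of `best_epoch` would then be restored.
    """
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


losses = [1.00, 0.80, 0.90, 0.85, 0.95, 0.70]
print(early_stop(losses, patience=3))  # (5, 2)
```

With `patience=3` the run stops at epoch 5 and never reaches the better loss at epoch 6, which illustrates the trade-off behind lowering the patience.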

## Step 9: Generate Training Analysis  

**Metrics Definitions**
* Loss measures how far each prediction is from the ground truth; here it is the Tversky loss (1 - Tversky index).
* The Tversky index is a generalization of the Dice coefficient. It is an overlap measure that allows flexible penalization of FP and FN, which suits imbalanced datasets.  

In [None]:
import seaborn as sns


with open(f"{ANALYSIS}/metrics_{training_timestamp}.json", "r") as f:
    metrics = json.load(f)

epochs = range(1, len(metrics["loss"]) + 1)

plt.figure(figsize=(8, 6))
sns.lineplot(x=epochs, y=metrics["loss"], label="Training Tversky Loss", color="blue")
sns.lineplot(x=epochs, y=metrics["val_loss"], label="Validation Tversky Loss", color="orange")
plt.title("Tversky Loss vs. Epochs")
plt.xlabel("Epochs")
plt.ylabel("Tversky Loss")
plt.legend()
plt.grid(True)
plt.savefig(f"{ANALYSIS}/tversky_loss_plot_{training_timestamp}.png")
plt.show()

plt.figure(figsize=(8, 6))
sns.lineplot(x=epochs, y=metrics["tversky_index"], label="Training Tversky Index", color="green")
sns.lineplot(x=epochs, y=metrics["val_tversky_index"], label="Validation Tversky Index", color="red")
plt.title("Tversky Index vs. Epochs")
plt.xlabel("Epochs")
plt.ylabel("Tversky Index")
plt.legend()
plt.grid(True)
plt.savefig(f"{ANALYSIS}/tversky_index_plot_{training_timestamp}.png")
plt.show()

print(f"\nPlots saved to {ANALYSIS}")

## Step 10: Test the Model

In [None]:
start_time = time.time()
results = model.predict(features_test, verbose=1)
prediction_duration = time.time() - start_time
image_prediction_time = prediction_duration / len(features_test)

predictions = (results.squeeze(-1) > 0.5).astype("int32")
predictions_flat = predictions.flatten()
labels_test_flat = labels_test.flatten()

print(f"Total prediction time: {prediction_duration:.4f} seconds")
print(f"Time per spatial map: {image_prediction_time:.4f} seconds")

## Step 11: Display the Results

In [None]:
false_positives = []
for i, (filename, true_mask, predicted_mask, features) in enumerate(zip(basenames_test, labels_test, predictions, features_test)):
    fp_flag = np.all(true_mask == 0) and np.any(predicted_mask != 0)
    if fp_flag:
        false_positives.append(filename)
    visualize_map(i, filename, features, true_mask, predicted_mask, fp_flag, figsize=(6, 4))

total_fp = len(false_positives)
print(f"\n[False Positives ({total_fp})]")
for filename in false_positives:
    print(filename)
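
The false positives listed above are counted per spatial map, not per pixel: a map is flagged when its true mask is empty but the prediction is not. The sketch below restates that rule on nested lists and, as a hypothetical extension that is not part of the notebook, names the other three image-level outcomes; `image_level_outcome` is an illustrative helper.

```python
def image_level_outcome(true_mask, predicted_mask):
    """Classify one spatial map by whether any shear-line pixel is present."""
    has_true = any(any(row) for row in true_mask)
    has_pred = any(any(row) for row in predicted_mask)
    if has_true and has_pred:
        return "TP"
    if has_pred:
        return "FP"  # the case flagged by fp_flag in the notebook
    if has_true:
        return "FN"
    return "TN"


empty = [[0, 0], [0, 0]]
line = [[0, 1], [0, 1]]
print(image_level_outcome(empty, line))  # FP
```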
    

## Step 12: Region-based Segmentation Evaluation

> Evaluate the spatial accuracy of the predicted shear line binary masks using segmentation metrics.

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

tverskys = []

for t_mask, p_mask in zip(labels_test, predictions):
    t_mask = t_mask.reshape(161, 141)
    p_mask = p_mask.reshape(161, 141)

    tverskys.append(tversky_index(t_mask, p_mask))

tversky_score = np.mean(tverskys)

true_all = np.concatenate([t.flatten() for t in labels_test])
pred_all = np.concatenate([p.flatten() for p in predictions])

cm = confusion_matrix(true_all, pred_all, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["No Shear Line", "Shear Line"])

fig, ax = plt.subplots(figsize=(6, 6))
disp.plot(cmap="magma", ax=ax)
plt.title(f"Tversky Index={tversky_score:.4f}")
plt.savefig(f"{ANALYSIS}/confusion_matrix_{training_timestamp}.png")
plt.show()

# Landis & Koch (1977) agreement scale (originally defined for the kappa statistic),
# expressed as contiguous half-open intervals [low, high) so each score maps to one label
interpretations = {
    (-1.0, 0.00): "Poor agreement",
    (0.00, 0.21): "Slight agreement",
    (0.21, 0.41): "Fair agreement",
    (0.41, 0.61): "Moderate agreement",
    (0.61, 0.81): "Substantial agreement",
    (0.81, 1.01): "Almost perfect agreement"
}

for score_range, interpretation in interpretations.items():
    if score_range[0] <= tversky_score < score_range[1]:
        print(f"Tversky Interpretation: {interpretation}")
        break
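
The same lookup can be written with `bisect`, which makes the half-open interval boundaries explicit. This is an alternative sketch, not the notebook's code: `interpret` is a hypothetical helper that reuses the thresholds above. Those thresholds were defined by Landis & Koch for the kappa statistic, so applying them to the Tversky index is the notebook's own convention.

```python
from bisect import bisect_right

# Lower bound of every band after the first, and the label of each band.
THRESHOLDS = [0.00, 0.21, 0.41, 0.61, 0.81]
LABELS = [
    "Poor agreement",
    "Slight agreement",
    "Fair agreement",
    "Moderate agreement",
    "Substantial agreement",
    "Almost perfect agreement",
]


def interpret(score):
    """Map a score in [-1, 1] to its agreement label."""
    return LABELS[bisect_right(THRESHOLDS, score)]


print(interpret(0.65))  # Substantial agreement
```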

## Step 13: Duplicate Notebook  

**Note:** Save the notebook manually before duplicating it.

In [None]:
VALIDATIONS = "validations"
os.makedirs(VALIDATIONS, exist_ok=True)

filename = "ShearLineCNN.ipynb"
with open(filename, "r", encoding="utf-8") as src:
    contents = src.read()
    checkpoint = f"{VALIDATIONS}/{filename}".replace(".ipynb", f"_{training_timestamp}.ipynb")
    with open(checkpoint, "w", encoding="utf-8") as dest:
        dest.write(contents)
        print(f"Checkpoint was created at '{checkpoint}'.")
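
An equivalent copy can be made with `shutil.copy2`, which also preserves file metadata. This is a minimal sketch under assumptions: `checkpoint_notebook` is a hypothetical helper, and the demonstration runs on a throwaway file in a temporary directory rather than on the real notebook.

```python
import shutil
import tempfile
import time
from pathlib import Path


def checkpoint_notebook(notebook, out_dir, timestamp):
    """Copy `notebook` into `out_dir` as '<stem>_<timestamp><suffix>'."""
    source = Path(notebook)
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / f"{source.stem}_{timestamp}{source.suffix}"
    shutil.copy2(source, target)
    return target


with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "ShearLineCNN.ipynb"
    demo.write_text("{}", encoding="utf-8")  # stand-in for a saved notebook
    copy = checkpoint_notebook(demo, Path(tmp) / "validations", int(time.time()))
    print(copy.name)
```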

> End of code.