<a href="https://colab.research.google.com/github/ShaneRLos/Conv-LSTM-SAR-XAI/blob/main/IEEE_2024_Conv_LSTM_XAI_SAR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load Datasets

In [None]:
import xarray as xr

# Open the normalized altimetry NetCDF file stored on Google Drive
altimetry_path = "/content/drive/MyDrive/Altimetry_Normalized/Normalized_Altimetry_Merged_V3.nc"
normalized_altimetry_dataset = xr.open_dataset(altimetry_path)

# Show a small preview (the leading entries along each dimension)
altimetry_preview = normalized_altimetry_dataset.head()
print(altimetry_preview)

In [None]:
import pandas as pd

# Read the normalized, merged buoy measurements from Google Drive
buoy_csv_path = '/content/drive/MyDrive/buoy-M3_year_long/merged_normalized_buoy_data_V3.csv'
merged_df = pd.read_csv(buoy_csv_path)

# Fill gaps in the wave-height series by linear interpolation (assigning the
# result back to the column instead of interpolating in place)
wave_height_filled = merged_df['WaveHeight'].interpolate(method='linear')
merged_df['WaveHeight'] = wave_height_filled

# Rows without a wind-speed reading are discarded from the same DataFrame
merged_df.dropna(subset=['WindSpeed'], inplace=True)

# Keep a preview and the shape around, then display both as the cell output
merged_df_head, merged_df_shape = merged_df.head(), merged_df.shape

merged_df_head, merged_df_shape

# SAR Image Preprocessing

In [None]:
# FILTER, NORMALIZE, RESIZE, SAVE AND LOAD

import rasterio
from scipy.ndimage import uniform_filter
import numpy as np
import os
from concurrent.futures import ProcessPoolExecutor
import cv2

def lee_filter(img, size):
    """Apply a Lee-style adaptive speckle filter to an image.

    Each pixel is blended between its local mean and its original value using
    a weight ``local_variance / (local_variance + overall_variance)``.

    Args:
        img: Array-like image. Non-floating input is converted to float64 so
            that squaring cannot overflow and the local statistics are not
            truncated to integers; floating input keeps its dtype.
        size: Window size forwarded to ``scipy.ndimage.uniform_filter``.

    Returns:
        A floating-point array with the same shape as ``img``.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    img_mean = uniform_filter(img, size)
    img_sqr_mean = uniform_filter(img**2, size)
    # E[x^2] - E[x]^2 can dip slightly below zero through round-off; a
    # variance is never negative, so clip it.
    img_variance = np.maximum(img_sqr_mean - img_mean**2, 0.0)
    overall_variance = np.var(img)

    denominator = img_variance + overall_variance
    # A perfectly flat image has zero local and overall variance (0/0). Use a
    # weight of 0 there so the result falls back to the local mean instead of
    # NaN. NaNs already present in the input still propagate.
    with np.errstate(divide='ignore', invalid='ignore'):
        img_weights = np.where(denominator == 0, 0.0, img_variance / denominator)
    return img_mean + img_weights * (img - img_mean)

def read_and_preprocess_sar_image(file_path, target_size=(128, 128)):
    """Read band 1 of a raster, despeckle, min-max normalize and resize it.

    Args:
        file_path: Path of the raster file to open with rasterio.
        target_size: Size tuple passed to ``cv2.resize`` as ``dsize``.

    Returns:
        The resized, normalized image as returned by ``cv2.resize``, or
        ``None`` when the image is skipped (NaNs after filtering, or no value
        range to normalize over).
    """
    # Only the read needs the dataset open; release the handle before the
    # numerical work starts.
    with rasterio.open(file_path) as src:
        sar_image = src.read(1, masked=True)  # First band as a masked array

    # Apply the Lee filter (3x3 window) to the raw band data.
    # NOTE(review): the filter sees the raw values under the mask as well, so
    # no-data pixels can influence their neighbours — confirm this is intended.
    filtered_image = lee_filter(sar_image.data, size=3)

    # Mask out the no-data values again after filtering
    filtered_image = np.ma.array(filtered_image, mask=sar_image.mask, fill_value=np.nan)

    # Skip images that contain NaNs in the unmasked area after filtering
    if np.isnan(filtered_image).any():
        print(f"NaN values found in the filtered image: {file_path}")
        return None  # Skip saving this image

    # Min-max normalization needs a non-zero value range; a constant or fully
    # masked image would otherwise divide by zero.
    min_val = filtered_image.min()
    max_val = filtered_image.max()
    value_range = max_val - min_val
    if np.ma.is_masked(value_range) or value_range == 0:
        print(f"Constant or fully masked image, cannot normalize: {file_path}")
        return None  # Skip saving this image
    normalized_image = (filtered_image - min_val) / value_range

    # Resize the image using OpenCV
    resized_image = cv2.resize(normalized_image.filled(), target_size, interpolation=cv2.INTER_AREA)

    return resized_image

def save_processed_image(file_path, processed_image, output_folder):
    """Write a processed image as a single-band float32 LZW-compressed GeoTIFF.

    The output name is the input base name with ``.tif`` replaced by
    ``_processed_resized.tif``. The raster profile is copied from the source
    file and updated with the new dtype, height, width, band count and
    compression.

    Args:
        file_path: Path of the source raster whose profile is reused.
        processed_image: 2-D array to write as band 1.
        output_folder: Destination directory; created if it does not exist.
    """
    output_filename = os.path.basename(file_path).replace('.tif', '_processed_resized.tif')
    output_path = os.path.join(output_folder, output_filename)
    # exist_ok avoids the check-then-create race when several worker
    # processes save into the same, not-yet-existing folder at once.
    os.makedirs(output_folder, exist_ok=True)
    with rasterio.open(file_path) as src:
        profile = src.profile
        # NOTE(review): the source transform is kept although height/width
        # change, so the georeferencing of the output presumably no longer
        # matches the pixel grid — confirm whether that matters downstream.
        profile.update(dtype=rasterio.float32, height=processed_image.shape[0], width=processed_image.shape[1], count=1, compress='lzw')
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(processed_image.astype(rasterio.float32), 1)

def process_file(file_path, output_folder):
    """Preprocess one SAR file and save the result into ``output_folder``.

    A missing input file is reported and skipped. Images for which
    preprocessing returns ``None`` are not saved.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return

    result = read_and_preprocess_sar_image(file_path)
    if result is None:
        # Preprocessing rejected the image, so there is nothing to write
        return

    save_processed_image(file_path, result, output_folder)

# Set input and output folders
input_folder = '/content/drive/MyDrive/ content drive MyDrive Final'
output_folder = '/content/drive/MyDrive/Processed_Images_10k_V1'

# Collect the raw .tif files, skipping anything that is already a derived
# product. '_processed_resized.tif' is the suffix this pipeline writes, so it
# is excluded as well to keep outputs from being re-processed.
derived_suffixes = ('_processed.tif', '_processed_resized.tif', '_interpolated.tif')
sar_files = [
    os.path.join(input_folder, f)
    for f in os.listdir(input_folder)
    if f.endswith('.tif') and not f.endswith(derived_suffixes)
]

# Process each SAR file in parallel. Every future's result is retrieved so
# that an exception raised inside a worker is reported per file instead of
# being silently discarded.
with ProcessPoolExecutor() as executor:
    pending = {
        executor.submit(process_file, path, output_folder): path
        for path in sar_files
    }
    for future, path in pending.items():
        try:
            future.result()
        except Exception as e:
            print(f"Error processing {path}: {e}")

# LOAD PROCESSED SAR IMAGES INTO NUMPY ARRAYS

def load_image_rasterio(image_path):
    """Load band 1 of a raster and append a trailing channel axis.

    Returns the band with one extra last dimension, or ``None`` (after
    printing the error) if anything goes wrong while loading.
    """
    try:
        with rasterio.open(image_path) as dataset:
            band = dataset.read(1)
            # Trailing axis of length 1 acts as the grayscale channel
            band = np.expand_dims(band, axis=-1)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None
    return band

# Gather every processed-and-resized image in the output folder.
# NOTE(review): os.listdir order is arbitrary; if the stacked images are later
# turned into time sequences, the order here must match the order used for
# the dates — confirm.
processed_suffix = '_processed_resized.tif'
image_paths = [
    os.path.join(output_folder, name)
    for name in os.listdir(output_folder)
    if name.endswith(processed_suffix)
]

# Smoke test: try loading the first image (if there is one)
test_img_path = image_paths[0] if image_paths else None
test_img = None
if test_img_path:
    test_img = load_image_rasterio(test_img_path)
print("Is the image loaded?", test_img is not None)

# Only load the full set when the smoke test succeeded
if test_img is not None:
    # Load every image, drop the ones that failed, and stack into one array
    loaded_images = (load_image_rasterio(path) for path in image_paths)
    sar_files = np.stack([image for image in loaded_images if image is not None])

    print("Number of SAR images:", len(sar_files))


# Timestamps

## Altimetry Timestamps

In [None]:
import pandas as pd

# Overall period covered by the study
start_date = '2015-05-01'
end_date = '2023-12-31'

# One entry per month, each positioned at the first day of that month
date_range = pd.date_range(start=start_date, end=end_date, freq='MS')

# Build (month start, month end) string pairs; both are stamped at 00:00 UTC
date_ranges = []
for month_start in date_range:
    month_end = month_start + pd.offsets.MonthEnd(1)
    date_ranges.append((
        month_start.strftime('%Y-%m-%dT00:00:00Z'),
        month_end.strftime('%Y-%m-%dT00:00:00Z'),
    ))

# For debugging: print the first few date ranges
for start, end in date_ranges[:5]:
    print(f"('{start}', '{end}')")


In [None]:
import pandas as pd

# Define the function to select altimetry data for a given timestamp
def select_altimetry_data_for_timestamp(altimetry_dataset, timestamp, tolerance=0):
    """Return the altimetry selection for the first time within a window.

    The window is ``timestamp`` plus/minus ``tolerance`` seconds (inclusive).
    Times are scanned in dataset order and the first one inside the window is
    selected with ``altimetry_dataset.sel(time=...)``. Returns ``None`` when
    no time falls inside the window.
    """
    half_window = pd.Timedelta(seconds=tolerance)
    window_start = timestamp - half_window
    window_end = timestamp + half_window

    first_match = next(
        (t for t in altimetry_dataset['time'] if window_start <= t <= window_end),
        None,
    )
    if first_match is None:
        return None
    return altimetry_dataset.sel(time=first_match)

# Dump the dataset summary under a header line for inspection
print("Altitude Dataset Information:", normalized_altimetry_dataset, sep="\n")

## Buoy Data Timestamps

In [None]:
# Buoy Data Filtering

import pandas as pd

# Parse the buoy timestamps as UTC-aware datetimes so they can be compared with
# the 'Z'-suffixed (UTC) range boundaries whether or not the CSV carried a
# timezone; comparing tz-naive values against tz-aware bounds raises TypeError.
merged_df['time'] = pd.to_datetime(merged_df['time'], utc=True)

# Collect one slice per date range and concatenate once at the end, instead of
# re-copying a growing DataFrame on every iteration.
monthly_slices = []

# Loop through each date range and filter the data
for start_date, end_date in date_ranges:
    # Convert start and end times to UTC-aware datetimes
    start_date = pd.to_datetime(start_date, utc=True)
    end_date = pd.to_datetime(end_date, utc=True)
    # Range ends are stamped at 00:00 of their final day, so `<= end_date`
    # would drop the rest of that day. Use an exclusive bound at the following
    # midnight so the whole final day is kept.
    end_exclusive = end_date.normalize() + pd.Timedelta(days=1)
    in_range = (merged_df['time'] >= start_date) & (merged_df['time'] < end_exclusive)
    monthly_slices.append(merged_df[in_range])

# Concatenate with a fresh 0..n-1 index (empty DataFrame if there were no ranges)
if monthly_slices:
    filtered_buoy_data = pd.concat(monthly_slices, ignore_index=True)
else:
    filtered_buoy_data = pd.DataFrame()

# Display the first few rows of the filtered data and its shape
filtered_buoy_data_head = filtered_buoy_data.head()
filtered_buoy_data_shape = filtered_buoy_data.shape

filtered_buoy_data_head, filtered_buoy_data_shape



# Alignment & Sequences

## Sequences

In [None]:
# Create Sequences Function

# Function to create overlapping sequences from combined image and mask data
def create_sequences(data, sequence_length=6):
    """Build every overlapping window of ``sequence_length`` consecutive items.

    Windows start at each index from 0 up to ``len(data) - sequence_length``
    (stride 1) and are returned stacked in a NumPy array. If ``data`` is
    shorter than ``sequence_length`` the result is an empty array.
    """
    window_count = len(data) - sequence_length + 1
    windows = [data[first:first + sequence_length] for first in range(window_count)]
    return np.array(windows)

# Window length used for the SAR image sequences
sequence_length = 6

# Build the overlapping sequences from the image array.
# NOTE(review): 'images' is not assigned anywhere in the visible notebook (the
# loading cell stores its stack in 'sar_files') — confirm where it comes from.
sequences = create_sequences(images, sequence_length=sequence_length)

print(f"Created {len(sequences)} sequences of shape {sequences.shape}")

## Alignment

In [None]:
# Alignment Sequences

import os
from datetime import datetime
import re
import numpy as np
import pandas as pd
import xarray as xr

# Function to extract dates from filenames
def extract_date_from_filename(filename):
    """Return the first valid ``YYYYMMDD`` date found in ``filename``.

    Every run of eight digits is tried in order, so an earlier digit run that
    is not a calendar date does not hide a valid date later in the name.

    Args:
        filename: File name or path to scan.

    Returns:
        A ``datetime`` at midnight of the parsed date.

    Raises:
        ValueError: If the name contains no 8-digit run, or none of the runs
            parses as a date (the last parsing error is chained as the cause).
    """
    last_error = None
    for match in re.finditer(r'\d{8}', filename):
        try:
            return datetime.strptime(match.group(), '%Y%m%d')
        except ValueError as e:
            # Not a calendar date; keep looking at later candidates
            last_error = e

    if last_error is not None:
        raise ValueError(f"Error parsing date from {filename}: {last_error}") from last_error
    raise ValueError(f"No valid date found in filename: {filename}")

# Folder holding the processed SAR images
directory_path = '/content/drive/MyDrive/Processed_Images_10k_V1'

# Full paths of every .tif in that folder.
# NOTE(review): os.listdir order is arbitrary and is not sorted here, so the
# dates below follow directory order rather than chronological order —
# confirm this matches the order in which the images were stacked.
sar_image_filenames = [
    os.path.join(directory_path, entry)
    for entry in os.listdir(directory_path)
    if entry.endswith('.tif')
]

# One date per file, parsed from its name
sequence_dates = list(map(extract_date_from_filename, sar_image_filenames))

# A sequence of N images ends at its Nth image, so every date from the Nth
# onwards is a potential sequence end date
N = 6  # Length of your sequences
sequence_end_dates = sequence_dates[N - 1:]

# Same dates as numpy.datetime64 values
sequence_end_dates_np = np.array(list(map(np.datetime64, sequence_end_dates)))

print(f"Adjusted Sequence End Dates (Potential Sequences): {len(sequence_end_dates_np)}")

# Function to align data with sequence dates using a tolerance window
def align_data_with_tolerance(sequence_dates, data, data_label, tolerance=pd.Timedelta('1 hour')):
    """Pick, for each date, the ``data_label`` value of the nearest row in time.

    Only rows whose ``'time'`` lies within ``date`` plus/minus ``tolerance``
    (inclusive) are considered. Dates with no row inside the window get NaN.
    Returns the picked values as a NumPy array, one per input date.
    """
    picked_values = []
    for target in sequence_dates:
        lower_bound = target - tolerance
        upper_bound = target + tolerance
        in_window = (data['time'] >= lower_bound) & (data['time'] <= upper_bound)
        candidates = data.loc[in_window]

        if candidates.empty:
            picked_values.append(np.nan)
            continue

        # Position of the row whose timestamp is closest to the target
        nearest_position = (candidates['time'] - target).abs().argmin()
        nearest_row = candidates.iloc[nearest_position]
        picked_values.append(nearest_row[data_label])
    return np.array(picked_values)

# Normalize buoy timestamps: parse as UTC, then drop the timezone so they are
# tz-naive
buoy_times_utc = pd.to_datetime(merged_df['time'], utc=True)
merged_df['time'] = buoy_times_utc.dt.tz_localize(None)

# Align wave heights and wind speeds with the sequence end dates
aligned_wave_heights, aligned_wind_speeds = (
    align_data_with_tolerance(sequence_end_dates_np, merged_df, column)
    for column in ('WaveHeight', 'WindSpeed')
)

# Store the altimetry time coordinate as datetime64[ns], tz-naive
altimetry_time_index = pd.to_datetime(normalized_altimetry_dataset['time'].values)
normalized_altimetry_dataset['time'] = altimetry_time_index.astype('datetime64[ns]')
if hasattr(normalized_altimetry_dataset['time'], 'tz'):
    normalized_altimetry_dataset['time'] = normalized_altimetry_dataset['time'].dt.tz_localize(None)

# Verify the conversion
altimetry_time_values = normalized_altimetry_dataset['time'].values
print("\nAltimetry Dataset Time Samples (after conversion):")
print(altimetry_time_values[:5])
print(type(altimetry_time_values[0]))

# Align altimetry data with debugging
def align_altimetry_data(sequence_dates, altimetry_data, tolerance=np.timedelta64(1, 'h'), verbose=True):
    """Align an altimetry ``VHM0`` value with each sequence date.

    For each date the altimetry time step closest in time is located. If it
    lies within ``tolerance``, the ``VHM0`` field at that step is averaged
    over ``latitude`` and ``longitude`` and used as the feature; otherwise
    NaN is recorded.

    Args:
        sequence_dates: Iterable of dates convertible with ``np.datetime64``.
        altimetry_data: Object exposing ``time.values``, ``isel(time=...)``
            and a ``'VHM0'`` variable with ``latitude``/``longitude`` dims.
        tolerance: Maximum allowed absolute time difference.
        verbose: When True (default, the previous behaviour) print the
            per-date debugging output.

    Returns:
        NumPy array with one feature (or NaN) per sequence date.
    """
    # The time coordinate does not change between dates: read it once
    altimetry_times = altimetry_data.time.values

    aligned_altimetry = []
    for date in sequence_dates:
        # Ensure date is numpy.datetime64
        date_np = np.datetime64(date)

        if len(altimetry_times) == 0:
            # No altimetry time steps at all: nothing can be aligned
            aligned_altimetry.append(np.nan)
            continue

        time_diff = np.abs(altimetry_times - date_np)

        if verbose:
            print(f"Date: {date_np}, Type: {type(date_np)}")
            print(f"Altimetry Time Values: {altimetry_times[:5]}, Type: {type(altimetry_times[0])}")
            print(f"Time Differences: {time_diff[:5]}")

        # The closest step is within tolerance exactly when any step is
        closest_time_index = int(np.argmin(time_diff))
        if time_diff[closest_time_index] <= tolerance:
            vhm0_data = altimetry_data.isel(time=closest_time_index)['VHM0']
            vhm0_feature = vhm0_data.mean(dim=['latitude', 'longitude']).values.item()
            aligned_altimetry.append(vhm0_feature)
        else:
            aligned_altimetry.append(np.nan)
    return np.array(aligned_altimetry)

aligned_altimetry_features = align_altimetry_data(sequence_end_dates_np, normalized_altimetry_dataset)

# Check for NaNs and lengths
print(f"Length of aligned_wave_heights: {len(aligned_wave_heights)}")
print(f"Length of aligned_wind_speeds: {len(aligned_wind_speeds)}")
print(f"Length of aligned_altimetry_features: {len(aligned_altimetry_features)}")
print(f"Length of sequences: {len(sequences)}")

# Sequence i is paired with end date i, so only the first
# min(#sequences, #end dates) entries can be matched. Working on that common
# prefix keeps the boolean masks the same length as the arrays they index,
# whichever of the two is longer.
n_sequences = len(sequences)
n_end_dates = len(sequence_end_dates_np)
n_aligned = min(n_sequences, n_end_dates)
if n_sequences != n_end_dates:
    print(f"Warning: {n_sequences} sequences but {n_end_dates} sequence end dates; "
          f"only the first {n_aligned} of each are paired")

# A pairing is valid when all three aligned features are present (not NaN)
valid_aligned = (
    ~np.isnan(aligned_wave_heights[:n_aligned])
    & ~np.isnan(aligned_wind_speeds[:n_aligned])
    & ~np.isnan(aligned_altimetry_features[:n_aligned])
)

# Full-length mask over the sequences; unpaired sequences stay False
valid_indices_full = np.zeros(n_sequences, dtype=bool)
valid_indices_full[:n_aligned] = valid_aligned

# Debugging: Print valid_indices_full length and sample
print(f"Length of valid_indices_full: {len(valid_indices_full)}")
print(f"Sample of valid_indices_full: {valid_indices_full[:10]}")

filtered_sequences = sequences[valid_indices_full]
filtered_wave_heights = aligned_wave_heights[:n_aligned][valid_aligned]
filtered_wind_speeds = aligned_wind_speeds[:n_aligned][valid_aligned]
filtered_altimetry_features = aligned_altimetry_features[:n_aligned][valid_aligned]

# Check shapes and alignments
print(f"Filtered SAR Sequences shape: {filtered_sequences.shape}")
print(f"Filtered Altimetry Features shape: {filtered_altimetry_features.shape}")
print(f"Filtered Wave Heights shape: {filtered_wave_heights.shape}")
print(f"Filtered Wind Speeds shape: {filtered_wind_speeds.shape}")


# Model

## Model Architecture

In [None]:
# Architecture

from tensorflow.keras.layers import concatenate, Input, ConvLSTM2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

# Define the RMSE metric
def root_mean_squared_error(y_true, y_pred):
    """Keras-backend RMSE: square root of the mean squared difference."""
    squared_error = K.square(y_pred - y_true)
    return K.sqrt(K.mean(squared_error))

# Define input shapes for SAR, altimetry data, and wind speed
# (shapes exclude the batch dimension)
sar_input_shape = (6, 128, 128, 1)  # Adjust based on your SAR data
altimetry_input_shape = (1,)  # SWH is a single feature
wind_speed_input_shape = (1,)  # Wind Speed is also a single feature

# SAR image input branch: two stacked ConvLSTM layers. The first returns the
# full sequence so the second can consume it; the second returns only its
# final output, which is then flattened.
sar_input = Input(shape=sar_input_shape, name='sar_input')
convlstm1 = ConvLSTM2D(32, (3, 3), activation='relu', return_sequences=True)(sar_input)
convlstm2 = ConvLSTM2D(64, (3, 3), activation='relu', return_sequences=False)(convlstm1)
flattened_sar = Flatten()(convlstm2)

# Altimetry input branch: a single scalar expanded by a small dense layer
altimetry_input = Input(shape=altimetry_input_shape, name='altimetry_input')
dense_altimetry = Dense(32, activation='relu')(altimetry_input)  # Reduced size for single feature

# Wind speed input branch: same structure as the altimetry branch
wind_speed_input = Input(shape=wind_speed_input_shape, name='wind_speed_input')
dense_wind_speed = Dense(32, activation='relu')(wind_speed_input)  # Reduced size for single feature

# Concatenate inputs from all branches
combined = concatenate([flattened_sar, dense_altimetry, dense_wind_speed])

# Fully connected layers after combining inputs
dense1 = Dense(128, activation='relu')(combined)
dropout = Dropout(0.5)(dense1)
output = Dense(1, activation='linear')(dropout)  # Predicting a single value: wave height

# Final model assembly; inputs must be supplied in this order:
# [SAR sequence, altimetry feature, wind speed]
model = Model(inputs=[sar_input, altimetry_input, wind_speed_input], outputs=output)

# Compile the model with MSE loss and MAE and RMSE as the metrics.
# NOTE(review): with this compile call the model reports loss plus two
# metrics (mae, rmse); no R-squared metric is registered here.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae', root_mean_squared_error])


## Model Training

In [None]:
# Training

from sklearn.model_selection import train_test_split
import numpy as np
import joblib

# Split the dataset into train+validation sets and a test set.
# NOTE(review): train_test_split shuffles by default and the sequences are
# overlapping windows, so windows that share images can land on both sides of
# the split — consider a chronological split; confirm intent.
(X_train_seq_filtered, X_test_seq_filtered,
 X_train_alt_filtered, X_test_alt_filtered,
 X_train_wind_filtered, X_test_wind_filtered,
 y_train_filtered, y_test_filtered) = train_test_split(
    filtered_sequences, filtered_altimetry_features, filtered_wind_speeds, filtered_wave_heights,
    test_size=0.2, random_state=42
)

# Further split the train+validation set into separate training and validation sets
(X_train_seq_filtered, X_validate_seq_filtered,
 X_train_alt_filtered, X_validate_alt_filtered,
 X_train_wind_filtered, X_validate_wind_filtered,
 y_train_filtered, y_validate_filtered) = train_test_split(
    X_train_seq_filtered, X_train_alt_filtered, X_train_wind_filtered, y_train_filtered,
    test_size=0.25, random_state=42  # Splits the 80% of data into 60% training and 20% validation
)

# Train your model using the training set and validate using the validation set
history_filtered = model.fit(
    [np.array(X_train_seq_filtered), np.array(X_train_alt_filtered), np.array(X_train_wind_filtered)],
    y_train_filtered,  # Inputs and targets for training
    validation_data=(
        [np.array(X_validate_seq_filtered), np.array(X_validate_alt_filtered), np.array(X_validate_wind_filtered)],
        y_validate_filtered),  # Inputs and targets for validation
    epochs=100,
    batch_size=16
)

# Save the training history
history_file = '/content/drive/MyDrive/Models/history_filtered_sixsequence.pkl'
with open(history_file, 'wb') as f:
    joblib.dump(history_filtered.history, f)

# Save the model
model_file = '/content/drive/MyDrive/Models/my_model_with_r2_sixsequence.h5'
model.save(model_file)

# After training, evaluate the model on the test set which is completely unseen
test_inputs = [np.array(X_test_seq_filtered), np.array(X_test_alt_filtered), np.array(X_test_wind_filtered)]
test_metrics = model.evaluate(test_inputs, y_test_filtered, verbose=0)

# evaluate() returns the loss followed by the compiled metrics (MAE, RMSE).
# Taking the first three values works whether or not further metrics were
# compiled in; unpacking a fixed count of four would fail for this model.
test_loss, test_mae, test_rmse = test_metrics[:3]

# R-squared is computed over the whole test set from the predictions rather
# than taken from a Keras metric.
test_predictions = np.asarray(model.predict(test_inputs, verbose=0), dtype=float).reshape(-1)
test_targets = np.asarray(y_test_filtered, dtype=float).reshape(-1)
residual_sum_squares = np.sum((test_targets - test_predictions) ** 2)
total_sum_squares = np.sum((test_targets - test_targets.mean()) ** 2)
# R-squared is undefined when the targets have no variance
test_r2 = 1.0 - residual_sum_squares / total_sum_squares if total_sum_squares > 0 else float('nan')

print(f"Test Loss: {test_loss}")
print(f"Test MAE: {test_mae}")
print(f"Test RMSE: {test_rmse}")
print(f"Test R-squared: {test_r2}")


# XAI

## Occlusion Sensitivity

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from sklearn.metrics import r2_score

# Load the model
# NOTE(review): the training cell saves to 'my_model_with_r2_sixsequence.h5',
# while this loads 'my_model_with_r2.h5' — confirm this is the intended file.
model_file = '/content/drive/MyDrive/Models/my_model_with_r2.h5'
# The model is only used for inference in the cells below, so skip restoring
# the compile state. That removes the need to provide every custom metric the
# saved model was compiled with ('r_squared' is not defined in this notebook,
# so referencing it here raised a NameError).
model = load_model(
    model_file,
    custom_objects={'root_mean_squared_error': root_mean_squared_error},
    compile=False,
)

# Load the data
# Ensure you have your data loaded here, e.g., filtered_sequences, filtered_altimetry_features, filtered_wave_heights, filtered_wind_speeds

def _iter_occluded_inputs(input_data, patch_size, stride):
    """Yield one occluded copy of ``input_data`` per patch position, row-major."""
    sar = np.asarray(input_data[0])
    # Spatial axes are the two directly before the trailing channel axis,
    # i.e. (..., height, width, channels).
    height, width = sar.shape[-3], sar.shape[-2]
    auxiliary_inputs = list(input_data[1:])
    for top in range(0, height - patch_size + 1, stride):
        for left in range(0, width - patch_size + 1, stride):
            occluded_sar = sar.copy()
            # Zero the patch in every leading position (batch / time step)
            occluded_sar[..., top:top + patch_size, left:left + patch_size, :] = 0
            yield [occluded_sar] + auxiliary_inputs


def apply_occlusion(input_data, patch_size=10, stride=5):
    """Create occluded variants of a multi-input sample.

    A square patch of zeros is slid over the spatial axes of the first input
    (the image data, laid out as ``(..., height, width, channels)``). The
    remaining inputs are passed through unchanged.

    Args:
        input_data: List of model inputs; element 0 is the image array.
        patch_size: Side length of the square occlusion patch, in pixels.
        stride: Step between successive patch positions, in pixels.

    Returns:
        A list with one entry per patch position (row-major order). Each
        entry is a list of model inputs whose first element is an occluded
        copy of the image array. The original arrays are not modified.
    """
    return list(_iter_occluded_inputs(input_data, patch_size, stride))


def compute_occlusion_impact(model, input_data, true_label, patch_size=10, stride=5):
    """Measure how much occluding each image patch changes the prediction.

    Args:
        model: Object with a Keras-style ``predict(inputs, verbose=...)``.
        input_data: List of model inputs; element 0 is the image array laid
            out as ``(..., height, width, channels)``.
        true_label: Unused; kept for interface compatibility.
        patch_size: Side length of the square occlusion patch, in pixels.
        stride: Step between successive patch positions, in pixels.

    Returns:
        ``(occlusion_map, base_prediction)`` where ``occlusion_map[r, c]`` is
        the mean absolute change in prediction for the patch at grid row
        ``r`` and column ``c``.

    Raises:
        ValueError: If no patch of ``patch_size`` fits inside the image.
    """
    sar = np.asarray(input_data[0])
    # Grid size derived from the same ranges the patches are generated from
    n_rows = len(range(0, sar.shape[-3] - patch_size + 1, stride))
    n_cols = len(range(0, sar.shape[-2] - patch_size + 1, stride))
    if n_rows == 0 or n_cols == 0:
        raise ValueError(
            f"patch_size={patch_size} does not fit in an image of size "
            f"{sar.shape[-3]}x{sar.shape[-2]}"
        )

    base_prediction = model.predict(input_data, verbose=0)

    # Occluded copies are generated lazily so only one is held at a time
    occlusion_impacts = np.array([
        np.abs(base_prediction - model.predict(occluded_sample, verbose=0)).mean()
        for occluded_sample in _iter_occluded_inputs(input_data, patch_size, stride)
    ])

    occlusion_map = occlusion_impacts.reshape(n_rows, n_cols)
    return occlusion_map, base_prediction

# Example to visualize the occlusion map
sample_index = 0  # Index of the sample to check
sample_data = [np.array([X_test_seq_filtered[sample_index]]),
               np.array([X_test_alt_filtered[sample_index]]),
               np.array([X_test_wind_filtered[sample_index]])]
true_label = y_test_filtered[sample_index]

# Compute occlusion impact
occlusion_map, base_prediction = compute_occlusion_impact(model, sample_data, true_label)

# Normalize occlusion map for better visualization. If no patch changed the
# prediction (or the map is empty) there is nothing to scale by, and dividing
# would produce NaNs or fail.
max_impact = np.max(occlusion_map) if occlusion_map.size else 0
if max_impact > 0:
    occlusion_map = occlusion_map / max_impact
else:
    print("Occlusion map has no non-zero impact; skipping normalization")

# Visualize the occlusion impact for SAR data
plt.figure(figsize=(10, 8))
plt.imshow(occlusion_map, cmap='jet')
plt.colorbar()
plt.title(f'Occlusion Impact for SAR Data for Sample {sample_index + 1}')
plt.show()


## Integrated Gradients

In [None]:
# INTEGRATED GRADIENTS

from tf_keras_vis.utils.scores import CategoricalScore
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear

# Score function selecting output index 0 (the model's single output)
score_function = CategoricalScore(0)

# Saliency helper working directly on the model (no clone), with the
# ReplaceToLinear model modifier applied
saliency = Saliency(model, model_modifier=ReplaceToLinear(), clone=False)

# Function to compute the attribution maps.
# NOTE(review): this calls tf-keras-vis Saliency with smoothing parameters,
# which looks like smoothed saliency rather than integrated gradients —
# confirm the intended method and naming.
def compute_integrated_gradients(input_data):
    smoothing_options = {'smooth_samples': 20, 'smooth_noise': 0.20}
    return saliency(score_function, input_data, **smoothing_options)

# Inspect the attribution maps of the first three test samples
for i in range(3):
    sample = [
        np.array(X_test_seq_filtered[i:i+1]),
        np.array(X_test_alt_filtered[i:i+1]),
        np.array(X_test_wind_filtered[i:i+1]),
    ]
    integrated_gradients = compute_integrated_gradients(sample)

    # SAR branch: one heat map per time step of the sequence
    sar_attributions = integrated_gradients[0]
    for t in range(sar_attributions.shape[1]):
        plt.figure(figsize=(10, 8))
        plt.imshow(sar_attributions[0, t, :, :], cmap='jet')
        plt.colorbar()
        plt.title(f'Integrated Gradients for SAR Data at Time Step {t+1}')
        plt.show()

    # Scalar branches: altimetry first, then wind speed
    for branch_index, branch_title in (
        (1, 'Integrated Gradients for Altimetry Data'),
        (2, 'Integrated Gradients for Wind Speed Data'),
    ):
        plt.figure(figsize=(10, 8))
        plt.plot(integrated_gradients[branch_index][0])
        plt.title(branch_title)
        plt.show()
