In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt

In [None]:
# Model definition
number_raster_layers = 9   # default number of input channels (raster layers)
number_pixels_layer = 19   # default side length, in pixels, of each square raster layer

class CNNRegressor(nn.Module):
    """Six-layer convolutional regressor that maps a raster stack to one scalar.

    Every convolution is 3x3 with stride 1 and padding 1, so the spatial size
    is unchanged up to the flatten step; each convolution is followed by a
    SELU activation. Sub-module names (``conv1``..``conv6``,
    ``selu1``..``selu6``, ``flatten``, ``fc``) and their registration order
    are kept stable so that existing ``state_dict`` checkpoints keep loading.

    Args:
        in_channels: Number of input channels. ``None`` (default) uses the
            module-level ``number_raster_layers`` at construction time.
        image_size: Height and width of the square input. ``None`` (default)
            uses the module-level ``number_pixels_layer`` at construction time.
    """

    # Output channels of conv1..conv6, in order.
    _CONV_WIDTHS = (16, 32, 64, 128, 256, 512)

    def __init__(self, in_channels=None, image_size=None):
        super().__init__()
        if in_channels is None:
            in_channels = number_raster_layers
        if image_size is None:
            image_size = number_pixels_layer

        channels = in_channels
        for position, width in enumerate(self._CONV_WIDTHS, start=1):
            # setattr goes through nn.Module.__setattr__, so each layer is
            # registered under the same name the original attribute had.
            setattr(self, f"conv{position}",
                    nn.Conv2d(channels, width, kernel_size=3, stride=1, padding=1))
            setattr(self, f"selu{position}", nn.SELU())
            channels = width

        self.flatten = nn.Flatten()
        # Spatial size is preserved by the padded convolutions, hence image_size ** 2.
        self.fc = nn.Linear(channels * image_size * image_size, 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        for position in range(1, len(self._CONV_WIDTHS) + 1):
            x = getattr(self, f"conv{position}")(x)
            x = getattr(self, f"selu{position}")(x)
        return self.fc(self.flatten(x))

model = CNNRegressor()

In [None]:
# Load model
path_model = '../../Data/Calibrated_models/global_regressor_V0.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; the model is never moved to another device in this notebook.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(path_model, map_location='cpu'))


In [None]:
# Switch the model to evaluation mode (if necessary)
model.eval()

In [None]:
# Load tensors
# map_location='cpu' keeps everything on the CPU, which is where the model lives and
# what the later `.numpy()` conversions require, even if the files were saved from a GPU.
# NOTE(review): torch.load unpickles the files, so only load data from a trusted source.
path_tensor_train = '../../Data/Calibrated_models/global_regressor_V0_tensor_y_train.pth'
y_train = torch.load(path_tensor_train, map_location='cpu')

# Separate name for the y_test path so that `path_tensor_test` refers to a single file.
path_tensor_y_test = '../../Data/Calibrated_models/global_regressor_V0_tensor_y_test.pth'
y_test = torch.load(path_tensor_y_test, map_location='cpu')

path_tensor_test = '../../Data/Calibrated_models/global_regressor_V0_test_tensor.pth'
test_tensor = torch.load(path_tensor_test, map_location='cpu')

path_tensor_training = '../../Data/Calibrated_models/global_regressor_V0_training_tensor.pth'
training_tensor = torch.load(path_tensor_training, map_location='cpu')

In [None]:
# Number of test sites, i.e. the length of the test tensor along its first dimension
total_number_test_sites = len(test_tensor)
print(total_number_test_sites)

In [None]:
# Select some sites to estimate the CNN importance for their pixels: 13, 81, 90, 132.

In [None]:
# Let's show the results for site "90"
input_tensor_number = 90  # alternatives: 13, 81, 132
# Insert a leading dimension of size 1 in front of the selected site's tensor
input_tensor_raw = torch.unsqueeze(test_tensor[input_tensor_number], dim=0)

In [None]:
# Display the shape of the selected site's tensor (it now has a leading dimension of size 1)
input_tensor_raw.shape

In [None]:
# Cast the input to 32-bit floating point so that it can be fed to the CNN
input_tensor = input_tensor_raw.to(torch.float32)

In [None]:
# Perform a forward pass to obtain the CNN output for the selected site.
# This is plain inference, so autograd is switched off to avoid building a graph.
with torch.no_grad():
    output = model(input_tensor)
print(output)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np

# Panel titles; title i is paired with channel i of the input tensor
layer_names = [
    "a) Bare (%)",
    "b) Built-Up (%)",
    "c) Crops (%)",
    "d) Grass (%)",
    "e) Moss/Lichen (%)",
    "f) Permanent water (%)",
    "g) Seasonal water (%)",
    "h) Shrub (%)",
    "i) Tree (%)",
]

# One panel per land cover layer, laid out on a 3 x 3 grid
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))

for i, (ax, layer_title) in enumerate(zip(axes.flat, layer_names)):
    # Draw channel i of the selected site
    im = ax.imshow(input_tensor[0, i], cmap='viridis')
    ax.set_title(layer_title, fontsize=14)
    ax.axis('off')

    # Reserve a thin strip to the right of the panel for its own color bar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Fill the strip with the color bar of this panel's image
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(labelsize=12)

fig.tight_layout()
plt.show()

# Save image
fig.savefig('../../Figures/figA2.png', dpi=300)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb 

input_tensor_squeeze = input_tensor.squeeze(0)

# Define land cover color palette (key = channel index of the land cover layer)
color_map = {
    0: 'saddlebrown',   # Bare
    1: 'red',           # BuiltUp
    2: 'yellow',        # Crops
    3: 'lightgreen',    # Grass
    4: 'limegreen',     # MossLichen
    5: 'blue',          # PermanentWater
    6: 'cyan',          # SeasonalWater
    7: 'olive',         # Shrub
    8: 'darkgreen'      # Tree
}

# Estimate the dominant land cover layer of every pixel (index of the largest channel).
# The tensor is converted to NumPy explicitly so that the index map is always a plain
# ndarray instead of depending on how np.argmax dispatches on a torch.Tensor.
dominant_layer_indices = input_tensor_squeeze.detach().cpu().numpy().argmax(axis=0)

# Create an empty RGB image
dominant_image = np.zeros((*dominant_layer_indices.shape, 3), dtype=np.uint8)

# Map each index to its corresponding color in the RGB image
for index, color in color_map.items():
    # Find where the dominant index is equal to the current layer index
    mask = dominant_layer_indices == index
    # to_rgb returns floats in [0, 1]. Round to the nearest integer before the uint8
    # cast: a plain cast truncates, so a product that lands just below an integer
    # because of float round-off would come out one level too dark.
    dominant_image[mask] = np.rint(np.array(to_rgb(color)) * 255).astype(np.uint8)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb
import matplotlib.patches as mpatches

# Land cover colors, keyed by layer index (0..8)
color_map = dict(enumerate([
    'saddlebrown',   # Bare
    'red',           # BuiltUp
    'yellow',        # Crops
    'lightgreen',    # Grass
    'limegreen',     # MossLichen
    'blue',          # PermanentWater
    'cyan',          # SeasonalWater
    'olive',         # Shrub
    'darkgreen',     # Tree
]))

# Land cover labels, keyed by the same layer index
labels = dict(enumerate([
    "Bare",
    "Built-Up",
    "Crops",
    "Grass",
    "Moss/Lichen",
    "Permanent water",
    "Seasonal water",
    "Shrub",
    "Tree",
]))

# Draw the dominant land cover map on a new figure
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(dominant_image)
ax.axis('on')

# One legend patch per land cover class
patches = [
    mpatches.Patch(color=to_rgb(shade), label=labels[key])
    for key, shade in color_map.items()
]

# Place the legend below the map, in three columns
fig.legend(handles=patches, loc='lower center', bbox_to_anchor=(0.5, 0.1), ncol=3, frameon=False, fontsize=13)

# Leave room at the bottom of the figure for the legend
fig.subplots_adjust(bottom=0.2)

# Keep the frame but hide the tick marks
ax.set_xticks([])
ax.set_yticks([])

# Save the figure
fig.savefig('../../Figures/figB3a.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()


In [None]:
# Create a plot showing CNN activations

In [None]:
def get_activation(layer, input, output):
    """Forward hook that stores the rectified layer output in the global ``activation``.

    Args:
        layer: Module the hook is attached to (unused).
        input: Inputs passed to the module (unused).
        output: Output tensor produced by the module.

    Returns:
        None, so the module output is left unchanged.
    """
    global activation
    # ReLU keeps only the positive responses, which makes the activations easier to visualize
    activation = output.relu()

In [None]:
# Register a forward hook on conv1: get_activation is called on every forward pass of that layer
hook = model.conv1.register_forward_hook(get_activation)

In [None]:
# Perform a forward pass to obtain the activations (this triggers the hook registered on conv1)
output = model(input_tensor)

In [None]:
# Remove the hook after using it to prevent memory leaks; later forward passes
# will then no longer overwrite the global `activation`
hook.remove()

In [None]:
# Verify that activations are defined
if 'activation' in globals():
    # One grayscale panel per feature map of the hooked layer, four panels per row
    num_plots = activation.shape[1]
    num_rows = (num_plots + 3) // 4
    fig, axes = plt.subplots(num_rows, 4, figsize=(12, num_rows * 3))
    for i, ax in enumerate(axes.flat):
        if i < num_plots:
            ax.imshow(activation[0, i].detach().numpy(), cmap='gray')
        # Axes are hidden on every panel, including unused trailing ones
        ax.axis('off')

    # Save inside this branch so that a stale `fig` left behind by an earlier cell
    # is never written to this path when no activation was recorded.
    fig.savefig('../../Figures/figA3.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No activation data was recorded.")

In [None]:
# Get grad-CAM for each pixel

In [None]:
import torch
import torch.nn.functional as F
import random

class GradCAM:
    """Grad-CAM heat maps for one layer of a model.

    Hooks on ``layer`` capture its forward output and the gradient of the
    model output with respect to that output. Calling the instance runs one
    forward and one backward pass and combines both into a heat map.

    Args:
        model: Module whose output is explained.
        layer: Sub-module of ``model`` whose output is inspected; the code
            assumes that output is 4-D ``(batch, channels, height, width)``.
    """

    def __init__(self, model, layer):
        self.model = model
        self.layer = layer
        self.gradient = None
        self.activation = None

        # register_backward_hook is deprecated; prefer the full variant and only
        # fall back to the legacy one on PyTorch versions that do not have it.
        register_backward = getattr(layer, "register_full_backward_hook", None)
        if register_backward is None:
            register_backward = layer.register_backward_hook

        self.hook_handles = [
            layer.register_forward_hook(self.save_activation),
            register_backward(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached handle on the layer output."""
        self.activation = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient with respect to the layer output."""
        self.gradient = grad_output[0].detach()

    def __call__(self, x, index=None, seed=None):
        """Compute the Grad-CAM heat map of ``x`` for the hooked layer.

        Args:
            x: Input batch passed to the model.
            index: Optional position of the output unit to explain (outputs are
                viewed as ``(batch, -1)``). ``None`` (default) back-propagates
                from all outputs at once, which for a single-output regressor
                is the same thing.
            seed: Optional seed applied to torch, NumPy and ``random`` before
                the forward pass. When ``None``, a module-level variable named
                ``seed`` is used if one exists (kept for backward
                compatibility); otherwise nothing is seeded.

        Returns:
            Tensor with the channel-averaged, gradient-weighted activations
            after ReLU, squeezed, and divided by its maximum when that maximum
            is positive (an all-zero map is returned as is rather than as NaN).

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. after ``release()``.
        """
        if seed is None:
            # Older callers set a module-level `seed` instead of passing it in.
            seed = globals().get("seed")
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)
            random.seed(seed)

        # Clear what a previous call captured
        self.gradient = None
        self.activation = None

        output = self.model(x)

        self.model.zero_grad()
        if index is None:
            output.backward(torch.ones_like(output))
        else:
            selected = output.reshape(output.shape[0], -1)[:, index]
            selected.backward(torch.ones_like(selected))

        if self.activation is None or self.gradient is None:
            raise RuntimeError("GradCAM hooks captured no data; were they released?")

        # Channel weights: gradient averaged over batch and spatial dimensions
        pooled_gradients = torch.mean(self.gradient, dim=[0, 2, 3])
        # Out-of-place product: the stored activation shares memory with the
        # layer output, so it must not be modified in place.
        weighted = self.activation * pooled_gradients.view(1, -1, 1, 1)

        heatmap = F.relu(torch.mean(weighted, dim=1).squeeze())
        peak = torch.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        return heatmap

    def release(self):
        """Remove the hooks from the layer; safe to call more than once."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

In [None]:
# Use GradCAM on the model with the conv6 layer; creating the object attaches its hooks to conv6
grad_cam = GradCAM(model, model.conv6)

In [None]:
# Set a fixed seed to ensure reproducibility (GradCAM.__call__ reads this module-level name)
seed = 42
# Get grad-CAM map; the hooks are released even if the computation raises
try:
    heatmap = grad_cam(input_tensor)
finally:
    grad_cam.release()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Convert the grad-CAM tensor to a NumPy array for plotting
heatmap_np = heatmap.cpu().numpy()

# Figure with two stacked rows: a tall one for the heat map, a thin one for its color bar
fig = plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[20, 1])

# Heat map panel, drawn without axis ticks and labels
ax = fig.add_subplot(gs[0])
cax = ax.imshow(heatmap_np, cmap='hot')
ax.set_axis_off()

# Horizontal color bar in the bottom row, with ticks on its lower edge
cbar_ax = fig.add_subplot(gs[1])
cbar = fig.colorbar(cax, cax=cbar_ax, orientation='horizontal')
cbar_ax.xaxis.set_ticks_position('bottom')
cbar.ax.tick_params(labelsize=13)

# Save the figure
fig.savefig('../../Figures/figB1.png', dpi=300, bbox_inches='tight')

# Show the figure
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_rgb
from matplotlib.cm import ScalarMappable
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

# Min-max rescale the grad-CAM map
heatmap_low, heatmap_high = heatmap.min(), heatmap.max()
normalized_heatmap = (heatmap - heatmap_low) / (heatmap_high - heatmap_low)

# RGBA version of `dominant_image`: RGB comes from the land cover map (scaled to 0-1),
# while the alpha channel follows the rescaled grad-CAM map
dominant_rgba_image = np.zeros((*dominant_image.shape[:2], 4), dtype=np.float32)
dominant_rgba_image[..., :3] = dominant_image[..., :3] / 255.0
dominant_rgba_image[..., 3] = normalized_heatmap

# Draw the RGBA image
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(dominant_rgba_image)
ax.axis('on')

# Legend label -> color for the land cover classes of `dominant_image`
color_map = {
    "Bare": 'saddlebrown',
    "Built-Up": 'red',
    "Crops": 'yellow',
    "Grass": 'lightgreen',
    "Moss/Lichen": 'limegreen',
    "Permanent water": 'blue',
    "Seasonal water": 'cyan',
    "Shrub": 'olive',
    "Tree": 'darkgreen',
}
patches = [
    mpatches.Patch(color=to_rgb(shade), label=name)
    for name, shade in color_map.items()
]

# Legend below the map, in three columns
fig.legend(handles=patches, loc='lower center', bbox_to_anchor=(0.5, 0.1), ncol=3, frameon=False, fontsize=13)

# Leave room at the bottom of the figure for the legend
fig.subplots_adjust(bottom=0.2)

# Keep the frame but hide the tick marks
ax.set_xticks([])
ax.set_yticks([])

plt.show()
fig.savefig('../../Figures/figB3b.png', dpi=300, bbox_inches='tight')

In [None]:
# We will estimate the groups of important pixels that are together based on grad-CAM > 0 and the dominant land covers map

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, to_rgb
from matplotlib.cm import ScalarMappable
import matplotlib.patches as mpatches

# Min-max rescale the grad-CAM map
heatmap_low, heatmap_high = heatmap.min(), heatmap.max()
normalized_heatmap = (heatmap - heatmap_low) / (heatmap_high - heatmap_low)

# RGBA version of `dominant_image`; RGB is scaled to 0-1 and rounded to 7 decimals
dominant_rgba_image = np.zeros((*dominant_image.shape[:2], 4), dtype=np.float32)
dominant_rgba_image[..., :3] = np.around(dominant_image[..., :3] / 255.0, 7)

########################################################
# Pixels whose rescaled grad-CAM value exceeds this threshold are kept
threshold = 0.0
#######################################################

# Binary alpha: 1 for pixels above the threshold, 0 for the rest
alpha_channel = np.where(normalized_heatmap > threshold, 1.0, 0)
dominant_rgba_image[..., 3] = alpha_channel

# Draw the thresholded image
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(dominant_rgba_image)
ax.axis('on')

# Legend label -> color for the land cover classes of `dominant_image`
color_map = {
    "Bare": 'saddlebrown',
    "BuiltUp": 'red',
    "Crops": 'yellow',
    "Grass": 'lightgreen',
    "MossLichen": 'limegreen',
    "PermanentWater": 'blue',
    "SeasonalWater": 'cyan',
    "Shrub": 'olive',
    "Tree": 'darkgreen',
}
patches = [
    mpatches.Patch(color=to_rgb(shade), label=name)
    for name, shade in color_map.items()
]

# Keep the frame but hide the tick marks
ax.set_xticks([])
ax.set_yticks([])

# Land cover legend, anchored just outside the upper-right corner of the axes
legend = ax.legend(handles=patches, loc='upper left', title="Land Cover Types", bbox_to_anchor=(1.05, 1), borderaxespad=0.)

plt.show()

In [None]:
# Sanity check: select the pixels whose alpha channel equals 1
alpha_mask = dominant_rgba_image[..., 3] == 1

# Same shape and dtype as the RGBA image; pixels outside the mask are zero in all channels
output_image = np.where(
    alpha_mask[..., np.newaxis],
    dominant_rgba_image,
    np.zeros_like(dominant_rgba_image),
)

# Visualize the result, which shows only the pixels with alpha = 1
check_fig, check_ax = plt.subplots(figsize=(10, 6))
check_ax.imshow(output_image)
check_ax.set_title("Pixels with Alpha = 1")
check_ax.axis('on')
plt.show()

In [None]:
import numpy as np
from skimage import measure, color
import matplotlib.pyplot as plt

# Foreground mask: True wherever at least one RGB channel (alpha ignored) is above zero
binary_image = (output_image[..., :3] > 0).any(axis=-1)

# Label connected foreground regions; connectivity=2 also joins diagonal neighbours
labeled_image = measure.label(binary_image, connectivity=2)

In [None]:
# Make sure only the RGB channels are used: drop the alpha channel if it is still there
has_alpha_channel = output_image.shape[-1] == 4
if has_alpha_channel:
    output_image = output_image[..., :3]

# Apply label2rgb again
from skimage import color

# Overlay the component labels on the RGB image; label 0 is treated as background
image_labeled = color.label2rgb(labeled_image, output_image, kind='overlay', bg_label=0)

# Draw the labeled image without axes
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(image_labeled)
ax.set_axis_off()

plt.show()
fig.savefig('../../Figures/figB3c.png', dpi=300, bbox_inches='tight')