In [None]:
from scipy.ndimage.filters import uniform_filter
from scipy.ndimage.measurements import variance

def lee_filter(img, size, channel=0):
    """Smooth one channel of an image with a Lee-style adaptive filter.

    Each pixel is pulled towards its local window mean; the amount of smoothing
    is driven by the ratio of the local variance to the variance of the whole
    channel (which is used here as the noise-variance estimate).

    Args:
        img: Array-like of shape (H, W, C).
        size (int): Side length of the square sliding window.
        channel (int): Index of the channel to filter.

    Returns:
        np.ndarray: Filtered channel of shape (H, W). Floating-point input keeps
        its dtype path; integer/bool input is processed in float64.
    """
    import numpy as np  # local import so this cell works before the notebook-wide import cell has run

    band = np.asarray(img)[:, :, channel]
    # uniform_filter returns the input dtype and `**2` wraps for small integer
    # types, so integer data must be promoted before any statistics are taken.
    if band.dtype.kind in "biu":
        band = band.astype(np.float64)

    local_mean = uniform_filter(band, (size, size))
    local_sqr_mean = uniform_filter(band ** 2, (size, size))
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
    local_variance = np.maximum(local_sqr_mean - local_mean ** 2, 0)

    overall_variance = variance(band)

    denominator = local_variance + overall_variance
    # A perfectly flat channel has zero local and overall variance; use weight 0
    # there (output = local mean) instead of producing 0/0 = NaN.
    weights = np.divide(
        local_variance,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator > 0,
    )
    return local_mean + weights * (band - local_mean)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from datasets import Dataset, load_from_disk
import pytorch_lightning as pl
import sys
import os
# Add the directory containing lit_sam_model.py to the Python path
sys.path.append(os.path.abspath("../"))
from model.adapterModel import LitSamModel
from utils.statistics import calculate_correlation
from model.samDataset import SAMDataset3

In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate the 'src' directory.
#    Plain scripts define __file__ (e.g. src/training/your_script.py), but Jupyter/IPython
#    kernels do not, so an unguarded reference raises NameError inside a notebook.
try:
    current_file = Path(__file__).resolve()
    src_dir = current_file.parent.parent
except NameError:
    current_file = None
    src_dir = None
    # Notebook fallback: walk up from the working directory until the config file is found
    # under "<dir>/src/config" or "<dir>/config". Searching (instead of assuming a fixed depth)
    # keeps this cell re-runnable after the os.chdir() below has moved the cwd.
    _cwd = Path.cwd().resolve()
    for _base in (_cwd, *_cwd.parents):
        for _candidate in (_base / "src", _base):
            if (_candidate / "config" / "config_general.yaml").is_file():
                src_dir = _candidate
                break
        if src_dir is not None:
            break
    if src_dir is None:
        raise FileNotFoundError(
            f"config/config_general.yaml not found in {_cwd} or any of its parent directories"
        )

# 2. The configuration lives in 'src/config'
config_path = src_dir / "config" / "config_general.yaml"

# 3. Load the YAML
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src')
# This ensures that "./data" in the YAML is interpreted relative to the Project_Root
PROJECT_ROOT = src_dir.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Load the saved test split from DATA_DIR.
# (`os.path.join` with a dot: a comma here would pass the `os.path` module itself as the
# dataset path and call an undefined `join`.)
test_dataset = load_from_disk(os.path.join(DATA_DIR, 'datasetTestFinal'))

In [None]:
# Convert the first sample's image to an array and report its shape.
test_image = np.array(test_dataset[0]['image'])
print(test_image.shape)

In [None]:
from dataprocessing.rcsHandlingFunctions import _rescale

def create_image(dataset, index):
    """Build a 3-channel composite (H, W, 3) from one dataset sample.

    Channel 0 blends the "VH0"/"VV0" layers, channel 1 blends the "VH1"/"VV1"
    layers (both with the same per-pixel weight ``w``), and channel 2 is the
    sample's "dem" layer unchanged.

    Args:
        dataset: Indexable collection whose items provide the keys
            "VH0", "VH1", "VV0", "VV1" and "dem".
        index: Position of the sample inside ``dataset``.

    Returns:
        np.ndarray: The three layers stacked along axis 2.
    """
    item = dataset[index]
    # The item's "slope" entry is not part of the composite, so it is not loaded.
    VH0, VH1, VV0, VV1, dem = (
        np.array(item[key]) for key in ("VH0", "VH1", "VV0", "VV1", "dem")
    )

    # Rescaled change between the two VH / VV layers.
    # NOTE(review): presumably _rescale maps the given (min, max) range onto [0, 1] — confirm.
    a = _rescale(VH1 - VH0, 0, .25)
    b = _rescale(VV1 - VV0, 0, .25)

    # Per-pixel blend weight derived from the difference of the two change maps.
    w = _rescale(a - b, 0, 1)

    r = w*VH0 + (1 - w)*VV0
    g = w*VH1 + (1 - w)*VV1

    return np.stack([r, g, dem], axis=2)

In [None]:
# Composite image for the sample at index 2.
test_image = create_image(dataset=test_dataset, index=2)

In [None]:
# Visualize the composite image.
# The array is cast to float32 before being handed to imshow; note that this rebinds
# `test_image`, so the cell changes the variable that earlier cells created.
test_image = np.asarray(test_image).astype(np.float32)
plt.imshow(test_image)

In [None]:
# Lee-filter channels 0 and 1 of sample 1 with a 3x3 window; channel 2 is copied as-is.
test_image = np.array(test_dataset[1]['image'])
test_image_filtered = np.zeros_like(test_image)
for band in (0, 1):
    test_image_filtered[:, :, band] = lee_filter(test_image, 3, channel=band)
test_image_filtered[:, :, 2] = test_image[:, :, 2]

# Side-by-side comparison: original (left) vs. filtered (right).
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
left_ax, right_ax = axs
left_ax.imshow(test_image)
left_ax.set_title('Original Image')
right_ax.imshow(test_image_filtered)
right_ax.set_title('Filtered Image')


In [None]:
# Average channels 0 and 1 of sample 1, then Lee-filter the averaged image (3x3 window).
test_image = np.array(test_dataset[1]['image'])

# Both of the first two output channels hold the same mean of the original pair;
# channel 2 is carried over unchanged.
channel_mean = (test_image[:, :, 0] + test_image[:, :, 1]) / 2
test_image_averaged = np.zeros_like(test_image)
test_image_averaged[:, :, 0] = channel_mean
test_image_averaged[:, :, 1] = channel_mean
test_image_averaged[:, :, 2] = test_image[:, :, 2]

test_image_filtered = np.zeros_like(test_image)
for band in (0, 1):
    test_image_filtered[:, :, band] = lee_filter(test_image_averaged, 3, channel=band)
test_image_filtered[:, :, 2] = test_image_averaged[:, :, 2]

# Side-by-side comparison: original (left) vs. filtered (right).
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
left_ax, right_ax = axs
left_ax.imshow(test_image)
left_ax.set_title('Original Image')
right_ax.imshow(test_image_filtered)
right_ax.set_title('Filtered Image')


In [None]:
def lee_filter_2(img, size, num_looks = 100, channel=0):
    """
    Apply a Lee-style speckle filter whose noise variance is derived from the number of looks.

    Args:
        img (np.ndarray): The input image of shape (H, W, C) (should be intensity data).
        size (int): The size of the sliding window.
        num_looks (float): The number of looks, used to estimate noise variance. Must be > 0.
        channel (int): The channel to process if the image is multi-channel.

    Returns:
        np.ndarray: The filtered channel of shape (H, W).

    Raises:
        ValueError: If ``num_looks`` is not greater than 0.
    """
    if num_looks <= 0:
        raise ValueError("Number of looks must be greater than 0.")

    # Extract the specified channel and make sure it is floating point: uniform_filter
    # returns the input dtype and zeros_like() below inherits it, so integer input would
    # otherwise yield truncated means and all-zero weights.
    img_channel = np.asarray(img)[:, :, channel]
    if img_channel.dtype.kind in "biu":
        img_channel = img_channel.astype(np.float64)

    # Local mean and variance over the sliding window
    img_mean = uniform_filter(img_channel, size=size)
    img_sqr_mean = uniform_filter(img_channel**2, size=size)
    local_variance = img_sqr_mean - img_mean**2

    # Noise variance from the number of looks (L): the coefficient of variation (Cu)
    # of multi-look intensity data is 1 / sqrt(L), so noise variance = local_mean^2 * Cu^2.
    cu = 1.0 / np.sqrt(num_looks)
    noise_variance = img_mean**2 * cu**2

    # Adaptive weight K = (local_variance - noise_variance) / (local_variance + noise_variance).
    # NOTE(review): textbook Lee weights divide by the local variance alone — confirm that the
    # summed denominator is intended.
    K = np.zeros_like(img_channel)
    den = local_variance + noise_variance

    # Pixels whose denominator is (almost) zero keep K = 0, i.e. output = local mean.
    # NOTE(review): 1e-6 is an absolute threshold; for very small intensities every pixel
    # may fall below it — confirm against the data range.
    non_zero_den_mask = den > 1e-6
    K[non_zero_den_mask] = (
        local_variance[non_zero_den_mask] - noise_variance[non_zero_den_mask]
    ) / den[non_zero_den_mask]
    K[K < 0] = 0 # Ensure weights are not negative

    # Blend between local mean (K = 0) and original pixel (K -> 1)
    return img_mean + K * (img_channel - img_mean)

In [None]:
def estimate_enl(image_patch):
    """
    Estimate the Equivalent Number of Looks (ENL) as mean^2 / variance of a patch.

    The estimate is only meaningful when the patch covers a homogeneous area.

    Args:
        image_patch (np.ndarray): Pixel values of the (homogeneous) area.

    Returns:
        float: mean(patch)^2 / var(patch), or infinity when the variance is exactly zero.
    """
    spread = np.var(image_patch)
    # A perfectly flat patch would divide by zero; report an infinite ENL instead.
    if spread == 0:
        return np.inf
    return (np.mean(image_patch) ** 2) / spread

In [None]:
# ENL estimate over the whole of sample 1 (all channels pooled together).
test_image = np.array(test_dataset[1]['image'])
enl = estimate_enl(test_image)
print(f"Estimated ENL: {enl}")

In [None]:
# Compare lee_filter_2 (5x5 window) on sample 1 for several assumed numbers of looks.
# One loop replaces three copy-pasted variants that differed only in `num_looks`.
test_image = np.array(test_dataset[1]['image'])

for num_looks in (200, 500, 3):  # example values for the number of looks
    test_image_filtered = np.zeros_like(test_image)
    for band in (0, 1):
        test_image_filtered[:, :, band] = lee_filter_2(
            test_image, 5, num_looks=num_looks, channel=band
        )
    test_image_filtered[:, :, 2] = test_image[:, :, 2]  # third channel is not filtered

    # Original (left) vs. filtered (right); one figure per num_looks value.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].imshow(test_image)
    axs[0].set_title('Original Image')
    axs[1].imshow(test_image_filtered)
    axs[1].set_title(f'Filtered Image (Num Looks: {num_looks})')

In [None]:
def lee_filter(img, size, channel=0):
    """Smooth one channel of an image with a Lee-style adaptive filter.

    NOTE(review): this re-defines the `lee_filter` from the first cell of the
    notebook; every call below this cell uses this definition.

    Each pixel is pulled towards its local window mean; the amount of smoothing
    is driven by the ratio of the local variance to the variance of the whole
    channel (which is used here as the noise-variance estimate).

    Args:
        img: Array-like of shape (H, W, C).
        size (int): Side length of the square sliding window.
        channel (int): Index of the channel to filter.

    Returns:
        np.ndarray: Filtered channel of shape (H, W). Floating-point input keeps
        its dtype path; integer/bool input is processed in float64.
    """
    band = np.asarray(img)[:, :, channel]
    # uniform_filter returns the input dtype and `**2` wraps for small integer
    # types, so integer data must be promoted before any statistics are taken.
    if band.dtype.kind in "biu":
        band = band.astype(np.float64)

    local_mean = uniform_filter(band, (size, size))
    local_sqr_mean = uniform_filter(band ** 2, (size, size))
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
    local_variance = np.maximum(local_sqr_mean - local_mean ** 2, 0)

    overall_variance = variance(band)

    denominator = local_variance + overall_variance
    # A perfectly flat channel has zero local and overall variance; use weight 0
    # there (output = local mean) instead of producing 0/0 = NaN.
    weights = np.divide(
        local_variance,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator > 0,
    )
    return local_mean + weights * (band - local_mean)

In [None]:
# Load the pickled test DataFrame stored in DATA_DIR.
# NOTE: unpickling can execute arbitrary code — only load pickle files from a trusted source.
df_loaded = pd.read_pickle(os.path.join(DATA_DIR, "test_df_sam.pkl"))

In [None]:
# Take the 'rcs' and 'dem' entries of the row labelled 0.
rcs_list, dem = (df_loaded.loc[0, column] for column in ('rcs', 'dem'))

In [None]:
# Make the parent directory importable so that the `dataprocessing` package resolves.
# NOTE(review): "../" is resolved against the current working directory, which the configuration
# cell changes with os.chdir(PROJECT_ROOT). The same append already ran in the import cell before
# that chdir — confirm whether this second append is still needed or points where intended.
sys.path.append(os.path.abspath("../"))
import dataprocessing.rcsHandlingFunctions as rcs

# Report the shapes of the two arrays taken from the DataFrame.
print(f"RCS: {rcs_list.shape}, DEM: {dem.shape}")

In [None]:
# Name the first four RCS layers (order: vv1, vv2, vh1, vh2).
vv1, vv2, vh1, vh2 = (rcs_list[layer] for layer in range(4))

In [None]:
# Stack the individual RCS layers along a new last axis so each layer becomes a channel.
rcs_reshaped = np.stack(rcs_list, axis=-1)
print(rcs_reshaped.shape)  # should output (364, 364, 4)

In [None]:
# ENL estimate (mean^2 / variance) computed over the entire vh2 layer.
# NOTE(review): estimate_enl is documented for a homogeneous patch — confirm that using the
# whole layer is intended.
enl = estimate_enl(vh2)
print(f"Estimated ENL: {enl}")

In [None]:
# Lee-filter (5x5 window) every channel of the stacked RCS array.
# No per-channel ENL is computed here: lee_filter takes no number-of-looks argument, so the
# estimate would go unused.
filterd_rcs_list = [
    lee_filter(rcs_reshaped, 5, channel=i) for i in range(rcs_reshaped.shape[-1])
]

rcs_filtered = np.stack(filterd_rcs_list, axis=-1)
print(f"Filtered RCS shape: {rcs_filtered.shape}")  # should output (364, 364, 4)

In [None]:
vv1 = rcs_list[0]
vv1_filterd = rcs_filtered[:, :, 0]

# Rescale both versions with the same (-23, -3) bounds before display.
# Plain assignments only: nothing is written onto the imported `rcs` module.
# NOTE(review): -23 / -3 look like a dB display range — confirm.
vv1_rescaled = rcs._rescale(vv1, -23, -3)
vv1_filterd_rescaled = rcs._rescale(vv1_filterd, -23, -3)

# Visualize the original and filtered RCS
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(vv1_rescaled, cmap='gray')
axs[0].set_title('Original VV1 RCS')
axs[1].imshow(vv1_filterd_rescaled, cmap='gray')
axs[1].set_title('Filtered VV1 RCS')

In [None]:
# Rescale the two VV layers (original and Lee-filtered) with the same (-23, -3) bounds.
# Plain assignments only: nothing is written onto the imported `rcs` module.
vv1 = rcs_list[0]
vv1_filterd = rcs_filtered[:, :, 0]
vv1_rescaled = rcs._rescale(vv1, -23, -3)
vv1_filterd_rescaled = rcs._rescale(vv1_filterd, -23, -3)

vv2 = rcs_list[1]
vv2_filterd = rcs_filtered[:, :, 1]
vv2_rescaled = rcs._rescale(vv2, -23, -3)
vv2_filterd_rescaled = rcs._rescale(vv2_filterd, -23, -3)

# Three-channel composites: (vv1, vv2, vv1) in the red, green and blue slots.
image_vv1 = np.stack((vv1_rescaled, vv2_rescaled, vv1_rescaled), axis=-1)
image_vv1_filterd = np.stack(
    (vv1_filterd_rescaled, vv2_filterd_rescaled, vv1_filterd_rescaled), axis=-1
)

# Visualize the original and filtered RCS
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(image_vv1)
axs[0].set_title('Original VV RCS')
axs[1].imshow(image_vv1_filterd)
axs[1].set_title('Filtered VV RCS')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Get_gradient_nopadding(nn.Module):
    """Replace the first two channels of a batch by their gradient magnitude.

    Central-difference kernels (3x3, applied with zero padding of 1, so the
    spatial size is preserved) give a vertical and a horizontal derivative for
    channels 0 and 1; the output for each is sqrt(dv^2 + dh^2 + 1e-6).
    Channel 2 is passed through untouched.
    """

    def __init__(self):
        super().__init__()
        horizontal = torch.tensor(
            [[0, 0, 0],
             [-1, 0, 1],
             [0, 0, 0]], dtype=torch.float32)
        vertical = torch.tensor(
            [[0, -1, 0],
             [0, 0, 0],
             [0, 1, 0]], dtype=torch.float32)
        # Fixed (non-trainable) kernels shaped (out_channels=1, in_channels=1, 3, 3).
        self.weight_h = nn.Parameter(data=horizontal.view(1, 1, 3, 3), requires_grad=False)
        self.weight_v = nn.Parameter(data=vertical.view(1, 1, 3, 3), requires_grad=False)

    def forward(self, x):
        """Return a (B, 3, H, W) tensor: two gradient-magnitude maps plus channel 2 of ``x``."""

        def magnitude(plane):
            plane = plane.unsqueeze(1)
            d_vertical = F.conv2d(plane, self.weight_v, padding=1)
            d_horizontal = F.conv2d(plane, self.weight_h, padding=1)
            # The small constant keeps the square root differentiable at zero gradient.
            return torch.sqrt(torch.pow(d_vertical, 2) + torch.pow(d_horizontal, 2) + 1e-6)

        maps = [magnitude(x[:, index]) for index in range(2)]
        maps.append(x[:, 2].unsqueeze(1))  # third channel is forwarded unchanged
        return torch.cat(maps, dim=1)


class Get_curvature(nn.Module):
    """Replace the first two channels of a batch by a curvature-like response.

    For channels 0 and 1 the module combines first- and second-order finite
    differences (all computed with zero padding, so the spatial size is kept):

        ((v^2 + 1) * h2 - 2 * v * h * w2 + (h^2 + 1) * v2) / ((v^2 + h^2 + 1)^(3/2) + 1e-10)

    where v/h are the first differences, v2/h2 the second differences and w2 the
    mixed difference. Channel 2 is passed through untouched.
    """

    def __init__(self):
        super(Get_curvature, self).__init__()
        # First-order central differences (3x3).
        kernel_v1 = [[0, -1, 0],
                     [0, 0, 0],
                     [0, 1, 0]]
        kernel_h1 = [[0, 0, 0],
                     [-1, 0, 1],
                     [0, 0, 0]]
        # Second-order differences with a two-pixel step (5x5).
        kernel_h2 = [[0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [1, 0, -2, 0, 1],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]
        kernel_v2 = [[0, 0, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, -2, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 1, 0, 0]]
        # Mixed (diagonal) difference (3x3).
        kernel_w2 = [[1, 0, -1],
                     [0, 0, 0],
                     [-1, 0, 1]]
        kernel_h1 = torch.FloatTensor(kernel_h1).unsqueeze(0).unsqueeze(0)
        kernel_v1 = torch.FloatTensor(kernel_v1).unsqueeze(0).unsqueeze(0)
        kernel_v2 = torch.FloatTensor(kernel_v2).unsqueeze(0).unsqueeze(0)
        kernel_h2 = torch.FloatTensor(kernel_h2).unsqueeze(0).unsqueeze(0)
        kernel_w2 = torch.FloatTensor(kernel_w2).unsqueeze(0).unsqueeze(0)
        # Fixed (non-trainable) kernels.
        self.weight_h1 = nn.Parameter(data=kernel_h1, requires_grad=False)
        self.weight_v1 = nn.Parameter(data=kernel_v1, requires_grad=False)
        self.weight_v2 = nn.Parameter(data=kernel_v2, requires_grad=False)
        self.weight_h2 = nn.Parameter(data=kernel_h2, requires_grad=False)
        self.weight_w2 = nn.Parameter(data=kernel_w2, requires_grad=False)

    def forward(self, x):
        """Return a (B, 3, H, W) tensor: two curvature maps plus channel 2 of ``x``."""
        x_list = []
        for i in range(2):
            plane = x[:, i].unsqueeze(1)
            # Each derivative is computed exactly once per channel.
            x_i_v = F.conv2d(plane, self.weight_v1, padding=1)
            x_i_h = F.conv2d(plane, self.weight_h1, padding=1)
            x_i_v2 = F.conv2d(plane, self.weight_v2, padding=2)
            x_i_h2 = F.conv2d(plane, self.weight_h2, padding=2)
            x_i_w2 = F.conv2d(plane, self.weight_w2, padding=1)

            denominator = torch.pow((torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1), 3 / 2)
            numerator = (
                torch.mul((torch.pow(x_i_v, 2) + 1), x_i_h2)
                - 2 * torch.mul(torch.mul(x_i_v, x_i_h), x_i_w2)
                + torch.mul((torch.pow(x_i_h, 2) + 1), x_i_v2)
            )
            x_list.append(torch.div(numerator, denominator + 1e-10))

        x_list.append(x[:, 2].unsqueeze(1))  # Append the third channel without modification
        return torch.cat(x_list, dim=1)


class FeatureEncoder(nn.Module):
    """Four-stage convolutional encoder for 3-channel input.

    Every stage is conv3x3 -> ReLU -> conv3x3 -> ReLU; stages 1-3 are followed
    by a 2x2 max-pool, stage 4 is not. ``out_dims[k]`` is the channel width of
    stage ``k + 1``. Sub-modules are registered as conv1..conv8, relu1..relu8
    and maxpool1..maxpool3.
    """

    def __init__(self, out_dims):
        super().__init__()

        in_channels = 3
        for stage in range(4):
            width = out_dims[stage]
            first = 2 * stage + 1  # index of the stage's first conv/relu pair
            setattr(self, f"conv{first}", nn.Conv2d(in_channels, width, kernel_size=3, padding=1))
            setattr(self, f"relu{first}", nn.ReLU(inplace=True))
            setattr(self, f"conv{first + 1}", nn.Conv2d(width, width, kernel_size=3, padding=1))
            setattr(self, f"relu{first + 1}", nn.ReLU(inplace=True))
            if stage < 3:
                # Only the first three stages halve the resolution.
                setattr(self, f"maxpool{stage + 1}", nn.MaxPool2d(kernel_size=2, stride=2))
            in_channels = width

    def forward(self, x):
        """Return the feature maps produced at the end of each of the four stages."""
        # Stage 1 (ends with pooling)
        x1 = self.maxpool1(self.relu2(self.conv2(self.relu1(self.conv1(x)))))
        # Stage 2 (ends with pooling)
        x2 = self.maxpool2(self.relu4(self.conv4(self.relu3(self.conv3(x1)))))
        # Stage 3 (ends with pooling)
        x3 = self.maxpool3(self.relu6(self.conv6(self.relu5(self.conv5(x2)))))
        # Stage 4 (no pooling)
        x4 = self.relu8(self.conv8(self.relu7(self.conv7(x3))))

        return x1, x2, x3, x4


class PMD_features(nn.Module):
    """Thin wrapper that applies either the gradient or the curvature head to a batch.

    Args:
        out_dims: Accepted for interface compatibility; not used by this class
            (the FeatureEncoder stage that would consume it is disabled).
        gradient: If truthy use ``Get_gradient_nopadding``, otherwise ``Get_curvature``.
    """

    def __init__(self, out_dims, gradient = True):
        super().__init__()
        head_class = Get_gradient_nopadding if gradient else Get_curvature
        self.PMD_head = head_class()
        # self.feature_ext = FeatureEncoder(out_dims)

    def forward(self, images):
        """Return the output of the selected head for ``images``."""
        # PMD_feature = self.feature_ext(PMD_images)
        return self.PMD_head(images)


# class Adapter(nn.Module):
#     def __init__(self, out_dims):
#         super(Adapter, self).__init__()
#         self.PMD_head = Get_gradient_nopadding()
#         self.feature_ext = FeatureEncoder(out_dims)
#
#     def forward(self, images):
#         PMD_images = self.PMD_head(images)
#         PMD_feature = self.feature_ext(PMD_images)
#
#         return PMD_feature

In [None]:
import matplotlib.pyplot as plt
# Show each layer of sample 0 as its own grayscale panel (the 1x3 grid assumes three layers).
image = np.array(test_dataset[0]['image'])
plt.figure(figsize=(12, 6))
for slot, layer in enumerate(np.moveaxis(image, 2, 0), start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(layer, cmap='gray')
    plt.axis('off')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Gradient-head PMD output for the composite image of sample 0.
pmd = PMD_features([32, 64, 128, 256])
image = create_image(test_dataset, 0)

# (H, W, C) array -> (1, C, H, W) float tensor expected by the head.
image_array = np.array(image)
image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0).float()
output = pmd(image_tensor)

# Left: input composite. Right: PMD output converted back to (H, W, C).
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(output[0].permute(1, 2, 0).cpu())

In [None]:
import matplotlib.pyplot as plt

pmd = PMD_features([32, 64, 128, 256])
image = np.array(test_dataset[0]['image'])

# Mix the first two layers of the image before calculating PMD.
# The mean is taken once from the untouched channels and then written to both;
# averaging in place would let channel 1 see the already-overwritten channel 0.
# NOTE(review): if the image has an integer dtype the sum can wrap and the mean is
# truncated on assignment — confirm the dtype of the dataset images.
mixed = (image[:, :, 0] + image[:, :, 1]) / 2
image[:, :, 0] = mixed
image[:, :, 1] = mixed

image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
output = pmd(image_tensor)

# Left: mixed input. Right: PMD (gradient) output converted back to (H, W, C).
plt.figure(figsize=(12, 6))
plt.subplot(1,2,1)
plt.imshow(image)
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(output[0].permute(1, 2, 0).cpu())

In [None]:
import matplotlib.pyplot as plt

pmd = PMD_features([32, 64, 128, 256], gradient=False)
image = test_dataset[1]['image']
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
output = pmd(image_tensor)
# Convert the output tensor to a (H, W, C) NumPy array
out_np = output[0].permute(1, 2, 0).cpu().numpy()

# Invert the output for visualization, then min-max scale it into [0, 1] so imshow
# can display it (the inversion must come first, otherwise it discards the scaling).
inverted = out_np * -1
value_range = inverted.max() - inverted.min()
if value_range > 0:
    out_norm = (inverted - inverted.min()) / value_range
else:
    # Constant output: nothing to scale, show a uniform panel.
    out_norm = np.zeros_like(inverted)

# Figure 1: input vs. raw curvature output
plt.figure(figsize=(12, 6))
plt.subplot(1,2,1)
plt.imshow(image)
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(output[0].permute(1, 2, 0).cpu())

# Figure 2: input vs. inverted, normalized curvature output
plt.figure(figsize=(12, 6))
plt.subplot(1,2,1)
plt.imshow(image)  # original image
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(out_norm)  # normalized PMD output
plt.axis('off')
plt.show()

In [None]:
import numpy as np

def add_speckle_noise_gamma(image, num_looks=1, amplitude=False, rng=None):
    """
    Adds simulated multiplicative speckle noise to a SAR image using a Gamma distribution.

    Args:
        image (array-like): The input SAR image, of any shape. Can be intensity or amplitude.
        num_looks (float): The number of looks used to control the noise level.
                           A smaller value (e.g., 1) produces more speckle.
                           Must be > 0.
        amplitude (bool): Set to True if the input image is in amplitude,
                          False if it is in intensity.
        rng (np.random.Generator | np.random.RandomState | None): Optional random
                          source for reproducible noise. When None, the global
                          ``np.random`` state is used.

    Returns:
        np.ndarray: The image with added speckle noise, same shape as the input.

    Raises:
        ValueError: If ``num_looks`` is not greater than 0.
    """
    if num_looks <= 0:
        raise ValueError("Number of looks must be greater than 0.")

    image = np.asarray(image)

    # The Gamma model for speckle applies to intensity, so amplitude input is squared first.
    intensity_image = image**2 if amplitude else image

    # Gamma(shape=k, scale=1/k) with k = number of looks has mean k * (1/k) = 1,
    # so multiplying by it does not bias the overall image intensity.
    # Drawing with size=image.shape makes the function work for any number of dimensions.
    sampler = np.random if rng is None else rng
    noise = sampler.gamma(shape=num_looks, scale=1.0/num_looks, size=image.shape)

    noisy_intensity = intensity_image * noise

    # Convert back to amplitude if that is what the caller supplied.
    return np.sqrt(noisy_intensity) if amplitude else noisy_intensity

In [None]:
import math
class GF(nn.Module):
    """Ratio-based gradient magnitude computed from half-window sums.

    Four fixed binary kernels of side ``2 * kensize + 1`` select the left, right,
    upper and lower part of the window (each part includes the centre
    column/row). ``forward`` returns the norm of the log-ratios
    (log(left / right), log(upper / lower)).

    ``nbins``, ``pool``, ``img_size`` and ``patch_size`` are stored on the
    instance but are not used by ``forward``; no orientation binning or pooling
    is performed.
    """

    def __init__(self, nbins=9, pool=7, kensize=5, img_size=224, patch_size=16):
        super().__init__()
        self.nbins = nbins
        self.pool = pool
        self.pi = math.pi
        self.img_size = img_size
        self.patch_size = patch_size
        self.k = kensize

        side = 2 * self.k + 1
        whole = slice(None)
        leading = slice(0, self.k + 1)   # first k+1 rows/columns (centre included)
        trailing = slice(self.k, side)   # last k+1 rows/columns (centre included)

        def half_window(rows, cols):
            # Binary mask that is 1 on the selected part of the window, 0 elsewhere.
            mask = torch.zeros(side, side, dtype=torch.float32)
            mask[rows, cols] = 1.0
            return mask.view(1, 1, side, side)

        self.register_buffer("weight_x1", half_window(whole, leading))    # left part
        self.register_buffer("weight_x2", half_window(whole, trailing))   # right part
        self.register_buffer("weight_y1", half_window(leading, whole))    # upper part
        self.register_buffer("weight_y2", half_window(trailing, whole))   # lower part

    @torch.no_grad()
    def forward(self, x):
        """Return the log-ratio gradient magnitude of ``x``.

        ``x`` must be shaped (B, 1, H, W): the kernels have a single input
        channel and the convolutions use ``groups=1``. The output has the same
        shape as the input.
        NOTE(review): the log-ratio needs positive window sums; data containing
        negative values (e.g. dB-scaled) may produce NaN — confirm the input range.
        """
        k = self.k
        # Reflect-pad so the output keeps the input size; the 1e-2 offset keeps
        # all-zero windows away from log(0) and division by zero.
        padded = F.pad(x, pad=(k, k, k, k), mode="reflect") + 1e-2

        def window_sum(weight):
            return F.conv2d(padded, weight, bias=None, stride=1, padding=0, groups=1)

        gx_rgb = torch.log(window_sum(self.weight_x1) / window_sum(self.weight_x2))
        gy_rgb = torch.log(window_sum(self.weight_y1) / window_sum(self.weight_y2))
        return torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)

In [None]:
def original_paper(image):
    """Apply the GF ratio-gradient operator to the first two channels of an image.

    Args:
        image: Array-like of shape (H, W, 3).

    Returns:
        np.ndarray: Array like ``image`` (same shape and dtype) whose channels 0
        and 1 hold the GF response and whose channel 2 is copied unchanged.
    """
    source = np.array(image)  # converted once instead of once per channel
    new_image = np.zeros_like(image)

    # The operator has no per-channel state, so a single instance serves both channels.
    sar_feature = GF(nbins=9, pool=7, kensize=1, img_size=512, patch_size=16)

    for i in range(2):
        # (H, W) -> (1, 1, H, W) float tensor, the layout GF.forward convolves over.
        image_tensor = torch.from_numpy(source[:, :, i]).unsqueeze(0).unsqueeze(0).float()
        output = sar_feature(image_tensor)
        new_image[:, :, i] = output[0, 0].cpu().numpy()

    new_image[:, :, 2] = source[:, :, 2]
    return new_image

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Build the composite image for the dataset sample at index 2.
test_image = np.array(create_image(test_dataset, 2)).astype(np.float32)

# 1. Original image, scaled by its global maximum for display.
original_disp = test_image / test_image.max()

# 2. Lee filter (3x3 window) on channels 0 and 1; channel 2 (the DEM layer of the
#    composite) is kept unchanged.
lee_filtered = np.zeros_like(test_image)
for band in (0, 1):
    lee_filtered[:, :, band] = lee_filter(test_image, 3, channel=band)
lee_filtered[:, :, 2] = test_image[:, :, 2]
lee_disp = lee_filtered / lee_filtered.max()

# 3. PMD image from the gradient head: (H, W, C) -> (1, C, H, W) tensor and back.
#    The head returns two gradient channels plus the third channel passed through.
pmd_model = PMD_features([32, 64, 128, 256])
test_tensor = torch.from_numpy(test_image).permute(2, 0, 1).unsqueeze(0)
pmd_output = pmd_model(test_tensor)
pmd_image = pmd_output[0].permute(1, 2, 0).cpu().numpy()
pmd_disp = pmd_image / pmd_image.max()

# 4. Ratio-gradient operator (GF) applied to the same image.
orig_paper_img = original_paper(test_image)

# 5. Label of the same dataset sample, shown as the mask panel.
mask = test_dataset[2]['label']

# 2x3 grid: five panels are filled, the remaining axis is left blank.
fig, axs = plt.subplots(2, 3, figsize=(18, 12))
plt.rcParams['axes.titlesize'] = 20
images = [original_disp, lee_disp, pmd_disp, orig_paper_img, mask]
titles = ['Original Image', 'Lee Filtered Image', 'PMD Image', 'Original Gradient', 'Mask']

for i, ax in enumerate(axs.flat):
    if i < len(images):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(titles[i])
    ax.axis('off')

plt.tight_layout()
plt.show()