# The PolInSAR Course - June 3rd, 2024
# SAR Polarimetry (PolSAR) 
# Part 2: Eigenvalues of the Polarimetric Coherency Matrix and the Entropy/Anisotropy/Alpha decomposition

* Acquisition: Nkok (Gabon), DLR's F-SAR, L-band

* Path to images: /projects/data/polsar/

* SLC (single-look complex) images:
    * HH: slc_16afrisr0107_Lhh_tcal_test.rat
    * HV: slc_16afrisr0107_Lhv_tcal_test.rat
    * VH: slc_16afrisr0107_Lvh_tcal_test.rat
    * VV: slc_16afrisr0107_Lvv_tcal_test.rat

Tips:
- use a function that performs the multilook (correlation) operation on a moving window of (looksa x looksr) pixels in azimuth x range
- focus on an azimuth x range block spanning pixels [5000, 15000] in azimuth and [0, 2000] in range.

In [None]:
!pip install pooch
import pooch
print(pooch.__version__)

!pip install pysarpro
import pysarpro
print(pysarpro.__version__)

!pip install scipy
import scipy
print(scipy.__version__)

!pip install numpy
import numpy
print(numpy.__version__)

!pip install matplotlib
import matplotlib
print(matplotlib.__version__)
!pip install ipympl


!pip install cartopy
import cartopy
print(cartopy.__version__)

!pip install pyproj
import pyproj
print(pyproj.__version__)

!pip install pyresample
import pyresample
print(pyresample.__version__)

!pip install rasterio
import rasterio
print(rasterio.__version__)

print("the requirments is are satisfied ================================================================= ")

In [None]:
# --- Download exercise data & import reader function
from pysarpro import io, data
from pysarpro.io import rrat

#data.download_all(directory="/projects", pattern=r'^data/polsar')

# --- Import useful libaries, functions, and modules
import sys
sys.path.append('/projects/src/')
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import uniform_filter
%matplotlib widget

**Auxiliary functions**

`HSV_colormap_to_rgb`: Generates an HSV composite representation based on a given colormap.

In [None]:
def HSV_colormap_to_rgb(colormap, h, s, v):
    """
    Build an HSV-style RGB composite that uses an arbitrary colormap for the hue.

    The hue ``h`` is mapped through ``colormap`` (instead of the classical
    'hsv' colour wheel). The colour is blended towards white by ``1 - s`` and
    finally scaled by ``v``.

    See https://en.wikipedia.org/wiki/HSL_and_HSV

    Parameters
    ----------
    colormap : callable
        Maps an array of hue values to RGBA values (trailing axis of length
        >= 3); the colormaps in matplotlib.cm are compatible.
    h : ndarray
        Hue values, usually in [0, 1].
    s : ndarray
        Saturation values in [0, 1]: 0 gives white, 1 the pure colormap colour.
    v : ndarray
        Value (brightness) values in [0, 1].

    Returns
    -------
    rgb : ndarray
        The broadcast shape of the inputs plus a trailing RGB axis of length 3.
    """
    pure_rgb = colormap(h)[..., :3]   # drop the alpha channel of the colormap
    sat = s[..., np.newaxis]
    # saturation mixes the pure colour with white ...
    blended = sat * pure_rgb + (1 - sat) * np.ones(3)
    # ... and the value scales the brightness
    return blended * v[..., np.newaxis]

`calculate_covariance`: Calculates the covariance between two images while performing a multi-looking operation.

In [None]:
def calculate_covariance(im1, im2, looksa, looksr):
    """
    Multilooked (boxcar-averaged) Hermitian cross-product of two complex images.

    Computes ``<im1 * conj(im2)>``, where ``<.>`` is a moving average over a
    ``looksa x looksr`` window (first axis x second axis). The real and
    imaginary parts are filtered separately with
    ``scipy.ndimage.uniform_filter``.

    Parameters
    ----------
    im1, im2 : ndarray
        2-D complex images of identical shape (azimuth x range here).
    looksa, looksr : int
        Window size along the first and the second axis.

    Returns
    -------
    ndarray
        Complex array with the same shape as the inputs.
    """
    # Form the product once: computing it separately for the real and the
    # imaginary part doubles time and peak memory on full SLC blocks.
    prod = im1 * np.conj(im2)
    size = [looksa, looksr]
    return uniform_filter(prod.real, size) + 1j * uniform_filter(prod.imag, size)

`calculate_eigenvalues_3`: Computes the eigenvalues of a 3x3 Hermitian matrix analytically (closed form) and returns them sorted from largest to smallest.

In [None]:
def calculate_eigenvalues_3(T11, T12, T13, T22, T23, T33):
    """
    Closed-form eigenvalues of a 3x3 Hermitian matrix, ordered from max to min.

    The matrix is ``[[T11, T12, T13], [conj(T12), T22, T23],
    [conj(T13), conj(T23), T33]]``. It is evaluated element-wise, so every
    input is an array of per-pixel values: 2-D azimuth x range in this
    notebook, but any common shape is accepted.

    Parameters
    ----------
    T11, T22, T33 : ndarray
        Diagonal elements (may be complex with a negligible imaginary part).
    T12, T13, T23 : ndarray
        Upper off-diagonal elements.

    Returns
    -------
    L1, L2, L3 : ndarray of float32
        Largest, intermediate and smallest eigenvalue, each with the input shape.

    Notes
    -----
    Cardano's formula divides by the cube root of the auxiliary quantity Z.
    Z vanishes for scalar matrices (three equal eigenvalues, e.g. all-zero
    no-data pixels). There all eigenvalues are returned as trace/3 instead of
    the NaN produced by 0/0.
    """
    # auxiliary quantities: A = sum of the principal 2x2 minors,
    # B = discriminant term, DET = determinant, TR = trace
    A = T11*T22 + T11*T33 + T22*T33 - T12*np.conj(T12) - T13*np.conj(T13) - T23*np.conj(T23)
    B = (T11**2 - T11*T22 + T22**2 - T11*T33 - T22*T33 + T33**2
         + 3*T12*np.conj(T12) + 3*T13*np.conj(T13) + 3*T23*np.conj(T23))
    DET = (T11*T22*T33 - T33*T12*np.conj(T12) - T22*T13*np.conj(T13) - T11*T23*np.conj(T23)
           + T12*np.conj(T13)*T23 + np.conj(T12)*T13*np.conj(T23))
    TR = T11 + T22 + T33

    X = 27*DET - 9*A*TR + 2*TR**3
    del DET, A
    # For Hermitian input X**2 - 4*B**3 <= 0, so the square root has to be
    # taken in the complex domain; cast explicitly so real-typed inputs work too.
    disc = np.asarray(X**2 - 4*B**3)
    disc = disc.astype(np.result_type(disc, np.complex64), copy=False)
    Z = X + np.sqrt(disc)
    del X, disc

    cbrt_Z = Z**(1/3.)
    del Z
    # Z == 0 only for scalar matrices, where B == 0 as well; every B / cbrt(Z)
    # term must then vanish, so divide by 1 instead of 0 at those pixels.
    cbrt_Z_den = np.where(cbrt_Z == 0, 1, cbrt_Z)

    # ... and here they are:
    LA = ( 1/3.*TR + 2**(1/3.)*B/(3*cbrt_Z_den) + cbrt_Z/(3*2**(1/3.)) )
    LB = ( 1/3.*TR - (1+1j*np.sqrt(3))*B/(3*2**(2/3.)*cbrt_Z_den) - (1-1j*np.sqrt(3))*cbrt_Z/(6*2**(1/3.)) )
    LC = ( 1/3.*TR - (1-1j*np.sqrt(3))*B/(3*2**(2/3.)*cbrt_Z_den) - (1+1j*np.sqrt(3))*cbrt_Z/(6*2**(1/3.)) )
    del cbrt_Z, cbrt_Z_den

    # order them: stack the real parts on a trailing axis and sort ascending
    roots = np.stack([np.real(LA), np.real(LB), np.real(LC)], axis=-1).astype('float32')
    del LA, LB, LC
    roots.sort(axis=-1)

    return roots[..., 2], roots[..., 1], roots[..., 0]
    

`calculate_eigenvectors_3`: Computes the eigenvectors of a 3x3 matrix analytically. 

In [None]:
def _eigenvector_3(T12, T13, T22, T23, T33, L):
    """Unit-norm eigenvector (trailing axis of length 3) of the 3x3 Hermitian matrix for eigenvalue L.

    The third component is fixed to 1 before normalisation. The closed form
    therefore divides by conj(T13) and by (T22 - L)*conj(T13) - conj(T12)*conj(T23);
    where either vanishes the result is NaN/inf.
    """
    c12, c13, c23 = np.conj(T12), np.conj(T13), np.conj(T23)
    num = (L - T33)*c12 + c13*T23
    den = (T22 - L)*c13 - c12*c23

    u = np.ones(np.broadcast(T12, T13, T22, T23, T33, L).shape + (3,), 'complex64')
    u[..., 0] = (L - T33)/c13 + (num*c23)/(den*c13)
    u[..., 1] = -num/den

    # normalize to get an orthonormal eigenvector
    norm = np.sqrt(np.abs(u[..., 0])**2 + np.abs(u[..., 1])**2 + np.abs(u[..., 2])**2)
    return u / norm[..., np.newaxis]


def calculate_eigenvectors_3(T11, T12, T13, T22, T23, T33, L1, L2, L3) :
    """
    Closed-form eigenvectors of a 3x3 Hermitian matrix for the given eigenvalues.

    Inputs are arrays of per-pixel matrix elements (2-D azimuth x range in this
    notebook; any common shape works), with the eigenvalues L1 >= L2 >= L3 as
    returned by ``calculate_eigenvalues_3``.

    Parameters
    ----------
    T11 : ndarray
        Not needed by the closed form; kept for a symmetric interface.
    T12, T13, T22, T23, T33 : ndarray
        Remaining matrix elements (lower triangle = conjugate of the upper).
    L1, L2, L3 : ndarray
        Eigenvalues for which the eigenvectors are computed.

    Returns
    -------
    U1, U2, U3 : ndarray of complex64
        Unit-norm eigenvectors for L1, L2 and L3, each of shape input + (3,).
    """
    return tuple(_eigenvector_3(T12, T13, T22, T23, T33, L) for L in (L1, L2, L3))


**Input parameters**

In [None]:
# Directory holding the F-SAR SLC files
path = '/projects/data/polsar/'
# Multilook window: number of looks in azimuth (looksa) and in range (looksr)
looksa, looksr = 7, 7

**Step 1: Load data**

In [None]:
# Read the same azimuth x range block from each SLC channel.
# NOTE(review): `block` is presumably [az_start, az_end, rg_start, rg_end], matching the
# [5000, 15000] x [0, 2000] window in the tips — confirm against pysarpro.io.rrat.
# Only HV is read for the cross-polar term; the VH file is not used in this notebook.
slcHH = rrat(path + 'slc_16afrisr0107_Lhh_tcal_test.rat', block = [5000, 15000, 0, 2000])
slcVV = rrat(path + 'slc_16afrisr0107_Lvv_tcal_test.rat', block = [5000, 15000, 0, 2000])
slcHV = rrat(path + 'slc_16afrisr0107_Lhv_tcal_test.rat', block = [5000, 15000, 0, 2000])

In [None]:
# check shape (rows = azimuth, columns = range)
slcHH.shape

**Step 2: Calculate the necessary elements of the coherency matrix**

In [None]:
# -- compute the Pauli components: HH+VV, HH-VV and 2*HV (HV stands in for HV+VH)
# NOTE(review): no 1/sqrt(2) normalisation is applied. It would only scale every
# T element by the same factor, which cancels in the eigenvalue ratios used for H/A/alpha.
pauli1, pauli2, pauli3 = slcHH + slcVV, slcHH - slcVV, 2*slcHV

In [None]:
# -- compute the six independent elements of the Hermitian coherency matrix:
#    T_ij = < pauli_i * conj(pauli_j) >, averaged over the looksa x looksr window
_pauli_pairs = ((pauli1, pauli1), (pauli2, pauli2), (pauli3, pauli3),
                (pauli1, pauli2), (pauli1, pauli3), (pauli2, pauli3))
T11, T22, T33, T12, T13, T23 = (
    calculate_covariance(first, second, looksa, looksr) for first, second in _pauli_pairs
)
del _pauli_pairs

In [None]:
# -- delete unused variables
# (the SLC channels are no longer needed once the Pauli components exist)
del slcHH, slcVV, slcHV

In [None]:
# the Pauli components are no longer needed once the T elements have been formed
del pauli1, pauli2, pauli3

**Step 3: Calculate eigenvalues**

In [None]:
# Per-pixel eigenvalues of T, ordered so that lambda1 >= lambda2 >= lambda3
lambda1, lambda2, lambda3 = calculate_eigenvalues_3(T11, T12, T13, T22, T23, T33)

In [None]:
# check shape (same as the T elements)
lambda1.shape

**Step 4: Calculate entropy**

In [None]:
# -- compute the probabilities associated with each eigenvalue:
#    p_i = lambda_i / (lambda1 + lambda2 + lambda3)
# The closed-form eigenvalues can come out slightly negative through rounding;
# clip them at zero so every p_i lies in [0, 1] (negative weights would otherwise
# leak into the entropy and into the mean alpha angle).
_lams = [np.clip(lam, 0, None) for lam in (lambda1, lambda2, lambda3)]
_span = _lams[0] + _lams[1] + _lams[2]
pr1, pr2, pr3 = (lam / _span for lam in _lams)
del _lams, _span

In [None]:
# -- compute the entropy H = -sum_i p_i * log3(p_i), which lies in [0, 1]
# Use the convention 0*log(0) = 0 (a zero eigenvalue would otherwise give NaN) and
# clip probabilities pushed slightly outside [0, 1] by rounding. NaN probabilities
# (pixels with zero total power) stay NaN.
def _plog3p(p):
    """Return p*log3(p) element-wise with 0*log3(0) := 0."""
    p = np.clip(p, 0, 1)
    return p * np.log10(np.where(p > 0, p, 1)) / np.log10(3)

entropy = -(_plog3p(pr1) + _plog3p(pr2) + _plog3p(pr3))

**Step 5: Calculate anisotropy** 

In [None]:
# -- compute the anisotropy (related to the minimum and intermediate eigenvalues)
#    A = (lambda2 - lambda3) / (lambda2 + lambda3)
# A = 0 when lambda2 = lambda3
# A = 1 when lambda2 >> lambda3
# Rounding-induced negative eigenvalues are clipped so that A stays in [0, 1].
# Where lambda2 = lambda3 = 0, A is set to 0 instead of the NaN from 0/0; such NaNs
# would otherwise propagate into later percentile-based visualisations.
_l2 = np.clip(lambda2, 0, None)
_l3 = np.clip(lambda3, 0, None)
_den = _l2 + _l3
anisotropy = np.divide(_l2 - _l3, _den, out=np.zeros_like(_den), where=_den > 0)
del _l2, _l3, _den

**Step 6: Calculate eigenvectors**

In [None]:
# -- compute the eigenvectors
# U1, U2, U3 belong to lambda1, lambda2, lambda3; each carries a trailing axis of length 3
U1, U2, U3 = calculate_eigenvectors_3(T11, T12, T13, T22, T23, T33, lambda1, lambda2, lambda3)

In [None]:
# check shape
U1.shape

In [None]:
# -- delete unused variables
# (the off-diagonal elements are not needed below; T11, T22, T33 are kept for the
#  amplitude images)
del T12, T23, T13

**Step 7: Calculate mean alpha angle**

In [None]:
# -- extract the alpha angles alpha_i = arccos(|u_i[0]|) [rad], one per eigenvector
# |u_i[0]| is the magnitude of the first component of a unit-norm complex64 vector
# and can round to slightly above 1, for which arccos would return NaN; clip it.
alpha1, alpha2, alpha3 = (
    np.arccos(np.clip(np.abs(U[:, :, 0]), 0, 1)) for U in (U1, U2, U3)
)  # [rad]

In [None]:
# -- delete unused variables
del U1, U2, U3

In [None]:
# -- compute the mean alpha angle: the probability-weighted average of alpha1..alpha3
alpha_mean = alpha1*pr1 + alpha2*pr2 + alpha3*pr3 # [rad]
alpha_mean = np.degrees(alpha_mean) # [deg]

**Step 8: Plots!**

In [None]:
# Calculations for Paulis RGB:
# -- define the 3D array for the Pauli representation (range x azimuth x RGB)
dimaz = lambda1.shape[0]
dimrg = lambda1.shape[1]
rgb_pauli = np.zeros((dimrg, dimaz, 3), 'float32')
# -- fill each channel with the amplitude sqrt(|Tii|), clipped to [0, 2.5 x mean(amplitude)],
#    transposed to range x azimuth and normalised to [0, 1]
#    R : HH-VV (T22),  G : HV (T33),  B : HH+VV (T11)
for _ch, _Tii in enumerate((T22, T33, T11)):
    _amp = np.sqrt(abs(_Tii))
    _amp = np.clip(_amp, 0, 2.5*np.mean(_amp)).T
    rgb_pauli[:, :, _ch] = _amp / np.max(_amp)
del _ch, _Tii, _amp


In [None]:
# # Plot: Pauli RGB and eigenvalue probabilities

# plt.figure(figsize=(15, 6*4))
# plt.subplot(4,1,1)
# plt.imshow(rgb_pauli, aspect = 'auto')
# plt.colorbar()

# plt.subplot(4,1,2)
# plt.imshow(np.transpose(pr1), cmap = 'turbo', vmin =0, vmax=1,aspect = 'auto')
# plt.colorbar()

# plt.subplot(4,1,3)
# plt.imshow(np.transpose(pr2), cmap = 'turbo', vmin =0, vmax=1,aspect = 'auto')
# plt.colorbar()

# plt.subplot(4,1,4)
# plt.imshow(np.transpose(pr3), cmap = 'turbo', vmin =0, vmax=1,aspect = 'auto')
# plt.colorbar()

# plt.tight_layout()


In [None]:
# # Plot: H, A, alpha

# plt.figure(figsize= (15, 6*3))
# plt.subplot(3,1,1)
# plt.imshow(np.transpose(entropy), cmap = 'gray', vmin = 0, vmax =1, aspect = 'auto')
# cb = plt.colorbar()
# cb.set_label('H')

# plt.subplot(3,1,2)
# plt.imshow(np.transpose(anisotropy), cmap = 'turbo', vmin = 0, vmax =1, aspect = 'auto')
# cb = plt.colorbar()
# cb.set_label('A')

# plt.subplot(3,1,3)
# plt.imshow(np.transpose(alpha_mean), cmap = 'turbo', vmin = 0, vmax = 90, aspect = 'auto')
# cb = plt.colorbar()
# cb.set_label('mean alpha [deg]')

# plt.tight_layout()


In [None]:
# HSI Color Representation:

HSI color representation:
- H (hue): mean alpha angle
- S (saturation):
     - Case 1: saturation = 1 (full color scale everywhere)
     - Case 2: saturation = 1 - entropy
          - entropy = 0 → saturation = 1: full color scale
          - entropy = 1 → saturation = 0: grayscale
- I (intensity): amplitude of the total power (span)

In [None]:
# Hue: mean alpha angle
# normalize the mean alpha angle: it has to be between 0 and 1 --> divide by 90 degrees
alpha_mean = alpha_mean / 90

# Import the colormap for plotting alpha
colormap = plt.colormaps.get('turbo')

# Intensity: normalize the amplitude
amp = np.sqrt(abs(T11) + abs(T22) + abs(T33))
amp = np.clip(amp, 0, 2.5*np.mean(amp))
amp = amp / np.max(amp)

# Saturation
# Case 1)
sat1 = np.ones_like(amp)
# Case 2)
sat2 = 1 - entropy

In [None]:
# Generate the HSV colormaps 

# Case 1
hsv_comp1 = HSV_colormap_to_rgb(colormap, alpha_mean, sat1, amp)

# Case 2
hsv_comp2 = HSV_colormap_to_rgb(colormap, alpha_mean, sat2, amp)


In [None]:
# -- delete unused variables
del amp, sat1, sat2

In [None]:
# # --- Plot: HSI representations

# plt.figure(figsize = (12,12))

# plt.subplot(2,1,1)
# plt.imshow(np.transpose(hsv_comp1, axes = (1,0,2)) , aspect = 'auto')

# plt.subplot(2,1,2)
# plt.imshow(np.transpose(hsv_comp2, axes = (1,0,2)) , aspect = 'auto')

# plt.tight_layout()

# PART 2📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙📙

## Requirements

In [None]:
# requirments 

import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import ViTModel, ViTFeatureExtractor, pipeline
import gradio as gr

In [None]:
# # works fine 


# # Plot each parameter individually
# fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# # Alpha Mean (degrees)
# im1 = axes[0].imshow(alpha_mean, cmap='turbo', vmin=0, vmax=90)
# axes[0].set_title('Alpha Mean [°]')
# plt.colorbar(im1, ax=axes[0])

# # Entropy (0 to 1)
# im2 = axes[1].imshow(entropy, cmap='viridis', vmin=0, vmax=1)
# axes[1].set_title('Entropy')
# plt.colorbar(im2, ax=axes[1])

# # Anisotropy (0 to 1)
# im3 = axes[2].imshow(anisotropy, cmap='plasma', vmin=0, vmax=1)
# axes[2].set_title('Anisotropy')
# plt.colorbar(im3, ax=axes[2])

# plt.tight_layout()
# plt.show()

In [None]:
# 1. First, import all required libraries
import numpy as np
import torch
from transformers import ViTImageProcessor, ViTModel  # Using ViTImageProcessor instead of deprecated ViTFeatureExtractor
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, widgets
from IPython.display import display
from transformers import pipeline

# 2. Inputs: alpha_mean, entropy and anisotropy are the 2-D maps computed in Part 1.

# 3. Stack the layers (H x W x 3) and normalise EACH channel to [0, 1].
#    A single global min/max lets the alpha channel (tens of degrees) dominate and
#    squeezes entropy/anisotropy (0..1) towards zero. nan-aware statistics keep a few
#    invalid pixels from turning the whole stack into NaN; remaining NaNs become 0.
sar_features = np.stack([alpha_mean, entropy, anisotropy], axis=-1).astype(np.float32)
_ch_min = np.nanmin(sar_features, axis=(0, 1))
_ch_range = np.nanmax(sar_features, axis=(0, 1)) - _ch_min
_ch_range[_ch_range == 0] = 1  # constant channel -> zeros instead of 0/0
sar_features = np.nan_to_num((sar_features - _ch_min) / _ch_range, nan=0.0)
del _ch_min, _ch_range

# 4. Convert to a PyTorch tensor, channel-first (THIS MUST COME BEFORE THE ViT CODE)
sar_tensor = torch.from_numpy(sar_features).permute(2, 0, 1).float()  # Shape: [3, H, W]

# 5. Now initialize ViT model
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

# 6. Extract patch embeddings. The processor resizes the whole scene to its configured
#    input size (printed below). The features are already floats in [0, 1], so the
#    processor's default 1/255 rescaling is disabled; otherwise the model would see an
#    almost constant image.
inputs = processor(images=sar_tensor.unsqueeze(0), do_rescale=False, return_tensors="pt")
with torch.no_grad():
    outputs = vit_model(**inputs)
patches = outputs.last_hidden_state  # Shape: [1, 197, 768]

print("Successfully extracted patches!")
print(f"Patch tensor shape: {patches.shape}")


## just for checking and testing
# Add these right after tensor conversion
print("\n=== Data Shape Verification ===")
print(f"Original shapes: {alpha_mean.shape}, {entropy.shape}, {anisotropy.shape}")
print(f"Stacked shape: {sar_features.shape}")
print(f"Tensor shape: {sar_tensor.shape}")

# Verify normalization
print("\n=== Value Ranges ===")
print(f"Stacked min/max: {sar_features.min():.2f}, {sar_features.max():.2f}")
print(f"Tensor min/max: {sar_tensor.min():.2f}, {sar_tensor.max():.2f}")

# ViT input checks
print("\n=== ViT Input Checks ===")
print(f"Processor default size: {processor.size}")
print(f"Input tensor shape to ViT: {inputs['pixel_values'].shape}")


# 1. Visualize stacked image BEFORE tensor conversion
plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(sar_features)  # Auto-displays first 3 channels as RGB
plt.title("Stacked SAR Features (HxWx3)")
plt.axis('off')

# 2. Visualize AFTER tensor conversion (per-channel)
plt.subplot(122)
# Convert back to HxWxC for visualization
tensor_vis = sar_tensor.permute(1, 2, 0).numpy()
plt.imshow(tensor_vis)
plt.title("Tensor Format (3xHxW -> HxWx3)")
plt.axis('off')

plt.tight_layout()
plt.show()



In [None]:
# def classify_scattering(alpha, entropy, anisotropy):
#     if entropy < 0.3:
#         if alpha < 30: return "Z1: Low Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z2: Low Entropy Dipole Scatter"
#         else: return "Z3: Low Entropy Multiple Scatter"
#     elif 0.3 <= entropy < 0.7:
#         if alpha < 30: return "Z4: Medium Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z5: Medium Entropy Dipole Scatter"
#         else: return "Z6: Medium Entropy Multiple Scatter"
#     else:
#         if alpha < 30: return "Z7: High Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z8: High Entropy Dipole Scatter"
#         else: return "Z9: High Entropy Multiple Scatter"

In [None]:

# import numpy as np
# import matplotlib.pyplot as plt
# from ipywidgets import interact, interactive, fixed, widgets
# from IPython.display import display
# from transformers import pipeline  # NEW: For VQA

# # Your SAR data (example arrays - replace with your actual data)
# alpha_mean = np.random.rand(1000, 1000) * 90  # 0-90 degrees
# entropy = np.random.rand(1000, 1000)          # 0-1
# anisotropy = np.random.rand(1000, 1000)       # 0-1
# sar_features = np.stack([alpha_mean, entropy, anisotropy], axis=-1)

# def create_sar_composite(alpha, entropy, anisotropy):
#     # Normalize each channel with enhanced contrast
#     r = np.clip((alpha - np.percentile(alpha, 5)) / (np.percentile(alpha, 95) - np.percentile(alpha, 5)), 0, 1)
#     g = np.clip((entropy - np.percentile(entropy, 5)) / (np.percentile(entropy, 95) - np.percentile(entropy, 5)), 0, 1)
#     b = np.clip((anisotropy - np.percentile(anisotropy, 5)) / (np.percentile(anisotropy, 95) - np.percentile(anisotropy, 5)), 0, 1)
    
#     # Apply gamma correction for better visual contrast
#     composite = np.stack([
#         np.power(r, 0.6),  # Red channel (alpha)
#         np.power(g, 0.7),  # Green channel (entropy)
#         np.power(b, 0.8)   # Blue channel (anisotropy)
#     ], axis=-1)
    
#     return np.clip(composite, 0, 1)

# # Create the enhanced composite
# sar_composite = create_sar_composite(alpha_mean, entropy, anisotropy)

# # Display with proper scaling
# plt.figure(figsize=(10, 10))
# plt.imshow(sar_composite, aspect='auto')
# plt.colorbar(label='Normalized Intensity')
# plt.title('Enhanced SAR Composite (Red=Alpha, Green=Entropy, Blue=Anisotropy)')
# plt.show()


# # NEW: Precompute patches for faster interaction
# def precompute_patches(feature_map, patch_size=16):
#     return np.lib.stride_tricks.sliding_window_view(
#         feature_map, (patch_size, patch_size)
#     )[::patch_size, ::patch_size].mean(axis=(2,3))

# alpha_patches = precompute_patches(alpha_mean)  # NEW
# entropy_patches = precompute_patches(entropy)   # NEW
# anisotropy_patches = precompute_patches(anisotropy)  # NEW

# # NEW: Initialize VQA pipeline
# qa_pipeline = pipeline("question-answering", 
#                       model="deepset/roberta-base-squad2")

# def classify_scattering(alpha, entropy, anisotropy):
#     """Classify scattering mechanism into zones"""
#     if entropy < 0.3:
#         if alpha < 30: return "Z1: Low Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z2: Low Entropy Dipole Scatter"
#         else: return "Z3: Low Entropy Multiple Scatter"
#     elif 0.3 <= entropy < 0.7:
#         if alpha < 30: return "Z4: Medium Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z5: Medium Entropy Dipole Scatter"
#         else: return "Z6: Medium Entropy Multiple Scatter"
#     else:
#         if alpha < 30: return "Z7: High Entropy Surface Scatter"
#         elif 30 <= alpha <= 50: return "Z8: High Entropy Dipole Scatter"
#         else: return "Z9: High Entropy Multiple Scatter"


# def analyze_region(x=0, y=0, question=""):
#     patch_size = 16
#     x, y = min(x, alpha_mean.shape[1]-16), min(y, alpha_mean.shape[0]-16)
    
#     # Use precomputed patches (faster)
#     patch_x, patch_y = x//16, y//16
#     patch_alpha = alpha_patches[patch_y, patch_x]
#     patch_entropy = entropy_patches[patch_y, patch_x]
#     patch_anisotropy = anisotropy_patches[patch_y, patch_x]
    
#     zone = classify_scattering(patch_alpha, patch_entropy, patch_anisotropy)
    
#     fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18,5))
    
#     # 1. Actual SAR patch with red border
#     ax1.imshow(sar_composite[y:y+patch_size, x:x+patch_size])
#     rect = plt.Rectangle(
#         (0, 0),           # Bottom-left corner (relative to patch)
#         patch_size,        # Width
#         patch_size,        # Height
#         linewidth=2,       # Border thickness
#         edgecolor='red',   # Border color
#         facecolor='none'   # Transparent fill
#     )
#     ax1.add_patch(rect)
#     ax1.set_title(f"SAR Patch ({x}-{x+patch_size}, {y}-{y+patch_size})")
    
#     # 2. Text with VQA
#     context = f"{zone}\nAlpha: {patch_alpha:.1f}°\nEntropy: {patch_entropy:.2f}\nAnisotropy: {patch_anisotropy:.2f}"
#     if question:
#         answer = qa_pipeline(question=question, context=context)
#         context += f"\n\nQ: {question}\nA: {answer['answer']}"
#     ax2.text(0.5, 0.5, context, ha='center', va='center', fontsize=10)
#     ax2.axis('off')
    
#     # 3. Enhanced H/Alpha plot
#     ax3.scatter(patch_alpha, patch_entropy, c='red', s=100)
#     ax3.set_xlim(0, 90); ax3.set_ylim(0, 1)
#     ax3.set_xlabel("Alpha Angle (°)"); ax3.set_ylabel("Entropy")
#     ax3.axhline(0.3, color='red', linestyle='--', alpha=0.5)
#     ax3.axhline(0.7, color='red', linestyle='--', alpha=0.5)
#     ax3.axvline(30, color='blue', linestyle='--', alpha=0.5)
#     ax3.axvline(50, color='blue', linestyle='--', alpha=0.5)
#     ax3.grid(True)
    
#     plt.tight_layout()
#     plt.show()
#     return zone

# # Create interactive widgets (unchanged)
# x_slider = widgets.IntSlider(
#     min=0, 
#     max=alpha_mean.shape[1]-16, 
#     step=16, 
#     value=0,
#     description='X Position:'
# )
# y_slider = widgets.IntSlider(
#     min=0, 
#     max=alpha_mean.shape[0]-16, 
#     step=16, 
#     value=0,
#     description='Y Position:'
# )
# question_box = widgets.Text(
#     placeholder='Ask about this region...',
#     description='Question:'
# )

# # Create interactive UI (unchanged)
# ui = widgets.VBox([
#     widgets.HBox([x_slider, y_slider]),
#     question_box
# ])

# out = widgets.interactive_output(
#     analyze_region,
#     {'x': x_slider, 'y': y_slider, 'question': question_box}
# )

# display(ui, out)







import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from IPython.display import display, clear_output
import ipywidgets as widgets
from transformers import pipeline

# Ensure we're using the correct backend for interactivity in Jupyter
%matplotlib widget

# Precompute patches if not already done
def ensure_valid_dimensions(feature_map, patch_size=16):
    """Zero-pad a 2-D map at the bottom/right so both sides are multiples of patch_size.

    Returns the input object itself when no padding is needed; otherwise a new,
    zero-padded array.
    """
    rows, cols = feature_map.shape
    # number of rows/columns missing to reach the next multiple of patch_size
    extra_rows = -rows % patch_size
    extra_cols = -cols % patch_size
    if not (extra_rows or extra_cols):
        return feature_map
    return np.pad(feature_map, ((0, extra_rows), (0, extra_cols)), mode='constant')

def precompute_patches(feature_map, patch_size=16):
    """Average a 2-D map over non-overlapping patch_size x patch_size patches.

    Parameters
    ----------
    feature_map : ndarray
        2-D array (rows x columns).
    patch_size : int
        Side length of the square patches.

    Returns
    -------
    ndarray
        Shape (ceil(rows / patch_size), ceil(cols / patch_size)) holding the mean
        of each patch. NaN pixels are ignored. Edge patches that extend past the
        map average only the pixels that exist: the padding is NaN rather than
        zero, so it no longer dilutes the mean. A patch with no valid pixel
        yields NaN.
    """
    import warnings

    fmap = np.asarray(feature_map)
    if not np.issubdtype(fmap.dtype, np.floating):
        fmap = fmap.astype(float)  # NaN padding needs a float dtype
    rows, cols = fmap.shape
    pad_rows, pad_cols = -rows % patch_size, -cols % patch_size
    if pad_rows or pad_cols:
        fmap = np.pad(fmap, ((0, pad_rows), (0, pad_cols)), constant_values=np.nan)
    n_rows, n_cols = fmap.shape[0] // patch_size, fmap.shape[1] // patch_size
    # (n_rows, patch, n_cols, patch) -> average over the two in-patch axes
    blocks = fmap.reshape(n_rows, patch_size, n_cols, patch_size)
    with warnings.catch_warnings():
        # all-NaN patches legitimately give NaN; silence "Mean of empty slice"
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmean(blocks, axis=(1, 3))

# Create SAR composite if not already done
def create_sar_composite(alpha, entropy, anisotropy):
    """Contrast-stretched RGB composite of alpha (R), entropy (G) and anisotropy (B).

    Each channel is stretched linearly between its 5th and 95th percentile,
    clipped to [0, 1] and gamma-corrected (0.6 / 0.7 / 0.8).

    Parameters
    ----------
    alpha, entropy, anisotropy : ndarray
        2-D maps of identical shape.

    Returns
    -------
    ndarray
        Array of shape input + (3,) with values in [0, 1]. Invalid (NaN)
        pixels are shown black, and a constant channel maps to 0.
    """
    def _stretch(channel):
        # nanpercentile ignores invalid pixels; np.percentile returns NaN for
        # the whole channel as soon as a single NaN is present.
        lo, hi = np.nanpercentile(channel, [5, 95])
        span = hi - lo
        if not span > 0:  # constant (or all-NaN) channel: avoid 0/0
            return np.zeros_like(channel, dtype=float)
        return np.clip((channel - lo) / span, 0, 1)

    # Apply gamma correction for better visual contrast
    composite = np.stack([
        np.power(_stretch(alpha), 0.6),       # Red channel (alpha)
        np.power(_stretch(entropy), 0.7),     # Green channel (entropy)
        np.power(_stretch(anisotropy), 0.8),  # Blue channel (anisotropy)
    ], axis=-1)

    return np.clip(np.nan_to_num(composite, nan=0.0), 0, 1)

# Patch-averaged maps of alpha, entropy and anisotropy (one value per patch)
alpha_patches, entropy_patches, anisotropy_patches = (
    precompute_patches(feature) for feature in (alpha_mean, entropy, anisotropy)
)

# RGB composite shown in the interactive viewer (R=alpha, G=entropy, B=anisotropy)
sar_composite = create_sar_composite(alpha_mean, entropy, anisotropy)

# Extractive question-answering model used to answer questions about a selected patch
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")

# Set up classification function
def classify_scattering(alpha, entropy, anisotropy):
    """Classify scattering mechanism into zones Z1..Z9 of a 3x3 H/alpha grid.

    Parameters
    ----------
    alpha : float
        Mean alpha angle in degrees (the thresholds below are in degrees).
    entropy : float
        Entropy in [0, 1].
    anisotropy : float
        Currently unused; accepted so callers can pass all three parameters.

    Returns
    -------
    str
        Zone label such as "Z1: Low Entropy Surface Scatter". Returns
        "Unclassified: no valid data" when alpha or entropy is NaN.

    Entropy classes: < 0.3 low, < 0.7 medium, otherwise high.
    Alpha classes: < 30 surface, <= 50 dipole, otherwise multiple.
    NOTE(review): these breaks differ from the published Cloude-Pottier H/alpha
    zone limits; confirm that the simplification is intended.
    """
    # Every comparison with NaN is False, so without this guard invalid
    # patches silently fell through to "Z9".
    if np.isnan(alpha) or np.isnan(entropy):
        return "Unclassified: no valid data"

    if entropy < 0.3:
        level, first_zone = "Low", 1
    elif entropy < 0.7:
        level, first_zone = "Medium", 4
    else:
        level, first_zone = "High", 7

    if alpha < 30:
        offset, mechanism = 0, "Surface"
    elif alpha <= 50:
        offset, mechanism = 1, "Dipole"
    else:
        offset, mechanism = 2, "Multiple"

    return f"Z{first_zone + offset}: {level} Entropy {mechanism} Scatter"

# Create a text widget for questions; its current value is passed to the analysis
# on every click, and value changes are observed by on_question_change further down
question_input = widgets.Text(
    placeholder='Ask about this region...',
    description='Question:',
    layout=widgets.Layout(width='80%')
)

# Create an output widget for the analysis results (cleared and redrawn on each update)
analysis_output = widgets.Output()

# Interactive visualization system
class SARImageAnalyzer:
    """Click-to-inspect viewer for the SAR composite.

    Shows ``sar_composite`` in a figure. A click selects the
    patch_size x patch_size patch under the cursor, outlines it, and renders an
    analysis panel into ``analysis_output``: the patch itself, its zone label
    plus an optional QA answer, and its position in the H/alpha plane.

    Relies on the notebook globals ``sar_composite``, ``alpha_patches``,
    ``entropy_patches``, ``anisotropy_patches``, ``classify_scattering``,
    ``qa_pipeline``, ``question_input`` and ``analysis_output``.
    """

    def __init__(self):
        self.patch_size = 16
        self.selected_x = 0   # column of the selected patch's top-left pixel
        self.selected_y = 0   # row of the selected patch's top-left pixel
        self.rect = None      # outline of the selection in the main figure
        self._analysis_fig = None  # last analysis figure, closed before drawing a new one
        self.setup_ui()

    def setup_ui(self):
        """Create the main figure showing ``sar_composite`` and register the click handler."""
        # Create main figure for the composite image (larger size)
        self.fig, self.ax = plt.subplots(figsize=(12, 12))
        self.img = self.ax.imshow(sar_composite, aspect='auto')
        plt.colorbar(self.img, label='Normalized Intensity')
        self.ax.set_title('SAR Composite (Red=Alpha, Green=Entropy, Blue=Anisotropy)\nClick to select a region')

        # Connect the click event
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    def on_click(self, event):
        """Snap a click to the patch grid, outline that patch and refresh the analysis."""
        if event.inaxes != self.ax:
            return

        # Work in patch indices clamped to the precomputed patch grid. imshow
        # centres pixels on integers, so xdata/ydata can be as low as -0.5;
        # flooring that gave index -1, and negative indexing wrapped around to
        # the last patch. Clamping to the grid (instead of to shape - patch_size)
        # also keeps the drawn patch aligned with its statistics when the image
        # size is not a multiple of patch_size.
        n_patch_rows, n_patch_cols = alpha_patches.shape
        col = min(max(int(event.xdata // self.patch_size), 0), n_patch_cols - 1)
        row = min(max(int(event.ydata // self.patch_size), 0), n_patch_rows - 1)
        self.selected_x = col * self.patch_size
        self.selected_y = row * self.patch_size

        # Highlight selection with rectangle (pixel edges lie at integer - 0.5)
        if self.rect is not None:
            self.rect.remove()
        self.rect = plt.Rectangle(
            (self.selected_x - 0.5, self.selected_y - 0.5),
            self.patch_size, self.patch_size,
            linewidth=3, edgecolor='white', facecolor='none'
        )
        self.ax.add_artist(self.rect)
        self.fig.canvas.draw_idle()

        # Update analysis (this will use the question from the input widget)
        with analysis_output:
            clear_output(wait=True)
            self.analyze_region(question_input.value)

    def analyze_region(self, question=""):
        """Draw the analysis panel for the currently selected patch.

        Parameters
        ----------
        question : str
            Optional question. When non-empty, ``qa_pipeline`` answers it using
            the patch's zone and parameter summary as context.
        """
        patch_x = self.selected_x // self.patch_size
        patch_y = self.selected_y // self.patch_size
        patch_alpha = alpha_patches[patch_y, patch_x]
        patch_entropy = entropy_patches[patch_y, patch_x]
        patch_anisotropy = anisotropy_patches[patch_y, patch_x]

        zone = classify_scattering(patch_alpha, patch_entropy, patch_anisotropy)

        # clear_output() only removes the previous figure's display; pyplot keeps
        # the figure alive, so close it to stop memory growing with every click
        # or question edit.
        previous = getattr(self, '_analysis_fig', None)
        if previous is not None:
            plt.close(previous)
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))
        self._analysis_fig = fig

        # 1. SAR patch with red border
        ax1.imshow(sar_composite[self.selected_y:self.selected_y + self.patch_size,
                                 self.selected_x:self.selected_x + self.patch_size],
                   interpolation='bilinear')  # smoother appearance
        ax1.add_patch(plt.Rectangle(
            (-0.5, -0.5),  # pixel edges: imshow centres pixels on integers
            self.patch_size, self.patch_size,
            linewidth=2, edgecolor='red', facecolor='none'
        ))
        ax1.set_title(f"SAR Patch ({self.selected_x}-{self.selected_x+self.patch_size}, "
                      f"{self.selected_y}-{self.selected_y+self.patch_size})")

        # 2. Text with classification and VQA
        context = f"{zone}\n\nAlpha: {patch_alpha:.1f}°\nEntropy: {patch_entropy:.2f}\nAnisotropy: {patch_anisotropy:.2f}"
        if question:
            answer = qa_pipeline(question=question, context=context)
            context += f"\n\nQ: {question}\nA: {answer['answer']}"
        ax2.text(0.5, 0.5, context, ha='center', va='center', fontsize=14)
        ax2.axis('off')

        # 3. H/alpha plot with zone boundaries (entropy 0.3/0.7, alpha 30/50 degrees)
        ax3.scatter(patch_alpha, patch_entropy, c='red', s=150)
        ax3.set_xlim(0, 90)
        ax3.set_ylim(0, 1)
        ax3.set_xlabel("Alpha Angle (°)", fontsize=12)
        ax3.set_ylabel("Entropy", fontsize=12)
        ax3.axhline(0.3, color='red', linestyle='--', alpha=0.5)
        ax3.axhline(0.7, color='red', linestyle='--', alpha=0.5)
        ax3.axvline(30, color='blue', linestyle='--', alpha=0.5)
        ax3.axvline(50, color='blue', linestyle='--', alpha=0.5)

        # Add zone labels Z1..Z9, row by row from low to high entropy
        for i, y in enumerate((0.15, 0.5, 0.85)):
            for j, x in enumerate((15, 40, 70)):
                ax3.text(x, y, f"Z{3 * i + j + 1}", fontsize=10)

        ax3.grid(True)

        plt.tight_layout()
        plt.show()

# Initialize the analyzer: creates the main composite figure and connects the click handler
analyzer = SARImageAnalyzer()

# Update analysis when question changes (without requiring a new click)
def on_question_change(change):
    """ipywidgets observer: redraw the analysis panel with the edited question."""
    is_value_edit = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_edit:
        return
    with analysis_output:
        clear_output(wait=True)
        analyzer.analyze_region(change['new'])

# re-run the analysis whenever the question text's value changes
question_input.observe(on_question_change, names='value')

# Display widgets: the question box above the analysis output area
display(question_input, analysis_output)


