In [None]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Tree types, in the order their entries are assumed to appear in each
# sensor's "spectral_stats" (the plots pair the i-th stat with tree_types[i]).
tree_types = ["Spruce", "Pine", "Deciduous"]
# One fixed viridis colour per tree type; derived from tree_types so that
# adding or removing a type cannot leave the two lists out of sync.
tree_colors = cm.viridis(np.linspace(0, 1, len(tree_types)))

# Wavelength window (in µm) used to clip Sentinel/Landsat bands in the
# combined comparison plot.
wavelength_min = 0  # 0 µm
wavelength_max = 1  # 1 µm

# Line style per satellite. HYPSO wavelengths are stored in nm and are
# converted to µm after the data is loaded.
satellite_styles = {
    "Sentinel": "-",
    "Landsat": "--",
    "Hypso": ":",
}

# Load every "spectral_data_<sensor>.pkl" file in the current directory into
# `spectral_data`, keyed by sensor name. The name is the third "_"-separated
# token of the file name with its extension removed, e.g.
# "spectral_data_Hypso.pkl" -> "Hypso".
# NOTE(security): pickle.load can execute arbitrary code; only keep trusted
# spectral_data_*.pkl files in this directory.
spectral_data = {}
# Sorted so that, if two files map to the same sensor name, the same one
# always wins (os.listdir order is arbitrary).
for file in sorted(os.listdir()):
    if file.startswith("spectral_data_") and file.endswith(".pkl"):
        with open(file, "rb") as f:
            data = pickle.load(f)
        sensor_name = file.split("_")[2].split(".")[0]  # Extract sensor name
        spectral_data[sensor_name] = data


def nm_to_um(wavelengths_nm):
    """Return *wavelengths_nm* converted from nanometres to micrometres.

    Accepts any array-like (list, tuple, int or float ndarray) and always
    returns a new float64 ndarray, so the caller's object is never mutated.
    An in-place ``/=`` would fail on lists and on integer arrays.
    """
    return np.asarray(wavelengths_nm, dtype=float) / 1000.0


# HYPSO wavelengths are stored in nm; convert them to µm so they share an
# axis with the other sensors.
if "Hypso" in spectral_data:
    spectral_data["Hypso"]["wavelengths"] = nm_to_um(
        spectral_data["Hypso"]["wavelengths"]
    )

# ---------------- PLOT 1: Landsat ----------------
# Mean spectral signature of each tree type as seen by Landsat: a line through
# the per-band means plus a marker at every band wavelength. Saved to
# spectral_comparison_landsat.png.
plt.figure(figsize=(8, 5))  # dedicated figure, same size as plot 2
for sensor in ["Landsat"]:
    if sensor not in spectral_data:
        continue

    wavelengths = spectral_data[sensor]["wavelengths"]
    stats = spectral_data[sensor]["spectral_stats"]

    # Tree types are matched to stats purely by position, so this assumes
    # stats is ordered like tree_types — TODO confirm against the producer.
    for i, stat in enumerate(stats.values()):
        if i >= len(tree_colors):  # more classes than colours: skip extras
            continue

        # Line through the band means
        plt.plot(
            wavelengths, stat["mean"], linestyle=satellite_styles[sensor],
            color=tree_colors[i], label=f"{sensor} - {tree_types[i]}"
        )

        # Marker at each Landsat band wavelength
        plt.scatter(wavelengths, stat["mean"], color=tree_colors[i], edgecolor='black', s=30, zorder=3)

plt.xlabel("Wavelength (µm)")
plt.ylabel("Reflectance")
plt.title("Landsat - Spectral Signatures")
plt.grid(True)
plt.legend()
plt.savefig("spectral_comparison_landsat.png")
plt.show()


# ---------------- PLOT 2: Sentinel, Landsat & HYPSO (clipped wavelength window) ----------------
# Overlays the per-tree-type mean spectra of all available sensors. Sentinel
# and Landsat bands are clipped to [wavelength_min, wavelength_max]; HYPSO is
# plotted unclipped.
# NOTE(review): the window is the fixed wavelength_min/wavelength_max
# constants, not HYPSO's actual band range. Derive it from the HYPSO
# wavelengths if a tight crop to HYPSO coverage is intended.
plt.figure(figsize=(8, 5))
for sensor in ["Sentinel", "Landsat", "Hypso"]:
    if sensor not in spectral_data:
        continue

    wavelengths = np.asarray(spectral_data[sensor]["wavelengths"])

    # One selector per sensor, applied identically to x and y below.
    if sensor == "Hypso":
        band_mask = slice(None)  # keep every HYPSO band
    else:
        band_mask = (wavelengths >= wavelength_min) & (wavelengths <= wavelength_max)
    plotted_wavelengths = wavelengths[band_mask]

    stats = spectral_data[sensor]["spectral_stats"]

    # Tree types are matched to stats by position (see tree_types ordering).
    for i, stat in enumerate(stats.values()):
        if i >= len(tree_colors):  # more classes than colours: skip extras
            continue

        plt.plot(
            plotted_wavelengths, np.asarray(stat["mean"])[band_mask],
            linestyle=satellite_styles[sensor],
            color=tree_colors[i], label=f"{sensor} - {tree_types[i]}"
        )

plt.xlabel("Wavelength (µm)")
plt.ylabel("Reflectance")
plt.title("Sentinel, Landsat & HYPSO - Spectral Signatures")
plt.grid(True)
plt.legend()
plt.savefig("spectral_comparison_sentinel_landsat_hypso.png")
plt.show()
