In [1]:
# Housekeeping cell: re-writes alpha_earth_app.ipynb in place (notebook -> notebook)
# with nbconvert's ClearMetadataPreprocessor enabled.
# NOTE(review): this rewrites the notebook file itself — presumably meant to be run
# manually before sharing rather than as part of "Run All"; confirm.
!jupyter nbconvert --ClearMetadataPreprocessor.enabled=True \
    --to notebook --inplace alpha_earth_app.ipynb


[NbConvertApp] Converting notebook alpha_earth_app.ipynb to notebook
[NbConvertApp] Writing 58235 bytes to alpha_earth_app.ipynb


# Chunk 1. Header

In [7]:
# ✅ Set your Earth Engine project ID once and reuse everywhere.
# The EE_PROJECT_ID environment variable, when set, overrides the default below,
# so the notebook can target another Cloud project without editing this cell.
import os

import ee

PROJECT_ID = os.environ.get("EE_PROJECT_ID", "animated-rhythm-449415-u3")

try:
    # First attempt: initialize directly with whatever credentials are available.
    ee.Initialize(project=PROJECT_ID)
    print("✅ EE already authenticated")
except Exception:
    # Any initialization failure falls back to the interactive notebook auth
    # flow, after which initialization is retried once.
    ee.Authenticate(auth_mode='notebook')
    ee.Initialize(project=PROJECT_ID)
    print("✅ EE authenticated & initialized")


✅ EE already authenticated


In [3]:
# ===============================================
# 📋 BINDER WELCOME MESSAGE
# ===============================================
# One print call; sep="\n" puts each banner line on its own output line.
print(
    "🛰️ Welcome to AlphaEarth Land Cover Classifier on Binder!",
    "📋 Click 'Run All' to start the complete app",
    "⚠️ Earth Engine may have limited access in Binder",
    "💡 For full features, use Google Colab instead",
    "🚀 Loading app...",
    sep="\n",
)

🛰️ Welcome to AlphaEarth Land Cover Classifier on Binder!
📋 Click 'Run All' to start the complete app
⚠️ Earth Engine may have limited access in Binder
💡 For full features, use Google Colab instead
🚀 Loading app...


In [4]:
# ===============================================
# LOAD DEPENDENCIES
# ===============================================

import ee
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import folium
import math
import os
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from IPython.display import display, clear_output, HTML
import ipywidgets as widgets

# Install and import ipyleaflet for interactive map
try:
    import ipyleaflet
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "ipyleaflet"])
    import ipyleaflet

from ipyleaflet import Map, DrawControl, basemaps, WidgetControl

# --- Optional imports for advanced algorithms ---
def _get_xgb():
    try:
        import xgboost as xgb
    except Exception:
        import sys, subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "xgboost"])
        import xgboost as xgb
    return xgb

def _get_lgb():
    try:
        import lightgbm as lgb
    except Exception:
        import sys, subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "lightgbm"])
        import lightgbm as lgb
    return lgb



# Chunk 3. App Section

In [8]:
# ===============================================
# ORIGINAL WORKING FUNCTIONS (from your working code)
# ===============================================

# Eight-step colour ramps used to render embedding layers on the map.
# Keys are 'embedding_<rank>' for ranks 1..3.
EMBEDDING_PALETTES = {
    # Blue to Red
    'embedding_1': [
        '#000080', '#0066CC', '#00CCFF', '#66FFCC',
        '#CCFF66', '#FFCC00', '#FF6600', '#CC0000',
    ],
    # Purple to Orange
    'embedding_2': [
        '#2D0066', '#6600CC', '#9966FF', '#CC99FF',
        '#FFCCFF', '#FFCC99', '#FF9966', '#FF6633',
    ],
    # Green to Yellow
    'embedding_3': [
        '#004D00', '#009900', '#33CC33', '#66FF66',
        '#99FF99', '#CCFF33', '#FFFF00', '#FFB300',
    ],
}

def sanitize_bbox(bbox):
    """Normalize a ``[minLon, minLat, maxLon, maxLat]`` box into drawable parts.

    Longitudes are wrapped into [-180, 180], latitudes are clamped to
    [-85, 85] (Web Mercator-friendly) and swapped if reversed. A box whose
    wrapped west edge lies east of its east edge is treated as crossing the
    antimeridian and is split into two boxes.

    Args:
        bbox: iterable of four values convertible to float, ordered
            minLon, minLat, maxLon, maxLat.

    Returns:
        dict with
          * ``"parts"``: list of one or two ``[minLon, minLat, maxLon, maxLat]``
            boxes, and
          * ``"center"``: ``[center_lat, center_lon]`` of the whole selection.

    Raises:
        ValueError / TypeError: if ``bbox`` does not hold exactly four
            float-convertible values.
    """
    minLon, minLat, maxLon, maxLat = map(float, bbox)

    def clamp(v, lo, hi):
        return max(lo, min(hi, v))

    def wrap(lon):
        # Maps any longitude into [-180, 180); note that +180 becomes -180.
        return ((lon + 180.0) % 360.0) - 180.0

    # Clamp to world bounds
    minLat = clamp(minLat, -85, 85)   # Web Mercator-friendly
    maxLat = clamp(maxLat, -85, 85)

    # If lats reversed, swap
    if minLat > maxLat:
        minLat, maxLat = maxLat, minLat

    if maxLon - minLon >= 360.0:
        # The box spans the whole globe; wrapping both edges would collapse it
        # to zero width, so keep the full longitude range explicitly.
        minLon, maxLon = -180.0, 180.0
    else:
        minLon = wrap(minLon)
        maxLon = wrap(maxLon)
        # wrap() turns an east edge at +180 into -180, which would look like an
        # antimeridian crossing with an empty second part. An east edge on the
        # antimeridian is +180 unless the box is degenerate at -180.
        if maxLon == -180.0 and minLon != -180.0:
            maxLon = 180.0

    # If longitudes equal and tiny width, give a small pad
    if abs(maxLon - minLon) < 1e-6:
        maxLon = min(180.0, minLon + 0.01)

    if minLon <= maxLon:
        parts = [[minLon, minLat, maxLon, maxLat]]
        center_lon = (minLon + maxLon) / 2.0
    else:
        # Crosses the antimeridian: split into two boxes
        parts = [
            [minLon, minLat, 180.0, maxLat],
            [-180.0, minLat, maxLon, maxLat]
        ]
        # Averaging the raw edges would land on the opposite side of the
        # planet; walk half the eastward width from the west edge instead.
        width = (maxLon + 360.0) - minLon
        center_lon = wrap(minLon + width / 2.0)

    center_lat = (minLat + maxLat) / 2.0
    return {"parts": parts, "center": [center_lat, center_lon]}

def region_from_parts(parts):
    """Build one Earth Engine geometry from a list of bounding boxes.

    Each entry of ``parts`` is passed to ``ee.Geometry.Rectangle``. A single
    rectangle is returned as-is; otherwise the rectangles are wrapped in an
    ``ee.Geometry.MultiPolygon`` and dissolved into one geometry.
    """
    rectangles = []
    for box in parts:
        rectangles.append(ee.Geometry.Rectangle(box))

    if len(rectangles) == 1:
        return rectangles[0]
    return ee.Geometry.MultiPolygon(rectangles).dissolve()

def load_stack(year=2020):
    """Load the annual satellite-embedding mosaic and the ESA label image.

    Args:
        year: calendar year used to filter the embedding collection
            (``year``-01-01 up to ``year + 1``-01-01).

    Returns:
        Tuple ``(emb, bands, emb_mask, esa)``:
          * ``emb``: mosaic of the filtered embedding collection, all bands.
          * ``bands``: server-side list of the mosaic's band names.
          * ``emb_mask``: mask of the first embedding band.
          * ``esa``: the 'Map' band of ESA/WorldCover/v100/2020 renamed to
            'label'. Note this image is fixed and does not depend on ``year``.
    """
    start_date = f'{year}-01-01'
    end_date = f'{year+1}-01-01'

    collection = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')
    emb_all = collection.filterDate(start_date, end_date).mosaic()

    # Band names come from the server, e.g. A00..A63
    bands = emb_all.bandNames()
    emb = emb_all.select(bands)

    # Mask of the first band, used as the valid (land) mask
    emb_mask = emb.select(0).mask()

    worldcover = ee.Image("ESA/WorldCover/v100/2020")
    esa = worldcover.select('Map').rename('label')
    return emb, bands, emb_mask, esa

def add_ee_layer(folium_map, ee_image_object, vis_params, name):
    """Add an Earth Engine image to a folium map as a toggleable tile overlay.

    Args:
        folium_map: folium map the layer is added to.
        ee_image_object: object wrapped with ``ee.Image`` before rendering.
        vis_params: visualization parameters passed to ``getMapId``.
        name: layer name shown in the layer control.
    """
    tile_url = ee.Image(ee_image_object).getMapId(vis_params)['tile_fetcher'].url_format
    overlay = folium.raster_layers.TileLayer(
        tiles=tile_url,
        attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
        name=name,
        overlay=True,
        control=True
    )
    overlay.add_to(folium_map)

def normalize_embedding(embedding_image, region):
    """Rescale an embedding image to the 0-1 range using regional min/max.

    The min and max of the image's first band are computed over ``region`` at
    a 1000 m scale, limited to [-10, 10], and used to stretch the image; the
    result is clamped to [0, 1].

    Args:
        embedding_image: Earth Engine image; statistics are read for its
            first band only.
        region: geometry over which min/max are computed.

    Returns:
        The normalized Earth Engine image.
    """
    # Name of the first band; reduceRegion keys are '<band>_min' / '<band>_max'.
    first_band = ee.String(embedding_image.bandNames().get(0))

    # Coarser scale for the statistics to speed things up.
    extremes = embedding_image.reduceRegion(
        reducer=ee.Reducer.minMax(),
        geometry=region,
        scale=1000,
        maxPixels=1e9
    )

    # Limit extreme values so outliers do not flatten the stretch.
    # NOTE(review): a missing statistic (e.g. fully masked region) is not
    # handled here — confirm callers never pass such a region.
    lower = ee.Number(extremes.get(first_band.cat('_min'))).max(-10)
    upper = ee.Number(extremes.get(first_band.cat('_max'))).min(10)

    value_range = upper.subtract(lower)
    return embedding_image.subtract(lower).divide(value_range).clamp(0, 1)

def top_k_importance(importances, bands, k=5, as_percent=True):
    """Return the ``k`` most important bands, highest first.

    Args:
        importances: array-like of importance scores, flattened before use.
        bands: iterable of band names aligned with ``importances``. If the two
            lengths differ, both are truncated to the shorter one.
        k: number of entries to return.
        as_percent: when True, scores are rescaled to sum to 100 (left
            untouched if their sum is zero).

    Returns:
        List of ``(display_name, score, original_band)`` tuples. The display
        name is one-based ('A00' -> 'A01'); a band whose name is not a single
        character followed by an integer keeps its original name.
    """
    vals = np.asarray(importances, dtype=float).ravel()
    bands = list(bands)
    count = min(vals.size, len(bands))
    vals, bands = vals[:count], bands[:count]

    def _display_name(band):
        # 'A00' -> 'A01'; fall back to the raw name instead of raising when
        # the suffix is not an integer.
        try:
            return f"A{int(band[1:]) + 1:02d}"
        except (TypeError, ValueError):
            return str(band)

    # Named `labels` rather than `display` so the IPython `display` imported
    # at module level is not shadowed.
    labels = [_display_name(b) for b in bands]

    if as_percent:
        total = vals.sum() or 1.0
        vals = 100.0 * vals / total

    order = np.argsort(vals)[::-1][:k]
    return [(labels[i], float(vals[i]), bands[i]) for i in order]  # Include original band name

# ===============================================
# 📱 APP CLASS DEFINITION WITH INTERACTIVE MAP
# ===============================================

class AlphaEarthApp:
    def __init__(self, project_id):
        self.project_id = project_id
        self.setup_constants()
        self.create_widgets()
        self.analysis_results = None
        self.selected_bbox = None

    def setup_constants(self):
        """Set up all constants and mappings"""
        self.ESA_CLASSES = {
            10: "Tree cover", 20: "Shrubland", 30: "Grassland", 40: "Cropland",
            50: "Built-up", 60: "Bare/sparse", 70: "Snow/ice", 80: "Water",
            90: "Herb. wetland", 95: "Mangroves", 100: "Moss/lichen"
        }

        self.CLASS_COLORS = {
            10: '#006400', 20: '#ffbb22', 30: '#ffff4c', 40: '#f096ff',
            50: '#fa0000', 60: '#b4b4b4', 70: '#f0f0f0', 80: '#0064c8',
            90: '#0096a0', 95: '#00cf75', 100: '#fae6a0'
        }

    def create_widgets(self):
        """Create all UI widgets with interactive map for region selection.

        Builds and stores on ``self``: the title banner, the region-selection
        map with its rectangle draw control, the two class dropdowns, the
        algorithm settings, the run button, the progress bar, and the
        status/output areas. Also wires the draw, click and dropdown-change
        callbacks and renders the initial class comparison.

        Reads ``self.ESA_CLASSES`` for the dropdown options. Nothing is
        displayed here; the widgets are only constructed.
        """
        # App title
        self.title = widgets.HTML(
            value="""
            <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; border-radius: 10px; margin-bottom: 20px;'>
                <h1 style='margin: 0; font-size: 28px;'>🛰️ Google AlphaEarth Land Cover Classifier</h1>
                <p style='margin: 10px 0 0 0; font-size: 14px; opacity: 0.9;'>
                    Advanced satellite embedding analysis for land cover classification
                </p>
            </div>
            """
        )

        # Region selection with interactive map
        self.region_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>📍 Interactive Region Selection</h3>"
        )

        # Static usage instructions shown above the map
        self.region_help = widgets.HTML(
            value="""
            <div style='background: #d1ecf1; padding: 15px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #0ea5e9; color: #2c3e50; font-size: 14px;'>
                <b style='color: #0c4a6e;'>🗺️ How to select your region:</b><br>
                <span style='color: #374151;'>1️⃣ Click the rectangle tool (📦) in the map toolbar</span><br>
                <span style='color: #374151;'>2️⃣ Click and drag on the map to draw your region</span><br>
                <span style='color: #374151;'>3️⃣ Use the trash tool (🗑️) to delete rectangles</span><br>
                <span style='color: #374151;'>4️⃣ Only the most recent rectangle will be used</span>
            </div>
            """
        )

        # Create interactive map with drawing tools
        self.selection_map = Map(
            basemap=basemaps.Esri.WorldImagery,
            center=[41.8, -72.6],  # Default center
            zoom=7,
            layout=widgets.Layout(height='350px', width='100%')
        )

        # Add drawing control: every shape option except `rectangle` is given
        # an empty configuration.
        self.draw_control = DrawControl(
            marker={},
            circle={},
            circlemarker={},
            polyline={},
            polygon={},
            rectangle={
                "shapeOptions": {
                    "fillColor": "red",
                    "color": "red",
                    "fillOpacity": 0.3,
                    "weight": 3
                }
            }
        )

        # Draw events are routed to handle_draw, which maintains selected_bbox
        self.draw_control.on_draw(self.handle_draw)
        self.selection_map.add_control(self.draw_control)

        # Region info display (empty until a region is drawn) and the
        # "no region selected" warning shown initially
        self.region_info = widgets.HTML()
        self.region_status = widgets.HTML(
            value="""
            <div style='background: #fff3cd; padding: 10px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #f59e0b; color: #2c3e50; font-size: 14px;'>
                <b style='color: #92400e;'>⚠️ No region selected</b><br>
                <span style='color: #374151;'>Please draw a rectangle on the map above to select your region of interest.</span>
            </div>
            """
        )

        # Class selection
        self.class_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>🌍 Land Cover Classes</h3>"
        )

        # Dropdown entries are ("<code> - <label>", code) pairs
        class_options = [(f"{k} - {v}", k) for k, v in self.ESA_CLASSES.items()]

        self.class_a = widgets.Dropdown(
            options=class_options, value=10, description='Class A:',
            style={'description_width': 'initial'}
        )

        self.class_b = widgets.Dropdown(
            options=class_options, value=80, description='Class B:',
            style={'description_width': 'initial'}
        )

        # Filled in by update_class_comparison
        self.class_comparison = widgets.HTML()

        # Algorithm settings
        self.algo_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>🤖 Algorithm Settings</h3>"
        )

        # Option values ('rf', 'gbt', 'xgb', 'lgb') are the names build_model accepts
        self.algorithm = widgets.Dropdown(
            options=[('Random Forest', 'rf'), ('Gradient Boosting', 'gbt'),
                    ('XGBoost', 'xgb'), ('LightGBM', 'lgb')],
            value='rf', description='Algorithm:', style={'description_width': 'initial'}
        )

        # Test share as a whole-number percentage
        self.test_size = widgets.IntSlider(
            value=25, min=10, max=40, step=5, description='Test Data %:',
            style={'description_width': 'initial'}
        )

        self.n_samples = widgets.IntSlider(
            value=100, min=50, max=300, step=25, description='Samples/class:',
            style={'description_width': 'initial'}
        )

        self.scale_m = widgets.IntSlider(
            value=500, min=100, max=2000, step=100, description='Scale (m):',
            style={'description_width': 'initial'}
        )

        self.seed = widgets.IntText(value=42, description='Seed:', style={'description_width': 'initial'})

        # Run button (initially disabled)
        self.run_button = widgets.Button(
            description='🚫 Draw a region first',
            button_style='',
            disabled=True,
            layout=widgets.Layout(width='300px', height='50px'),
        )

        # Progress bar (initially hidden)
        self.progress = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='Progress:',
            bar_style='info',
            style={'bar_color': '#3b82f6', 'description_width': 'initial'},
            layout=widgets.Layout(width='400px', margin='10px 0', visibility='hidden')
        )

        # Status and output
        self.status = widgets.HTML()
        self.output_area = widgets.Output()

        # Bind events
        self.run_button.on_click(self.run_analysis)

        # Auto-update class comparison
        for widget in [self.class_a, self.class_b]:
            widget.observe(self.update_class_comparison, names='value')

        # Initial updates: render the comparison for the default classes
        self.update_class_comparison(None)

    def handle_draw(self, target, action, geo_json):
        """Handle rectangle drawing on the map"""
        if action == 'created' and geo_json['geometry']['type'] == 'Polygon':
            # Extract coordinates from the drawn rectangle
            coords = geo_json['geometry']['coordinates'][0]

            # Get bounding box
            lons = [coord[0] for coord in coords]
            lats = [coord[1] for coord in coords]

            bbox = [min(lons), min(lats), max(lons), max(lats)]
            self.selected_bbox = bbox

            # Update region info
            self.update_region_info()

            # Enable run button
            self.run_button.disabled = False
            self.run_button.description = '🚀 RUN ANALYSIS'
            self.run_button.button_style = 'primary'

            print(f"✅ Region selected! Bounds: {[round(b, 3) for b in bbox]}")

        elif action == 'deleted':
            # Reset when rectangle is deleted
            self.selected_bbox = None
            self.region_info.value = ""
            self.region_status.value = """
            <div style='background: #fff3cd; padding: 10px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #f59e0b; color: #2c3e50; font-size: 14px;'>
                <b style='color: #92400e;'>⚠️ No region selected</b><br>
                <span style='color: #374151;'>Please draw a rectangle on the map above to select your region of interest.</span>
            </div>
            """

            # Disable run button
            self.run_button.disabled = True
            self.run_button.description = '🚫 Draw a region first'
            self.run_button.button_style = ''

    def update_region_info(self):
        """Update region information display"""
        if self.selected_bbox:
            bbox = self.selected_bbox
            width = abs(bbox[2] - bbox[0])
            height = abs(bbox[3] - bbox[1])
            area = width * height

            self.region_info.value = f"""
            <div style='background: #f0f9ff; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #3b82f6; color: #1e293b; font-size: 14px;'>
                <b style='color: #1e40af; font-size: 15px;'>📏 Selected Region:</b><br>
                <span style='color: #374151; font-weight: 500;'>Longitude: {bbox[0]:.3f} to {bbox[2]:.3f} ({width:.3f}°)</span><br>
                <span style='color: #374151; font-weight: 500;'>Latitude: {bbox[1]:.3f} to {bbox[3]:.3f} ({height:.3f}°)</span><br>
                <span style='color: #374151; font-weight: 500;'>Area: {area:.4f} square degrees</span>
            </div>
            """

            self.region_status.value = ""  # Clear the "no region" message

    def update_class_comparison(self, change):
        """Update class comparison display with better styling"""
        try:
            class_a_name = self.ESA_CLASSES[self.class_a.value]
            class_b_name = self.ESA_CLASSES[self.class_b.value]

            # Calculate difficulty
            difficulty = self.calculate_class_similarity(self.class_a.value, self.class_b.value)
            if difficulty > 0.6:
                diff_label = "🔴 Hard"
                diff_color = "#dc2626"
            elif difficulty > 0.4:
                diff_label = "🟡 Medium"
                diff_color = "#d97706"
            else:
                diff_label = "🟢 Easy"
                diff_color = "#059669"

            self.class_comparison.value = f"""
            <div style='background: #e8f5e8; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #10b981; color: #1e293b; font-size: 14px;'>
                <b style='color: #047857; font-size: 15px;'>🆚 Comparison:</b><br>
                <span style='color: #374151; font-weight: 500;'>
                    {self.class_a.value} ({class_a_name}) vs {self.class_b.value} ({class_b_name})
                </span><br>
                <span style='color: {diff_color}; font-weight: bold;'>Classification difficulty: {diff_label}</span>
            </div>
            """
        except:
            self.class_comparison.value = ""

    def calculate_class_similarity(self, class_a, class_b):
        """Return a similarity score for a pair of class codes.

        The score comes from a fixed lookup keyed by the pair in ascending
        order, so argument order does not matter. Pairs not in the table get
        0.35. Callers use higher scores to mean a harder classification.
        """
        pair_scores = {
            (10, 20): 0.8,
            (20, 30): 0.7,
            (30, 40): 0.6,
            (90, 95): 0.9,
            (10, 95): 0.5,
            (20, 100): 0.4,
            (60, 100): 0.5,
            (50, 60): 0.4,
            (50, 80): 0.1,
            (70, 80): 0.15,
            (10, 50): 0.2,
            (10, 80): 0.2,
            (50, 70): 0.1,
            (40, 80): 0.25,
        }

        low, high = sorted((class_a, class_b))
        return pair_scores.get((low, high), 0.35)

    def display_app(self):
        """Assemble every widget into one vertical layout and display it.

        Sections appear top to bottom as: style overrides, title, region
        selection, class selection, algorithm settings, run controls, output,
        with a horizontal rule between sections.
        """

        # Style overrides so text stays readable inside the widgets
        custom_css = widgets.HTML("""
        <style>
            .widget-html-content {
                color: #1e293b !important;
            }
            .widget-text input, .widget-dropdown select {
                color: #1e293b !important;
                background-color: white !important;
            }
            .widget-label {
                color: #374151 !important;
                font-weight: 500 !important;
            }
        </style>
        """)

        def divider():
            # A fresh widget per call: each rule is its own widget instance.
            return widgets.HTML("<hr style='border: 1px solid #e5e7eb; margin: 20px 0;'>")

        # Interactive Region Selection
        region_section = [
            self.region_title,
            self.region_help,
            self.selection_map,
            self.region_status,
            self.region_info,
        ]

        # Class Section
        class_section = [
            self.class_title,
            widgets.HBox([self.class_a, self.class_b]),
            self.class_comparison,
        ]

        # Algorithm Section
        algorithm_section = [
            self.algo_title,
            widgets.HBox([self.algorithm, self.test_size]),
            widgets.HBox([self.n_samples, self.scale_m]),
            self.seed,
        ]

        # Run Section, centred
        run_section = widgets.VBox(
            [self.run_button, self.progress, self.status],
            layout=widgets.Layout(align_items='center'),
        )

        children = [custom_css, self.title]
        children += region_section
        children.append(divider())
        children += class_section
        children.append(divider())
        children += algorithm_section
        children.append(divider())
        children.append(run_section)
        children.append(divider())
        # Output Section
        children.append(self.output_area)

        display(widgets.VBox(children))

    def ee_pairwise_sample_global(self, class_a, class_b, *, year, region, scale_m, n_per_class, seed):
        """Draw a stratified sample of two classes from Earth Engine.

        Pixels are restricted to those whose label equals ``class_a`` or
        ``class_b`` and that pass the embedding mask, then sampled with
        ``n_per_class`` points requested for each class.

        Args:
            class_a, class_b: label values to sample.
            year: passed to ``load_stack`` to pick the embedding year.
            region: geometry to sample within.
            scale_m: sampling scale passed to ``stratifiedSample``.
            n_per_class: points requested per class.
            seed: sampling seed.

        Returns:
            Tuple ``(df, band_names, hist, size, pair_fc)``:
              * ``df``: DataFrame of sampled properties (band columns plus
                'label'); empty with those columns when nothing was sampled.
              * ``band_names``: list of embedding band names.
              * ``hist``: label histogram of the sample.
              * ``size``: number of sampled features.
              * ``pair_fc``: the server-side sampled feature collection
                (with geometries).
        """
        emb, bands, emb_mask, esa = load_stack(year)
        pair_mask = esa.eq(class_a).Or(esa.eq(class_b))
        stack = (emb.updateMask(emb_mask).updateMask(pair_mask)
                ).addBands(esa.updateMask(emb_mask).updateMask(pair_mask))

        # Band names plus the label column; built once and reused below.
        props = bands.cat(ee.List(['label']))

        pair_fc = stack.stratifiedSample(
            numPoints=n_per_class * 2,
            classBand='label',
            region=region,
            scale=scale_m,
            classValues=[class_a, class_b],
            classPoints=[n_per_class, n_per_class],
            seed=seed,
            geometries=True
        ).select(props)

        size = pair_fc.size().getInfo()
        hist = pair_fc.aggregate_histogram('label').getInfo()

        # Fetch the band names a single time; each getInfo() is a blocking
        # round trip to the server.
        band_names = list(bands.getInfo())

        if size:
            # Strip geometries before download; only the properties are needed.
            attr_fc = pair_fc.map(lambda f: ee.Feature(None, f.toDictionary(props)))
            feats = attr_fc.getInfo()['features']
        else:
            feats = []

        if feats:
            df = pd.DataFrame([f['properties'] for f in feats])
        else:
            df = pd.DataFrame(columns=[*band_names, 'label'])
        return df, band_names, hist, size, pair_fc

    def build_model(self, name, random_state=42):
        """Construct an unfitted classifier for the given algorithm key.

        Args:
            name: one of 'rf', 'gbt', 'lgb', 'xgb' (case-insensitive).
            random_state: seed forwarded to the estimator.

        Returns:
            The configured, unfitted estimator.

        Raises:
            ValueError: if ``name`` is not one of the supported keys.
        """
        kind = name.lower()

        if kind == "rf":
            return RandomForestClassifier(n_estimators=300, random_state=random_state, n_jobs=-1)

        if kind == "gbt":
            return GradientBoostingClassifier(random_state=random_state)

        if kind in ("lgb", "xgb"):
            # Settings common to both boosted-tree libraries.
            shared = dict(
                n_estimators=300, max_depth=6, learning_rate=0.1,
                subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
                n_jobs=-1, random_state=random_state,
            )
            if kind == "lgb":
                lgb = _get_lgb()
                return lgb.LGBMClassifier(
                    objective='binary', metric='binary_logloss', verbose=-1, **shared
                )
            xgb = _get_xgb()
            return xgb.XGBClassifier(
                tree_method="hist", objective="binary:logistic",
                eval_metric="logloss", **shared
            )

        raise ValueError("CLASSIFIER must be 'rf', 'gbt', 'lgb', or 'xgb'")

    def run_analysis(self, button):
        """Main analysis function using the WORKING approach"""
        with self.output_area:
            clear_output(wait=True)

            # Check if region is selected
            if self.selected_bbox is None:
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        ❌ Please draw a region on the map first!
                    </span>
                </div>
                """
                return

            # Show progress bar and update status
            self.progress.layout.visibility = 'visible'
            self.progress.value = 10
            self.progress.description = 'Starting...'

            self.status.value = """
            <div style='background: #dbeafe; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #3b82f6; text-align: center;'>
                <span style='color: #1e40af; font-weight: bold; font-size: 16px;'>
                    🛰️ Processing satellite embeddings and generating analysis...
                </span>
            </div>
            """

            # Validation
            if self.class_a.value == self.class_b.value:
                self.progress.layout.visibility = 'hidden'
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        ❌ Please select different classes!
                    </span>
                </div>
                """
                return

            try:
                # Use the selected bbox from the map
                bbox = self.selected_bbox

                # 1) Build robust ROI geometry - USING WORKING CODE
                roi_info = sanitize_bbox(bbox)
                region = region_from_parts(roi_info["parts"])
                center_lat, center_lon = roi_info["center"]

                print("Using ROI bbox:", bbox)

                self.progress.value = 30
                self.progress.description = 'Loading satellite data...'

                # 2) Sample - USING WORKING CODE
                df, EMBED_BANDS, hist, size, pair_fc = self.ee_pairwise_sample_global(
                    self.class_a.value, self.class_b.value,
                    year=2020,
                    region=region,
                    scale_m=self.scale_m.value,
                    n_per_class=self.n_samples.value,
                    seed=self.seed.value
                )
                print(f"Sample size: {size}  |  class hist: {hist}")

                self.progress.value = 50
                self.progress.description = 'Training model...'

                # 3) Check availability
                counts = df['label'].value_counts().to_dict() if len(df) else {}
                have_a = counts.get(self.class_a.value, 0)
                have_b = counts.get(self.class_b.value, 0)
                MIN_PER_CLASS_REQUIRED = max(20, int(0.2 * self.n_samples.value))

                if have_a < MIN_PER_CLASS_REQUIRED or have_b < MIN_PER_CLASS_REQUIRED:
                    print(f"⚠️ Pairwise comparison NOT possible in this ROI.")
                    print(f"   Got A={have_a} (class {self.class_a.value}), B={have_b} (class {self.class_b.value}).")
                    print(f"   Try enlarging the drawn region, increasing scale, or choosing different classes.")
                    self.progress.layout.visibility = 'hidden'
                    self.status.value = """
                    <div style='background: #fef3c7; padding: 12px; border-radius: 8px; margin: 10px 0;
                                border-left: 4px solid #f59e0b; text-align: center;'>
                        <span style='color: #92400e; font-weight: bold; font-size: 16px;'>
                            ⚠️ Insufficient data - try larger region or different classes
                        </span>
                    </div>
                    """
                    return

                # 4) Train/test split and model - USING WORKING CODE
                label_map = {self.class_a.value: 0, self.class_b.value: 1}
                df = df[df['label'].isin(label_map.keys())].copy()
                y = df['label'].map(label_map).astype(int).values
                X = df[EMBED_BANDS].astype(float).values

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=self.test_size.value/100, random_state=self.seed.value, stratify=y
                )

                clf = self.build_model(self.algorithm.value, random_state=self.seed.value)
                clf.fit(X_train, y_train)
                y_proba = clf.predict_proba(X_test)[:, 1]
                y_pred  = (y_proba >= 0.5).astype(int)

                self.progress.value = 70
                self.progress.description = 'Calculating metrics...'

                # 5) Metrics
                acc  = accuracy_score(y_test, y_pred)
                auc  = roc_auc_score(y_test, y_proba)
                report = classification_report(y_test, y_pred, output_dict=True)
                cm     = confusion_matrix(y_test, y_pred)

                # 6) Feature importance - USING WORKING CODE
                importances = clf.feature_importances_
                top5 = top_k_importance(importances, EMBED_BANDS, k=5, as_percent=True)
                top3 = top_k_importance(importances, EMBED_BANDS, k=3, as_percent=True)  # Get top 3 for mapping

                self.progress.value = 90
                self.progress.description = 'Creating map...'

                # 7) Load embedding data for visualization - USING WORKING CODE EXACTLY
                emb_full, _, emb_mask_vis, esa_vis = load_stack(2020)

                # Create individual embedding layers for the top 3 most important embeddings
                print(f"\nPreparing top 3 embedding layers for visualization...")
                embedding_layers = []

                for i, (display_name, importance_pct, original_band) in enumerate(top3):
                    print(f"  Processing {display_name} (importance: {importance_pct:.2f}%)...")

                    # Select the specific embedding band
                    embedding_band = emb_full.select(original_band).updateMask(emb_mask_vis).clip(region)

                    # Normalize the embedding for better visualization - USING WORKING FUNCTION
                    normalized_embedding = normalize_embedding(embedding_band, region)

                    # Store layer info
                    embedding_layers.append({
                        'image': normalized_embedding,
                        'name': display_name,
                        'importance': importance_pct,
                        'palette_key': f'embedding_{i+1}',
                        'rank': i + 1
                    })

                # 8) Map layers - USING WORKING CODE EXACTLY
                classA_img = esa_vis.eq(self.class_a.value).updateMask(esa_vis.eq(self.class_a.value)).updateMask(emb_mask_vis).clip(region)
                classB_img = esa_vis.eq(self.class_b.value).updateMask(esa_vis.eq(self.class_b.value)).updateMask(emb_mask_vis).clip(region)

                # Create Folium map
                m = folium.Map(location=[center_lat, center_lon], zoom_start=8, control_scale=True)

                # Add satellite basemap
                folium.TileLayer(
                    tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
                    attr='Google',
                    name='Google Satellite',
                    overlay=False,
                    control=True
                ).add_to(m)

                # Add Earth Engine layers - USING WORKING CODE
                color_a = self.CLASS_COLORS.get(self.class_a.value, '#FF3B30')
                color_b = self.CLASS_COLORS.get(self.class_b.value, '#007AFF')

                add_ee_layer(m, classA_img, {'palette': [color_a]}, f'Pixels: {self.class_a.value} ({self.ESA_CLASSES.get(self.class_a.value,"?")})')
                add_ee_layer(m, classB_img, {'palette': [color_b]}, f'Pixels: {self.class_b.value} ({self.ESA_CLASSES.get(self.class_b.value,"?")})')

                # Add all top 3 embedding layers - WITH DEBUGGING
                print(f"\nAdding all top 3 embedding layers to map (you can toggle them on/off in the layer control)...")

                for layer in embedding_layers:
                    try:
                        palette = EMBEDDING_PALETTES[layer['palette_key']]
                        layer_name = f"🏆 {layer['name']} ({layer['importance']:.1f}%)"

                        print(f"  🔍 Debugging layer: {layer_name}")

                        # Test the image before adding to map
                        test_stats = layer['image'].reduceRegion(
                            reducer=ee.Reducer.minMax(),
                            geometry=region,
                            scale=2000,
                            maxPixels=1e6
                        ).getInfo()
                        print(f"  📊 Layer stats: {test_stats}")

                        # Try to get map tiles
                        print(f"  🗺️ Getting map tiles...")
                        map_id_dict = ee.Image(layer['image']).getMapId({
                            'min': 0,
                            'max': 1,
                            'palette': palette
                        })
                        print(f"  ✅ Map tiles URL obtained: {map_id_dict['tile_fetcher'].url_format[:100]}...")

                        # Add to map
                        folium.raster_layers.TileLayer(
                            tiles=map_id_dict['tile_fetcher'].url_format,
                            attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
                            name=layer_name,
                            overlay=True,
                            control=True
                        ).add_to(m)

                        print(f"  ✅ Successfully added layer: {layer_name}")

                    except Exception as e:
                        print(f"  ❌ Failed to add layer {layer_name}: {str(e)}")
                        # Continue with next layer
                        continue




                print(f"Total embedding layers added: {len(embedding_layers)}")

                # Add sample points - USING WORKING CODE
                sample_points = pair_fc.getInfo()

                if 'features' in sample_points:
                    # Create separate FeatureGroups for each class
                    class_a_name = self.ESA_CLASSES.get(self.class_a.value, "?")
                    class_b_name = self.ESA_CLASSES.get(self.class_b.value, "?")

                    points_layer_a = folium.FeatureGroup(name=f'📍 Samples: {self.class_a.value} ({class_a_name})')
                    points_layer_b = folium.FeatureGroup(name=f'📍 Samples: {self.class_b.value} ({class_b_name})')

                    count_a = count_b = 0

                    for feature in sample_points['features']:
                        coords = feature['geometry']['coordinates']
                        label = feature['properties']['label']

                        if label == self.class_a.value:
                            folium.CircleMarker(
                                location=[coords[1], coords[0]],
                                radius=4,
                                popup=f'Class {label} ({class_a_name})',
                                color='white',
                                fillColor='black',
                                fillOpacity=1,
                                weight=1
                            ).add_to(points_layer_a)
                            count_a += 1
                        elif label == self.class_b.value:
                            folium.CircleMarker(
                                location=[coords[1], coords[0]],
                                radius=4,
                                popup=f'Class {label} ({class_b_name})',
                                color='black',
                                fillColor='white',
                                fillOpacity=1,
                                weight=1
                            ).add_to(points_layer_b)
                            count_b += 1

                    points_layer_a.add_to(m)
                    points_layer_b.add_to(m)

                    print(f"  Added {count_a} sample points for Class {self.class_a.value} ({class_a_name})")
                    print(f"  Added {count_b} sample points for Class {self.class_b.value} ({class_b_name})")

                # Add ROI boundary
                roi_coords = []
                for part in roi_info["parts"]:
                    minLon, minLat, maxLon, maxLat = part
                    roi_coords.append([
                        [minLat, minLon], [minLat, maxLon],
                        [maxLat, maxLon], [maxLat, minLon], [minLat, minLon]
                    ])

                roi_layer = folium.FeatureGroup(name='🗺️ ROI Boundary')
                for coords in roi_coords:
                    folium.Polygon(
                        locations=coords,
                        color='black',
                        weight=2,
                        fill=False,
                        popup='Selected ROI Boundary'
                    ).add_to(roi_layer)
                roi_layer.add_to(m)

                # Add layer control
                # Add layer control
                folium.LayerControl().add_to(m)

                # Add "Deselect All" button to the layer control panel
                deselect_all_script = '''
                <script>
                // Wait for map to load, then add the "Deselect All" button
                setTimeout(function() {
                    // Find the layer control
                    var layerControl = document.querySelector('.leaflet-control-layers');
                    if (layerControl) {
                        // Create the deselect all button
                        var deselectButton = document.createElement('div');
                        deselectButton.innerHTML = `
                            <button onclick="deselectAllOverlays()"
                                    style="width: 100%; margin: 5px 0; padding: 6px 8px;
                                           background: #ef4444; color: white; border: none;
                                           border-radius: 4px; cursor: pointer; font-size: 11px;
                                           font-weight: bold; text-align: center;">
                                🚫 Deselect All Overlays
                            </button>
                        `;

                        // Insert the button at the bottom of the overlays section
                        var overlaysSection = layerControl.querySelector('.leaflet-control-layers-overlays');
                        if (overlaysSection) {
                            overlaysSection.appendChild(deselectButton);
                            console.log('✅ Deselect All button added to layer control');
                        }
                    }
                }, 1000); // Wait 1 second for map to fully load

                // Function to deselect all overlay layers
                function deselectAllOverlays() {
                    var overlayInputs = document.querySelectorAll('.leaflet-control-layers-overlays input[type="checkbox"]');
                    var uncheckedCount = 0;

                    overlayInputs.forEach(function(input) {
                        if (input.checked) {
                            input.click(); // Properly trigger the layer hide
                            uncheckedCount++;
                        }
                    });

                    if (uncheckedCount > 0) {
                        console.log(`🚫 Deselected ${uncheckedCount} overlay layers`);
                        // Optional: Show brief notification
                        showNotification(`🚫 ${uncheckedCount} layers deselected`);
                    } else {
                        showNotification('ℹ️ No layers were selected');
                    }
                }

                // Simple notification function
                function showNotification(message) {
                    var notification = document.createElement('div');
                    notification.innerHTML = message;
                    notification.style.cssText = `
                        position: fixed; top: 20px; right: 20px; z-index: 10000;
                        background: #1f2937; color: white; padding: 10px 15px;
                        border-radius: 6px; font-size: 13px; font-weight: bold;
                        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
                        transition: opacity 0.3s ease;
                    `;
                    document.body.appendChild(notification);

                    // Auto-remove after 2 seconds
                    setTimeout(function() {
                        notification.style.opacity = '0';
                        setTimeout(function() {
                            if (notification.parentNode) {
                                notification.parentNode.removeChild(notification);
                            }
                        }, 300);
                    }, 2000);
                }
                </script>
                '''

                # Add the script to the map
                m.get_root().html.add_child(folium.Element(deselect_all_script))

                # Display results first
                self.display_metrics(acc, auc, report, cm, len(y_train), len(y_test), importances, EMBED_BANDS, top5)

                # Then display map
                print("💡 Use the layer control panel (📋 icon in top-right corner of map) to toggle all layers on/off")
                display(m)

                out_path = os.path.abspath("alphaearth_map.html")
                m.save(out_path)
                print(f"\nSaved interactive map to: {out_path}")

                self.status.value = """
                <div style='background: #dcfce7; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #16a34a; text-align: center;'>
                    <span style='color: #166534; font-weight: bold; font-size: 16px;'>
                        ✅ Analysis completed! Check results below.
                    </span>
                </div>
                """

                # Hide progress bar after completion
                self.progress.layout.visibility = 'hidden'

            except Exception as e:
                print(f"❌ Error: {str(e)}")
                self.progress.layout.visibility = 'hidden'  # Hide progress bar on error
                self.status.value = f"""
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        ❌ Error: {str(e)}
                    </span>
                </div>
                """

    def display_metrics(self, acc, auc, report, cm, n_train, n_test, importances, embed_bands, top5):
        """Render the classification summary: metric cards, text report and importance chart.

        Args:
            acc: Accuracy as a fraction; rendered as a percentage.
            auc: ROC AUC score; rendered with three decimals.
            report: Mapping with '0' and '1' entries, each holding 'precision',
                'recall' and 'f1-score' (presumably sklearn's
                classification_report(..., output_dict=True) -- confirm at the caller).
            cm: 2x2 confusion matrix, rows = true label, columns = predicted label.
            n_train: Number of training samples shown in the "Train" card.
            n_test: Number of test samples shown in the "Test" card.
            importances: Per-band importance scores; assumed to be index-aligned
                with ``embed_bands``.
            embed_bands: Band names whose characters after the first one parse as
                an integer index (looks like 'A00'..'A63' -- TODO confirm).
            top5: Iterable of (display_name, importance_pct, original_band) tuples.

        Returns:
            None. All output goes through display(), print() and matplotlib.
        """
        # Number of highest-importance bars highlighted in the chart. Kept in one
        # place so the colouring and the chart title cannot drift apart.
        n_highlight = 3

        # Create metrics display
        metrics_html = f"""
        <div style='display: flex; flex-wrap: wrap; gap: 15px; margin: 20px 0;'>
            <div style='background: #dcfce7; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #166534;'>🎯 Accuracy</h4>
                <h2 style='margin: 5px 0 0 0; color: #16a34a;'>{acc*100:.1f}%</h2>
            </div>
            <div style='background: #dbeafe; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #1e40af;'>📈 ROC AUC</h4>
                <h2 style='margin: 5px 0 0 0; color: #2563eb;'>{auc:.3f}</h2>
            </div>
            <div style='background: #f3e8ff; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #7c2d12;'>🔢 Train</h4>
                <h2 style='margin: 5px 0 0 0; color: #a855f7;'>{n_train}</h2>
            </div>
            <div style='background: #fed7d7; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #7c2d12;'>🔢 Test</h4>
                <h2 style='margin: 5px 0 0 0; color: #dc2626;'>{n_test}</h2>
            </div>
        </div>
        """

        display(HTML(metrics_html))

        print("\n=== Classification Results ===")
        print("📋 Per-class metrics:")
        # Report key '0' belongs to class A and '1' to class B.
        for class_value, report_key in ((self.class_a.value, '0'), (self.class_b.value, '1')):
            scores = report[report_key]
            print(f"   Class {class_value}: Precision={scores['precision']:.3f}, Recall={scores['recall']:.3f}, F1={scores['f1-score']:.3f}")

        print("\nConfusion matrix (rows=true, cols=pred):")
        cm_df = pd.DataFrame(cm, index=['A(0)','B(1)'], columns=['P0','P1'])
        print(cm_df)

        print("\n🏆 Top 5 Most Important Embeddings:")
        for rank, (display_name, importance_pct, _original_band) in enumerate(top5, start=1):
            print(f"   {rank}. {display_name}: {importance_pct:.2f}%")

        # Feature importance plot - from working code
        vals = np.asarray(importances, dtype=float).ravel()
        bands = list(embed_bands)
        # Guard against a length mismatch by plotting only the common prefix.
        n_bars = min(vals.size, len(bands))
        vals, bands = vals[:n_bars], bands[:n_bars]
        # Explicit int dtype keeps the array usable as an index even when empty.
        idx = np.array([int(b[1:]) for b in bands], dtype=int)      # 0..63
        order = np.argsort(idx)
        arr = vals[order]
        total = arr.sum() or 1.0          # avoid dividing by zero when all scores are 0
        arr = 100.0 * arr / total
        xlabels = [f"A{n+1:02d}" for n in idx[order]]    # A01..A64

        # Positions (within vals) of the highest-scoring bands. Computed once,
        # instead of re-sorting the importances for every single bar.
        top_positions = set(np.argsort(vals)[-n_highlight:].tolist())
        colors = ['yellow' if int(pos) in top_positions else 'steelblue' for pos in order]

        plt.figure(figsize=(14, 4))
        plt.bar(range(len(arr)), arr, color=colors)
        plt.xticks(range(len(arr)), xlabels, rotation=90)
        plt.ylabel("Importance (%)")
        # The legend text now states the number of bars that are actually highlighted.
        plt.title(f"Embedding importance: {self.class_a.value} vs {self.class_b.value} [{self.algorithm.value.upper()}] (Yellow = Top {n_highlight} on map)")
        plt.tight_layout()
        plt.show()

# ===============================================
# 🚀 LAUNCH THE APP
# ===============================================

# Start-up banner, printed before the app object is constructed.
for startup_line in (
    "🛰️ Initializing Google AlphaEarth Land Cover Classifier...",
    "🔄 Installing ipyleaflet for interactive map...",
):
    print(startup_line)

# Build the app for the configured Earth Engine project and call its
# display_app() entry point.
app = AlphaEarthApp(PROJECT_ID)
app.display_app()

# Usage hints, emitted as one newline-joined block after display_app() returns.
usage_lines = [
    "\n✅ App ready!",
    "💡 Instructions:",
    "   1. Draw a rectangle on the map above to select your region",
    "   2. Choose your land cover classes to compare",
    "   3. Configure algorithm settings",
    "   4. Click 'RUN ANALYSIS' to start",
]
print("\n".join(usage_lines))

🛰️ Initializing Google AlphaEarth Land Cover Classifier...
🔄 Installing ipyleaflet for interactive map...


VBox(children=(HTML(value='\n        <style>\n            .widget-html-content {\n                color: #1e29…


✅ App ready!
💡 Instructions:
   1. Draw a rectangle on the map above to select your region
   2. Choose your land cover classes to compare
   3. Configure algorithm settings
   4. Click 'RUN ANALYSIS' to start


# Chunk 4. Requirements.txt Section