<h1 style="text-align:center; color:blue; font-size:32px;">
  <b>On the top menu, click "Run" and then "Run All Cells"</b>
</h1>


<h2 style="text-align:center; font-size:26px;">
  📺 Watch <a href="https://www.youtube.com/watch?v=lC2y_dAHKvM" target="_blank">Tutorial Video</a>
</h2>


In [7]:
import ee
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import folium
import math
import os
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, confusion_matrix
from IPython.display import display, clear_output, HTML
import ipywidgets as widgets

# Install and import required packages.
# Keys are pip distribution names; values are the module each one provides
# (they differ for google-auth, which provides `google.auth`).
# NOTE(review): installs are unpinned; pin versions if reproducibility matters.
import importlib
import subprocess
import sys

packages_to_install = ['ipyleaflet', 'gspread', 'google-auth']
_PIP_TO_MODULE = {'ipyleaflet': 'ipyleaflet', 'gspread': 'gspread', 'google-auth': 'google.auth'}
for package in packages_to_install:
    try:
        importlib.import_module(_PIP_TO_MODULE[package])
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", package])
        # A package installed mid-session can be missed by cached path finders
        # until the import caches are cleared.
        importlib.invalidate_caches()

# Keep the module names bound at top level (the original per-package imports bound them).
import ipyleaflet
import google.auth

from ipyleaflet import Map, DrawControl, basemaps, WidgetControl
import gspread
from google.oauth2.service_account import Credentials

# --- Optional imports for classifying algorithms ---
def ensure_package(module_name, pip_name=None):
    """Import ``module_name``, pip-installing it first if it is not importable.

    Args:
        module_name: Dotted module name to import.
        pip_name: Distribution name for ``pip install``; defaults to ``module_name``.

    Returns:
        The imported module object.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
        Any non-ImportError raised while importing. Only a missing module triggers
        an install; other import-time failures are re-raised rather than masked
        by a reinstall attempt.
    """
    import importlib
    import subprocess
    import sys

    try:
        return importlib.import_module(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", pip_name or module_name])
        # Make the freshly installed package visible to the import system.
        importlib.invalidate_caches()
        return importlib.import_module(module_name)

def _get_xgb():
    """Return the ``xgboost`` module, installing it on first use if needed."""
    return ensure_package("xgboost")

def _get_lgb():
    """Return the ``lightgbm`` module, installing it on first use if needed."""
    return ensure_package("lightgbm")

# ===============================================
# PROJECT CONFIGURATION
# ===============================================
PROJECT_ID = "animated-rhythm-449415-u3"

# Google Sheets Configuration
GOOGLE_SHEET_NAME = "AlphaEarth User Interactions"
GOOGLE_SHEET_URL = "https://docs.google.com/spreadsheets/d/1XXhGbSc5_UwXmsuvhkJJnqdrilxgPGvo-t1AYTcP6T0/edit?gid=0#gid=0"

# ===============================================
# SIMPLE EARTH ENGINE AUTHENTICATION
# ===============================================
# Records whether Earth Engine is usable in this session.
EE_INITIALIZED = False

print("Authenticating with Google Earth Engine...")
try:
    # Succeeds when credentials from a previous authentication are available.
    # NOTE(review): PROJECT_ID is not passed here; newer earthengine-api releases
    # may require ee.Initialize(project=<the user's own Cloud project>) — confirm.
    ee.Initialize()
    print("Earth Engine already initialized!")
    EE_INITIALIZED = True
except Exception:
    # Catch Exception rather than everything: a bare `except:` would also
    # swallow KeyboardInterrupt / SystemExit.
    try:
        print("Starting Earth Engine authentication...")
        print("Click the link below, sign in, and paste the code when prompted")
        ee.Authenticate()
        ee.Initialize()
        print("Earth Engine authenticated successfully!")
        EE_INITIALIZED = True
    except Exception as e:
        print(f"Earth Engine authentication failed: {str(e)}")
        print("Some features will be limited")
        EE_INITIALIZED = False

print("Welcome to AlphaEarth Land Cover Classifier!")
print("Loading app...")

# ===============================================
# ORIGINAL WORKING FUNCTIONS
# ===============================================

# Color palettes for embedding visualization
EMBEDDING_PALETTES = {
    'embedding_1': ['#000080', '#0066CC', '#00CCFF', '#66FFCC', '#CCFF66', '#FFCC00', '#FF6600', '#CC0000'],
    'embedding_2': ['#2D0066', '#6600CC', '#9966FF', '#CC99FF', '#FFCCFF', '#FFCC99', '#FF9966', '#FF6633'],
    'embedding_3': ['#004D00', '#009900', '#33CC33', '#66FF66', '#99FF99', '#CCFF33', '#FFFF00', '#FFB300']
}

def sanitize_bbox(bbox):
    """Normalise a ``[minLon, minLat, maxLon, maxLat]`` box for Earth Engine.

    - Longitudes outside [-180, 180] are wrapped into that range. 180 itself is
      kept, so a box ending on the antimeridian (or spanning the whole world) is
      not mistaken for a crossing.
    - Latitudes are clamped to [-85, 85] and swapped if reversed.
    - A zero-width box gets a 0.01 degree longitude pad.
    - A box with minLon > maxLon crosses the antimeridian and is split in two.

    Args:
        bbox: Sequence of four numbers ``[minLon, minLat, maxLon, maxLat]``.

    Returns:
        dict with ``"parts"`` (list of ``[minLon, minLat, maxLon, maxLat]`` boxes)
        and ``"center"`` (``[lat, lon]``, lon within [-180, 180]).
    """
    minLon, minLat, maxLon, maxLat = map(float, bbox)

    def clamp(v, lo, hi):
        return max(lo, min(hi, v))

    def wrap_lon(v):
        # Only wrap out-of-range values: the modulo form alone maps 180 -> -180,
        # which turned a box ending at 180 into a bogus antimeridian crossing.
        return v if -180.0 <= v <= 180.0 else ((v + 180) % 360) - 180

    minLon = wrap_lon(minLon)
    maxLon = wrap_lon(maxLon)
    minLat = clamp(minLat, -85, 85)
    maxLat = clamp(maxLat, -85, 85)

    # If lats reversed, swap
    if minLat > maxLat:
        minLat, maxLat = maxLat, minLat

    # Zero-width box: pad eastward, or westward when already at the 180 edge.
    if abs(maxLon - minLon) < 1e-6:
        if minLon + 0.01 <= 180.0:
            maxLon = minLon + 0.01
        else:
            minLon = maxLon - 0.01

    if minLon <= maxLon:
        parts = [[minLon, minLat, maxLon, maxLat]]
        center_lon = (minLon + maxLon) / 2.0
    else:
        # Crosses the antimeridian: split into two boxes
        parts = [
            [minLon, minLat, 180.0, maxLat],
            [-180.0, minLat, maxLon, maxLat]
        ]
        # Midpoint measured eastward from minLon across the antimeridian;
        # averaging the two parts' edges would land on the far side of the globe.
        center_lon = minLon + ((maxLon - minLon) % 360) / 2.0
        if center_lon > 180.0:
            center_lon -= 360.0

    center_lat = (minLat + maxLat) / 2.0
    return {"parts": parts, "center": [center_lat, center_lon]}

def region_from_parts(parts):
    """Convert the bbox parts produced by ``sanitize_bbox`` into one ``ee.Geometry``.

    A single part becomes an ``ee.Geometry.Rectangle``; two parts (an
    antimeridian split) become a dissolved MultiPolygon of both rectangles.
    An empty ``parts`` list raises IndexError.
    """
    if len(parts) <= 1:
        return ee.Geometry.Rectangle(parts[0])
    rectangles = [ee.Geometry.Rectangle(box) for box in parts]
    return ee.Geometry.MultiPolygon(rectangles).dissolve()

def load_stack(year=2020):
    """Load the AlphaEarth annual embedding mosaic for ``year`` and ESA WorldCover labels.

    Returns:
        Tuple ``(emb, bands, emb_mask, esa)``: the embedding mosaic, its band
        names (an ``ee.List``), the mask of its first band, and the WorldCover
        ``Map`` band renamed to ``label``.

    NOTE: the labels always come from WorldCover v100 (2020); ``year`` only
    selects the embeddings.
    """
    date_start, date_end = f'{year}-01-01', f'{year+1}-01-01'
    collection = ee.ImageCollection('GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL')
    mosaic = collection.filterDate(date_start, date_end).mosaic()

    band_names = mosaic.bandNames()
    embeddings = mosaic.select(band_names)
    first_band_mask = embeddings.select(0).mask()

    labels = ee.Image("ESA/WorldCover/v100/2020").select('Map').rename('label')
    return embeddings, band_names, first_band_mask, labels

def add_ee_layer(folium_map, ee_image_object, vis_params, name):
    """Add an Earth Engine image to ``folium_map`` as a toggleable tile overlay.

    Args:
        folium_map: Target ``folium.Map``.
        ee_image_object: Anything ``ee.Image`` accepts.
        vis_params: Visualisation parameters passed to ``getMapId``.
        name: Layer name shown in the layer control.
    """
    tile_url = ee.Image(ee_image_object).getMapId(vis_params)['tile_fetcher'].url_format
    overlay = folium.raster_layers.TileLayer(
        tiles=tile_url,
        attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
        name=name,
        overlay=True,
        control=True,
    )
    overlay.add_to(folium_map)

def normalize_embedding(embedding_image, region):
    """Rescale ``embedding_image`` to [0, 1] using its FIRST band's range over ``region``.

    The min/max are computed at 1000 m scale, then limited to [-10, 10]. The same
    linear map is applied to every band, and the result is clamped to [0, 1].
    """
    first_band = ee.String(embedding_image.bandNames().get(0))
    extremes = embedding_image.reduceRegion(
        reducer=ee.Reducer.minMax(),
        geometry=region,
        scale=1000,
        maxPixels=1e9,
    )
    low = ee.Number(extremes.get(first_band.cat('_min'))).max(-10)
    high = ee.Number(extremes.get(first_band.cat('_max'))).min(10)
    span = high.subtract(low)
    return embedding_image.subtract(low).divide(span).clamp(0, 1)

def top_k_importance(importances, bands, k=5, as_percent=True):
    """Return the ``k`` most important bands, highest importance first.

    Importances and bands are truncated to the shorter of the two. A band name
    such as ``A00`` (0-based) is labelled ``A01`` (1-based).

    Returns:
        List of ``(label, value, band)`` tuples. ``value`` is a percentage of the
        total when ``as_percent`` is true (an all-zero total divides by 1.0).
    """
    scores = np.asarray(importances, dtype=float).ravel()
    band_list = list(bands)
    n = min(scores.size, len(band_list))
    scores = scores[:n]
    band_list = band_list[:n]

    labels = ["A%02d" % (int(name[1:]) + 1) for name in band_list]
    if as_percent:
        scores = 100.0 * scores / (scores.sum() or 1.0)

    ranked = np.argsort(scores)[::-1][:k]
    return [(labels[j], float(scores[j]), band_list[j]) for j in ranked]

def all_embedding_importance(importances, bands, as_percent=True):
    """Return ``(label, importance)`` for all 64 embeddings, ordered A01..A64.

    Band names such as ``A00`` (0-based) map to labels ``A01`` (1-based). Labels
    with no matching band get 0.0. With ``as_percent``, values are percentages
    of the summed importances (an all-zero total divides by 1.0).
    """
    scores = np.asarray(importances, dtype=float).ravel()
    band_list = list(bands)
    n = min(scores.size, len(band_list))
    scores, band_list = scores[:n], band_list[:n]

    by_label = {}
    for name, value in zip(band_list, scores):
        by_label["A%02d" % (int(name[1:]) + 1)] = value

    if as_percent:
        total = scores.sum() or 1.0
        by_label = {label: 100.0 * value / total for label, value in by_label.items()}

    ordered_labels = ("A%02d" % i for i in range(1, 65))
    return [(label, float(by_label.get(label, 0.0))) for label in ordered_labels]

# ===============================================
# GOOGLE SHEETS DATA LOGGING
# ===============================================

def setup_google_sheets(service_account_info=None):
    """Create an authorized gspread client for interaction logging.

    SECURITY: credentials must never be embedded in the notebook. A private key
    that has ever been published with this notebook must be treated as
    compromised and revoked/rotated in Google Cloud IAM.

    Credentials are resolved in this order:
      1. the ``service_account_info`` argument (a parsed service-account key dict);
      2. the ``ALPHAEARTH_SA_JSON`` environment variable (the key JSON as a string);
      3. the ``ALPHAEARTH_SA_FILE`` environment variable (path to the key JSON file).

    Args:
        service_account_info: Optional service-account key as a dict.

    Returns:
        An authorized gspread client, or ``None`` when no credentials are
        configured or setup fails (logging is then disabled; errors are printed).
    """
    import json

    try:
        if service_account_info is None:
            raw_json = os.environ.get("ALPHAEARTH_SA_JSON")
            key_file = os.environ.get("ALPHAEARTH_SA_FILE")
            if raw_json:
                service_account_info = json.loads(raw_json)
            elif key_file:
                with open(key_file, encoding="utf-8") as fh:
                    service_account_info = json.load(fh)
            else:
                print("Google Sheets logging disabled: no service-account credentials configured "
                      "(set ALPHAEARTH_SA_JSON or ALPHAEARTH_SA_FILE)")
                return None

        from google.oauth2.service_account import Credentials
        import gspread

        # Drive scope is needed because the sheet is opened by title (a Drive lookup).
        scopes = [
            'https://www.googleapis.com/auth/spreadsheets',
            'https://www.googleapis.com/auth/drive'
        ]

        credentials = Credentials.from_service_account_info(service_account_info, scopes=scopes)
        return gspread.authorize(credentials)

    except Exception as e:
        print(f"Google Sheets setup failed: {str(e)}")
        return None

def test_sheets_connection(write_test_row=True):
    """Check that the logging sheet can be opened and, optionally, written to.

    Args:
        write_test_row: When True (the default, original behaviour), append the
            row ``["test", "connection", "working"]`` to verify write access.
            That row lands in the shared research sheet; pass False to only
            check that the sheet can be opened.

    Returns:
        True if every attempted step succeeded, otherwise False (errors are printed).
    """
    try:
        gc = setup_google_sheets()
        if gc is None:
            print("❌ Sheets client not configured")
            return False

        # Use the shared constant so this checks the same sheet the app logs to.
        sheet = gc.open(GOOGLE_SHEET_NAME).sheet1
        print("✅ Successfully connected to Google Sheets")

        if write_test_row:
            sheet.append_row(["test", "connection", "working"])
            print("✅ Test data written successfully")
        return True

    except Exception as e:
        print(f"❌ Connection test failed: {str(e)}")
        return False

# ===============================================
# APP CLASS DEFINITION WITH GOOGLE SHEETS LOGGING
# ===============================================

class AlphaEarthApp:
    def __init__(self, project_id):
        """Build the app: reset run state, define constants, build widgets, connect to Sheets.

        Args:
            project_id: Stored on the instance as ``self.project_id``.
        """
        self.project_id = project_id

        # Per-run state is reset before the widgets are built.
        self.selected_bbox = None
        self.analysis_results = None
        self.last_row_number = None

        # Constants first: class codes, colours and countries are defined before widget creation.
        self.setup_constants()
        self.create_widgets()

        # None when Sheets logging is unavailable.
        self.sheets_client = setup_google_sheets()

    def setup_constants(self):
        """Set up all constants and mappings"""
        self.ESA_CLASSES = {
            10: "Tree cover", 20: "Shrubland", 30: "Grassland", 40: "Cropland",
            50: "Built-up", 60: "Bare/sparse", 70: "Snow/ice", 80: "Water",
            90: "Herb. wetland", 95: "Mangroves", 100: "Moss/lichen",
            999: "All other classes"
        }

        self.CLASS_COLORS = {
            10: '#006400', 20: '#ffbb22', 30: '#ffff4c', 40: '#f096ff',
            50: '#fa0000', 60: '#b4b4b4', 70: '#f0f0f0', 80: '#0064c8',
            90: '#0096a0', 95: '#00cf75', 100: '#fae6a0',
            999: '#808080'
        }
        
        # Countries list for dropdown
        self.COUNTRIES = [
            "Select your current country",  # Default option
            "Afghanistan", "Albania", "Algeria", "Argentina", "Armenia", "Australia", 
            "Austria", "Azerbaijan", "Bahrain", "Bangladesh", "Belarus", "Belgium", 
            "Bolivia", "Bosnia and Herzegovina", "Brazil", "Bulgaria", "Cambodia", 
            "Canada", "Chile", "China", "Colombia", "Costa Rica", "Croatia", 
            "Czech Republic", "Denmark", "Ecuador", "Egypt", "Estonia", "Ethiopia", 
            "Finland", "France", "Georgia", "Germany", "Ghana", "Greece", "Guatemala", 
            "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", 
            "Israel", "Italy", "Japan", "Jordan", "Kazakhstan", "Kenya", "South Korea", 
            "Kuwait", "Latvia", "Lebanon", "Lithuania", "Luxembourg", "Malaysia", 
            "Mexico", "Morocco", "Netherlands", "New Zealand", "Nigeria", "Norway", 
            "Pakistan", "Panama", "Peru", "Philippines", "Poland", "Portugal", 
            "Qatar", "Romania", "Russia", "Saudi Arabia", "Serbia", "Singapore", 
            "Slovakia", "Slovenia", "South Africa", "Spain", "Sri Lanka", "Sweden", 
            "Switzerland", "Thailand", "Turkey", "Ukraine", "United Arab Emirates", 
            "United Kingdom", "United States", "Uruguay", "Venezuela", "Vietnam"
        ]

    def log_interaction_to_sheets(self, results_data):
        """Log user interaction data to Google Sheets with feedback as last column"""
        from datetime import datetime
        
        # Create timestamp in format: YYYY MM DD HH
        now = datetime.now()
        formatted_timestamp = f"{now.year} {now.month:02d} {now.day:02d} {now.hour:02d}"
        
        # Get current country from dropdown
        current_country = self.country_dropdown.value if self.country_dropdown.value != "Select your current country" else "Not selected"
        
        # Calculate ROI area
        bbox = self.selected_bbox
        width = abs(bbox[2] - bbox[0])  # max_lon - min_lon
        height = abs(bbox[3] - bbox[1])  # max_lat - min_lat
        area = width * height
        
        # Prepare base row data (33 columns)
        base_row_data = [
            formatted_timestamp,
            now.isoformat(),
            current_country,  # Store selected country
            bbox[0],  # min_lon
            bbox[1],  # min_lat
            bbox[2],  # max_lon
            bbox[3],  # max_lat
            results_data['center_lat'],
            results_data['center_lon'],
            width,  # roi_width_degrees
            height,  # roi_height_degrees
            area,  # roi_area_square_degrees
            self.class_a.value,
            self.ESA_CLASSES[self.class_a.value],
            self.class_b.value,
            self.ESA_CLASSES[self.class_b.value],
            self.algorithm.value,
            self.test_size.value,
            self.n_samples.value,
            self.scale_m.value,
            self.seed.value,
            results_data['accuracy'],
            results_data['roc_auc'],
            results_data['precision_a'],
            results_data['recall_a'],
            results_data['f1_a'],
            results_data['precision_b'],
            results_data['recall_b'],
            results_data['f1_b'],
            results_data['n_train'],
            results_data['n_test'],
            # Top 5 embeddings as separate columns (2 columns)
            results_data['top_embeddings'][0][0] if len(results_data['top_embeddings']) > 0 else '',
            results_data['top_embeddings'][0][1] if len(results_data['top_embeddings']) > 0 else 0,
        ]
        
        # Add top 4 more embeddings (8 more columns)
        for i in range(1, 5):
            if i < len(results_data['top_embeddings']):
                base_row_data.extend([
                    results_data['top_embeddings'][i][0],
                    results_data['top_embeddings'][i][1]
                ])
            else:
                base_row_data.extend(['', 0])
        
        # Add all 64 embedding names and importances (128)       
        all_embedding_data = []
        if 'all_embeddings' in results_data:
            for embedding_name, importance in results_data['all_embeddings']:
                all_embedding_data.append(embedding_name)  # Add name
                all_embedding_data.append(importance)      # Add importance
        else:
            # Fallback: fill with empty names and zero importances
            for i in range(1, 65):
                all_embedding_data.append(f'A{i:02d}')  # Add name
                all_embedding_data.append(0.0)          # Add importance
        
        # Combine: base_data (41 cols) + all_embedding_data (128 cols) = 169 columns
        # Feedback will be column 170 (the last column)
        row_data = base_row_data + all_embedding_data + [""]  # Empty feedback placeholder
        
        try:
            if self.sheets_client is None:
                return
            
            # Open the Google Sheet
            sheet = self.sheets_client.open(GOOGLE_SHEET_NAME).sheet1
            
            # Check if header exists, if not create it
            try:
                existing_data = sheet.get_all_values()
                if len(existing_data) == 0:
                    # Create header row - BASE HEADERS (41 columns)
                    base_headers = [
                        'timestamp', 'timestamp_iso', 'current_country',
                        'min_lon', 'min_lat', 'max_lon', 'max_lat',
                        'center_lat', 'center_lon', 'roi_width_degrees', 'roi_height_degrees', 'roi_area_square_degrees',
                        'class_a', 'class_a_name', 'class_b', 'class_b_name',
                        'algorithm', 'test_size_percent', 'samples_per_class', 'scale_meters', 'seed',
                        'accuracy', 'roc_auc', 'precision_class_a', 'recall_class_a', 'f1_class_a',
                        'precision_class_b', 'recall_class_b', 'f1_class_b', 'n_train', 'n_test',
                        'top_embedding_1', 'top_embedding_1_pct', 'top_embedding_2', 'top_embedding_2_pct',
                        'top_embedding_3', 'top_embedding_3_pct', 'top_embedding_4', 'top_embedding_4_pct',
                        'top_embedding_5', 'top_embedding_5_pct'
                    ]
                    
                    # Add all 64 embedding name and importance columns (128 columns)
                    embedding_headers = []
                    for i in range(1, 65):
                        embedding_headers.append(f'A{i:02d}_name')
                        embedding_headers.append(f'A{i:02d}_importance')
                    
                    # Add feedback header as THE LAST COLUMN (1 column)
                    feedback_headers = ['user_feedback']
                    
                    # Total: 41 + 128 + 1 = 170 columns
                    headers = base_headers + embedding_headers + feedback_headers
                    sheet.append_row(headers)
            except:
                pass
            
            # Add the new interaction data
            sheet.append_row(row_data)
            
            # Store the row number for later feedback update
            current_rows = len(sheet.get_all_values())
            self.last_row_number = current_rows
            
        except Exception as e:
            print(f"Failed to log to Google Sheets: {str(e)}")
            print("Data logging disabled for this session")

    def update_feedback_in_sheets(self, feedback_text):
        """Update the most recent row with feedback data in the LAST column"""
        try:
            if self.sheets_client is None or self.last_row_number is None:
                return
            
            sheet = self.sheets_client.open(GOOGLE_SHEET_NAME).sheet1
            
            # Calculate the feedback column (THE ABSOLUTE LAST COLUMN)
            # Base columns (41) + embedding columns (128) + feedback (1) = 170 total columns
            # So feedback is in column 170
            feedback_column = 170
            
            # Update the feedback cell in the LAST column
            sheet.update_cell(self.last_row_number, feedback_column, feedback_text)
            
        except Exception as e:
            print(f"Failed to update feedback in Google Sheets: {str(e)}")

    def ee_pairwise_sample_with_others(self, class_a, class_b, *, year, region, scale_m, n_per_class, seed):
        """Draw a stratified two-class sample of embedding pixels from Earth Engine.

        Either class (but not both) may be 999, meaning "All other classes".
        Every code in ``all_esa_classes`` except the other, specific class is
        then merged under the single label 999.

        Args:
            class_a, class_b: ESA WorldCover class codes, or 999 for "all others".
            year: Embedding year passed to ``load_stack``.
            region: Region passed to ``stratifiedSample``.
            scale_m: Sampling scale in metres.
            n_per_class: Points requested for each of the two classes.
            seed: Random seed for ``stratifiedSample``.

        Returns:
            Tuple ``(df, band_names, hist, size, pair_fc)``:
              - a DataFrame with one column per embedding band plus ``label``;
              - the band names as a list;
              - the label histogram;
              - the number of sampled features;
              - the sampled ``ee.FeatureCollection`` (with geometries).

        Raises:
            ValueError: if both classes are 999 or both name the same class.
        """
        OTHERS = 999
        all_esa_classes = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]

        if class_a == OTHERS and class_b == OTHERS:
            raise ValueError("Both classes cannot be 'All other classes'")
        if class_a == class_b:
            # Identical classes would give stratifiedSample duplicate classValues.
            raise ValueError("Class A and Class B must be different")

        emb, bands, emb_mask, esa = load_stack(year)

        if OTHERS in (class_a, class_b):
            # One side is "all others": merge every remaining class under label 999.
            specific_class = class_b if class_a == OTHERS else class_a
            other_classes = [c for c in all_esa_classes if c != specific_class]
            others_mask = esa.eq(other_classes[0])
            for other_class in other_classes[1:]:
                others_mask = others_mask.Or(esa.eq(other_class))
            pair_mask = others_mask.Or(esa.eq(specific_class))
            # NOTE(review): defensive cast so the label band can represent 999 even
            # if the WorldCover band uses a narrow integer type — confirm and keep.
            relabeled_esa = esa.toInt16().where(others_mask, OTHERS)
        else:
            pair_mask = esa.eq(class_a).Or(esa.eq(class_b))
            relabeled_esa = esa

        props = bands.cat(ee.List(['label']))
        stack = (emb.updateMask(emb_mask).updateMask(pair_mask)
                 .addBands(relabeled_esa.updateMask(emb_mask).updateMask(pair_mask).rename('label')))

        pair_fc = stack.stratifiedSample(
            numPoints=n_per_class * 2,
            classBand='label',
            region=region,
            scale=scale_m,
            classValues=[class_a, class_b],
            classPoints=[n_per_class, n_per_class],
            seed=seed,
            geometries=True
        ).select(props)

        size = pair_fc.size().getInfo()
        hist = pair_fc.aggregate_histogram('label').getInfo()
        band_names = bands.getInfo()  # fetched once and reused below

        if size:
            # Drop geometries so getInfo() returns only band/label properties.
            attr_fc = pair_fc.map(lambda f: ee.Feature(None, f.toDictionary(props)))
            feats = attr_fc.getInfo()['features']
        else:
            feats = []

        if feats:
            df = pd.DataFrame([f['properties'] for f in feats])
        else:
            df = pd.DataFrame(columns=[*band_names, 'label'])
        return df, list(band_names), hist, size, pair_fc

    def create_widgets(self):
        """Create all UI widgets with interactive map for region selection"""
        # App title
        self.title = widgets.HTML(
            value="""
            <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; border-radius: 10px; margin-bottom: 20px;'>
                <h1 style='margin: 0; font-size: 28px;'>Google AlphaEarth Foundations App for Land Cover Classification</h1>
                <p style='margin: 10px 0 0 0; font-size: 16px; font-style: italic; opacity: 0.9;'>
                     🎵🎶 Will you still need me, will you still feed me, when I'm Sixty-four? 🎵🎶
                </p>
            </div>
            """
        )

        # Welcome description box with data collection notice
        self.welcome_box = widgets.HTML(
            value=f"""
            <div style='background: #f8f9ff; padding: 20px; border-radius: 10px; margin-bottom: 20px;
                        border-left: 4px solid #4f46e5; color: #1e293b;'>
                <p style='margin: 0 0 15px 0; font-size: 16px; line-height: 1.6;'>
                    <span style='color: #dc2626; font-weight: bold;'>Welcome, please read before running the App:</span> 
                    this app uses the <strong><a href="https://deepmind.google/discover/blog/alphaearth-foundations-helps-map-our-planet-in-unprecedented-detail/" target="_blank" style='color: #2563eb; text-decoration: none;'>Google AlphaEarth AI-Foundation Model</a></strong> for Land Cover Classification tasks.
                    You will interact face to face with the 64 embeddings of this model (year 2020) to classify 
                    land cover types using validated label data from the Ecological Society of America (2020). This is a unique opportunity 
                    to explore how abstract embeddings from a cutting-edge Google DeepMind foundation model capture ecological patterns 
                    and to identify which embeddings are most important for specific pairwise classification tasks.
                </p>
                
                
                
                <p style='margin: 0 0 10px 0; font-size: 16px; font-weight: 500; color: #1e40af;'>
                    For example, you can compare:
                </p>
                <ul style='margin: 0 0 15px 20px; font-size: 15px; line-height: 1.5; color: #2563eb;'>
                    <li style='color: #2563eb; font-weight: 500;'><strong>Built-up vs Water</strong></li>
                    <li style='color: #2563eb; font-weight: 500;'><strong>Cropland vs Grassland</strong></li>
                    <li style='color: #2563eb; font-weight: 500;'><strong>Shrubland vs Tree cover </strong></li>
                     <li style='color: #2563eb; font-weight: 500;'><strong>or select among 11 land cover types (55 possible combinations) </strong></li>
                </ul>
                
                <p style='margin: 0 0 15px 0; font-size: 16px; line-height: 1.6; color: #1e293b;'>
                    and determine which embeddings play the most significant role in those pairwise distinctions. To support 
                    your analysis, the app provides four classifier options: Random Forest, XGBoost, Gradient Boosting, and LightGBM, 
                    plus other control parameters. You'll gain intuitive insights into how powerful this foundation model can be for your machine learning tasks
                    in geospatial science.
                </p>
                
                <p style='margin: 0 0 15px 0; font-size: 16px; line-height: 1.6; color: #1e293b;'>
                    <strong style='color: #7c2d12;'>🌍 Research Data Collection:</strong> Each time you run a classification, 
                    your interaction data (region coordinates, class selections, algorithm parameters, and model performance) 
                    will be automatically stored for research purposes from a Citizen Science perspective. Our main objectives are twofolded: 1) to 
                    create a large crowdsourced dataset to identify relationships between land cover types and specific embedding 
                    patterns; and 2) help people around the world to get familiar with Google Alpha Earth through intuitive and understandable classification tasks.
                    No personal information will be registered, all interactions are 100% anonymous.
                </p>
                
                <p style='margin: 0 0 15px 0; font-size: 16px; line-height: 1.6; color: #1e293b;'>                    
                    🌍 To run this app with all features (not just demo),  You must create and register a Google Cloud Project under the non-commercial 
                    option, which automatically enables the Earth Engine API. Once your account and project are set up, the app can connect 
                    securely to Earth Engine, moving beyond demo mode and giving you full access to datasets and geospatial analysis tools.                   
                </p>
                
                <p style='margin: 0 0 15px 0; font-size: 16px; line-height: 1.6; color: #1e293b;'>                    
                    🌍 Watch a tutorial video here for detailed explanation before running the app
                    <a href="https://www.youtube.com/watch?v=lC2y_dAHKvM" target="_blank" style="color: #2563eb; text-decoration: underline; font-weight: 500;">https://www.youtube.com/watch?v=lC2y_dAHKvM</a>.                   
                </p>
                
                <div style='background: #e0f2fe; padding: 15px; border-radius: 8px; border-left: 4px solid #0288d1;'>
                    <h4 style='margin: 0 0 10px 0; font-size: 16px; font-weight: bold; color: #01579b;'>
                        Instructions:
                    </h4>
                    <ol style='margin: 0; font-size: 15px; line-height: 1.6; color: #1e293b;'>
                        <li style='margin-bottom: 6px;'><strong>Select your current country</strong></li>
                        <li style='margin-bottom: 6px;'><strong>Select a Region of Interest (ROI)</strong> drawing a rectangle on the map</li>
                        <li style='margin-bottom: 6px;'><strong>Select land cover classes</strong> for classification</li>
                        <li style='margin-bottom: 6px;'><strong>Select Algorithm Settings</strong> (type, % of test data, number of samples per class, spatial scale and seed)</li>
                        <li style='margin-bottom: 6px;'><strong>Hit the RUN ANALYSIS button</strong></li>
                        <li style='margin-bottom: 6px;'><strong>Check results</strong></li>
                        <li style='margin-bottom: 6px;'><strong>Optional: download embedding rasters</strong></li>
                        <li style='margin-bottom: 6px;'><strong>Provide feedback</strong> about your experience</li>
                    </ol>
                </div>
            </div>
            """
        )

        # Attribution box
        self.attribution_box = widgets.HTML(
            value="""
            <div style='background: #f1f5f9; padding: 20px; border-radius: 8px; margin-bottom: 20px;
                        text-align: center; border-left: 4px solid #64748b;'>
                <div style='display: flex; justify-content: center; align-items: center; gap: 20px; margin-bottom: 15px;'>
                    <a href="https://sdslab.io/" target="_blank" style="text-decoration: none;">
                        <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/b/bb/NU_RGB_seal_R.png/250px-NU_RGB_seal_R.png" 
                             alt="Northeastern University" style="height: 60px; width: auto; cursor: pointer; transition: opacity 0.3s; border-radius: 4px;"
                             onmouseover="this.style.opacity='0.8'" onmouseout="this.style.opacity='1'">
                    </a>
                    <a href="https://www.gmri.org/" target="_blank" style="text-decoration: none;">
                        <img src="https://upload.wikimedia.org/wikipedia/en/9/90/Gulf_of_Maine_Research_Institute_logo.png" 
                             alt="Gulf of Maine Research Institute" style="height: 60px; width: auto; cursor: pointer; transition: opacity 0.3s; border-radius: 4px;"
                             onmouseover="this.style.opacity='0.8'" onmouseout="this.style.opacity='1'">
                    </a>
                </div>
                <p style='margin: 0; font-size: 14px; color: #475569; line-height: 1.5;'>
                    This App was created by <strong>Felipe Benavides</strong> from the SDS Lab at<br>
                    <strong>Northeastern University</strong> and the <strong>Gulf of Maine Research Institute</strong><br>
                    <a href="mailto:pipeben@gmail.com" style="color: #0ea5e9; text-decoration: none; font-weight: 500;">pipeben@gmail.com</a>
                </p>
                
                <div style='border-top: 1px solid #cbd5e1; margin: 15px 0 5px 0; padding-top: 15px;'>
                    <p style='margin: 0 0 8px 0; font-size: 13px; font-weight: bold; color: #374151;'>
                        How to Cite This App:
                    </p>
                    <p style='margin: 0; font-size: 12px; color: #64748b; line-height: 1.4; font-style: italic;'>
                        Benavides, F. (2025). AlphaEarth Foundation Model for Land Cover Classification App. 
                        SDS Lab, Northeastern University and Gulf of Maine Research Institute. 
                        <em>Zenodo</em>. <a href="https://doi.org/10.5281/zenodo.16911104" target="_blank" style="color: #2563eb; text-decoration: none;">https://doi.org/10.5281/zenodo.16911104</a>
                    </p>
                </div>
            </div>
            """
        )

        # Current country selection
        self.country_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>User Information</h3>"
        )

        self.country_dropdown = widgets.Dropdown(
            options=self.COUNTRIES,
            value="Select your current country",
            description='What country are you now?',
            style={'description_width': '170px'},
            layout=widgets.Layout(width='500px')
        )

        # Region selection with interactive map
        self.region_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>Region of Interest (ROI) selection</h3>"
        )

        self.region_help = widgets.HTML(
            value="""
            <div style='background: #d1ecf1; padding: 15px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #0ea5e9; color: #2c3e50; font-size: 14px;'>
                <b style='color: #0c4a6e;'>How to select your region:</b><br>
                <span style='color: #374151;'>1. Click the rectangle tool in the map toolbar</span><br>
                <span style='color: #374151;'>2. Click and drag on the map to draw your region</span><br>
                <span style='color: #374151;'>3. Use the trash tool to delete rectangles</span><br>                
            </div>
            """
        )

        # Create interactive map with drawing tools
        self.selection_map = Map(
            basemap=basemaps.Esri.WorldImagery,
            center=[41.8, -72.6],
            zoom=7,
            scroll_wheel_zoom=True,
            layout=widgets.Layout(height='350px', width='100%')
        )

        # Add drawing control
        self.draw_control = DrawControl(
            marker={},
            circle={},
            circlemarker={},
            polyline={},
            polygon={},
            rectangle={
                "shapeOptions": {
                    "fillColor": "red",
                    "color": "red",
                    "fillOpacity": 0.3,
                    "weight": 3
                }
            }
        )

        self.draw_control.on_draw(self.handle_draw)
        self.selection_map.add_control(self.draw_control)

        # Region info display
        self.region_info = widgets.HTML()
        self.region_status = widgets.HTML(
            value="""
            <div style='background: #fff3cd; padding: 10px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #f59e0b; color: #2c3e50; font-size: 14px;'>
                <b style='color: #92400e;'>No region selected</b><br>
                <span style='color: #374151;'>Please draw a rectangle on the map above to select your region of interest.</span>
            </div>
            """
        )

        # Class selection
        self.class_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>Land Cover Classes</h3>"
        )

        class_options = [(f"{k} - {v}", k) for k, v in self.ESA_CLASSES.items()]

        self.class_a = widgets.Dropdown(
            options=class_options, value=10, description='Class A:',
            style={'description_width': 'initial'}
        )

        self.class_b = widgets.Dropdown(
            options=class_options, value=80, description='Class B:',
            style={'description_width': 'initial'}
        )

        self.class_comparison = widgets.HTML()

        # Algorithm settings
        self.algo_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 20px 0 10px 0;'>Algorithm Settings</h3>"
        )

        self.algorithm = widgets.Dropdown(
            options=[('Random Forest', 'rf'), ('Gradient Boosting', 'gbt'),
                    ('XGBoost', 'xgb'), ('LightGBM', 'lgb')],
            value='rf', description='Algorithm:', style={'description_width': 'initial'}
        )

        self.test_size = widgets.IntSlider(
            value=25, min=10, max=40, step=5, description='Test Data %:',
            style={'description_width': 'initial'}
        )

        self.n_samples = widgets.IntSlider(
            value=100, min=50, max=300, step=25, description='Samples/class:',
            style={'description_width': 'initial'}
        )

        self.scale_m = widgets.IntSlider(
            value=500, min=100, max=2000, step=100, description='Scale (m):',
            style={'description_width': 'initial'}
        )

        self.seed = widgets.IntText(value=42, description='Seed:', style={'description_width': 'initial'})

        # Run button (initially disabled)
        self.run_button = widgets.Button(
            description='Select country and draw region first',
            button_style='',
            disabled=True,
            layout=widgets.Layout(width='350px', height='50px'),
        )

        # Progress bar (initially hidden)
        self.progress = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='Progress:',
            bar_style='info',
            style={'bar_color': '#3b82f6', 'description_width': 'initial'},
            layout=widgets.Layout(width='400px', margin='10px 0', visibility='hidden')
        )

        # Status and output
        self.status = widgets.HTML()
        self.output_area = widgets.Output()

        # Feedback section (initially hidden)
        self.feedback_title = widgets.HTML(
            "<h3 style='color: #2c3e50; margin: 30px 0 10px 0;'>Feedback</h3>"
        )

        self.feedback_question = widgets.HTML(
            value="""
            <div style='background: #f0fdf4; padding: 15px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #16a34a; color: #1e293b;'>
                <h4 style='margin: 0 0 10px 0; color: #166534; font-size: 16px;'>
                    How has this app helped you to get familiar with Google AlphaEarth embeddings?
                </h4>
                <p style='margin: 0; font-size: 13px; color: #374151;'>
                    
                </p>
            </div>
            """
        )

        self.feedback_textarea = widgets.Textarea(
            value='',
            placeholder='Share your experience with the AlphaEarth embeddings...',
            description='Your feedback:',
            style={'description_width': '100px'},
            layout=widgets.Layout(width='100%', height='120px')
        )

        self.feedback_button = widgets.Button(
            description='Submit Feedback',
            button_style='success',
            layout=widgets.Layout(width='200px', height='40px'),
            disabled=False
        )

        self.feedback_status = widgets.HTML()

        self.feedback_section = widgets.VBox([
            self.feedback_title,
            self.feedback_question,
            self.feedback_textarea,
            widgets.HBox([self.feedback_button], layout=widgets.Layout(justify_content='center')),
            self.feedback_status
        ], layout=widgets.Layout(visibility='hidden', margin='20px 0'))

        # Bind events
        self.run_button.on_click(self.run_analysis)
        self.feedback_button.on_click(self.submit_feedback)

        # Auto-update class comparison and button state
        for widget in [self.class_a, self.class_b]:
            widget.observe(self.update_class_comparison, names='value')
        
        self.country_dropdown.observe(self.update_button_state, names='value')

        # Initial updates
        self.update_class_comparison(None)
        self.update_button_state(None)

    def update_button_state(self, change):
        """Enable the RUN button only once a country is chosen and a ROI is drawn.

        Used as an ``observe`` callback on the country dropdown and called
        directly after drawing; ``change`` is not inspected. The button label
        tells the user which prerequisite is still missing (country first).
        """
        has_country = self.country_dropdown.value != "Select your current country"
        has_region = self.selected_bbox is not None

        # (disabled, description, button_style)
        if not has_country:
            state = (True, 'Select country first', '')
        elif not has_region:
            state = (True, 'Draw a region on map', '')
        else:
            state = (False, 'RUN ANALYSIS', 'primary')

        button = self.run_button
        button.disabled, button.description, button.button_style = state

    def handle_draw(self, target, action, geo_json):
        """Keep ``self.selected_bbox`` in sync with the rectangle drawn on the map.

        Registered with ``self.draw_control.on_draw``. ipyleaflet calls it with
        the control that fired (``target``, unused), the ``action`` ('created',
        'edited' or 'deleted') and the affected shape as a GeoJSON Feature.

        * 'created' / 'edited' Polygon: store its bounding box as
          ``[min_lon, min_lat, max_lon, max_lat]`` and refresh the region panel
          and the RUN button. 'edited' must be handled because the DrawControl's
          edit tool lets users reshape a rectangle; otherwise the analysis would
          run on the stale, pre-edit bounds.
        * 'deleted': forget the selection, restore the "No region selected"
          notice and refresh (i.e. disable) the RUN button.
        """
        geometry = (geo_json or {}).get('geometry') or {}

        if action in ('created', 'edited') and geometry.get('type') == 'Polygon':
            ring = geometry['coordinates'][0]  # outer ring of the rectangle
            lons = [pt[0] for pt in ring]
            lats = [pt[1] for pt in ring]

            bbox = [min(lons), min(lats), max(lons), max(lats)]
            self.selected_bbox = bbox

            self.update_region_info()
            self.update_button_state(None)

            verb = "selected" if action == 'created' else "updated"
            print(f"Region {verb}! Bounds: {[round(b, 3) for b in bbox]}")

        elif action == 'deleted':
            self.selected_bbox = None
            self.region_info.value = ""
            self.region_status.value = """
            <div style='background: #fff3cd; padding: 10px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #f59e0b; color: #2c3e50; font-size: 14px;'>
                <b style='color: #92400e;'>No region selected</b><br>
                <span style='color: #374151;'>Please draw a rectangle on the map above to select your region of interest.</span>
            </div>
            """
            self.update_button_state(None)

    def update_region_info(self):
        """Update region information display"""
        if self.selected_bbox:
            bbox = self.selected_bbox
            width = abs(bbox[2] - bbox[0])
            height = abs(bbox[3] - bbox[1])
            area = width * height

            self.region_info.value = f"""
            <div style='background: #f0f9ff; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #3b82f6; color: #1e293b; font-size: 14px;'>
                <b style='color: #1e40af; font-size: 15px;'>Selected Region:</b><br>
                <span style='color: #374151; font-weight: 500;'>Longitude: {bbox[0]:.3f} to {bbox[2]:.3f} ({width:.3f}°)</span><br>
                <span style='color: #374151; font-weight: 500;'>Latitude: {bbox[1]:.3f} to {bbox[3]:.3f} ({height:.3f}°)</span><br>
                <span style='color: #374151; font-weight: 500;'>Area: {area:.4f} square degrees</span>
            </div>
            """

            self.region_status.value = ""

    def update_class_comparison(self, change):
        """Refresh the "Comparison" box shown under the two class dropdowns.

        Registered with ``observe`` on ``class_a``/``class_b`` (``change`` is not
        used) and called once during construction. Writes HTML into
        ``self.class_comparison.value``:

        * an "Invalid Selection" box when both classes are 999 ("All other classes");
        * an "Invalid Selection" box when the same class is picked twice, since
          ``run_analysis`` refuses identical classes anyway;
        * otherwise "A (name) vs B (name)" plus an Easy/Medium/Hard label derived
          from ``calculate_class_similarity`` (pairs involving 999 are "Medium").

        Any unexpected error (e.g. a code missing from ``ESA_CLASSES``) clears
        the box instead of breaking the widget callback.
        """
        try:
            code_a = self.class_a.value
            code_b = self.class_b.value

            if code_a == 999 and code_b == 999:
                self.class_comparison.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; color: #1e293b; font-size: 14px;'>
                    <b style='color: #dc2626; font-size: 15px;'>Invalid Selection:</b><br>
                    <span style='color: #dc2626; font-weight: 500;'>
                        Both classes cannot be "All other classes". Please select one specific class.
                    </span>
                </div>
                """
                return

            if code_a == code_b:
                # Without this, identical classes were labelled "Easy", which is misleading.
                self.class_comparison.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; color: #1e293b; font-size: 14px;'>
                    <b style='color: #dc2626; font-size: 15px;'>Invalid Selection:</b><br>
                    <span style='color: #dc2626; font-weight: 500;'>
                        Class A and Class B are the same. Please select two different classes.
                    </span>
                </div>
                """
                return

            class_a_name = self.ESA_CLASSES[code_a]
            class_b_name = self.ESA_CLASSES[code_b]

            if 999 in (code_a, code_b):
                # One side is "All other classes": no pair-specific score exists.
                diff_label, diff_color = "Medium", "#d97706"
            else:
                difficulty = self.calculate_class_similarity(code_a, code_b)
                if difficulty > 0.6:
                    diff_label, diff_color = "Hard", "#dc2626"
                elif difficulty > 0.4:
                    diff_label, diff_color = "Medium", "#d97706"
                else:
                    diff_label, diff_color = "Easy", "#059669"

            self.class_comparison.value = f"""
            <div style='background: #e8f5e8; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #10b981; color: #1e293b; font-size: 14px;'>
                <b style='color: #047857; font-size: 15px;'>Comparison:</b><br>
                <span style='color: #374151; font-weight: 500;'>
                    {code_a} ({class_a_name}) vs {code_b} ({class_b_name})
                </span><br>
                <span style='color: {diff_color}; font-weight: bold;'>Classification difficulty: {diff_label}</span>
            </div>
            """
        except Exception:
            # Best-effort display only; never let a rendering problem escape the
            # observe callback (but do not swallow KeyboardInterrupt/SystemExit).
            self.class_comparison.value = ""

    def calculate_class_similarity(self, class_a, class_b):
        """Return a heuristic 0-1 confusability score for a pair of class codes.

        Higher scores mean the pair is harder to separate. The lookup is
        order-independent; pairs absent from the table score 0.35.
        """
        pair = (min(class_a, class_b), max(class_a, class_b))

        confusability = {
            (10, 20): 0.8,
            (20, 30): 0.7,
            (30, 40): 0.6,
            (90, 95): 0.9,
            (10, 95): 0.5,
            (20, 100): 0.4,
            (60, 100): 0.5,
            (50, 60): 0.4,
            (50, 80): 0.1,
            (70, 80): 0.15,
            (10, 50): 0.2,
            (10, 80): 0.2,
            (50, 70): 0.1,
            (40, 80): 0.25,
        }
        return confusability.get(pair, 0.35)

    def submit_feedback(self, button):
        """Validate the feedback text and send it to the Google Sheet.

        Bound to ``feedback_button.on_click`` (``button`` is unused).

        * Empty or whitespace-only text: show a red prompt and stop.
        * Otherwise the stripped text is passed to ``update_feedback_in_sheets``.
          The button is disabled while that call runs so a slow request cannot
          be submitted twice. On success a thank-you note is shown and the form
          is locked. If the call raises, an error note is shown and the form is
          left editable so the user can retry.
        """
        feedback_text = self.feedback_textarea.value.strip()

        if not feedback_text:
            self.feedback_status.value = """
            <div style='background: #fee2e2; padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center;'>
                <span style='color: #dc2626; font-weight: bold;'>Please provide some feedback before submitting.</span>
            </div>
            """
            return

        # Guard against double clicks while the (network-bound) sheet update runs.
        self.feedback_button.disabled = True
        try:
            self.update_feedback_in_sheets(feedback_text)
        except Exception as exc:
            # Exceptions raised inside widget callbacks are not shown in the cell
            # output by default, so report the failure in the feedback panel.
            self.feedback_button.disabled = False
            self.feedback_status.value = f"""
            <div style='background: #fee2e2; padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center;'>
                <span style='color: #dc2626; font-weight: bold;'>Sorry, your feedback could not be recorded ({type(exc).__name__}). Please try again.</span>
            </div>
            """
            return

        self.feedback_status.value = """
        <div style='background: #dcfce7; padding: 12px; border-radius: 8px; margin: 10px 0; text-align: center;'>
            <span style='color: #166534; font-weight: bold; font-size: 16px;'>
                Thank you for your feedback! Your response has been recorded.
            </span>
        </div>
        """

        # Lock the form after a successful submission.
        self.feedback_textarea.disabled = True
        self.feedback_button.disabled = True
        self.feedback_button.description = 'Feedback Submitted'

    def display_app(self):
        """Assemble every widget section into one vertical layout and display it.

        Top to bottom: CSS overrides, title/welcome/attribution, country picker,
        ROI map, class pickers, algorithm settings, centred run controls, the
        results output area and the (initially hidden) feedback section, with
        horizontal rules between the main sections.
        """
        def divider():
            # A fresh HTML widget per separator, matching one inline <hr> per slot.
            return widgets.HTML("<hr style='border: 1px solid #e5e7eb; margin: 20px 0;'>")

        style_overrides = widgets.HTML("""
        <style>
            .widget-html-content {
                color: #1e293b !important;
            }
            .widget-text input, .widget-dropdown select {
                color: #1e293b !important;
                background-color: white !important;
            }
            .widget-label {
                color: #374151 !important;
                font-weight: 500 !important;
            }
        </style>
        """)

        header = [style_overrides, self.title, self.welcome_box, self.attribution_box]
        country_section = [self.country_title, self.country_dropdown]
        region_section = [
            self.region_title,
            self.region_help,
            self.selection_map,
            self.region_status,
            self.region_info,
        ]
        class_section = [
            self.class_title,
            widgets.HBox([self.class_a, self.class_b]),
            self.class_comparison,
        ]
        settings_section = [
            self.algo_title,
            widgets.HBox([self.algorithm, self.test_size]),
            widgets.HBox([self.n_samples, self.scale_m]),
            self.seed,
        ]
        run_controls = widgets.VBox(
            [self.run_button, self.progress, self.status],
            layout=widgets.Layout(align_items='center'),
        )

        children = (
            header
            + country_section + [divider()]
            + region_section + [divider()]
            + class_section + [divider()]
            + settings_section + [divider()]
            + [run_controls, divider(), self.output_area, self.feedback_section]
        )

        display(widgets.VBox(children))

    def ee_pairwise_sample_global(self, class_a, class_b, *, year, region, scale_m, n_per_class, seed):
        """Draw a balanced, stratified pixel sample of two land-cover classes.

        The image stack from ``load_stack(year)`` is masked to pixels where the
        embedding mask is valid and the class image equals ``class_a`` or
        ``class_b``. ``stratifiedSample`` then takes up to ``n_per_class`` points
        per class over ``region`` at ``scale_m``, using ``seed``.

        Stratification uses the ``'label'`` band (presumably the class band
        returned by ``load_stack`` as ``esa`` — confirm there).

        Returns
        -------
        tuple
            ``(df, band_names, hist, size, pair_fc)``:

            * ``df`` has one row per sampled pixel, with the embedding bands plus
              ``label``. It is an empty frame with those columns when nothing
              was sampled.
            * ``band_names`` is the list of embedding band names.
            * ``hist`` is the ``label`` histogram (``aggregate_histogram``).
            * ``size`` is the number of sampled features.
            * ``pair_fc`` is the sampled ``ee.FeatureCollection``, which keeps
              point geometries.
        """
        emb, bands, emb_mask, esa = load_stack(year)

        # Fetch band names once (each getInfo() is a server round trip) and build
        # the selected property list once; both are reused below.
        band_names = list(bands.getInfo())
        props = bands.cat(ee.List(['label']))

        pair_mask = esa.eq(class_a).Or(esa.eq(class_b))
        stack = (emb.updateMask(emb_mask).updateMask(pair_mask)
                 .addBands(esa.updateMask(emb_mask).updateMask(pair_mask)))

        pair_fc = stack.stratifiedSample(
            numPoints=n_per_class * 2,
            classBand='label',
            region=region,
            scale=scale_m,
            classValues=[class_a, class_b],
            classPoints=[n_per_class, n_per_class],
            seed=seed,
            geometries=True
        ).select(props)

        size = pair_fc.size().getInfo()
        hist = pair_fc.aggregate_histogram('label').getInfo()

        # Drop geometries before download so only the attribute table is transferred.
        attr_fc = pair_fc.map(lambda f: ee.Feature(None, f.toDictionary(props)))
        feats = attr_fc.getInfo()['features'] if size else []
        if feats:
            df = pd.DataFrame([f['properties'] for f in feats])
        else:
            df = pd.DataFrame(columns=[*band_names, 'label'])
        return df, band_names, hist, size, pair_fc

    def build_model(self, name, random_state=42):
        """Instantiate an unfitted binary classifier for the requested algorithm.

        Parameters
        ----------
        name : str
            Case-insensitive algorithm key: 'rf' (RandomForest), 'gbt'
            (sklearn GradientBoosting), 'lgb' (LightGBM) or 'xgb' (XGBoost).
            LightGBM/XGBoost are imported lazily (and pip-installed if missing).
        random_state : int, default 42
            Seed forwarded to the estimator.

        Returns
        -------
        estimator
            Unfitted model; callers use ``fit``, ``predict_proba`` and
            ``feature_importances_``.

        Raises
        ------
        ValueError
            If ``name`` is not one of the supported keys.
        """
        key = name.lower()
        if key == "rf":
            return RandomForestClassifier(n_estimators=300, random_state=random_state, n_jobs=-1)
        if key == "gbt":
            return GradientBoostingClassifier(random_state=random_state)
        if key == "lgb":
            lgb = _get_lgb()
            return lgb.LGBMClassifier(
                n_estimators=300, max_depth=6, learning_rate=0.1,
                # LightGBM applies row subsampling only when subsample_freq > 0.
                # The default of 0 silently ignores `subsample`, so bag on every
                # iteration to make subsample=0.8 take effect (as in the XGBoost setup).
                subsample=0.8, subsample_freq=1,
                colsample_bytree=0.8, reg_lambda=1.0,
                objective='binary', metric='binary_logloss',
                n_jobs=-1, random_state=random_state, verbose=-1
            )
        if key == "xgb":
            xgb = _get_xgb()
            return xgb.XGBClassifier(
                n_estimators=300, max_depth=6, learning_rate=0.1,
                subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
                tree_method="hist", objective="binary:logistic",
                eval_metric="logloss", n_jobs=-1, random_state=random_state
            )
        raise ValueError("CLASSIFIER must be 'rf', 'gbt', 'lgb', or 'xgb'")

    def create_custom_colorbar(self, palette, layer_name, position_offset=0):
        """Create a small, custom-positioned colorbar with white background"""
        bottom_offset = 20 + (position_offset * 80)
        gradient_colors = ', '.join(palette)
        
        colorbar_html = f"""
        <div style='position: absolute; bottom: {bottom_offset}px; right: 20px; z-index: 9999; 
                    background-color: rgba(255, 255, 255, 0.5); border: 1px solid #ccc; 
                    border-radius: 4px; padding: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.2);
                    font-family: Arial, sans-serif;'>
            <div style='text-align: center; font-size: 10px; font-weight: bold; color: #333; 
                        margin-bottom: 3px; white-space: nowrap;'>
                {layer_name}
            </div>
            <div style='width: 120px; height: 12px; background: linear-gradient(to right, {gradient_colors}); 
                        border: 1px solid #999; border-radius: 2px; margin-bottom: 2px;'></div>
            <div style='display: flex; justify-content: space-between; font-size: 8px; color: #666;
                        width: 120px;'>
                <span>0.0</span>
                <span>0.5</span>
                <span>1.0</span>
            </div>
        </div>
        """
        return colorbar_html

    def run_analysis(self, button):
        """Main analysis function with Google Sheets logging"""
        with self.output_area:
            clear_output(wait=True)

            # Check if country is selected
            if self.country_dropdown.value == "Select your current country":
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        Please select the country you are currently in!
                    </span>
                </div>
                """
                return

            if self.selected_bbox is None:
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        Please draw a region on the map first!
                    </span>
                </div>
                """
                return
            
            if not EE_INITIALIZED:
                self.status.value = """
                <div style='background: #fef3c7; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #f59e0b; text-align: center;'>
                    <span style='color: #92400e; font-weight: bold; font-size: 16px;'>
                        Demo Mode Active - Earth Engine Not Available
                    </span><br>
                    <span style='color: #92400e; font-size: 14px;'>
                        For full functionality, please use the Jupyter version or authenticate Earth Engine
                    </span>
                </div>
                """
                print("This is a demonstration version of AlphaEarth Land Cover Classifier")
                print("Earth Engine authentication required for satellite data analysis")
                return

            if self.class_a.value == 999 and self.class_b.value == 999:
                self.progress.layout.visibility = 'hidden'
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        Both classes cannot be "All other classes"! Select one specific class.
                    </span>
                </div>
                """
                return

            # Show progress
            self.progress.layout.visibility = 'visible'
            self.progress.value = 10
            self.progress.description = 'Starting...'

            self.status.value = """
            <div style='background: #dbeafe; padding: 12px; border-radius: 8px; margin: 10px 0;
                        border-left: 4px solid #3b82f6; text-align: center;'>
                <span style='color: #1e40af; font-weight: bold; font-size: 16px;'>
                    Processing embeddings and generating results...
                </span>
            </div>
            """

            if self.class_a.value == self.class_b.value:
                self.progress.layout.visibility = 'hidden'
                self.status.value = """
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        Please select different classes!
                    </span>
                </div>
                """
                return

            try:
                bbox = self.selected_bbox
                roi_info = sanitize_bbox(bbox)
                region = region_from_parts(roi_info["parts"])
                center_lat, center_lon = roi_info["center"]

                self.progress.value = 30
                self.progress.description = 'Loading satellite data...'

                df, EMBED_BANDS, hist, size, pair_fc = self.ee_pairwise_sample_with_others(
                    self.class_a.value, self.class_b.value,
                    year=2020,
                    region=region,
                    scale_m=self.scale_m.value,
                    n_per_class=self.n_samples.value,
                    seed=self.seed.value
                )

                self.progress.value = 50
                self.progress.description = 'Training model...'

                counts = df['label'].value_counts().to_dict() if len(df) else {}
                have_a = counts.get(self.class_a.value, 0)
                have_b = counts.get(self.class_b.value, 0)
                MIN_PER_CLASS_REQUIRED = max(20, int(0.2 * self.n_samples.value))

                if have_a < MIN_PER_CLASS_REQUIRED or have_b < MIN_PER_CLASS_REQUIRED:
                    print(f"Pairwise comparison NOT possible in this ROI.")
                    print(f"   Got A={have_a} (class {self.class_a.value}), B={have_b} (class {self.class_b.value}).")
                    print(f"   Try enlarging the drawn region, increasing scale, or choosing different classes.")
                    self.progress.layout.visibility = 'hidden'
                    return

                label_map = {self.class_a.value: 0, self.class_b.value: 1}
                df = df[df['label'].isin(label_map.keys())].copy()
                y = df['label'].map(label_map).astype(int).values
                X = df[EMBED_BANDS].astype(float).values

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=self.test_size.value/100, random_state=self.seed.value, stratify=y
                )

                clf = self.build_model(self.algorithm.value, random_state=self.seed.value)
                clf.fit(X_train, y_train)
                y_proba = clf.predict_proba(X_test)[:, 1]
                y_pred  = (y_proba >= 0.5).astype(int)

                self.progress.value = 70
                self.progress.description = 'Calculating metrics...'

                acc  = accuracy_score(y_test, y_pred)
                auc  = roc_auc_score(y_test, y_proba)
                report = classification_report(y_test, y_pred, output_dict=True)
                cm_matrix = confusion_matrix(y_test, y_pred)

                importances = clf.feature_importances_
                top5 = top_k_importance(importances, EMBED_BANDS, k=5, as_percent=True)
                top3 = top_k_importance(importances, EMBED_BANDS, k=3, as_percent=True)
                
                # Calculate all embedding importances
                all_importances = all_embedding_importance(importances, EMBED_BANDS, as_percent=True)

                self.progress.value = 90
                self.progress.description = 'Creating map...'

                emb_full, _, emb_mask_vis, esa_vis = load_stack(2020)

                embedding_layers = []
                for i, (display_name, importance_pct, original_band) in enumerate(top3):
                    embedding_band = emb_full.select(original_band).updateMask(emb_mask_vis).clip(region)
                    normalized_embedding = normalize_embedding(embedding_band, region)

                    embedding_layers.append({
                        'image': normalized_embedding,
                        'name': display_name,
                        'importance': importance_pct,
                        'palette_key': f'embedding_{i+1}',
                        'rank': i + 1
                    })

                # Apply the same relabeling logic for visualization
                if self.class_a.value == 999 or self.class_b.value == 999:
                    # We need to create relabeled ESA for visualization too
                    all_esa_classes = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]
                    
                    if self.class_a.value == 999:
                        # Class A = "all others", Class B = specific
                        specific_class = self.class_b.value
                        other_classes = [c for c in all_esa_classes if c != specific_class]
                        
                        # Create combined mask for all "other" classes
                        others_mask_vis = esa_vis.eq(other_classes[0])
                        for other_class in other_classes[1:]:
                            others_mask_vis = others_mask_vis.Or(esa_vis.eq(other_class))
                        
                        # Relabel ESA for visualization
                        relabeled_esa_vis = esa_vis.where(others_mask_vis, 999)
                        
                    elif self.class_b.value == 999:
                        # Class B = "all others", Class A = specific  
                        specific_class = self.class_a.value
                        other_classes = [c for c in all_esa_classes if c != specific_class]
                        
                        # Create combined mask for all "other" classes
                        others_mask_vis = esa_vis.eq(other_classes[0])
                        for other_class in other_classes[1:]:
                            others_mask_vis = others_mask_vis.Or(esa_vis.eq(other_class))
                        
                        # Relabel ESA for visualization
                        relabeled_esa_vis = esa_vis.where(others_mask_vis, 999)
                    
                    # Use relabeled ESA for creating class images
                    classA_img = relabeled_esa_vis.eq(self.class_a.value).updateMask(relabeled_esa_vis.eq(self.class_a.value)).updateMask(emb_mask_vis).clip(region)
                    classB_img = relabeled_esa_vis.eq(self.class_b.value).updateMask(relabeled_esa_vis.eq(self.class_b.value)).updateMask(emb_mask_vis).clip(region)
                    
                else:
                    # Normal case: both classes are specific, use original ESA
                    classA_img = esa_vis.eq(self.class_a.value).updateMask(esa_vis.eq(self.class_a.value)).updateMask(emb_mask_vis).clip(region)
                    classB_img = esa_vis.eq(self.class_b.value).updateMask(esa_vis.eq(self.class_b.value)).updateMask(emb_mask_vis).clip(region)

                m = folium.Map(location=[center_lat, center_lon], zoom_start=8, control_scale=True)

                folium.TileLayer(
                    tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
                    attr='Google',
                    name='Google Satellite',
                    overlay=False,
                    control=True
                ).add_to(m)

                color_a = self.CLASS_COLORS.get(self.class_a.value, '#FF3B30')
                color_b = self.CLASS_COLORS.get(self.class_b.value, '#007AFF')

                add_ee_layer(m, classA_img, {'palette': [color_a]}, f'Pixels: {self.class_a.value} ({self.ESA_CLASSES.get(self.class_a.value,"?")})')
                add_ee_layer(m, classB_img, {'palette': [color_b]}, f'Pixels: {self.class_b.value} ({self.ESA_CLASSES.get(self.class_b.value,"?")})')

                all_colorbars_html = ""
                for i, layer in enumerate(embedding_layers):
                    try:
                        palette = EMBEDDING_PALETTES[layer['palette_key']]
                        layer_name = f"{layer['name']} ({layer['importance']:.1f}%)"

                        add_ee_layer(m, layer['image'], {
                            'min': 0,
                            'max': 1,
                            'palette': palette
                        }, layer_name)

                        colorbar_html = self.create_custom_colorbar(
                            palette, 
                            layer['name'], 
                            position_offset=i
                        )
                        all_colorbars_html += colorbar_html

                    except Exception as e:
                        print(f"  Failed to add layer {layer_name}: {str(e)}")
                        continue

                if all_colorbars_html:
                    m.get_root().html.add_child(folium.Element(all_colorbars_html))

                sample_points = pair_fc.getInfo()

                if 'features' in sample_points:
                    class_a_name = self.ESA_CLASSES.get(self.class_a.value, "?")
                    class_b_name = self.ESA_CLASSES.get(self.class_b.value, "?")

                    points_layer_a = folium.FeatureGroup(name=f'Samples: {self.class_a.value} ({class_a_name})')
                    points_layer_b = folium.FeatureGroup(name=f'Samples: {self.class_b.value} ({class_b_name})')

                    for feature in sample_points['features']:
                        coords = feature['geometry']['coordinates']
                        label = feature['properties']['label']

                        if label == self.class_a.value:
                            folium.CircleMarker(
                                location=[coords[1], coords[0]],
                                radius=4,
                                popup=f'Class {label} ({class_a_name})',
                                color='white',
                                fillColor='black',
                                fillOpacity=1,
                                weight=1
                            ).add_to(points_layer_a)
                        elif label == self.class_b.value:
                            folium.CircleMarker(
                                location=[coords[1], coords[0]],
                                radius=4,
                                popup=f'Class {label} ({class_b_name})',
                                color='black',
                                fillColor='white',
                                fillOpacity=1,
                                weight=1
                            ).add_to(points_layer_b)

                    points_layer_a.add_to(m)
                    points_layer_b.add_to(m)

                roi_coords = []
                for part in roi_info["parts"]:
                    minLon, minLat, maxLon, maxLat = part
                    roi_coords.append([
                        [minLat, minLon], [minLat, maxLon],
                        [maxLat, maxLon], [maxLat, minLon], [minLat, minLon]
                    ])

                roi_layer = folium.FeatureGroup(name='ROI Boundary')
                for coords in roi_coords:
                    folium.Polygon(
                        locations=coords,
                        color='black',
                        weight=2,
                        fill=False,
                        popup='Selected ROI Boundary'
                    ).add_to(roi_layer)
                roi_layer.add_to(m)

                folium.LayerControl().add_to(m)

                # UPDATE RESULTS DATA TO INCLUDE ALL EMBEDDINGS
                results_data = {
                    'center_lat': center_lat,
                    'center_lon': center_lon,
                    'accuracy': acc,
                    'roc_auc': auc,
                    'precision_a': report['0']['precision'],
                    'recall_a': report['0']['recall'], 
                    'f1_a': report['0']['f1-score'],
                    'precision_b': report['1']['precision'],
                    'recall_b': report['1']['recall'],
                    'f1_b': report['1']['f1-score'],
                    'confusion_matrix': cm_matrix,
                    'n_train': len(y_train),
                    'n_test': len(y_test),
                    'top_embeddings': [(name, float(pct)) for name, pct, _ in top5],
                    'all_embeddings': all_importances  # Add all 64 embeddings
                }

                # Log to Google Sheets
                self.log_interaction_to_sheets(results_data)

                # Display results
                self.display_metrics(acc, auc, report, cm_matrix, len(y_train), len(y_test), importances, EMBED_BANDS, top5)

                print("Use the layer control panel to toggle layers on/off")
                print("Color bars in bottom-right corner show normalized embedding values (0.0 - 1.0)")
                display(m)
                
                # Create download button for embedding rasters
                download_button = widgets.Button(
                    description='Download Top 3 Embedding Rasters',
                    button_style='info',
                    layout=widgets.Layout(width='400px', height='45px', margin='15px 0'),
                    tooltip='Generate download links for the top 3 most important embeddings'
                )
                
                download_status = widgets.HTML()
                
                def download_embeddings(b):
                    download_status.value = """
                    <div style='background: #dbeafe; padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center;'>
                        <strong>Generating download links...</strong> This may take a few moments.
                    </div>
                    """
                    
                    try:
                        download_links_html = "<div style='background: #f0f9ff; padding: 15px; border-radius: 8px; margin: 10px 0;'>"
                        download_links_html += """
                        <h4 style='color: #1e40af; margin: 0 0 10px 0; text-align: center;'>Download Instructions:</h4>
                        <div style='background: #fff3cd; padding: 10px; border-radius: 6px; margin-bottom: 10px; font-size: 13px; color: #92400e;'>
                            <strong>Important:</strong> Earth Engine gives complicated filenames. When downloading:<br>                            
                            <strong>Copy the suggested name</strong> below and paste it as the filename.
                        </div>
                        """
                        
                        for i, layer in enumerate(embedding_layers):
                            print(f"Generating download link for {layer['name']} (Rank {layer['rank']})...")
                            
                            clean_filename = f"{layer['name']}_{bbox[0]:.2f}_{bbox[2]:.2f}_{bbox[1]:.2f}_{bbox[3]:.2f}.tif"
                            
                            download_url = layer['image'].getDownloadURL({
                                'scale': self.scale_m.value,
                                'crs': 'EPSG:4326',
                                'region': region,
                                'format': 'GEO_TIFF'
                            })
                            
                            download_links_html += f"""
                            <div style='margin: 8px 0; padding: 12px; background: white; border-radius: 6px; border-left: 3px solid #3b82f6;'>
                                <strong>{layer['name']}</strong> ({layer['importance']:.1f}% importance)<br>
                                
                                <div style='background: #f1f5f9; padding: 8px; border-radius: 4px; margin: 6px 0; border: 1px dashed #94a3b8;'>
                                    <strong style='color: #1e40af;'>Suggested filename:</strong><br>
                                    <code style='background: #e2e8f0; padding: 2px 6px; border-radius: 3px; font-family: monospace; color: #1e293b; font-weight: bold;'>{clean_filename}</code>
                                </div>
                                
                                <a href="{download_url}" target="_blank" style='color: #2563eb; text-decoration: none; font-weight: 500;'>
                                    Click → Save As → Use filename above
                                </a>
                            </div>
                            """
                        
                        download_links_html += "</div>"
                        
                        download_status.value = """
                        <div style='background: #dcfce7; padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center;'>
                            <strong>Download links ready!</strong> Follow the instructions to save with clean filenames.
                        </div>
                        """ + download_links_html
                        
                    except Exception as e:
                        download_status.value = f"""
                        <div style='background: #fee2e2; padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center;'>
                            <strong>Download generation failed:</strong> {str(e)}
                        </div>
                        """

                download_button.on_click(download_embeddings)
                
                display(widgets.VBox([
                    download_button,
                    download_status
                ], layout=widgets.Layout(align_items='center')))

                # Show feedback section after successful analysis
                self.feedback_section.layout.visibility = 'visible'

                self.status.value = """
                <div style='background: #dcfce7; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #16a34a; text-align: center;'>
                    <span style='color: #166534; font-weight: bold; font-size: 16px;'>
                        Analysis completed! Check results above and provide feedback below.
                    </span>
                </div>
                """

                self.progress.layout.visibility = 'hidden'

            except Exception as e:
                print(f"Error: {str(e)}")
                self.progress.layout.visibility = 'hidden'
                self.status.value = f"""
                <div style='background: #fee2e2; padding: 12px; border-radius: 8px; margin: 10px 0;
                            border-left: 4px solid #ef4444; text-align: center;'>
                    <span style='color: #dc2626; font-weight: bold; font-size: 16px;'>
                        Error: {str(e)}
                    </span>
                </div>
                """

    def display_metrics(self, acc, auc, report, cm, n_train, n_test, importances, embed_bands, top5):
        """Display classification results.

        Renders, in order:
        1. HTML metric cards for accuracy, ROC AUC and the train/test sizes.
        2. Per-class precision/recall/F1 and the confusion matrix, printed to stdout.
        3. The top-3 embeddings, printed from ``top5``.
        4. A bar chart of all embedding importances (as % of the total), with
           the three largest highlighted in yellow.

        Args:
            acc: Overall accuracy as a fraction; rendered as ``acc*100`` percent.
            auc: ROC AUC score; rendered with 3 decimals.
            report: Per-class metrics dict indexed by the string labels ``'0'``
                (class A) and ``'1'`` (class B), each holding ``'precision'``,
                ``'recall'`` and ``'f1-score'``. Presumably produced by
                ``classification_report(..., output_dict=True)``; confirm at the
                call site.
            cm: 2x2 confusion matrix, rows = true label, cols = predicted label.
            n_train: Number of training samples.
            n_test: Number of test samples.
            importances: Per-band importances, position-aligned with ``embed_bands``.
            embed_bands: Band names. Everything after the first character must
                parse as an int; it is used to order the bars and build the
                1-based ``A##`` tick labels.
            top5: Sequence of ``(display_name, importance_pct, original_band)``
                tuples; only the first three are printed.

        Returns:
            None. Everything is emitted via ``display``, ``print`` and ``plt.show``.
        """
        metrics_html = f"""
        <div style='display: flex; flex-wrap: wrap; gap: 15px; margin: 20px 0;'>
            <div style='background: #dcfce7; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #166534;'>Accuracy</h4>
                <h2 style='margin: 5px 0 0 0; color: #16a34a;'>{acc*100:.1f}%</h2>
            </div>
            <div style='background: #dbeafe; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #1e40af;'>ROC AUC</h4>
                <h2 style='margin: 5px 0 0 0; color: #2563eb;'>{auc:.3f}</h2>
            </div>
            <div style='background: #f3e8ff; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #7c2d12;'>Train</h4>
                <h2 style='margin: 5px 0 0 0; color: #a855f7;'>{n_train}</h2>
            </div>
            <div style='background: #fed7d7; padding: 15px; border-radius: 8px; text-align: center; min-width: 120px;'>
                <h4 style='margin: 0; color: #7c2d12;'>Test</h4>
                <h2 style='margin: 5px 0 0 0; color: #dc2626;'>{n_test}</h2>
            </div>
        </div>
        """

        display(HTML(metrics_html))

        print("\n=== Classification Results ===")
        print("Per-class metrics:")
        print(f"   Class {self.class_a.value}: Precision={report['0']['precision']:.3f}, Recall={report['0']['recall']:.3f}, F1={report['0']['f1-score']:.3f}")
        print(f"   Class {self.class_b.value}: Precision={report['1']['precision']:.3f}, Recall={report['1']['recall']:.3f}, F1={report['1']['f1-score']:.3f}")

        print("\nConfusion matrix (rows=true, cols=pred):")
        cm_df = pd.DataFrame(cm, index=['A(0)','B(1)'], columns=['P0','P1'])
        print(cm_df)

        print("\nTop 3 Most Important Embeddings:")
        for rank, (display_name, importance_pct, _original_band) in enumerate(top5[:3], start=1):
            print(f"   {rank}. {display_name}: {importance_pct:.2f}%")

        # Pair importances with band names. Truncate both to the shorter length
        # so a length mismatch cannot cause an IndexError below.
        vals = np.asarray(importances, dtype=float).ravel()
        bands = list(embed_bands)
        n_bands = min(vals.size, len(bands))
        vals, bands = vals[:n_bands], bands[:n_bands]

        # Order bars by the integer suffix of each band name (b[1:]).
        # dtype=int keeps an empty band list from becoming a float array.
        band_idx = np.array([int(b[1:]) for b in bands], dtype=int)
        order = np.argsort(band_idx)
        arr = vals[order]
        total = arr.sum() or 1.0  # avoid division by zero when all importances are 0
        arr = 100.0 * arr / total
        xlabels = [f"A{k+1:02d}" for k in band_idx[order]]

        # Positions (into `vals`) of the three largest importances.
        # Computed once, outside the comprehension, and on the truncated array
        # so every highlighted position corresponds to a plotted bar.
        top3_pos = set(np.argsort(vals)[-3:].tolist())
        colors = ['yellow' if pos in top3_pos else 'steelblue' for pos in order]

        class_a_name = self.ESA_CLASSES.get(self.class_a.value, f"Class {self.class_a.value}")
        class_b_name = self.ESA_CLASSES.get(self.class_b.value, f"Class {self.class_b.value}")

        fig, ax = plt.subplots(figsize=(14, 4))
        ax.bar(range(len(arr)), arr, color=colors)
        ax.set_xticks(range(len(arr)))
        ax.set_xticklabels(xlabels, rotation=90)
        ax.set_xlabel("Embedding band")
        ax.set_ylabel("Importance (%)")
        ax.set_title(f"Embedding importance: {class_a_name} vs {class_b_name} [model = {self.algorithm.value.upper()}] (Yellow = Top 3 highest)")
        fig.tight_layout()
        plt.show()

Authenticating with Google Earth Engine...
Earth Engine already initialized!
Welcome to AlphaEarth Land Cover Classifier!
Loading app...


In [9]:
# ===============================================
# 🚀 LAUNCH THE APP
# ===============================================
# Construct the app with the PROJECT_ID set in the project-configuration
# cell, then render its widget UI as this cell's output.
# `app` stays bound in the notebook namespace so later cells can inspect it.
app = AlphaEarthApp(PROJECT_ID)
app.display_app()

VBox(children=(HTML(value='\n        <style>\n            .widget-html-content {\n                color: #1e29…