# Data Preparation (Handle Null)

In [None]:
import pandas as pd
df = pd.read_csv('filled_by_hybrid.csv')

In [None]:
df_filtered = df[df['SCI_NAME'].notnull() & df['PERIODFROM'].notnull()]
df_filtered[['SCI_NAME', 'PERIODFROM']]

In [3]:
import pandas as pd
import numpy as np

def handle_nulls_and_empty_strings(df):
    print("Starting null and empty string handling...")
    print(f"Initial shape: {df.shape}")
    
    # 1. Check for empty strings in all columns and convert to NaN
    empty_string_counts = {}
    for col in df.columns:
        if df[col].dtype == 'object':  # Only check string columns
            empty_count = (df[col] == '').sum()
            if empty_count > 0:
                empty_string_counts[col] = empty_count
                # Convert empty strings to NaN for consistent handling
                df[col] = df[col].replace('', np.nan)
    
    print(f"Found empty strings in {len(empty_string_counts)} columns:")
    for col, count in empty_string_counts.items():
        print(f"  - '{col}': {count} empty strings")
    
    # 2. Filter to keep only rows with valid labels
    valid_mask = df['SCI_NAME'].notna() & df['PERIODFROM'].notna()
    valid_rows = df[valid_mask]
    
    print(f"Rows with valid SCI_NAME and PERIODFROM: {len(valid_rows)} out of {len(df)}")
    
    # 3. Drop columns with high null percentages
    cols_to_drop = [
        'NAME_ENG', 'FEATURE_EN', 'GEO_DES_EN', 
        'PERIODTO', 'DIS_BY_ENG', 'DIS_BY_TH',
        'DIS_DAY', 'DIS_MONTH', 'DIS_YEAR',
        'POTENTIAL1', 'POTENTIAL2', 'POTENTIAL3', 'POTENTIAL4', 'POTENTIAL5', 'POTENTIAL6', 'POTENTIAL7',
        'STATUS', 'UPDATEBY', 'UPDATEDATE', 'OWNER_ENG'
    ]
    valid_rows = valid_rows.drop(columns=[col for col in cols_to_drop if col in valid_rows.columns])
    
    # 4. Create missing flags for important columns
    valid_rows['formation_missing'] = valid_rows['FORMATION'].isna().astype(int)
    valid_rows['geo_group_missing'] = valid_rows['GEO_GROUP'].isna().astype(int)
    valid_rows['has_locality'] = valid_rows['LOCALITY'].notna().astype(int)
    valid_rows['has_geo_desc'] = valid_rows['GEO_DES_TH'].notna().astype(int)
    
    if 'FEATURE_TH' in valid_rows.columns:
        valid_rows['has_feature'] = valid_rows['FEATURE_TH'].notna().astype(int)
    
    # 5. Fill missing values - Categorical
    categorical_fills = {
        'FORMATION': 'Unknown',
        'GEO_GROUP': 'Unknown'
    }
    if 'FEATURE_TH' in valid_rows.columns:
        categorical_fills['FEATURE_TH'] = 'Unknown'
    
    valid_rows = valid_rows.fillna(categorical_fills)
    
    # 6. Fill missing values - Text
    text_fields = ['LOCALITY', 'GEO_DES_TH', 'FOS_DES_TH', 'IMPORTA_TH', 'F_PART']
    text_fills = {field: 'No information' for field in text_fields if field in valid_rows.columns}
    valid_rows = valid_rows.fillna(text_fills)
    
    # 7. Fill missing values - Numeric with province means
    numeric_cols = ['UTM_E', 'UTM_N', 'geometry_x', 'geometry_y']
    for col in numeric_cols:
        if col in valid_rows.columns:
            # Calculate province means
            province_col = 'PROVINCE_fossil' if 'PROVINCE_fossil' in valid_rows.columns else 'PROVINCE'
            province_means = valid_rows.groupby(province_col)[col].transform('mean')
            # Fill missing values with province means
            valid_rows[col] = valid_rows[col].fillna(province_means)
    
    # 8. For any remaining NaNs in numeric columns, use global mean
    numeric_cols_all = valid_rows.select_dtypes(include=['float64', 'int64']).columns
    for col in numeric_cols_all:
        if valid_rows[col].isna().sum() > 0:
            valid_rows[col] = valid_rows[col].fillna(valid_rows[col].mean())
    
    # 9. Fill any remaining NaNs with appropriate values based on column type
    for col in valid_rows.columns:
        if valid_rows[col].isna().sum() > 0:
            if valid_rows[col].dtype == 'object':  # String columns
                valid_rows[col] = valid_rows[col].fillna('No information')
            elif np.issubdtype(valid_rows[col].dtype, np.number):  # Numeric columns
                valid_rows[col] = valid_rows[col].fillna(0)
    
    # 10. Final check for empty strings
    empty_after = {}
    for col in valid_rows.columns:
        if valid_rows[col].dtype == 'object':  # Only check string columns
            empty_count = (valid_rows[col] == '').sum()
            if empty_count > 0:
                empty_after[col] = empty_count
                # Replace remaining empty strings with 'No information'
                valid_rows[col] = valid_rows[col].replace('', 'No information')
    
    if empty_after:
        print("\nFound and replaced remaining empty strings in these columns:")
        for col, count in empty_after.items():
            print(f"  - '{col}': {count} empty strings")
    else:
        print("\nNo remaining empty strings found.")
    
    print(f"Final shape after processing: {valid_rows.shape}")
    print(f"Remaining nulls: {valid_rows.isna().sum().sum()}")
    
    return valid_rows

# Load the data
df = pd.read_csv('filled_by_hybrid.csv')
print(f"Original data shape: {df.shape}")

# Handle nulls and empty strings, keep only valid label rows
cleaned_df = handle_nulls_and_empty_strings(df)

# Verify that no empty strings remain
empty_strings_left = False
for col in cleaned_df.columns:
    if cleaned_df[col].dtype == 'object':
        if (cleaned_df[col] == '').any():
            empty_strings_left = True
            print(f"Warning: Column '{col}' still has empty strings!")

if not empty_strings_left:
    print("Success! No empty strings remain in the dataset.")

# Save the cleaned data
cleaned_df.to_csv('cleaned_fossil_data.csv', index=False)

# Display column statistics after processing
print("\nColumn statistics after processing:")
for col in cleaned_df.columns:
    if cleaned_df[col].dtype == 'object':
        print(f"{col}: {cleaned_df[col].nunique()} unique values, no empty strings")
    else:
        print(f"{col}: min={cleaned_df[col].min()}, max={cleaned_df[col].max()}")

Original data shape: (12016, 69)
Starting null and empty string handling...
Initial shape: (12016, 69)
Found empty strings in 0 columns:
Rows with valid SCI_NAME and PERIODFROM: 9817 out of 12016

No remaining empty strings found.
Final shape after processing: (9817, 54)
Remaining nulls: 0
Success! No empty strings remain in the dataset.

Column statistics after processing:
FOSSIL_ID: min=20200800005, max=20230600069
SCI_NAME: 222 unique values, no empty strings
COM_NAME: 81 unique values, no empty strings
F_GROUP: 5 unique values, no empty strings
F_TYPE: 36 unique values, no empty strings
F_PART: 451 unique values, no empty strings
DIS_NAME: 88 unique values, no empty strings
PROVINCE_fossil: 26 unique values, no empty strings
DISTRICT_fossil: 58 unique values, no empty strings
TAMBOM: 79 unique values, no empty strings
REGIS_CODE: 114 unique values, no empty strings
location_key: 79 unique values, no empty strings
SITE_ID: min=1.0, max=517.0
SITE_CODE: min=136019.0, max=581006.0
NAM

In [4]:
import pandas as pd
import numpy as np

def analyze_dataset(file_path):
    """Generate a comprehensive analysis of the dataset for review"""
    
    # Load the data
    df = pd.read_csv(file_path)
    
    # Basic info
    print("=" * 80)
    print(f"DATASET ANALYSIS: {file_path}")
    print("=" * 80)
    print(f"Rows: {df.shape[0]}, Columns: {df.shape[1]}")
    
    # Check for nulls and empty strings
    null_counts = df.isnull().sum()
    empty_string_counts = {}
    
    for col in df.columns:
        if df[col].dtype == 'object':
            empty_string_counts[col] = (df[col] == '').sum()
    
    # Label distribution
    print("\nLABELS DISTRIBUTION:")
    print("-" * 80)
    
    if 'SCI_NAME' in df.columns:
        print(f"SCI_NAME unique values: {df['SCI_NAME'].nunique()}")
        print("Top 5 most common SCI_NAME values:")
        print(df['SCI_NAME'].value_counts().head(5).to_string())
    
    if 'PERIODFROM' in df.columns:
        print(f"\nPERIODFROM unique values: {df['PERIODFROM'].nunique()}")
        print("PERIODFROM distribution:")
        print(df['PERIODFROM'].value_counts().to_string())
    
    # Columns with nulls
    print("\nCOLUMNS WITH NULL VALUES:")
    print("-" * 80)
    if null_counts.sum() > 0:
        print(null_counts[null_counts > 0].to_string())
    else:
        print("No null values found in any column.")
    
    # Columns with empty strings
    print("\nCOLUMNS WITH EMPTY STRINGS:")
    print("-" * 80)
    empty_cols = {k: v for k, v in empty_string_counts.items() if v > 0}
    if empty_cols:
        for col, count in empty_cols.items():
            print(f"{col}: {count} empty strings")
    else:
        print("No empty strings found in any column.")
    
    # Data types overview
    print("\nDATA TYPES:")
    print("-" * 80)
    print(df.dtypes.value_counts().to_string())
    
    # Column categories
    print("\nCOLUMN CATEGORIES:")
    print("-" * 80)
    
    # Categorize columns
    cat_cols = []
    num_cols = []
    text_cols = []
    binary_cols = []
    id_cols = []
    
    for col in df.columns:
        if col in ['FOSSIL_ID', 'SITE_ID', 'GlobalID', 'OBJECTID', 'REGIS_CODE']:
            id_cols.append(col)
        elif df[col].dtype == 'object' and df[col].nunique() > 20:
            text_cols.append(col)
        elif df[col].dtype == 'object':
            cat_cols.append(col)
        elif df[col].nunique() <= 2:
            binary_cols.append(col)
        else:
            num_cols.append(col)
    
    print(f"ID columns ({len(id_cols)}): {', '.join(id_cols)}")
    print(f"Numeric columns ({len(num_cols)}): {', '.join(num_cols)}")
    print(f"Categorical columns ({len(cat_cols)}): {', '.join(cat_cols)}")
    print(f"Text columns ({len(text_cols)}): {', '.join(text_cols)}")
    print(f"Binary columns ({len(binary_cols)}): {', '.join(binary_cols)}")
    
    # Sample data
    print("\nSAMPLE DATA (5 rows):")
    print("-" * 80)
    print(df.head(5).to_string())
    
    return "Analysis complete"

# Run the analysis - replace with your file path
analyze_dataset('cleaned_fossil_data.csv')

DATASET ANALYSIS: cleaned_fossil_data.csv
Rows: 9817, Columns: 54

LABELS DISTRIBUTION:
--------------------------------------------------------------------------------
SCI_NAME unique values: 222
Top 5 most common SCI_NAME values:
Phuwiangosaurus sirindhornae    2085
N/A (Dinosaur)                  1906
N/A (Brachiopod)                1072
Sauropod                        1024
N/A (Crocodile)                  481

PERIODFROM unique values: 15
PERIODFROM distribution:
ครีเทเชียสตอนต้น                      5483
เพอร์เมียน                            3485
เพอร์เมียนตอนต้น                       291
จูแรสสิกตอนปลาย                        182
ออร์โดวิเชียน                          150
ไทรแอสสิกตอนปลาย                       102
จูแรสซิก                                99
แคมเบรียนตอนปลาย                         7
ครีเทเซียสตอนต้น                         4
ไทรแอสสิก                                4
ออร์โดวีเชียน-ไซลูเรียน                  4
จูแรสสิกตอนปลายถึงครีเทเชียสตอนต้น       2
ครีเทเชียส  

'Analysis complete'

# Data Preparation and Feature Extraction

In [2]:
import pandas as pd
import numpy as np
import utm
import re
import os
from datetime import datetime
import time
from sklearn.preprocessing import StandardScaler
import requests
from sklearn.preprocessing import OneHotEncoder
import warnings
import matplotlib.pyplot as plt
import folium
from folium.plugins import MarkerCluster
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import AutoTokenizer, AutoModel
from sentinelhub import (
    SHConfig, SentinelHubRequest, MimeType, CRS, BBox,
    DataCollection, MosaickingOrder
)
import rasterio
from rasterio.transform import from_bounds
from rasterio.warp import calculate_default_transform, reproject, Resampling
import concurrent.futures
from collections import defaultdict
warnings.filterwarnings('ignore')

# For screenshot capturing
try:
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.webdriver.chrome.service import Service
    from webdriver_manager.chrome import ChromeDriverManager
    SELENIUM_AVAILABLE = True
except ImportError:
    SELENIUM_AVAILABLE = False
    print("Warning: Selenium not available. Map screenshots will use placeholders.")

# Sentinel Hub credentials
SENTINEL_INSTANCE_ID = "6042f305-e734-420c-bc66-bbf0de1ac960"
SENTINEL_CLIENT_ID = "ac14f011-f105-491c-8b31-4568e2f70590"
SENTINEL_CLIENT_SECRET = "1BvZ1rCPBFYygsiiXKfu2KVQNpFgzWNl"
OPENWEATHER_API_KEY = "3718c28f9dc7148926ffc3956330c0fa"

def extract_features(df, save_interim=True, use_real_data=True, satellite_batch_size=20, 
                    satellite_delay_seconds=30, text_batch_size=100):
    """
    Main function to extract all features from the cleaned dataset with batch processing
    
    Args:
        df: Cleaned pandas DataFrame
        save_interim: Whether to save intermediate results
        use_real_data: Whether to use real satellite and text data
        satellite_batch_size: Size of batches for satellite processing
        satellite_delay_seconds: Delay between satellite batches in seconds
        text_batch_size: Size of batches for text processing
        
    Returns:
        DataFrame with all extracted features
    """
    print("Starting feature extraction process with batch processing...")
    
    # Create copy to avoid modifying original
    processed_df = df.copy()
    
    # Create output directories
    os.makedirs('satellite_data', exist_ok=True)
    os.makedirs('satellite_maps', exist_ok=True)
    os.makedirs('interim_results', exist_ok=True)
    
    # 1. Coordinate & Spatial Features
    processed_df = add_coordinate_features(processed_df)
    if save_interim:
        processed_df.to_csv('interim_results/1_coordinates_features.csv', index=False)
        print("  Saved interim results after coordinate processing")
    
    # 2. Satellite & Remote Sensing Features 
    processed_df = add_satellite_features(
        processed_df, 
        fetch_real_data=use_real_data,
        batch_size=satellite_batch_size,
        delay_seconds=satellite_delay_seconds,
        process_all=True  # Process all data points
    )
    if save_interim:
        processed_df.to_csv('interim_results/2_satellite_features.csv', index=False)
        print("  Saved interim results after satellite processing")
    
    # 3. Text & Semantic Features
    processed_df = add_text_features(
        processed_df, 
        use_wangchan_bert=use_real_data,
        batch_size=text_batch_size,
        save_interim=save_interim
    )
    if save_interim:
        processed_df.to_csv('interim_results/3_text_features.csv', index=False)
        print("  Saved interim results after text processing")
    
    # 4. Temporal Features
    processed_df = add_temporal_features(processed_df)
    
    # 5. Statistical & Frequency Features
    processed_df = add_statistical_features(processed_df)
    
    # 6. Categorical Features
    processed_df = add_categorical_features(processed_df)
    
    # 7. Scale Numeric Features
    processed_df = scale_numeric_features(processed_df)
    
    # 8. Create map visualization
    if save_interim and use_real_data:
        map_path = create_map_visualization(processed_df, 'fossil_map.html')
        print(f"Interactive map created at {map_path}")
    
    # 9. Create dataset summary
    if save_interim:
        create_dataset_summary(processed_df, 'dataset_summary.html')
    
    # Save final result
    if save_interim:
        # Full dataset
        processed_df.to_csv('fossil_features_full.csv', index=False)
        
        # Simplified dataset with key features for training
        training_cols = select_training_features(processed_df)
        processed_df[training_cols].to_csv('fossil_features_training.csv', index=False)
        
    print("Feature extraction complete!")
    print(f"Full dataset saved to fossil_features_full.csv")
    print(f"Training dataset saved to fossil_features_training.csv")
    
    return processed_df

def select_training_features(df):
    """Select the most important features for model training"""
    # Core ID and target features
    essential_cols = ['FOSSIL_ID', 'SCI_NAME', 'COM_NAME', 'F_GROUP', 'F_TYPE']
    
    # Spatial features
    spatial_cols = ['latitude', 'longitude', 'pangaea_lat', 'pangaea_lon', 
                   'distance_from_centroid', 'nearest_related_site_km']
    
    # Satellite features
    satellite_cols = ['ndvi', 'red_band', 'green_band', 'blue_band', 'nir_band', 
                      'has_real_satellite_data', 'satellite_img_path']
    
    # Text features
    text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
    text_cols += ['has_bert_embeddings']
    text_cols += [col for col in df.columns if col.startswith('has_') 
                 and not col == 'has_real_satellite_data' and not col == 'has_bert_embeddings']
    
    # Temporal features
    temporal_cols = ['period_age_mya', 'geological_era']
    
    # Statistical features
    stat_cols = ['species_frequency', 'location_frequency', 'province_density', 
                'period_frequency', 'group_frequency']
    
    # Other important features
    other_cols = ['group_id', 'period_id']
    
    # Combine all selected columns
    selected_cols = (essential_cols + spatial_cols + satellite_cols + 
                    text_cols + temporal_cols + stat_cols + other_cols)
    
    # Ensure all columns exist in the dataframe
    final_cols = [col for col in selected_cols if col in df.columns]
    
    print(f"Selected {len(final_cols)} key features for training dataset")
    return final_cols

def capture_map_screenshot(html_path, output_path, width=800, height=600):
    """
    Capture a screenshot of an HTML map and carefully remove UI elements without affecting tiles
    """
    if not SELENIUM_AVAILABLE:
        return False
        
    try:
        print(f"  Capturing screenshot of {html_path}")
        
        # Set up Chrome options
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--disable-gpu")
        options.add_argument(f"--window-size={width},{height}")
        
        # Initialize Chrome driver
        driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)
        
        # Load the HTML file
        file_url = f"file://{os.path.abspath(html_path)}"
        driver.get(file_url)
        
        # Wait for the map to fully load - increased wait time
        time.sleep(5)
        
        # Use a more careful approach to hide elements rather than removing them
        driver.execute_script("""
            // Hide controls but keep tiles visible
            var elementsToHide = document.querySelectorAll(
                '.leaflet-control-container, .leaflet-marker-icon, .leaflet-marker-shadow, ' + 
                '.leaflet-popup, .leaflet-overlay-pane, .leaflet-marker-pane, ' + 
                '.leaflet-tooltip-pane, .leaflet-popup-pane'
            );
            
            elementsToHide.forEach(function(el) {
                el.style.display = 'none';
            });
            
            // Make sure tile pane is visible
            var tilePanes = document.querySelectorAll('.leaflet-tile-pane, .leaflet-tile-container, .leaflet-tile');
            tilePanes.forEach(function(el) {
                el.style.display = 'block';
                el.style.opacity = '1';
            });
            
            // Add debugging info to console
            console.log('Tiles visibility check: ' + 
                document.querySelectorAll('.leaflet-tile').length + ' tiles found');
        """)
        
        # Wait for the script to execute completely
        time.sleep(2)
        
        # Take screenshot
        driver.save_screenshot(output_path)
        
        # Close browser
        driver.quit()
        
        print(f"  Clean satellite screenshot saved to {output_path}")
        return True
        
    except Exception as e:
        print(f"  Error capturing screenshot: {e}")
        return False

def add_coordinate_features(df):
    """Add coordinate-based features and track Pangaea coordinate source"""
    print("Adding coordinate features...")
    
    # Count original missing UTM values
    missing_utm = df[df['UTM_E'].isna() | df['UTM_N'].isna() | df['ZONE'].isna()].shape[0]
    print(f"  Original missing UTM values: {missing_utm} rows")
    
    # 1. Fix out-of-range UTM values before conversion
    # Standard UTM easting should be between 100,000 and 1,000,000 meters
    # For values much larger, divide by 10 to fix possible decimal errors
    df['UTM_E_fixed'] = df['UTM_E'].copy()
    df['UTM_N_fixed'] = df['UTM_N'].copy()
    
    # Fix too large easting values
    large_easting = (df['UTM_E'] > 999999) & (df['UTM_E'].notna())
    if large_easting.any():
        print(f"  Fixing {large_easting.sum()} rows with too large easting values")
        df.loc[large_easting, 'UTM_E_fixed'] = df.loc[large_easting, 'UTM_E'] / 10
    
    # Fix too small easting values
    small_easting = (df['UTM_E'] < 100000) & (df['UTM_E'].notna())
    if small_easting.any():
        print(f"  Fixing {small_easting.sum()} rows with too small easting values")
        df.loc[small_easting, 'UTM_E_fixed'] = df.loc[small_easting, 'UTM_E'] * 10
    
    # 2. Convert UTM to lat/long
    def utm_to_latlon(row):
        try:
            utm_e = float(row['UTM_E_fixed'])
            utm_n = float(row['UTM_N_fixed'])
            zone_text = str(row['ZONE']).strip()
            
            # Extract the numeric part and hemisphere
            zone_num = int(re.findall(r'\d+', zone_text)[0])
            hemisphere = 'N' if 'N' in zone_text else 'S'
            
            # Check if the coordinates are within valid UTM range
            if 100000 <= utm_e <= 999999 and utm_n > 0:
                lat, lon = utm.to_latlon(utm_e, utm_n, zone_num, hemisphere)
                return pd.Series({'latitude': lat, 'longitude': lon})
            else:
                return pd.Series({'latitude': np.nan, 'longitude': np.nan})
                
        except Exception as e:
            # Silently return NaN for conversion errors
            return pd.Series({'latitude': np.nan, 'longitude': np.nan})
    
    # Apply UTM conversion
    df[['latitude', 'longitude']] = df.apply(utm_to_latlon, axis=1)
    
    # Count missing values after conversion
    missing_coords = df[df['latitude'].isna() | df['longitude'].isna()].shape[0]
    print(f"  Missing coordinates after conversion: {missing_coords} rows")
    
    # Fill missing values with province means
    province_col = 'PROVINCE_fossil' if 'PROVINCE_fossil' in df.columns else 'PROVINCE'
    
    # Calculate province averages for valid coordinates
    valid_coords = df.dropna(subset=['latitude', 'longitude'])
    if valid_coords.shape[0] > 0:
        province_means = valid_coords.groupby(province_col)[['latitude', 'longitude']].mean()
        
        # Fill missing coordinates with province means
        for idx, row in df[df['latitude'].isna() | df['longitude'].isna()].iterrows():
            province = row[province_col]
            if province in province_means.index:
                df.at[idx, 'latitude'] = province_means.loc[province, 'latitude']
                df.at[idx, 'longitude'] = province_means.loc[province, 'longitude']
    
    # Fill any remaining NaNs with dataset mean
    if df['latitude'].isna().any() or df['longitude'].isna().any():
        lat_mean = df['latitude'].dropna().mean()
        lon_mean = df['longitude'].dropna().mean()
        df['latitude'] = df['latitude'].fillna(lat_mean)
        df['longitude'] = df['longitude'].fillna(lon_mean)
    
    print(f"  Coordinates filled - remaining missing: {df['latitude'].isna().sum()} rows")
    
    # 3. Add period age mapping
    # Mapping of period names to geological ages in millions of years
    period_ages = {
        'เพอร์เมียน': 299,
        'เพอร์เมียนตอนต้น': 290,
        'ไทรแอสซิก': 251,
        'ไทรแอสซิกตอนต้น': 245,
        'ไทรแอสสิกตอนปลาย': 220,
        'จูแรสซิก': 201,
        'จูแรสสิกตอนปลาย': 160,
        'ครีเทเชียส': 145,
        'ครีเทเชียสตอนต้น': 140,
        'ครีเทเซียสตอนต้น': 140,
        'ออร์โดวิเชียน': 485,
        'ออร์โดวีเชียน': 485,
        'ออร์โดวีเชียน-ไซลูเรียน': 450,
        'แคมเบรียนตอนปลาย': 500,
        'ยออร์โดวิเชียน': 485,
        'จูแรสสิกตอนปลายถึงครีเทเชียสตอนต้น': 150,
    }
    
    # Add numerical age
    df['period_age_mya'] = df['PERIODFROM'].map(period_ages).fillna(0)
    
    # 4. Calculate Pangaea coordinates using enhanced method that returns ratio
    print("  Attempting real Pangaea coordinates calculation...")
    df, real_ratio = calculate_pangaea_coordinates(df)
    
    # Add the real ratio as a column in the dataframe for reference
    df['pangaea_real_ratio'] = real_ratio
    
    # Count real vs approximated coordinates
    if 'pangaea_coords_approximated' in df.columns:
        real_count = (df['pangaea_coords_approximated'] == 0).sum()
        approx_count = (df['pangaea_coords_approximated'] == 1).sum()
        print(f"  Pangaea coordinates: {real_count} real, {approx_count} approximated")
        print(f"  Real coordinate ratio: {real_ratio:.2%}")
    
    # 5. Distance from province centroid
    print("  Calculating distance from province centroid...")
    # Group by province
    province_groups = df.groupby(province_col)
    
    # Calculate distance for each group
    for province, group in province_groups:
        if len(group) <= 1:
            continue
            
        # Get province center
        center_lat = group['latitude'].mean()
        center_lon = group['longitude'].mean()
        
        # Calculate distances
        for idx in group.index:
            lat = df.loc[idx, 'latitude']
            lon = df.loc[idx, 'longitude']
            
            # Simple Euclidean distance (in degrees)
            distance = np.sqrt((lat - center_lat)**2 + (lon - center_lon)**2)
            df.loc[idx, 'distance_from_centroid'] = distance
    
    # Fill any NaN values in distance
    df['distance_from_centroid'] = df['distance_from_centroid'].fillna(0)
    
    # 6. Calculate distance to nearest related fossil sites
    print("  Calculating distance to nearest related sites...")
    
    # Function to calculate nearest distance within a group
    def calc_nearest(group, lat_col='latitude', lon_col='longitude'):
        # Create a dictionary to store results
        nearest_distances = {}
        
        # Convert to numpy for faster calculations
        ids = group.index.values
        lats = group[lat_col].values
        lons = group[lon_col].values
        
        # For each point
        for i, idx in enumerate(ids):
            lat1, lon1 = lats[i], lons[i]
            
            # Calculate distances to all other points
            min_dist = float('inf')
            for j, _ in enumerate(ids):
                if i == j:
                    continue
                    
                lat2, lon2 = lats[j], lons[j]
                
                # Simple Euclidean distance
                dist = np.sqrt((lat1 - lat2)**2 + (lon1 - lon2)**2)
                
                if dist < min_dist:
                    min_dist = dist
            
            # Store the minimum distance
            if min_dist != float('inf'):
                nearest_distances[idx] = min_dist
            else:
                nearest_distances[idx] = 0  # No other points in group
        
        return nearest_distances
    
    # Group by species to find related fossils
    species_groups = df.groupby('SCI_NAME')
    
    # Calculate nearest distances for each species
    for species, group in species_groups:
        if len(group) <= 1:
            continue
            
        # Calculate nearest distances
        distances = calc_nearest(group)
        
        # Update dataframe
        for idx, dist in distances.items():
            df.loc[idx, 'nearest_related_site_deg'] = dist
    
    # Fill any NaN values
    df['nearest_related_site_deg'] = df['nearest_related_site_deg'].fillna(0)
    
    # Convert to approximate kilometers (1 degree ≈ 111 km)
    df['nearest_related_site_km'] = df['nearest_related_site_deg'] * 111
    
    return df

def calculate_pangaea_coordinates(df, models_dir=os.getcwd()):
    """
    Calculate Pangaea coordinates for given locations and ages using GPlates.
    Returns ratio of real vs. approximated coordinates in the result.

    Args:
        df (pd.DataFrame): DataFrame with latitude, longitude, and period_age_mya columns.
        models_dir (str): Path to directory containing .rot (rotation) and .gpml (feature) files.

    Returns:
        tuple: (DataFrame with new columns, real_ratio)
    """
    print("🔄 **Calculating Pangaea coordinates...**")
    
    # Initialize counters for tracking real vs. approximated coordinates
    total_points = len(df)
    real_points = 0
    approx_points = 0

    required_cols = ['latitude', 'longitude', 'period_age_mya']
    if not all(col in df.columns for col in required_cols):
        print("❌ Missing required columns in DataFrame!")
        df = calculate_pangaea_coordinates_simple(df)
        return df, 0.0  # Return ratio of 0 - all approximated

    try:
        import pygplates
        print("✅ Successfully imported pygplates")
    except ImportError:
        print("❌ pygplates not available, using approximation method")
        df = calculate_pangaea_coordinates_simple(df)
        return df, 0.0  # Return ratio of 0 - all approximated

    if not os.path.exists(models_dir):
        print(f"❌ Directory {models_dir} not found")
        df = calculate_pangaea_coordinates_simple(df)
        return df, 0.0  # Return ratio of 0 - all approximated

    rotation_files = [os.path.join(models_dir, f) for f in os.listdir(models_dir) if f.endswith('.rot')]
    feature_files = [os.path.join(models_dir, f) for f in os.listdir(models_dir) if f.endswith('.gpml')]

    if not rotation_files:
        print("⚠️ No rotation files found. Using simplified approximation.")
        df = calculate_pangaea_coordinates_simple(df)
        return df, 0.0  # Return ratio of 0 - all approximated

    if not feature_files:
        print("⚠️ No feature files found. Using simplified approximation.")
        df = calculate_pangaea_coordinates_simple(df)
        return df, 0.0  # Return ratio of 0 - all approximated

    print(f"📊 Found {len(rotation_files)} rotation files and {len(feature_files)} feature files")

    # Try both plate IDs - the one from your test and the original one
    THAILAND_PLATE_IDS = [619, 60301]  # Try your test ID first, then the original

    best_results = None
    best_success_rate = 0
    best_plate_id = None

    for plate_id in THAILAND_PLATE_IDS:
        print(f"🔍 Trying with Thailand Plate ID: {plate_id}")
        
        for rot_file in rotation_files:
            rot_name = os.path.basename(rot_file)
            try:
                rotation_model = pygplates.RotationModel(rot_file)
                print(f"✅ Loaded rotation model: {rot_name}")

                # Check if we even need feature files by testing direct reconstruction
                can_reconstruct_directly = True
                test_point = pygplates.PointOnSphere(13.7, 100.5)  # Bangkok coordinates
                test_age = 200  # Test with 200 Ma
                try:
                    reconstructed = rotation_model.get_rotation(test_age, plate_id) * test_point
                    print(f"✅ Direct reconstruction works with plate ID {plate_id}!")
                except Exception:
                    can_reconstruct_directly = False
                    
                if can_reconstruct_directly:
                    # If direct reconstruction works, try that approach for all points
                    results = []
                    success_count = 0
                    
                    for idx, row in df.iterrows():
                        lat, lon, age = row['latitude'], row['longitude'], row['period_age_mya']
                        
                        if pd.isna(lat) or pd.isna(lon) or age <= 0:
                            results.append((np.nan, np.nan))
                            continue
                            
                        try:
                            point = pygplates.PointOnSphere(float(lat), float(lon))
                            rotation = rotation_model.get_rotation(float(age), plate_id)
                            
                            if rotation:
                                reconstructed_point = rotation * point
                                lat_lon = reconstructed_point.to_lat_lon()
                                results.append((lat_lon[0], lat_lon[1]))
                                success_count += 1
                                continue
                            
                            # If rotation failed
                            results.append((np.nan, np.nan))
                            
                        except Exception as e:
                            print(f"⚠️ Error reconstructing point: {e}")
                            results.append((np.nan, np.nan))
                    
                    # Success rate calculation
                    success_rate = success_count / len(df) if len(df) > 0 else 0
                    print(f"🎯 Direct Success Rate: {success_rate:.2%} ({success_count}/{len(df)} points)")
                    
                    # Keep track of best result
                    if success_rate > best_success_rate:
                        best_success_rate = success_rate
                        best_results = results
                        best_plate_id = plate_id
                        
                    # If success rate is good enough, use this result
                    if success_rate > 0.7:
                        df['pangaea_lat'] = [result[0] for result in results]
                        df['pangaea_lon'] = [result[1] for result in results]
                        df['pangaea_coords_approximated'] = 0  # Real coordinates
                        real_points = success_count
                        approx_points = total_points - success_count
                        real_ratio = real_points / total_points if total_points > 0 else 0
                        print(f"✅ Using direct method with {rot_name} and Plate ID {plate_id} (Success Rate: {success_rate:.2%})")
                        print(f"📊 Real coordinate ratio: {real_ratio:.2%} ({real_points}/{total_points})")
                        return df, real_ratio
                
                # If direct reconstruction failed or had low success, try with feature files
                for feat_file in feature_files:
                    feat_name = os.path.basename(feat_file)
                    print(f"  📌 Trying: {rot_name} + {feat_name}")

                    try:
                        results = []
                        success_count = 0

                        for idx, row in df.iterrows():
                            lat, lon, age = row['latitude'], row['longitude'], row['period_age_mya']

                            if pd.isna(lat) or pd.isna(lon) or age <= 0:
                                results.append((np.nan, np.nan))
                                continue

                            try:
                                point = pygplates.PointOnSphere(float(lat), float(lon))

                                # Try reconstruct_point
                                reconstructed_point = pygplates.reconstruct_point(rotation_model, point, float(age), plate_id)
                                if reconstructed_point:
                                    results.append((reconstructed_point.get_latitude(), reconstructed_point.get_longitude()))
                                    success_count += 1
                                    continue

                                # Try full reconstruct()
                                feature = pygplates.Feature()
                                feature.set_geometry(point)
                                feature.set_reconstruction_plate_id(plate_id)

                                reconstructed_features = []
                                pygplates.reconstruct([feature], rotation_model, reconstructed_features, float(age))

                                if reconstructed_features:
                                    reconstructed_geom = reconstructed_features[0].get_reconstructed_geometry()
                                    results.append((reconstructed_geom.get_latitude(), reconstructed_geom.get_longitude()))
                                    success_count += 1
                                    continue

                                # If all failed
                                results.append((np.nan, np.nan))

                            except Exception as e:
                                print(f"⚠️ Point error: {e}")
                                results.append((np.nan, np.nan))

                        # Success Rate Calculation
                        success_rate = success_count / len(df) if len(df) > 0 else 0
                        print(f"    🎯 Success Rate: {success_rate:.2%} ({success_count}/{len(df)} points)")

                        # Keep track of best result
                        if success_rate > best_success_rate:
                            best_success_rate = success_rate
                            best_results = results
                            best_plate_id = plate_id

                        # If success rate is good enough, use this result
                        if success_rate > 0.7:
                            df['pangaea_lat'] = [result[0] for result in results]
                            df['pangaea_lon'] = [result[1] for result in results]
                            df['pangaea_coords_approximated'] = 0  # Real coordinates
                            real_points = success_count
                            approx_points = total_points - success_count
                            real_ratio = real_points / total_points if total_points > 0 else 0
                            print(f"  ✅ Using {rot_name} + {feat_name} with Plate ID {plate_id} (Success Rate: {success_rate:.2%})")
                            print(f"📊 Real coordinate ratio: {real_ratio:.2%} ({real_points}/{total_points})")
                            return df, real_ratio

                    except Exception as e:
                        print(f"    ❌ Error with feature file {feat_name}: {e}")

            except Exception as e:
                print(f"❌ Error loading rotation model {rot_name}: {e}")

    # If real reconstruction partially succeeded, use best available and fill in the rest
    if best_results and best_success_rate > 0.2:
        print(f"⚠️ Using best available result with Plate ID {best_plate_id} (Success Rate: {best_success_rate:.2%})")
        
        # Set the successful reconstructions
        df['pangaea_lat'] = np.nan
        df['pangaea_lon'] = np.nan
        df['pangaea_coords_approximated'] = 1  # Start with all approximated
        
        successful_count = 0
        for i, result in enumerate(best_results):
            if not (pd.isna(result[0]) or pd.isna(result[1])):
                df.iloc[i, df.columns.get_loc('pangaea_lat')] = result[0]
                df.iloc[i, df.columns.get_loc('pangaea_lon')] = result[1]
                df.iloc[i, df.columns.get_loc('pangaea_coords_approximated')] = 0  # Mark as real
                successful_count += 1
        
        # Fill in missing values with approximation
        missing_indices = df[df['pangaea_lat'].isna()].index
        for idx in missing_indices:
            lat = df.loc[idx, 'latitude']
            lon = df.loc[idx, 'longitude']
            age = df.loc[idx, 'period_age_mya']
            
            approx_lat, approx_lon = approximate_pangaea_coords(lat, lon, age)
            df.loc[idx, 'pangaea_lat'] = approx_lat
            df.loc[idx, 'pangaea_lon'] = approx_lon
            df.loc[idx, 'pangaea_coords_approximated'] = 1  # Approximated
        
        real_ratio = successful_count / total_points
        print(f"📊 Mixed approach ratio: {real_ratio:.2%} ({successful_count}/{total_points} real coordinates)")
        return df, real_ratio

    # Fallback to Approximation for all points
    print("⚠️ No good reconstructions found. Using approximation method for all points.")
    df = calculate_pangaea_coordinates_simple(df)
    return df, 0.0  # All approximated

def calculate_pangaea_coordinates_simple(df):
    """
    Approximate Pangaea coordinates using a simple transformation.
    This is a fallback method if GPlates reconstruction fails.

    Args:
        df (pd.DataFrame): DataFrame with ['latitude', 'longitude', 'period_age_mya'].

    Returns:
        pd.DataFrame: DataFrame with estimated ['pangaea_lat', 'pangaea_lon'].
    """
    print("🔄 Using simplified approximation for Pangaea coordinates.")

    pangaea_coords = [approximate_pangaea_coords(row['latitude'], row['longitude'], row['period_age_mya']) 
                     for _, row in df.iterrows()]
    
    df['pangaea_lat'] = [coord[0] for coord in pangaea_coords]
    df['pangaea_lon'] = [coord[1] for coord in pangaea_coords]
    df['pangaea_coords_approximated'] = 1  # Flag as approximated
    
    return df

# Helper function made global for reuse
def approximate_pangaea_coords(lat, lon, age_mya):
    if pd.isna(lat) or pd.isna(lon) or age_mya <= 0:
        return np.nan, np.nan

    # Approximate movement based on age (simple linear shift)
    age_factor = min(age_mya / 300, 1.0)  # Normalize 0-1 for 300 Ma max

    # Thailand-specific estimation: Move southwest in older times
    lat_offset = -10 * age_factor
    lon_offset = -20 * age_factor

    return lat + lat_offset, lon + lon_offset

def add_satellite_features(df, fetch_real_data=True, save_maps=True, batch_size=20, delay_seconds=0, process_all=True):
    """
    Add satellite-based features with optimized processing for duplicate Fossil IDs
    
    Args:
        df: DataFrame with latitude/longitude columns
        fetch_real_data: Whether to fetch real satellite data
        save_maps: Whether to save folium maps as images
        batch_size: Size of batches for processing
        delay_seconds: Delay between batches in seconds
        process_all: Process all data points (not just a sample)
        
    Returns:
        DataFrame with satellite features added
    """
    print("Adding satellite features with optimized processing...")
    
    # Create tracking columns
    df['processing_status'] = 'Not processed'  # Track processing status
    df['error_message'] = ''                   # Store error messages
    df['api_response'] = ''                    # Track API responses
    
    if fetch_real_data:
        try:
            # Set up Sentinel Hub configuration
            config = SHConfig()
            config.instance_id = SENTINEL_INSTANCE_ID
            config.sh_client_id = SENTINEL_CLIENT_ID
            config.sh_client_secret = SENTINEL_CLIENT_SECRET
            
            # Create directories
            os.makedirs('satellite_data', exist_ok=True)
            os.makedirs('satellite_maps', exist_ok=True)
            os.makedirs('processing_logs', exist_ok=True)
            
            # Initialize columns
            df['ndvi_real'] = np.nan
            df['red_band_real'] = np.nan
            df['green_band_real'] = np.nan
            df['blue_band_real'] = np.nan
            df['nir_band_real'] = np.nan
            df['satellite_img_path'] = ''
            df['map_img_path'] = ''
            df['direct_satellite_img_path'] = ''
            
            # Find unique fossil IDs with their coordinates
            print("Identifying unique locations to process...")
            fossil_locations = {}
            fossil_id_groups = defaultdict(list)
            
            # Group by FOSSIL_ID and get representative coordinates
            for idx, row in df.iterrows():
                fossil_id = row['FOSSIL_ID']
                lat, lon = row['latitude'], row['longitude']
                
                # Skip if coordinates are missing
                if pd.isna(lat) or pd.isna(lon):
                    continue
                
                fossil_locations[fossil_id] = (lat, lon, idx)
                fossil_id_groups[fossil_id].append(idx)
            
            unique_fossil_ids = list(fossil_locations.keys())
            total_unique = len(unique_fossil_ids)
            total_records = len(df)
            
            print(f"Found {total_unique} unique fossil IDs out of {total_records} total records")
            
            # Create a tracker file to record progress
            tracker_path = 'processing_logs/satellite_processing_tracker.csv'
            if not os.path.exists(tracker_path):
                tracker_df = pd.DataFrame({
                    'fossil_id': unique_fossil_ids,
                    'processed': False,
                    'has_ndvi': False,
                    'has_rgb': False,
                    'has_map': False,
                    'error': ''
                })
                tracker_df.to_csv(tracker_path, index=False)
            else:
                tracker_df = pd.read_csv(tracker_path)
                # Filter for unprocessed fossil IDs
                unprocessed_ids = tracker_df[~tracker_df['processed']]['fossil_id'].tolist()
                if unprocessed_ids and len(unprocessed_ids) < total_unique:
                    print(f"Continuing previous run - {len(unprocessed_ids)} unique locations remaining")
                    unique_fossil_ids = [fid for fid in unique_fossil_ids if fid in unprocessed_ids]
            
            # Calculate batches for parallel processing
            num_batches = (len(unique_fossil_ids) + batch_size - 1) // batch_size
            
            # NDVI evaluation script for Sentinel-2
            ndvi_evalscript = """
            //VERSION=3
            function setup() {
                return {
                    input: ["B04", "B08"],
                    output: { 
                        bands: 1,
                        sampleType: "FLOAT32"
                    }
                };
            }

            function evaluatePixel(sample) {
                let ndvi = (sample.B08 - sample.B04) / (sample.B08 + sample.B04);
                return [ndvi];
            }
            """
            
            # RGB + NIR evaluation script
            rgb_nir_evalscript = """
            //VERSION=3
            function setup() {
                return {
                    input: ["B04", "B03", "B02", "B08"],
                    output: {
                        bands: 4
                    }
                };
            }

            function evaluatePixel(sample) {
                return [sample.B04, sample.B03, sample.B02, sample.B08];
            }
            """
            
            # Function to normalize bands for better visualization
            def normalize_band(band, lower_percent=2, upper_percent=98):
                p_low = np.percentile(band, lower_percent)
                p_high = np.percentile(band, upper_percent)
                return np.clip((band - p_low) / (p_high - p_low), 0, 1)
            
            # Function to create simulated satellite image
            def create_simulated_satellite_image(size=128, lat=0, lon=0):
                # Create simulated image based on coordinates
                seed = int((lat + 90) * 1000 + (lon + 180) * 10)
                np.random.seed(seed)
                
                simulated_r = np.zeros((size, size))
                simulated_g = np.zeros((size, size))
                simulated_b = np.zeros((size, size))
                
                # Create base colors
                base_r = np.random.uniform(0.10, 0.20)
                base_g = np.random.uniform(0.15, 0.25)
                base_b = np.random.uniform(0.05, 0.15)
                
                # Generate patterns
                for i in range(size):
                    for j in range(size):
                        # Mix land and water
                        is_land = np.random.random() > 0.3  # 70% land
                        
                        if is_land:
                            # Land: green and brown
                            simulated_r[i, j] = base_r + np.random.uniform(-0.05, 0.05)
                            simulated_g[i, j] = base_g + np.random.uniform(-0.05, 0.05)
                            simulated_b[i, j] = base_b + np.random.uniform(-0.05, 0.05)
                        else:
                            # Water: blue
                            simulated_r[i, j] = np.random.uniform(0.03, 0.10)
                            simulated_g[i, j] = np.random.uniform(0.10, 0.20)
                            simulated_b[i, j] = np.random.uniform(0.30, 0.40)
                
                # Add landscape patterns
                scale = 10
                x = np.linspace(0, scale, size)
                y = np.linspace(0, scale, size)
                x_grid, y_grid = np.meshgrid(x, y)
                
                # Use sine waves for patterns
                pattern1 = np.sin(x_grid + lat/10) * np.cos(y_grid + lon/10) * 0.05
                pattern2 = np.sin(x_grid * 2 + lon/5) * np.cos(y_grid * 2 + lat/5) * 0.03
                
                simulated_r += pattern1
                simulated_g += pattern2
                simulated_b += (pattern1 + pattern2) / 2
                
                # Clip values to valid range
                simulated_r = np.clip(simulated_r, 0, 1)
                simulated_g = np.clip(simulated_g, 0, 1)
                simulated_b = np.clip(simulated_b, 0, 1)
                
                # Create RGB image
                rgb_img = np.stack([simulated_r, simulated_g, simulated_b], axis=2)
                
                # Calculate simulated NDVI
                ndvi = (simulated_g - simulated_r) / (simulated_g + simulated_r + 1e-10)
                ndvi = np.clip(ndvi, -1, 1)
                
                return rgb_img, ndvi
            
            # Function to process a single location
            def process_location(fossil_id):
                try:
                    lat, lon, ref_idx = fossil_locations[fossil_id]
                    
                    print(f"Processing fossil ID {fossil_id} at location {lat:.4f}, {lon:.4f}...")
                    
                    # Create a small bounding box around the point
                    bbox = BBox(bbox=[lon - 0.02, lat - 0.02, lon + 0.02, lat + 0.02], crs=CRS.WGS84)
                    
                    # Fetch NDVI data
                    ndvi_request = SentinelHubRequest(
                        evalscript=ndvi_evalscript,
                        input_data=[
                            SentinelHubRequest.input_data(
                                data_collection=DataCollection.SENTINEL2_L2A,
                                time_interval=('2020-01-01', '2022-12-31'),
                                mosaicking_order=MosaickingOrder.LEAST_CC
                            )
                        ],
                        responses=[
                            SentinelHubRequest.output_response('default', MimeType.TIFF)
                        ],
                        bbox=bbox,
                        size=(128, 128),
                        config=config
                    )
                    
                    # Fetch RGB+NIR data
                    rgb_nir_request = SentinelHubRequest(
                        evalscript=rgb_nir_evalscript,
                        input_data=[
                            SentinelHubRequest.input_data(
                                data_collection=DataCollection.SENTINEL2_L2A,
                                time_interval=('2020-01-01', '2022-12-31'),
                                mosaicking_order=MosaickingOrder.LEAST_CC
                            )
                        ],
                        responses=[
                            SentinelHubRequest.output_response('default', MimeType.TIFF)
                        ],
                        bbox=bbox,
                        size=(128, 128),
                        config=config
                    )
                    
                    # Results dictionary to store all outputs
                    results = {
                        'fossil_id': fossil_id,
                        'ndvi_real': np.nan,
                        'red_band_real': np.nan,
                        'green_band_real': np.nan,
                        'blue_band_real': np.nan,
                        'nir_band_real': np.nan,
                        'satellite_img_path': '',
                        'map_img_path': '',
                        'direct_satellite_img_path': '',
                        'processing_status': 'Failed',
                        'error_message': '',
                        'valid_ndvi_data': False,
                        'valid_rgb_data': False,
                        'map_success': False
                    }
                    
                    # Get NDVI data and calculate mean value
                    try:
                        ndvi_data = ndvi_request.get_data()[0]
                        # Calculate mean NDVI, excluding no-data values (typically < -1 or > 1)
                        valid_ndvi = ndvi_data[(ndvi_data > -1) & (ndvi_data < 1)]
                        if len(valid_ndvi) > 0:
                            mean_ndvi = np.mean(valid_ndvi)
                            results['ndvi_real'] = mean_ndvi
                            results['valid_ndvi_data'] = True
                            
                            # Save NDVI visualization using PIL instead of matplotlib
                            try:
                                ndvi_path = f'satellite_data/ndvi_{fossil_id}.png'
                                save_array_as_image(ndvi_data, ndvi_path, cmap='ndvi', 
                                                  title=f'NDVI: {fossil_id}')
                            except Exception as viz_error:
                                print(f"  Error visualizing NDVI data: {viz_error}")
                                # Try with matplotlib as fallback, using lock
                                try:
                                    with matplotlib_lock:
                                        plt.figure(figsize=(5, 5))
                                        plt.imshow(ndvi_data, cmap='RdYlGn', vmin=-1, vmax=1)
                                        plt.colorbar(label='NDVI')
                                        plt.title(f'NDVI for Fossil ID: {fossil_id}')
                                        plt.axis('off')
                                        plt.savefig(ndvi_path)
                                        plt.close('all')
                                except Exception:
                                    pass
                    except Exception as e:
                        results['error_message'] += f"NDVI error: {str(e)}; "
                    
                    # Get RGB+NIR data and calculate mean band values
                    try:
                        rgb_nir_data = rgb_nir_request.get_data()[0]
                        
                        # Extract bands
                        red_band = rgb_nir_data[:, :, 0]
                        green_band = rgb_nir_data[:, :, 1]
                        blue_band = rgb_nir_data[:, :, 2]
                        nir_band = rgb_nir_data[:, :, 3]
                        
                        # Check if data is valid (not all zeros or NaNs)
                        band_means = [np.mean(band) for band in [red_band, green_band, blue_band]]
                        if all(mean < 0.001 for mean in band_means):
                            print(f"Very low band values detected for fossil ID {fossil_id}. Data may be invalid.")
                        else:
                            results['valid_rgb_data'] = True
                        
                        # Calculate mean values for each band
                        results['red_band_real'] = np.mean(red_band)
                        results['green_band_real'] = np.mean(green_band)
                        results['blue_band_real'] = np.mean(blue_band)
                        results['nir_band_real'] = np.mean(nir_band)
                        
                        # Save RGB visualization using PIL instead of matplotlib
                        try:
                            # Create ready-to-save RGB image array
                            if results['valid_rgb_data']:
                                # Use percentile normalization
                                red_norm = normalize_band(red_band)
                                green_norm = normalize_band(green_band) 
                                blue_norm = normalize_band(blue_band)
                                
                                # Stack channels if valid
                                if (red_norm.shape == green_norm.shape == blue_norm.shape and
                                    np.isfinite(red_norm).all() and np.isfinite(green_norm).all() and 
                                    np.isfinite(blue_norm).all()):
                                    rgb_img = np.stack([red_norm, green_norm, blue_norm], axis=2)
                                else:
                                    # Use simulated image if normalization failed
                                    rgb_img, _ = create_simulated_satellite_image(128, lat, lon)
                            else:
                                # Use simulated image if data is invalid
                                rgb_img, _ = create_simulated_satellite_image(128, lat, lon)
                            
                            # Save both RGB visualizations using PIL
                            rgb_path = f'satellite_data/rgb_{fossil_id}.png'
                            direct_path = f'satellite_data/direct_satellite_{fossil_id}.png'
                            
                            # Save regular RGB visualization
                            save_array_as_image(rgb_img, rgb_path, 
                                              title=f'RGB: {fossil_id}')
                            
                            # Save direct satellite image (no title)
                            save_array_as_image(rgb_img, direct_path)
                            
                            # Store paths
                            results['satellite_img_path'] = rgb_path
                            results['direct_satellite_img_path'] = direct_path
                            
                        except Exception as viz_error:
                            print(f"  Error saving RGB visualization: {viz_error}")
                            
                            # Create a fallback image
                            try:
                                from PIL import Image, ImageDraw
                                img = Image.new('RGB', (128, 128), color=(100, 100, 100))
                                draw = ImageDraw.Draw(img)
                                draw.text((10, 10), f"ID:{fossil_id}", fill=(255, 255, 255))
                                
                                rgb_path = f'satellite_data/fallback_rgb_{fossil_id}.png'
                                direct_path = f'satellite_data/fallback_direct_{fossil_id}.png'
                                
                                img.save(rgb_path)
                                img.save(direct_path)
                                
                                # Store paths
                                results['satellite_img_path'] = rgb_path
                                results['direct_satellite_img_path'] = direct_path
                            except Exception:
                                # Last resort - record empty paths
                                results['satellite_img_path'] = ''
                                results['direct_satellite_img_path'] = ''
                        
                    except Exception as e:
                        results['error_message'] += f"RGB error: {str(e)}; "
                        
                        # Create simulated image as fallback
                        rgb_img, sim_ndvi = create_simulated_satellite_image(128, lat, lon)
                        
                        # Save simulated RGB image
                        plt.figure(figsize=(5, 5))
                        plt.imshow(rgb_img)
                        plt.title(f'Simulated RGB for Fossil ID: {fossil_id}')
                        plt.axis('off')
                        rgb_path = f'satellite_data/sim_rgb_{fossil_id}.png'
                        plt.savefig(rgb_path)
                        plt.close()
                        
                        # Save direct satellite image (simulated)
                        plt.figure(figsize=(8, 8))
                        plt.imshow(rgb_img)
                        plt.axis('off')
                        plt.tight_layout(pad=0)
                        direct_path = f'satellite_data/direct_sim_{fossil_id}.png'
                        plt.savefig(direct_path, bbox_inches='tight', pad_inches=0, dpi=150)
                        plt.close()
                        
                        # Store path to the simulated image
                        results['satellite_img_path'] = rgb_path
                        results['direct_satellite_img_path'] = direct_path
                        
                        # Use simulated NDVI if real data failed
                        if not results['valid_ndvi_data']:
                            mean_sim_ndvi = np.mean(sim_ndvi)
                            print(f"Using simulated NDVI for fossil ID {fossil_id}: {mean_sim_ndvi:.4f}")
                            
                            # Save simulated NDVI visualization using PIL
                            try:
                                mean_sim_ndvi = np.mean(sim_ndvi)
                                print(f"      Using simulated NDVI: {mean_sim_ndvi:.4f}")
                                
                                ndvi_path = f'satellite_data/sim_ndvi_{fossil_id}.png'
                                save_array_as_image(sim_ndvi, ndvi_path, cmap='ndvi', 
                                                  title=f'Sim NDVI: {fossil_id}')
                            except Exception as viz_error:
                                print(f"  Error visualizing simulated NDVI: {viz_error}")
                    
                    # Create interactive map with Folium but with NO markers or UI
                    if save_maps:
                        try:
                            # Create simplified map with no controls (increased zoom to 16)
                            map_html_path = f'satellite_maps/clean_map_{fossil_id}.html'
                            create_satellite_only_map(lat, lon, 16, map_html_path)
                            
                            # Capture screenshot
                            map_img_path = f'satellite_maps/map_{fossil_id}.png'
                            screenshot_success = capture_map_screenshot_simple(
                                map_html_path, 
                                map_img_path
                            )
                            
                            if screenshot_success:
                                results['map_success'] = True
                                results['map_img_path'] = map_img_path
                            else:
                                # If screenshot fails, use the direct satellite image
                                if os.path.exists(results['direct_satellite_img_path']):
                                    # Copy the direct satellite image
                                    import shutil
                                    shutil.copy(results['direct_satellite_img_path'], map_img_path)
                                    print(f"Used direct satellite image for map: {map_img_path}")
                                    results['map_success'] = True
                                    results['map_img_path'] = map_img_path
                        except Exception as e:
                            results['error_message'] += f"Map error: {str(e)}; "
                    
                    # Mark processing status
                    if results['valid_ndvi_data'] or results['valid_rgb_data'] or results['map_success']:
                        results['processing_status'] = 'Completed'
                    else:
                        results['processing_status'] = 'Partial'
                    
                    return results
                    
                except Exception as e:
                    error_msg = str(e)
                    print(f"Error processing fossil ID {fossil_id}: {error_msg}")
                    return {
                        'fossil_id': fossil_id,
                        'processing_status': 'Failed',
                        'error_message': error_msg,
                        'valid_ndvi_data': False,
                        'valid_rgb_data': False,
                        'map_success': False
                    }
            
            # Process satellite data with threading lock for matplotlib
            print(f"Starting parallel processing of {len(unique_fossil_ids)} unique fossil IDs...")
            
            # Create a threading lock for matplotlib operations
            import threading
            matplotlib_lock = threading.Lock()
            
            # Force matplotlib to use Agg backend
            import matplotlib
            matplotlib.use('Agg')  # Use non-interactive backend
            
            # Function to generate images with PIL instead of matplotlib
            def save_array_as_image(array, output_path, cmap=None, title=None):
                """Save numpy array as image using PIL instead of matplotlib"""
                try:
                    from PIL import Image, ImageDraw, ImageFont
                    
                    # Validate input
                    if array is None or not isinstance(array, np.ndarray):
                        # Create a gray image if input is invalid
                        img = Image.new('RGB', (128, 128), (128, 128, 128))
                        if title:
                            draw = ImageDraw.Draw(img)
                            draw.text((10, 10), title, (255, 255, 255))
                        img.save(output_path)
                        return True
                    
                    # For NDVI (single channel with colormap)
                    if array.ndim == 2 and cmap == 'ndvi':
                        # Normalize to 0-1 range
                        valid_mask = np.isfinite(array)
                        if not np.any(valid_mask):
                            img = Image.new('RGB', (128, 128), (128, 128, 128))
                        else:
                            # Get min/max of valid values
                            valid_data = array[valid_mask]
                            vmin, vmax = -1, 1  # NDVI range
                            
                            # Normalize and clip
                            normalized = np.zeros_like(array)
                            normalized[valid_mask] = np.clip((array[valid_mask] - vmin) / (vmax - vmin), 0, 1)
                            
                            # Create RGB representation (red-yellow-green)
                            rgb = np.zeros((*normalized.shape, 3), dtype=np.uint8)
                            
                            # Red channel (high for low NDVI, low for high NDVI)
                            rgb[..., 0] = np.clip(255 * (1 - normalized), 0, 255).astype(np.uint8)
                            
                            # Green channel (high for high NDVI)
                            rgb[..., 1] = np.clip(255 * normalized, 0, 255).astype(np.uint8)
                            
                            # Blue channel (low)
                            rgb[..., 2] = np.zeros_like(normalized, dtype=np.uint8)
                            
                            # Create image
                            img = Image.fromarray(rgb)
                        
                        # Add title if provided
                        if title:
                            draw = ImageDraw.Draw(img)
                            draw.text((10, 10), title, (255, 255, 255))
                        
                        img.save(output_path)
                        return True
                    
                    # For RGB data (3-channel)
                    elif array.ndim == 3 and array.shape[2] == 3:
                        # Check if values are in 0-1 range or 0-255
                        if array.max() <= 1.0:
                            array = (array * 255).astype(np.uint8)
                        else:
                            array = array.astype(np.uint8)
                        
                        # Create image
                        img = Image.fromarray(array)
                        
                        # Add title if provided
                        if title:
                            draw = ImageDraw.Draw(img)
                            draw.text((10, 10), title, (255, 255, 255))
                        
                        img.save(output_path)
                        return True
                    
                    # Fallback for other arrays
                    else:
                        img = Image.new('RGB', (128, 128), (100, 100, 100))
                        draw = ImageDraw.Draw(img)
                        if title:
                            draw.text((10, 10), title, (255, 255, 255))
                        else:
                            draw.text((10, 10), "Array visualization", (255, 255, 255))
                        img.save(output_path)
                        return True
                        
                except Exception as e:
                    print(f"  Error saving image with PIL: {e}")
                    
                    # Last resort fallback
                    try:
                        img = Image.new('RGB', (128, 128), (50, 50, 50))
                        draw = ImageDraw.Draw(img)
                        draw.text((10, 10), "Error generating image", (255, 255, 255))
                        img.save(output_path)
                        return True
                    except:
                        return False
            
            # Process in batches to avoid memory issues
            all_results = {}
            
            for batch_num in range(num_batches):
                start_idx = batch_num * batch_size
                end_idx = min((batch_num + 1) * batch_size, len(unique_fossil_ids))
                batch_fossil_ids = unique_fossil_ids[start_idx:end_idx]
                
                print(f"\nProcessing batch {batch_num + 1}/{num_batches} with {len(batch_fossil_ids)} unique fossil IDs...")
                
                # Process batch with limited parallelism to avoid resource contention
                max_parallel = min(4, len(batch_fossil_ids))  # Limit to 4 threads max to reduce contention
                completed_fossils = []
                failed_fossils = []
                
                print(f"  Using {max_parallel} parallel workers for processing")
                
                with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel) as executor:
                    # Submit all tasks
                    future_to_fossil = {executor.submit(process_location, fid): fid for fid in batch_fossil_ids}
                    
                    # Process as they complete
                    for future in concurrent.futures.as_completed(future_to_fossil):
                        fid = future_to_fossil[future]
                        try:
                            result = future.result()
                            all_results[fid] = result
                            
                            # Update tracker
                            if tracker_df is not None:
                                tracker_idx = tracker_df[tracker_df['fossil_id'] == fid].index
                                if len(tracker_idx) > 0:
                                    tracker_df.loc[tracker_idx[0], 'processed'] = True
                                    tracker_df.loc[tracker_idx[0], 'has_ndvi'] = result['valid_ndvi_data']
                                    tracker_df.loc[tracker_idx[0], 'has_rgb'] = result['valid_rgb_data'] 
                                    tracker_df.loc[tracker_idx[0], 'has_map'] = result['map_success']
                                    tracker_df.loc[tracker_idx[0], 'error'] = result['error_message']
                            
                            # Track as completed
                            if result['processing_status'] in ['Completed', 'Partial']:
                                completed_fossils.append(fid)
                            else:
                                failed_fossils.append(fid)
                                
                            # Print progress
                            progress = len(completed_fossils) + len(failed_fossils)
                            percent = (progress / len(batch_fossil_ids)) * 100
                            print(f"  Progress: {progress}/{len(batch_fossil_ids)} ({percent:.1f}%) - Success: {len(completed_fossils)} Failed: {len(failed_fossils)}")
                            
                        except Exception as e:
                            print(f"Error processing fossil ID {fid}: {e}")
                            failed_fossils.append(fid)
                            
                            # Create a basic failure result
                            all_results[fid] = {
                                'fossil_id': fid,
                                'processing_status': 'Failed',
                                'error_message': f"Exception: {str(e)}",
                                'valid_ndvi_data': False,
                                'valid_rgb_data': False,
                                'map_success': False,
                                'satellite_img_path': '',
                                'map_img_path': '',
                                'direct_satellite_img_path': '',
                                'ndvi_real': np.nan,
                                'red_band_real': np.nan,
                                'green_band_real': np.nan,
                                'blue_band_real': np.nan,
                                'nir_band_real': np.nan
                            }
                            
                            # Update tracker
                            if tracker_df is not None:
                                tracker_idx = tracker_df[tracker_df['fossil_id'] == fid].index
                                if len(tracker_idx) > 0:
                                    tracker_df.loc[tracker_idx[0], 'processed'] = True
                                    tracker_df.loc[tracker_idx[0], 'has_ndvi'] = False
                                    tracker_df.loc[tracker_idx[0], 'has_rgb'] = False
                                    tracker_df.loc[tracker_idx[0], 'has_map'] = False
                                    tracker_df.loc[tracker_idx[0], 'error'] = str(e)
                
                # Save tracker after each batch
                if tracker_df is not None:
                    tracker_df.to_csv(tracker_path, index=False)
                    print(f"Progress saved after batch {batch_num + 1}")
                    
                # Batch summary
                print(f"\nBatch {batch_num + 1} summary:")
                print(f"  Total processed: {len(batch_fossil_ids)}")
                print(f"  Successful: {len(completed_fossils)} ({len(completed_fossils)/len(batch_fossil_ids)*100:.1f}%)")
                print(f"  Failed: {len(failed_fossils)} ({len(failed_fossils)/len(batch_fossil_ids)*100:.1f}%)")
                
                # Introduce delay if specified
                if delay_seconds > 0 and batch_num < num_batches - 1:
                    print(f"Waiting {delay_seconds} seconds before next batch...")
                    time.sleep(delay_seconds)
            
            # Apply results to all rows with the same fossil ID
            print("Applying results to all records...")
            success_count = 0
            
            for fossil_id, result in all_results.items():
                # Get all indices with this fossil ID
                indices = fossil_id_groups[fossil_id]
                
                for idx in indices:
                    # Apply all result fields to this row
                    for key, value in result.items():
                        if key != 'fossil_id':  # Skip the fossil_id key
                            df.loc[idx, key] = value
                
                # Count successful processing
                if result['processing_status'] in ['Completed', 'Partial']:
                    success_count += 1
            
            # Fill missing values with simulated data for consistent features
            print("Finalizing satellite features...")
            missing_ndvi = df['ndvi_real'].isna()
            if missing_ndvi.any():
                print(f"{missing_ndvi.sum()} locations using simulated NDVI data")
                
                # Generate uniform simulated values
                df['ndvi'] = np.where(
                    df['ndvi_real'].isna(),
                    (np.sin(df['latitude'] * 10) * np.cos(df['longitude'] * 10) * 0.3 + 0.5).clip(0, 1),
                    df['ndvi_real']
                )
                
                # Similarly for other bands
                for band in ['red', 'green', 'blue', 'nir']:
                    sim_column = f"{band}_band"
                    real_column = f"{band}_band_real"
                    
                    # Generate simulated values
                    if band == 'red':
                        sim_values = (df['ndvi'] * 0.3 + 0.2).clip(0, 1)
                    elif band == 'green':
                        sim_values = (df['ndvi'] * 0.5 + 0.3).clip(0, 1)
                    elif band == 'blue':
                        sim_values = (df['ndvi'] * 0.2 + 0.1).clip(0, 1)
                    else:  # nir
                        sim_values = (df['ndvi'] * 0.7 + 0.3).clip(0, 1)
                    
                    # Use real values where available, simulated elsewhere
                    df[sim_column] = np.where(
                        df[real_column].isna(),
                        sim_values,
                        df[real_column]
                    )
            else:
                # If we have real NDVI for all points, just copy to the standard column
                df['ndvi'] = df['ndvi_real']
                df['red_band'] = df['red_band_real']
                df['green_band'] = df['green_band_real']
                df['blue_band'] = df['blue_band_real']
                df['nir_band'] = df['nir_band_real']
            
            # Add satellite data source flag
            df['has_real_satellite_data'] = ~df['ndvi_real'].isna()
                
            # Add terrain features (use simulated for now)
            df['elevation_simulated'] = (
                np.sin(df['latitude'] * 5) * np.cos(df['longitude'] * 5) * 500 + 500
            )
            
            print(f"Added satellite features with {success_count}/{total_unique} successful unique locations")
            print(f"Total records processed: {total_records}")
            
        except Exception as e:
            print(f"Global error in satellite processing: {e}")
            print("Falling back to simulated satellite data")
            _add_simulated_satellite_features(df)
    else:
        # Use simulated data
        print("Using simulated satellite data as requested")
        _add_simulated_satellite_features(df)
    
    return df

def create_satellite_only_map(lat, lon, zoom, output_html_path):
    """
    Create a simplified map with ONLY satellite tiles and no controls or markers
    """
    # Create very minimal HTML with just the satellite tiles
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Satellite Only</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <link rel="stylesheet" href="https://unpkg.com/leaflet@1.7.1/dist/leaflet.css"/>
        <script src="https://unpkg.com/leaflet@1.7.1/dist/leaflet.js"></script>
        <style>
            body {{ margin: 0; padding: 0; }}
            #map {{ position: absolute; top: 0; bottom: 0; width: 100%; height: 100%; }}
            .leaflet-control-container {{ display: none !important; }}
        </style>
    </head>
    <body>
        <div id="map"></div>
        <script>
            var map = L.map('map', {{
                center: [{lat}, {lon}],
                zoom: {zoom},
                zoomControl: false,
                attributionControl: false,
                dragging: false,
                touchZoom: false,
                scrollWheelZoom: false,
                doubleClickZoom: false,
                boxZoom: false
            }});
            
            // Add only satellite tiles
            L.tileLayer('https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{{z}}/{{y}}/{{x}}', {{
                maxZoom: 18,
                attribution: ''
            }}).addTo(map);
            
            document.addEventListener('DOMContentLoaded', function() {{
                setTimeout(function() {{
                    var controls = document.querySelectorAll('.leaflet-control-container');
                    controls.forEach(function(el) {{ el.style.display = 'none'; }});
                }}, 100);
            }});
        </script>
    </body>
    </html>
    """
    
    # Write the HTML to file
    with open(output_html_path, 'w') as f:
        f.write(html_content)
    
    return output_html_path


def capture_map_screenshot_simple(html_path, output_path, width=800, height=600):
    """Simplified version that just captures the screenshot without modifications"""
    if not SELENIUM_AVAILABLE:
        return False
        
    try:
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--disable-gpu")
        options.add_argument(f"--window-size={width},{height}")
        
        driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)
        
        file_url = f"file://{os.path.abspath(html_path)}"
        driver.get(file_url)
        
        # Long wait to ensure tiles load completely
        time.sleep(10)
        
        driver.save_screenshot(output_path)
        driver.quit()
        return True
    except Exception as e:
        print(f"Error capturing screenshot: {e}")
        return False

def ensure_all_images_exist(df):
    """
    Make sure every location has all required satellite imagery
    """
    print("Ensuring all locations have satellite imagery...")
    
    # Create directories if needed
    os.makedirs('satellite_data', exist_ok=True)
    os.makedirs('satellite_maps', exist_ok=True)
    
    # Count missing images
    total_rows = len(df)
    missing_direct = df[df['direct_satellite_img_path'] == ''].shape[0]
    missing_rgb = df[df['satellite_img_path'] == ''].shape[0]
    missing_map = df[df['map_img_path'] == ''].shape[0]
    
    print(f"Found {missing_direct} missing direct images, {missing_rgb} missing RGB images, {missing_map} missing map images")
    
    # Function to create simulated satellite image
    def create_simulated_satellite_image(size=128, lat=0, lon=0):
        # Create simulated image based on coordinates
        seed = int((lat + 90) * 1000 + (lon + 180) * 10)
        np.random.seed(seed)
        
        simulated_r = np.zeros((size, size))
        simulated_g = np.zeros((size, size))
        simulated_b = np.zeros((size, size))
        
        # Create base colors
        base_r = np.random.uniform(0.10, 0.20)
        base_g = np.random.uniform(0.15, 0.25)
        base_b = np.random.uniform(0.05, 0.15)
        
        # Generate patterns
        for i in range(size):
            for j in range(size):
                # Mix land and water
                is_land = np.random.random() > 0.3  # 70% land
                
                if is_land:
                    # Land: green and brown
                    simulated_r[i, j] = base_r + np.random.uniform(-0.05, 0.05)
                    simulated_g[i, j] = base_g + np.random.uniform(-0.05, 0.05)
                    simulated_b[i, j] = base_b + np.random.uniform(-0.05, 0.05)
                else:
                    # Water: blue
                    simulated_r[i, j] = np.random.uniform(0.03, 0.10)
                    simulated_g[i, j] = np.random.uniform(0.10, 0.20)
                    simulated_b[i, j] = np.random.uniform(0.30, 0.40)
        
        # Add landscape patterns
        scale = 10
        x = np.linspace(0, scale, size)
        y = np.linspace(0, scale, size)
        x_grid, y_grid = np.meshgrid(x, y)
        
        # Use sine waves for patterns
        pattern1 = np.sin(x_grid + lat/10) * np.cos(y_grid + lon/10) * 0.05
        pattern2 = np.sin(x_grid * 2 + lon/5) * np.cos(y_grid * 2 + lat/5) * 0.03
        
        simulated_r += pattern1
        simulated_g += pattern2
        simulated_b += (pattern1 + pattern2) / 2
        
        # Clip values to valid range
        simulated_r = np.clip(simulated_r, 0, 1)
        simulated_g = np.clip(simulated_g, 0, 1)
        simulated_b = np.clip(simulated_b, 0, 1)
        
        # Create RGB image
        rgb_img = np.stack([simulated_r, simulated_g, simulated_b], axis=2)
        
        # Calculate simulated NDVI
        ndvi = (simulated_g - simulated_r) / (simulated_g + simulated_r + 1e-10)
        ndvi = np.clip(ndvi, -1, 1)
        
        return rgb_img, ndvi
    
    # Process all rows with missing images
    created_count = 0
    for idx, row in df.iterrows():
        if idx % 100 == 0:
            print(f"Checking row {idx}/{total_rows}...")
            
        fossil_id = row['FOSSIL_ID']
        lat = row['latitude']
        lon = row['longitude']
        
        try:
            # Skip if coordinates are missing
            if pd.isna(lat) or pd.isna(lon):
                continue
                
            # Check and fix direct satellite image
            if row['direct_satellite_img_path'] == '' or not os.path.exists(row['direct_satellite_img_path']):
                # Create simulated image
                rgb_img, _ = create_simulated_satellite_image(128, lat, lon)
                
                # Save direct satellite image
                plt.figure(figsize=(8, 8))
                plt.imshow(rgb_img)
                plt.axis('off')
                plt.tight_layout(pad=0)
                direct_path = f'satellite_data/fixed_direct_{fossil_id}.png'
                plt.savefig(direct_path, bbox_inches='tight', pad_inches=0, dpi=150)
                plt.close()
                
                # Update dataframe
                df.loc[idx, 'direct_satellite_img_path'] = direct_path
                created_count += 1
            
            # Check and fix RGB image
            if row['satellite_img_path'] == '' or not os.path.exists(row['satellite_img_path']):
                # Create simulated image if needed
                if 'rgb_img' not in locals():
                    rgb_img, _ = create_simulated_satellite_image(128, lat, lon)
                
                # Save RGB visualization
                plt.figure(figsize=(5, 5))
                plt.imshow(rgb_img)
                plt.title(f'Fixed RGB for Fossil ID: {fossil_id}')
                plt.axis('off')
                rgb_path = f'satellite_data/fixed_rgb_{fossil_id}.png'
                plt.savefig(rgb_path)
                plt.close()
                
                # Update dataframe
                df.loc[idx, 'satellite_img_path'] = rgb_path
                created_count += 1
            
            # Check and fix map image
            if row['map_img_path'] == '' or not os.path.exists(row['map_img_path']):
                # First try to create a clean map
                map_success = False
                try:
                    # Create clean map
                    map_html_path = f'satellite_maps/fixed_map_{fossil_id}.html'
                    create_satellite_only_map(lat, lon, 13, map_html_path)
                    
                    # Capture screenshot
                    map_img_path = f'satellite_maps/fixed_map_{fossil_id}.png'
                    map_success = capture_map_screenshot_simple(map_html_path, map_img_path)
                except Exception:
                    map_success = False
                
                # If map failed, use direct image
                if not map_success:
                    # Use direct image if available
                    if os.path.exists(df.loc[idx, 'direct_satellite_img_path']):
                        import shutil
                        direct_img = df.loc[idx, 'direct_satellite_img_path']
                        map_img_path = f'satellite_maps/direct_as_map_{fossil_id}.png'
                        shutil.copy(direct_img, map_img_path)
                    else:
                        # Create a new image
                        if 'rgb_img' not in locals():
                            rgb_img, _ = create_simulated_satellite_image(128, lat, lon)
                        
                        # Save as map
                        plt.figure(figsize=(8, 8))
                        plt.imshow(rgb_img)
                        plt.axis('off')
                        plt.tight_layout(pad=0)
                        map_img_path = f'satellite_maps/fallback_map_{fossil_id}.png'
                        plt.savefig(map_img_path, bbox_inches='tight', pad_inches=0)
                        plt.close()
                
                # Update dataframe
                df.loc[idx, 'map_img_path'] = map_img_path
                created_count += 1
        
        except Exception as e:
            print(f"Error fixing images for row {idx}: {e}")
    
    print(f"Created {created_count} missing images")
    return df

def _add_simulated_satellite_features(df):
    """Add simulated satellite features (fallback method)"""
    # NDVI: Normalized Difference Vegetation Index
    df['ndvi'] = (
        np.sin(df['latitude'] * 10) * np.cos(df['longitude'] * 10) * 0.3 + 0.5
    ).clip(0, 1)
    
    # Spectral bands
    df['red_band'] = (df['ndvi'] * 0.3 + 0.2).clip(0, 1)
    df['green_band'] = (df['ndvi'] * 0.5 + 0.3).clip(0, 1)
    df['blue_band'] = (df['ndvi'] * 0.2 + 0.1).clip(0, 1)
    df['nir_band'] = (df['ndvi'] * 0.7 + 0.3).clip(0, 1)
    
    # Terrain features
    df['elevation_simulated'] = (
        np.sin(df['latitude'] * 5) * np.cos(df['longitude'] * 5) * 500 + 500
    )
    
    # Flag for data source
    df['has_real_satellite_data'] = 0
    
    # Add empty image paths
    df['satellite_img_path'] = ''
    df['map_img_path'] = ''
    df['direct_satellite_img_path'] = ''
    
    # Create directories
    os.makedirs('satellite_data', exist_ok=True)
    os.makedirs('satellite_maps', exist_ok=True)
    
    # Generate simulated images for all locations
    print("  Generating simulated satellite images for all locations...")
    
    # Function to create simulated satellite image
    def create_simulated_satellite_image(size=128, lat=0, lon=0):
        # Create simulated image based on coordinates
        seed = int((lat + 90) * 1000 + (lon + 180) * 10)
        np.random.seed(seed)
        
        simulated_r = np.zeros((size, size))
        simulated_g = np.zeros((size, size))
        simulated_b = np.zeros((size, size))
        
        # Create base colors
        base_r = np.random.uniform(0.10, 0.20)
        base_g = np.random.uniform(0.15, 0.25)
        base_b = np.random.uniform(0.05, 0.15)
        
        # Generate patterns
        for i in range(size):
            for j in range(size):
                # Mix land and water
                is_land = np.random.random() > 0.3  # 70% land
                
                if is_land:
                    # Land: green and brown
                    simulated_r[i, j] = base_r + np.random.uniform(-0.05, 0.05)
                    simulated_g[i, j] = base_g + np.random.uniform(-0.05, 0.05)
                    simulated_b[i, j] = base_b + np.random.uniform(-0.05, 0.05)
                else:
                    # Water: blue
                    simulated_r[i, j] = np.random.uniform(0.03, 0.10)
                    simulated_g[i, j] = np.random.uniform(0.10, 0.20)
                    simulated_b[i, j] = np.random.uniform(0.30, 0.40)
        
        # Add landscape patterns
        scale = 10
        x = np.linspace(0, scale, size)
        y = np.linspace(0, scale, size)
        x_grid, y_grid = np.meshgrid(x, y)
        
        # Use sine waves for patterns
        pattern1 = np.sin(x_grid + lat/10) * np.cos(y_grid + lon/10) * 0.05
        pattern2 = np.sin(x_grid * 2 + lon/5) * np.cos(y_grid * 2 + lat/5) * 0.03
        
        simulated_r += pattern1
        simulated_g += pattern2
        simulated_b += (pattern1 + pattern2) / 2
        
        # Clip values to valid range
        simulated_r = np.clip(simulated_r, 0, 1)
        simulated_g = np.clip(simulated_g, 0, 1)
        simulated_b = np.clip(simulated_b, 0, 1)
        
        # Create RGB image
        rgb_img = np.stack([simulated_r, simulated_g, simulated_b], axis=2)
        
        # Calculate simulated NDVI
        ndvi = (simulated_g - simulated_r) / (simulated_g + simulated_r + 1e-10)
        ndvi = np.clip(ndvi, -1, 1)
        
        return rgb_img, ndvi
    
    # Process in batches for better performance
    batch_size = 100
    total_rows = len(df)
    num_batches = (total_rows + batch_size - 1) // batch_size
    
    for batch_num in range(num_batches):
        start_idx = batch_num * batch_size
        end_idx = min((batch_num + 1) * batch_size, total_rows)
        
        print(f"  Generating images for batch {batch_num + 1}/{num_batches} (rows {start_idx}-{end_idx-1})...")
        
        for idx in range(start_idx, end_idx):
            try:
                lat = df.iloc[idx]['latitude']
                lon = df.iloc[idx]['longitude']
                fossil_id = df.iloc[idx]['FOSSIL_ID']
                
                # Skip if coordinates are missing
                if pd.isna(lat) or pd.isna(lon):
                    continue
                
                # Create simulated images
                rgb_img, sim_ndvi = create_simulated_satellite_image(128, lat, lon)
                
                # Save simulated RGB image
                plt.figure(figsize=(5, 5))
                plt.imshow(rgb_img)
                plt.title(f'Simulated RGB for Fossil ID: {fossil_id}')
                plt.axis('off')
                rgb_path = f'satellite_data/sim_rgb_{fossil_id}.png'
                plt.savefig(rgb_path)
                plt.close()
                
                # Save direct satellite image
                plt.figure(figsize=(8, 8))
                plt.imshow(rgb_img)
                plt.axis('off')
                plt.tight_layout(pad=0)
                direct_path = f'satellite_data/sim_direct_{fossil_id}.png'
                plt.savefig(direct_path, bbox_inches='tight', pad_inches=0, dpi=150)
                plt.close()
                
                # Save map image (just use the direct image)
                map_path = f'satellite_maps/sim_map_{fossil_id}.png'
                import shutil
                shutil.copy(direct_path, map_path)
                
                # Update paths in dataframe
                df.iloc[idx, df.columns.get_loc('satellite_img_path')] = rgb_path
                df.iloc[idx, df.columns.get_loc('direct_satellite_img_path')] = direct_path
                df.iloc[idx, df.columns.get_loc('map_img_path')] = map_path
                
            except Exception as e:
                print(f"    Error generating simulated image for row {idx}: {e}")
    
    print("  Added simulated satellite features and images")
    return df

def add_text_features(df, use_wangchan_bert=True, batch_size=100, save_interim=True):
    """
    Add text-based features including WangChanBERTa embeddings
    
    Args:
        df: DataFrame with text columns
        use_wangchan_bert: Whether to use WangChanBERTa model for embeddings
        batch_size: Size of batches for processing
        save_interim: Whether to save interim results
        
    Returns:
        DataFrame with text features added
    """
    print("Adding text features...")
    
    # Calculate text length features
    text_columns = ['FOS_DES_TH', 'GEO_DES_TH', 'F_PART']
    for col in text_columns:
        if col in df.columns:
            df[f'{col}_length'] = df[col].astype(str).str.len()
    
    # Keyword presence features
    fossil_keywords = ['กระดูก', 'ฟัน', 'ฟอสซิล', 'ปะการัง', 'เปลือก', 'แบรคิโอพอด']
    for keyword in fossil_keywords:
        df[f'has_{keyword}'] = df['FOS_DES_TH'].astype(str).str.contains(keyword).astype(int)
    
    print("  Added text length and keyword features")
    
    # Generate WangChanBERTa embeddings
    if use_wangchan_bert:
        try:
            print("  Loading WangChanBERTa model...")
            # Load tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
            model = AutoModel.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
            
            # Function to get embeddings
            def get_bert_embeddings(text, max_length=128):
                # Prepare inputs
                inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                
                # Get embeddings
                with torch.no_grad():
                    outputs = model(**inputs)
                
                # Use the [CLS] token embedding (first token) as the sentence embedding
                embeddings = outputs.last_hidden_state[:, 0, :].numpy()
                return embeddings[0]  # Return as a flat array
            
            # Process in batches to avoid memory issues
            total_rows = len(df)
            embedding_dim = 768  # WangChanBERTa's embedding dimension
            
            # Initialize empty array to store all embeddings
            all_embeddings = np.zeros((total_rows, embedding_dim))
            
            print(f"  Generating embeddings for {total_rows} rows in batches of {batch_size}...")
            
            # Process in batches
            num_batches = (total_rows + batch_size - 1) // batch_size
            for i in range(0, total_rows, batch_size):
                end_idx = min(i + batch_size, total_rows)
                batch_num = i // batch_size + 1
                
                print(f"  Processing batch {batch_num}/{num_batches} (rows {i} to {end_idx-1})...")
                
                batch_texts = df['FOS_DES_TH'].iloc[i:end_idx].fillna("").astype(str).values
                
                # Process each text individually
                for j, text in enumerate(batch_texts):
                    try:
                        embeddings = get_bert_embeddings(text)
                        all_embeddings[i+j] = embeddings
                    except Exception as e:
                        print(f"    Error generating embedding for row {i+j}: {e}")
                        # Use random embedding as fallback
                        all_embeddings[i+j] = np.random.randn(embedding_dim)
                
                # Save interim embeddings
                if save_interim and batch_num % 5 == 0:
                    np.save(f'interim_embeddings_batch_{batch_num}.npy', all_embeddings[:end_idx])
                    print(f"    Saved interim embeddings through batch {batch_num}")
            
            # Store full embeddings as columns
            # Only store a few components to avoid too many columns
            num_components = 10  # Store only first 10 components
            for j in range(num_components):
                df[f'text_embedding_{j}'] = all_embeddings[:, j]
            
            # Also store the embedding norms as a feature
            df['text_embedding_norm'] = np.linalg.norm(all_embeddings, axis=1)
            
            # Save embeddings to numpy file for future use
            np.save('text_embeddings.npy', all_embeddings)
            
            print(f"  Added {num_components} WangChanBERTa embedding components")
            print(f"  Full embeddings (dimension {embedding_dim}) saved to text_embeddings.npy")
            df['has_bert_embeddings'] = 1
            
        except Exception as e:
            print(f"  Error generating WangChanBERTa embeddings: {e}")
            print("  Falling back to simulated embeddings")
            # Fallback to simulated embeddings
            for i in range(10):
                df[f'text_embedding_{i}'] = np.random.randn(len(df))
            df['text_embedding_norm'] = np.random.uniform(0.9, 1.1, len(df))
            df['has_bert_embeddings'] = 0
    else:
        print("  Skipping WangChanBERTa embeddings as requested")
        df['has_bert_embeddings'] = 0
    
    return df

def add_temporal_features(df):
    """Add temporal features"""
    print("Adding temporal features...")
    
    # Extract year from created_date
    if 'created_date' in df.columns:
        try:
            # Assuming created_date is in milliseconds since epoch
            df['discovery_year'] = pd.to_datetime(df['created_date'], unit='ms').dt.year
        except Exception as e:
            print(f"  Error converting created_date: {e}")
            # Fallback to a default year
            df['discovery_year'] = 2020
    
    # Map period to era
    period_to_era = {
        'เพอร์เมียน': 'Paleozoic',
        'เพอร์เมียนตอนต้น': 'Paleozoic',
        'ไทรแอสซิก': 'Mesozoic',
        'ไทรแอสซิกตอนต้น': 'Mesozoic',
        'ไทรแอสสิกตอนปลาย': 'Mesozoic',
        'จูแรสซิก': 'Mesozoic',
        'จูแรสสิกตอนปลาย': 'Mesozoic',
        'ครีเทเชียส': 'Mesozoic',
        'ครีเทเชียสตอนต้น': 'Mesozoic',
        'ครีเทเซียสตอนต้น': 'Mesozoic',
        'ออร์โดวิเชียน': 'Paleozoic',
        'ออร์โดวีเชียน': 'Paleozoic',
        'ออร์โดวีเชียน-ไซลูเรียน': 'Paleozoic',
        'แคมเบรียนตอนปลาย': 'Paleozoic',
        'ยออร์โดวิเชียน': 'Paleozoic',
        'จูแรสสิกตอนปลายถึงครีเทเชียสตอนต้น': 'Mesozoic',
    }
    
    df['geological_era'] = df['PERIODFROM'].map(period_to_era)
    
    return df

def add_statistical_features(df):
    """Add statistical and frequency features"""
    print("Adding statistical features...")
    
    # Species frequency
    df['species_frequency'] = df.groupby('SCI_NAME')['FOSSIL_ID'].transform('count')
    
    # Location frequency (by district)
    location_col = ['PROVINCE_fossil', 'DISTRICT_fossil'] if 'PROVINCE_fossil' in df.columns else ['PROVINCE', 'DISTRICT']
    df['location_frequency'] = df.groupby(location_col)['FOSSIL_ID'].transform('count')
    
    # Formation-species correlation
    if 'FORMATION' in df.columns and df['FORMATION'].notna().any():
        df['formation_frequency'] = df.groupby('FORMATION')['FOSSIL_ID'].transform('count')
    
    # Province discovery density
    province_col = 'PROVINCE_fossil' if 'PROVINCE_fossil' in df.columns else 'PROVINCE'
    df['province_density'] = df.groupby(province_col)['FOSSIL_ID'].transform('count')
    
    # Period distribution
    df['period_frequency'] = df.groupby('PERIODFROM')['FOSSIL_ID'].transform('count')
    
    # F_GROUP statistics
    df['group_frequency'] = df.groupby('F_GROUP')['FOSSIL_ID'].transform('count')
    
    return df

def add_categorical_features(df):
    """Add categorical encoding features"""
    print("Adding categorical encoding features...")
    
    # Create group_id for easier modeling
    group_mapping = {group: idx for idx, group in enumerate(df['F_GROUP'].unique())}
    df['group_id'] = df['F_GROUP'].map(group_mapping)
    
    # Create period_id for easier modeling
    period_mapping = {period: idx for idx, period in enumerate(df['PERIODFROM'].unique())}
    df['period_id'] = df['PERIODFROM'].map(period_mapping)
    
    # Save mappings for later reference
    pd.DataFrame(list(group_mapping.items()), columns=['group', 'id']).to_csv('group_mapping.csv', index=False)
    pd.DataFrame(list(period_mapping.items()), columns=['period', 'id']).to_csv('period_mapping.csv', index=False)
    
    # For demonstration - note that full one-hot encoding will be done in model preprocessing
    print("  Added categorical IDs for F_GROUP and PERIODFROM")
    
    return df

def scale_numeric_features(df):
    """Scale numeric features"""
    print("Scaling numeric features...")
    
    # Identify numeric columns to scale
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    
    # Exclude certain columns
    exclude_cols = [
        'FOSSIL_ID', 'SITE_ID', 'OBJECTID', 'created_date', 'last_edited_date', 
        'PERIOD_CONFIDENCE', 'group_id', 'period_id'  # Keep IDs and confidence scores unscaled
    ]
    
    scale_cols = [col for col in numeric_cols if col not in exclude_cols 
                 and 'missing' not in col and 'has_' not in col
                 and col not in ['pangaea_lat', 'pangaea_lon']]  # Keep pangaea coords unscaled for visualization
    
    # For demonstration, we'll add a note rather than actually scaling
    # This will be done in the preprocessing pipeline
    
    print(f"  Identified {len(scale_cols)} numeric features to scale")
    print("  Note: Actual scaling will be done in the preprocessing pipeline")
    
    # Save the list of columns to scale for later use
    with open('scale_columns.txt', 'w') as f:
        f.write('\n'.join(scale_cols))
    
    return df

def create_map_visualization(df, output_file='fossil_map.html'):
    """
    Create an interactive map visualization of the fossil locations
    
    Args:
        df: DataFrame with latitude/longitude coordinates
        output_file: Path to save the HTML map
        
    Returns:
        Path to the saved map
    """
    print(f"Creating map visualization at {output_file}...")
    
    # Create a map centered on the mean coordinates
    center_lat = df['latitude'].mean()
    center_lon = df['longitude'].mean()
    m = folium.Map(location=[center_lat, center_lon], zoom_start=6)
    
    # Add base layers
    folium.TileLayer('OpenStreetMap', name='Street Map').add_to(m)
    folium.TileLayer(
        'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        name='Satellite',
        attr='Esri'
    ).add_to(m)
    
    # Add a base layer group
    feature_group = folium.FeatureGroup(name="Fossil Locations")
    
    # Create marker clusters for different periods
    jurassic_cluster = MarkerCluster(name="Jurassic").add_to(m)
    triassic_cluster = MarkerCluster(name="Triassic").add_to(m)
    permian_cluster = MarkerCluster(name="Permian").add_to(m)
    cretaceous_cluster = MarkerCluster(name="Cretaceous").add_to(m)
    other_cluster = MarkerCluster(name="Other Periods").add_to(m)
    
    # Add markers for each fossil location
    for idx, row in df.iterrows():
        try:
            lat, lon = row['latitude'], row['longitude']
            
            # Skip if coordinates are missing
            if pd.isna(lat) or pd.isna(lon):
                continue
                
            # Determine marker color and cluster based on period
            period = row['PERIODFROM'] if 'PERIODFROM' in row else 'Unknown'
            
            if 'จูแรสซิก' in str(period):
                color = 'green'
                cluster = jurassic_cluster
            elif 'ไทรแอสซิก' in str(period):
                color = 'blue'
                cluster = triassic_cluster
            elif 'เพอร์เมียน' in str(period):
                color = 'red'
                cluster = permian_cluster
            elif 'ครีเทเชียส' in str(period):
                color = 'orange'
                cluster = cretaceous_cluster
            else:
                color = 'gray'
                cluster = other_cluster
            
            # Create popup content with fossil info
            popup_content = f"""
            <div style="font-family: Arial; width: 250px">
                <h4>Fossil ID: {row.get('FOSSIL_ID', 'Unknown')}</h4>
                <b>Scientific Name:</b> {row.get('SCI_NAME', 'Unknown')}<br>
                <b>Common Name:</b> {row.get('COM_NAME', 'Unknown')}<br>
                <b>Period:</b> {period}<br>
                <b>Type:</b> {row.get('F_TYPE', 'Unknown')}<br>
                <b>Location:</b> {row.get('PROVINCE', 'Unknown')}, {row.get('DISTRICT', 'Unknown')}<br>
            """
            
            # Add NDVI info if available
            if 'ndvi' in row:
                popup_content += f"<b>NDVI:</b> {row['ndvi']:.4f}<br>"
                
            # Add link to satellite image if available
            if 'satellite_img_path' in row and row['satellite_img_path']:
                popup_content += f"""
                <a href="{row['satellite_img_path']}" target="_blank">
                    View Satellite Image
                </a><br>
                """
            
            popup_content += "</div>"
            
            # Create marker with popup
            folium.Marker(
                location=[lat, lon],
                popup=folium.Popup(popup_content, max_width=300),
                icon=folium.Icon(color=color)
            ).add_to(cluster)
            
        except Exception as e:
            print(f"  Error adding marker for row {idx}: {e}")
    
    # Add layer control
    folium.LayerControl().add_to(m)
    
    # Save the map
    m.save(output_file)
    print(f"  Map saved to {output_file}")
    
    return output_file

def create_dataset_summary(df, output_file='dataset_summary.html'):
    """
    Create an HTML summary of the dataset with key statistics
    
    Args:
        df: DataFrame with all features
        output_file: Path to save the HTML summary
    """
    print(f"Creating dataset summary at {output_file}...")
    
    try:
        # Calculate basic statistics
        num_fossils = len(df)
        num_species = df['SCI_NAME'].nunique()
        num_locations = len(df.groupby(['PROVINCE', 'DISTRICT']))
        num_periods = df['PERIODFROM'].nunique()
        real_satellite = (df['has_real_satellite_data'] == 1).sum() if 'has_real_satellite_data' in df.columns else 0
        real_embeddings = (df['has_bert_embeddings'] == 1).sum() if 'has_bert_embeddings' in df.columns else 0
        
        # Period distribution
        period_counts = df['PERIODFROM'].value_counts().reset_index()
        period_counts.columns = ['Period', 'Count']
        period_counts['Percentage'] = (period_counts['Count'] / num_fossils * 100).round(1)
        
        # Type distribution
        type_counts = df['F_TYPE'].value_counts().reset_index()
        type_counts.columns = ['Type', 'Count']
        type_counts['Percentage'] = (type_counts['Count'] / num_fossils * 100).round(1)
        
        # Group distribution
        group_counts = df['F_GROUP'].value_counts().reset_index()
        group_counts.columns = ['Group', 'Count']
        group_counts['Percentage'] = (group_counts['Count'] / num_fossils * 100).round(1)
        
        # Create HTML content
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Fossil Dataset Summary</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
                h1, h2, h3 {{ color: #2c3e50; }}
                .container {{ max-width: 1200px; margin: 0 auto; }}
                .stats-box {{ background-color: #f8f9fa; border-radius: 5px; padding: 15px; margin-bottom: 20px; }}
                .stats-grid {{ display: flex; flex-wrap: wrap; }}
                .stat-item {{ flex: 1; min-width: 200px; margin: 10px; padding: 15px; background-color: #ffffff; 
                             border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); }}
                table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
                th, td {{ text-align: left; padding: 12px; }}
                th {{ background-color: #4CAF50; color: white; }}
                tr:nth-child(even) {{ background-color: #f2f2f2; }}
                .feature-list {{ column-count: 3; column-gap: 20px; }}
                .real-data {{ color: green; }}
                .simulated-data {{ color: orange; }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Fossil Dataset Summary</h1>
                <div class="stats-box">
                    <h2>Overview</h2>
                    <div class="stats-grid">
                        <div class="stat-item">
                            <h3>Total Fossils</h3>
                            <p>{num_fossils}</p>
                        </div>
                        <div class="stat-item">
                            <h3>Unique Species</h3>
                            <p>{num_species}</p>
                        </div>
                        <div class="stat-item">
                            <h3>Locations</h3>
                            <p>{num_locations}</p>
                        </div>
                        <div class="stat-item">
                            <h3>Geological Periods</h3>
                            <p>{num_periods}</p>
                        </div>
                    </div>
                </div>
                
                <div class="stats-box">
                    <h2>Data Quality</h2>
                    <div class="stats-grid">
                        <div class="stat-item">
                            <h3>Satellite Data</h3>
                            <p class="real-data">{real_satellite} real samples</p>
                            <p class="simulated-data">{num_fossils - real_satellite} simulated samples</p>
                        </div>
                        <div class="stat-item">
                            <h3>Text Embeddings</h3>
                            <p class="real-data">{real_embeddings} with WangChanBERTa</p>
                            <p class="simulated-data">{num_fossils - real_embeddings} simulated</p>
                        </div>
                    </div>
                </div>
                
                <div class="stats-box">
                    <h2>Geological Period Distribution</h2>
                    <table>
                        <tr>
                            <th>Period</th>
                            <th>Count</th>
                            <th>Percentage</th>
                        </tr>
        """
        
        # Add period rows
        for _, row in period_counts.iterrows():
            html_content += f"""
                        <tr>
                            <td>{row['Period']}</td>
                            <td>{row['Count']}</td>
                            <td>{row['Percentage']}%</td>
                        </tr>
            """
        
        html_content += """
                    </table>
                </div>
                
                <div class="stats-box">
                    <h2>Fossil Type Distribution</h2>
                    <table>
                        <tr>
                            <th>Type</th>
                            <th>Count</th>
                            <th>Percentage</th>
                        </tr>
        """
        
        # Add type rows
        for _, row in type_counts.iterrows():
            html_content += f"""
                        <tr>
                            <td>{row['Type']}</td>
                            <td>{row['Count']}</td>
                            <td>{row['Percentage']}%</td>
                        </tr>
            """
        
        html_content += """
                    </table>
                </div>
                
                <div class="stats-box">
                    <h2>Fossil Group Distribution</h2>
                    <table>
                        <tr>
                            <th>Group</th>
                            <th>Count</th>
                            <th>Percentage</th>
                        </tr>
        """
        
        # Add group rows
        for _, row in group_counts.iterrows():
            html_content += f"""
                        <tr>
                            <td>{row['Group']}</td>
                            <td>{row['Count']}</td>
                            <td>{row['Percentage']}%</td>
                        </tr>
            """
        
        html_content += """
                    </table>
                </div>
                
                <div class="stats-box">
                    <h2>Features Available</h2>
                    <div class="feature-list">
        """
        
        # Add feature list
        categories = {
            'Identifiers': ['FOSSIL_ID', 'SCI_NAME', 'COM_NAME', 'F_GROUP', 'F_TYPE'],
            'Spatial': ['latitude', 'longitude', 'pangaea_lat', 'pangaea_lon', 'distance_from_centroid'],
            'Satellite': ['ndvi', 'red_band', 'green_band', 'blue_band', 'nir_band'],
            'Text': [col for col in df.columns if col.startswith('text_embedding_')],
            'Temporal': ['period_age_mya', 'geological_era'],
            'Statistical': ['species_frequency', 'location_frequency', 'province_density'],
            'Other': [col for col in df.columns if col not in 
                     ['FOSSIL_ID', 'SCI_NAME', 'COM_NAME', 'F_GROUP', 'F_TYPE',
                      'latitude', 'longitude', 'pangaea_lat', 'pangaea_lon', 'distance_from_centroid',
                      'ndvi', 'red_band', 'green_band', 'blue_band', 'nir_band',
                      'period_age_mya', 'geological_era', 'species_frequency', 'location_frequency', 'province_density'] and
                     not col.startswith('text_embedding_')]
        }
        
        for category, features in categories.items():
            html_content += f"<h3>{category}</h3><ul>"
            for feature in features:
                if feature in df.columns:
                    html_content += f"<li>{feature}</li>"
            html_content += "</ul>"
        
        html_content += """
                    </div>
                </div>
            </div>
        </body>
        </html>
        """
        
        # Write to file
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(html_content)
        
        print(f"  Dataset summary saved to {output_file}")
        
    except Exception as e:
        print(f"  Error creating dataset summary: {e}")

# Main execution
if __name__ == "__main__":
    # Load the cleaned data
    print("Loading dataset...")
    df = pd.read_csv('cleaned_fossil_data.csv')
    print(f"Loaded dataset with {len(df)} records")
    
    # Extract all features
    features_df = extract_features(
        df, 
        save_interim=True,
        use_real_data=True,
        satellite_batch_size=50,    # Process 20 locations per batch
        satellite_delay_seconds=0, # Wait 30 seconds between batches
        text_batch_size=100         # Process 100 texts per batch
    )
    
    # Report on the results
    print("\nFeature Extraction Summary:")
    print(f"Original columns: {len(df.columns)}")
    print(f"New features added: {len(features_df.columns) - len(df.columns)}")
    print(f"Total features: {len(features_df.columns)}")
    
    # List new features added
    new_features = [col for col in features_df.columns if col not in df.columns]
    print("\nNew features added:")
    for feature in new_features:
        print(f"  - {feature}")
        
    print("\nDataset is now ready for training!")

Loading dataset...
Loaded dataset with 9817 records
Starting feature extraction process with batch processing...
Adding coordinate features...
  Original missing UTM values: 0 rows
  Fixing 3224 rows with too large easting values
  Missing coordinates after conversion: 0 rows
  Coordinates filled - remaining missing: 0 rows
  Attempting real Pangaea coordinates calculation...
🔄 **Calculating Pangaea coordinates...**
✅ Successfully imported pygplates
📊 Found 13 rotation files and 21 feature files
🔍 Trying with Thailand Plate ID: 619
✅ Loaded rotation model: 1000-410_toy_introversion_simplified.rot
✅ Direct reconstruction works with plate ID 619!
🎯 Direct Success Rate: 99.96% (9813/9817 points)
✅ Using direct method with 1000-410_toy_introversion_simplified.rot and Plate ID 619 (Success Rate: 99.96%)
📊 Real coordinate ratio: 99.96% (9813/9817)
  Pangaea coordinates: 9817 real, 0 approximated
  Real coordinate ratio: 99.96%
  Calculating distance from province centroid...
  Calculating di

In [13]:
data = pd.read_csv('fossil_features_full.csv')
data

Unnamed: 0,FOSSIL_ID,SCI_NAME,COM_NAME,F_GROUP,F_TYPE,F_PART,DIS_NAME,PROVINCE_fossil,DISTRICT_fossil,TAMBOM,...,discovery_year,geological_era,species_frequency,location_frequency,formation_frequency,province_density,period_frequency,group_frequency,group_id,period_id
0,20200800005,Chonetinella andamanensis,แบรคิโอพอด,Invertebrate,แบรคิโอพอด,Dorsal internal mould and ventral external mou...,คณะวิจัยร่วม การลำดับชั้นหินและบรรพชีวินวิทยาโ...,พังงา,เกาะยาว,เกาะยาวใหญ่,...,2023,Paleozoic,5,340,7380,348,3485,2633,0,0
1,20200800005,Chonetinella andamanensis,แบรคิโอพอด,Invertebrate,แบรคิโอพอด,Dorsal internal mould and ventral external mou...,คณะวิจัยร่วม การลำดับชั้นหินและบรรพชีวินวิทยาโ...,พังงา,เกาะยาว,เกาะยาวใหญ่,...,2023,Paleozoic,5,340,204,348,3485,2633,0,0
2,20200800005,Chonetinella andamanensis,แบรคิโอพอด,Invertebrate,แบรคิโอพอด,Dorsal internal mould and ventral external mou...,คณะวิจัยร่วม การลำดับชั้นหินและบรรพชีวินวิทยาโ...,พังงา,เกาะยาว,เกาะยาวใหญ่,...,2023,Paleozoic,5,340,204,348,3485,2633,0,0
3,20200800005,Chonetinella andamanensis,แบรคิโอพอด,Invertebrate,แบรคิโอพอด,Dorsal internal mould and ventral external mou...,คณะวิจัยร่วม การลำดับชั้นหินและบรรพชีวินวิทยาโ...,พังงา,เกาะยาว,เกาะยาวใหญ่,...,2023,Paleozoic,5,340,79,348,3485,2633,0,0
4,20200800005,Chonetinella andamanensis,แบรคิโอพอด,Invertebrate,แบรคิโอพอด,Dorsal internal mould and ventral external mou...,คณะวิจัยร่วม การลำดับชั้นหินและบรรพชีวินวิทยาโ...,พังงา,เกาะยาว,เกาะยาวใหญ่,...,2023,Paleozoic,5,340,204,348,3485,2633,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9812,20230600023,Rusa unicolor,กวางป่า,Vertebrate,สัตว์กีบคู่,ฟัน,คณะสำรวจโบราณชีววิทยาไทย-ฝรั่งเศส,ชัยภูมิ,คอนสาร,ทุ่งลุยลาย,...,2023,Paleozoic,83,1041,694,1043,3485,7097,1,0
9813,20230600023,Rusa unicolor,กวางป่า,Vertebrate,สัตว์กีบคู่,ฟัน,คณะสำรวจโบราณชีววิทยาไทย-ฝรั่งเศส,ชัยภูมิ,คอนสาร,ทุ่งลุยลาย,...,2023,Paleozoic,83,1041,694,1043,3485,7097,1,0
9814,20230600023,Rusa unicolor,กวางป่า,Vertebrate,สัตว์กีบคู่,ฟัน,คณะสำรวจโบราณชีววิทยาไทย-ฝรั่งเศส,ชัยภูมิ,คอนสาร,ทุ่งลุยลาย,...,2023,Paleozoic,83,1041,7380,1043,3485,7097,1,0
9815,20230600068,Suidae indet.,หมู,Vertebrate,สัตว์กีบคู่,ฟัน,คณะสำรวจวิจัยโบราณชีววิทยาไทย-ญี่ปุ่น,พะเยา,เชียงม่วน,สระ,...,2023,Paleozoic,6,125,7380,125,291,7097,1,1
