In [1]:
import ee
import numpy as np
import datetime

# CONFIG: Adjust this to speed up or slow down
# MAX_WORKERS is the thread-pool size used by the parallel download function.
MAX_WORKERS = 50  # Try 10, 20, 30, 40 - higher = faster but may hit rate limits
# Number of sites per batch; the coordinates are split into batches of this size
# and one pickle file is written/read per batch.
chunk_size = 1000

# Specify pixel size and tile size at the top of the cell
PIXEL_SIZE = 10   # size of each pixel in meters (also used as the sampling scale)
# Default `patch_size` of extract_satellite_embeddings_parallel.
example_TILE_SIZE = 9     # number of pixels on one side of the tile
# Patch size actually passed to the batch download loop (1 = a single pixel per site).
TILE_SIZE = 1

Parallel download function

In [2]:
from concurrent.futures import ThreadPoolExecutor, as_completed

def extract_single_site(site_idx, lat, lon, collection, unique_years, patch_size, n_dims):
    """Extract the embedding patch of one site for every requested year.

    Designed to be submitted to a thread pool: it never raises. Any failure is
    logged and the affected entries simply stay NaN.

    Args:
        site_idx: Index of the site; returned unchanged so the caller can place
            the result when futures complete out of order.
        lat: Latitude of the patch centre in decimal degrees.
        lon: Longitude of the patch centre in decimal degrees.
        collection: Earth Engine image collection to sample from.
        unique_years: List of years; its order defines the year axis.
        patch_size: Number of pixels on one side of the square patch.
        n_dims: Number of embedding bands to read (bands "A00", "A01", ...).

    Returns:
        Tuple ``(site_idx, embeddings)`` where ``embeddings`` is a float32 array
        of shape ``(n_dims, len(unique_years), patch_size, patch_size)``.
    """
    import logging  # local import keeps this cell independent of the import cell

    logger = logging.getLogger(__name__)
    embeddings = np.full((n_dims, len(unique_years), patch_size, patch_size), np.nan, dtype=np.float32)

    try:
        point = ee.Geometry.Point([lon, lat])
        intersecting_tiles = collection.filterBounds(point)
        total_tiles = intersecting_tiles.size().getInfo()

        # The sampling grid does not depend on the tile, so build it once.
        # Each point carries its (row, col) so results can be placed exactly,
        # regardless of the order in which Earth Engine returns them or of
        # masked pixels being dropped.
        half_size = patch_size // 2
        grid_points = ee.FeatureCollection([
            ee.Feature(
                ee.Geometry.Point([
                    lon + (col - half_size) * PIXEL_SIZE / (111000 * np.cos(np.radians(lat))),
                    lat + (row - half_size) * PIXEL_SIZE / 111000,
                ]),
                {'row': row, 'col': col},
            )
            for row in range(patch_size)
            for col in range(patch_size)
        ])

        for tile_idx in range(total_tiles):
            try:
                img = ee.Image(intersecting_tiles.toList(1, tile_idx).get(0))
                time_start = img.get('system:time_start').getInfo()
                # Timestamps are epoch milliseconds in UTC. Converting in local
                # time shifts a Jan-1 00:00 UTC image into the previous year on
                # machines west of UTC.
                tile_year = datetime.datetime.fromtimestamp(
                    time_start / 1000, tz=datetime.timezone.utc
                ).year

                if tile_year not in unique_years:
                    continue
                year_idx = unique_years.index(tile_year)

                # sampleRegions returns one feature per input point and copies
                # the point's properties (row/col) next to the band values.
                samples = img.sampleRegions(
                    collection=grid_points,
                    scale=PIXEL_SIZE,
                    geometries=False,
                ).getInfo()

                for feature in samples.get('features') or []:
                    props = feature['properties']
                    row, col = int(props['row']), int(props['col'])
                    for i in range(n_dims):
                        val = props.get(f"A{i:02d}", None)
                        if val is not None:
                            embeddings[i, year_idx, row, col] = val
            except Exception as exc:
                # Best effort per tile: keep NaN for this tile but leave a trace.
                logger.warning("Site %s (lat=%s, lon=%s): tile %s failed: %r",
                               site_idx, lat, lon, tile_idx, exc)
    except Exception as exc:
        # Best effort per site: the whole site stays NaN.
        logger.warning("Site %s (lat=%s, lon=%s): extraction failed: %r",
                       site_idx, lat, lon, exc)

    return site_idx, embeddings


def extract_satellite_embeddings_parallel(coordinates, years=None, patch_size=example_TILE_SIZE):
    """Download embedding patches for many sites in parallel.

    One task per site is submitted to a thread pool of MAX_WORKERS threads and
    the results are written into a single array as they complete.

    Args:
        coordinates: List or array of [lat, lon] coordinates. A single
            [lat, lon] pair is accepted as well.
        years: List of years to extract embeddings for. If None, all years in
            the collection are used. Years missing from the collection are
            dropped with a warning; the remaining years keep the order given.
        patch_size: Size of patch for embedding extraction (pixels per side).

    Returns:
        Tuple ``(embeddings, coords, unique_years)``:
        - embeddings: float32 array of shape
          ``(n_sites, 64, n_years, patch_size, patch_size)``, NaN where no
          value was retrieved.
        - coords: the coordinates as a 2-D numpy array.
        - unique_years: the years that define the year axis, in order.
    """
    coords = np.array(coordinates)
    if coords.size == 0:
        # No sites: keep a well-formed (0, 2) array so the loops below are no-ops.
        coords = coords.reshape(0, 2)
    elif coords.ndim == 1:
        coords = coords.reshape(1, -1)

    collection = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
    # Deduplicate server-side: many images share the same start time, and only
    # the distinct values are needed to derive the available years.
    timestamps = collection.aggregate_array('system:time_start').distinct().getInfo()
    # Timestamps are epoch milliseconds in UTC; a local-time conversion would
    # shift Jan-1 00:00 UTC into the previous year on machines west of UTC.
    all_years = sorted({
        datetime.datetime.fromtimestamp(ts / 1000, tz=datetime.timezone.utc).year
        for ts in timestamps
    })

    if years is None:
        unique_years = all_years
    else:
        unique_years = [y for y in years if y in all_years]
        if len(unique_years) < len(years):
            missing = set(years) - set(unique_years)
            print(f"Warning: The following requested years are not available in the collection: {missing}")

    n_sites, n_years, n_dims = coords.shape[0], len(unique_years), 64
    embeddings = np.full((n_sites, n_dims, n_years, patch_size, patch_size), np.nan, dtype=np.float32)

    total_points = coords.shape[0]
    completed = 0

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_site = {
            executor.submit(extract_single_site, site_idx, lat, lon, collection, unique_years, patch_size, n_dims): site_idx
            for site_idx, (lat, lon) in enumerate(coords)
        }

        # Futures finish in arbitrary order; the returned site index says
        # where each result belongs.
        for future in as_completed(future_to_site):
            site_idx, site_embeddings = future.result()
            embeddings[site_idx] = site_embeddings
            completed += 1
            percent = 100.0 * completed / total_points
            print(f"Progress: {percent:6.3f}%", end='\r')

    return embeddings, coords, unique_years


Initialize the connection to Google Earth Engine (GEE)

In [3]:
import os
save_path = "3_AlphaEarth"

# Earth Engine is only needed when embeddings still have to be downloaded, so
# the connection is skipped when the first batch file is already on disk.
# NOTE(review): only batch 0 is checked; if a later batch is missing, the
# download cell will run without an initialized client — confirm this is intended.
if not os.path.exists(f"{save_path}/0_2018gps_2017_2018embeddings.pkl"):
    try:
        # Works when credentials are already stored.
        ee.Initialize()
    except Exception:
        # Otherwise run the interactive authentication flow, then retry.
        ee.Authenticate()
        ee.Initialize()

Chunk into batches

In [4]:
# Load the preprocessed sites and split their 2018 GPS coordinates into batches of `chunk_size`
import pickle
import time

# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("2_preprocessed.pkl", "rb") as f:
    df = pickle.load(f)

# 2018 GPS position of every row, as [lat, lon] pairs in row order.
lat, long = df['gps_lat_2018'], df['gps_long_2018']
coordinates_18 = [list(pair) for pair in zip(lat, long)]

# Consecutive slices of at most `chunk_size` sites; the last one may be shorter.
batch_starts = range(0, len(coordinates_18), chunk_size)
coordinate_18_packs = [coordinates_18[start:start + chunk_size] for start in batch_starts]



Download all batches (saves after each batch)

In [5]:
os.makedirs(save_path, exist_ok=True)

# Extract 2017 and 2018 for 2018 GPS coordinates.
# Each batch is saved to its own file so an interrupted run can resume:
# batches whose file already exists are skipped.
for i, pack in enumerate(coordinate_18_packs):
    batch_path = f"{save_path}/{i}_2018gps_2017_2018embeddings.pkl"
    if os.path.exists(batch_path):
        print(f"Batch {i} exists, skipping")
        continue
    print(f"\nProcessing batch {i+1}/{len(coordinate_18_packs)} ({len(pack)} sites) with MAX_WORKERS={MAX_WORKERS}")
    embeddings, _, _ = extract_satellite_embeddings_parallel(pack, [2017, 2018], patch_size=TILE_SIZE)
    # Write to a temporary file and rename it into place. Otherwise a run
    # interrupted mid-write leaves a truncated pickle that the existence check
    # above would treat as a finished batch.
    tmp_path = f"{batch_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(embeddings, f)
    os.replace(tmp_path, batch_path)
    print(f"\nSaved batch {i+1}")

Batch 0 exists, skipping
Batch 1 exists, skipping
Batch 2 exists, skipping
Batch 3 exists, skipping
Batch 4 exists, skipping
Batch 5 exists, skipping
Batch 6 exists, skipping
Batch 7 exists, skipping
Batch 8 exists, skipping
Batch 9 exists, skipping
Batch 10 exists, skipping
Batch 11 exists, skipping
Batch 12 exists, skipping
Batch 13 exists, skipping
Batch 14 exists, skipping
Batch 15 exists, skipping
Batch 16 exists, skipping
Batch 17 exists, skipping
Batch 18 exists, skipping
Batch 19 exists, skipping
Batch 20 exists, skipping
Batch 21 exists, skipping
Batch 22 exists, skipping
Batch 23 exists, skipping
Batch 24 exists, skipping
Batch 25 exists, skipping
Batch 26 exists, skipping
Batch 27 exists, skipping


In [6]:
# Get 2017 and 2018 for 2018 GPS coordinates
def get_multi_year(coordinates, year_idx):
    """Load the saved batches and return the centre-pixel embedding of one year.

    Reads the per-batch pickle files from `save_path`, in batch order, and
    stacks them into one row per coordinate.

    Args:
        coordinates: The full coordinate list the batches were built from; only
            its length is used (with `chunk_size`) to find the number of batches.
        year_idx: Index along the year axis of the saved arrays.

    Returns:
        Float array of shape ``(len(coordinates), 64)``. Rows for which no
        saved value exists are NaN.
    """
    import math

    n_coords = len(coordinates)
    # NaN (not zero) marks rows that were never filled, matching the
    # missing-data convention of the saved embeddings.
    AlphaEarth_values = np.full((n_coords, 64), np.nan)
    num_packs = math.ceil(n_coords / chunk_size)
    idx = 0
    for i in range(num_packs):
        # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
        with open(f"{save_path}/{i}_2018gps_2017_2018embeddings.pkl", "rb") as f:
            pack_values = pickle.load(f)
        # Shape: (n_sites, n_dims, n_years, patch_size, patch_size)
        # Take the real centre of the patch; [0, 0] is only the centre when
        # patch_size == 1.
        center_row = pack_values.shape[3] // 2
        center_col = pack_values.shape[4] // 2
        pack_values = pack_values[:, :, year_idx, center_row, center_col]
        n = pack_values.shape[0]
        # Never write past the end of the output, even if a batch file holds
        # more rows than expected.
        take = max(0, min(n, n_coords - idx))
        AlphaEarth_values[idx:idx + take] = pack_values[:take]
        idx += n
    return AlphaEarth_values

# The second argument indexes the year axis of the saved batches. The batches
# were requested with years [2017, 2018], and the extractor keeps the requested
# order, so 0 -> 2017 and 1 -> 2018.
# NOTE(review): this assumes both years were available in the collection when
# the batches were downloaded — otherwise the year axis is shorter.
AE_18_2017 = get_multi_year(coordinates_18, 0)  # 2017 is first (order of the requested years)
AE_18_2018 = get_multi_year(coordinates_18, 1)  # 2018 is second

In [7]:
import pandas as pd

# One column per embedding dimension, tagged with the GPS year and the image year.
cols_18_2017 = [f"AE{i:02d}_2018gps_2017" for i in range(64)]
cols_18_2018 = [f"AE{i:02d}_2018gps_2018" for i in range(64)]

# The embedding rows are in the same order as df's rows. pd.concat(axis=1)
# aligns on index labels, so the new frames must carry df's index; with a
# default RangeIndex they would be misaligned whenever df's index is not 0..n-1.
df_ae_18_2017 = pd.DataFrame(AE_18_2017, columns=cols_18_2017, index=df.index)
df_ae_18_2018 = pd.DataFrame(AE_18_2018, columns=cols_18_2018, index=df.index)

# Drop embedding columns left over from a previous run of this cell so that
# re-running it does not create duplicate columns.
df = pd.concat(
    [
        df.drop(columns=cols_18_2017 + cols_18_2018, errors="ignore"),
        df_ae_18_2017,
        df_ae_18_2018,
    ],
    axis=1,
)

In [8]:
# Persist the data frame (including the embedding columns concatenated above)
# for later use.
df.to_pickle("3_with_gee.pkl")