# Learned Activations v3: Corrected Architecture & Fair Evaluation

## Fixes from v1/v2 based on independent review:

1. **Shared gating network** - Gate is now shared across all layers (not per-layer)
2. **Fair SatCLIP baseline** - Ridge regression on frozen embeddings (not MLP)
3. **Correct SIREN config** - w0_initial=30 for first layer, w0=1 for subsequent (matches SatCLIP)
4. **Correct HybridEncoder init** - SIREN-specific initialization when using sine activations
5. **Spatial blocking** - Grid-based train/test splits to prevent spatial leakage

## Architecture Summary

### SatCLIP (baseline)
- Spherical Harmonics (L=10: 100 features, L=40: 1600 features)
- SIREN network: w0_initial=30 (first layer), w0=1 (subsequent layers)
- Output: 256-dim embedding

### Our Learned Activations
- Fourier-parameterized: g(x) = scale * sum_k [a_k*sin(w_k*x) + b_k*cos(w_k*x)] + bias
- Spatially-varying: Mixture of experts with **shared** location-based gating
- Direct (lat, lon) input OR hybrid with SH features

In [None]:
# Setup
import os
import sys

# Colab-only bootstrap. The presence of the COLAB_GPU environment variable is
# what this notebook uses as its "running on Colab" marker. The `!` lines are
# IPython shell escapes, so this cell only runs inside a notebook kernel.
if 'COLAB_GPU' in os.environ:
    # Remove leftovers from a previous run so the clone below starts clean.
    !rm -rf sample_data .config satclip gpw_data 2>/dev/null
    # SatCLIP sources; a later cell appends ./satclip/satclip to sys.path.
    !git clone https://github.com/1hamzaiqbal/satclip.git
    # Extra packages installed on top of the Colab image.
    !pip install lightning torchgeo huggingface_hub rasterio --quiet

In [None]:
# Mount Google Drive (Colab only) and extract GPW data
import os
import zipfile

# google.colab only exists inside Colab. Importing it unconditionally would
# break local runs, which the import cell further down explicitly supports.
try:
    from google.colab import drive
except ImportError:
    drive = None

GPW_DIR = './gpw_data'
os.makedirs(GPW_DIR, exist_ok=True)

SOURCE_ZIP_PATH = '/content/drive/MyDrive/grad/learned_activations/dataverse_files.zip'

if drive is not None:
    drive.mount('/content/drive')

    print("Extracting GPW data...")
    with zipfile.ZipFile(SOURCE_ZIP_PATH, 'r') as z:
        z.extractall(GPW_DIR)
else:
    # Without Drive there is no outer archive to unpack; the nested archive
    # (or the extracted .tif) must already be present in GPW_DIR.
    print(f"Not running in Colab; expecting GPW data to already be in {GPW_DIR}")

# The outer archive contains one zip per resolution; only 15-minute data is used.
zip_path = os.path.join(GPW_DIR, 'gpw-v4-population-density-rev11_2020_15_min_tif.zip')
if os.path.exists(zip_path):
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(GPW_DIR)
    print("Extracted 15-min resolution")

print("Done!")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import warnings
import math
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

# Make the SatCLIP sources importable. On Colab they were cloned into
# ./satclip by the setup cell; otherwise a `satclip` directory next to the
# current working directory's parent is used.
if 'COLAB_GPU' in os.environ:
    sys.path.append('./satclip/satclip')
    GPW_DIR = './gpw_data'
else:
    sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'satclip'))
    GPW_DIR = './gpw_data'

# These imports must come after the sys.path change above: `load` is a module
# inside the SatCLIP checkout, not an installed package.
from huggingface_hub import hf_hub_download
from load import get_satclip

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load SatCLIP models
# Checkpoints are downloaded from the Hugging Face Hub. Both models are
# switched to eval mode; they are used as frozen feature extractors below.
print("Loading SatCLIP models...")
satclip_l10 = get_satclip(hf_hub_download("microsoft/SatCLIP-ViT16-L10", "satclip-vit16-l10.ckpt"), device=device)
satclip_l40 = get_satclip(hf_hub_download("microsoft/SatCLIP-ViT16-L40", "satclip-vit16-l40.ckpt"), device=device)
satclip_l10.eval()
satclip_l40.eval()
print("SatCLIP models loaded!")

---
## 1. Data Loading with Spatial Blocking

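**Key fix**: Grid-based spatial blocking — whole grid cells (5° by default) are assigned to either train or test, so train and test points never share a cell. This limits leakage caused by spatial autocorrelation.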

In [None]:
# Setting the limit to None turns off Pillow's maximum-image-size check, so
# large rasters can be opened without tripping it.
Image.MAX_IMAGE_PIXELS = None

def load_gpw_raster(resolution='15_min', year=2020):
    """Load a GPW v4 population-density GeoTIFF plus pixel-centre coordinates.

    Coordinates are derived from the image shape alone: the code assumes the
    raster spans longitude -180..180 (left to right) and latitude 90..-90
    (top to bottom). GeoTIFF georeferencing metadata is not read.

    Args:
        resolution: Resolution tag used in the GPW file name (e.g. '15_min').
        year: Data year used in the GPW file name.

    Returns:
        (data, (lons, lats)): data is the 2-D pixel array, lons holds one
        pixel-centre longitude per column and lats one pixel-centre latitude
        per row (descending). Returns (None, None) if the file is missing.
    """
    tif_file = f"{GPW_DIR}/gpw_v4_population_density_rev11_{year}_{resolution}.tif"
    if not os.path.exists(tif_file):
        print(f"File not found: {tif_file}")
        return None, None

    # The context manager releases the file handle once the pixels have been
    # copied into the NumPy array.
    with Image.open(tif_file) as img:
        data = np.array(img)
    height, width = data.shape

    lon_step = 360 / width
    lat_step = 180 / height
    # Offsets of half a step place each coordinate at the pixel centre.
    lons = np.linspace(-180 + lon_step/2, 180 - lon_step/2, width)
    lats = np.linspace(90 - lat_step/2, -90 + lat_step/2, height)

    return data, (lons, lats)


def sample_with_spatial_blocking(data, coords, n_samples=10000, seed=42, bounds=None,
                                  grid_size=5.0, test_ratio=0.3):
    """
    Sample raster pixels and split them into train/test by grid cell.

    The region is tiled into grid_size-degree cells. A random subset of cells
    is held out for testing and every sampled pixel inherits the split of the
    cell it falls in, so train and test points never share a cell.

    Note: this reseeds NumPy's global random state with `seed`.

    Args:
        data: Population raster (rows = latitude, columns = longitude).
            Pixels with values <= -1e30 are treated as no-data.
        coords: (lons, lats) pixel-centre arrays matching data's columns/rows
        n_samples: Total samples to draw (all valid pixels if fewer exist)
        seed: Random seed
        bounds: Optional (lon_min, lat_min, lon_max, lat_max)
        grid_size: Size of grid cells in degrees (default 5.0 = ~555 km at equator)
        test_ratio: Fraction of grid cells held out for testing

    Returns:
        coords_train, values_train, coords_test, values_test
        (coords arrays have shape (n, 2) with columns lon, lat; all four are
        empty when no valid pixel exists)
    """
    np.random.seed(seed)
    lons, lats = coords
    lons = np.asarray(lons)
    lats = np.asarray(lats)
    valid_mask = data > -1e30

    if bounds is not None:
        lon_min, lat_min, lon_max, lat_max = bounds
        # Broadcasting two 1-D masks avoids materialising full lon/lat grids.
        lon_ok = (lons >= lon_min) & (lons <= lon_max)
        lat_ok = (lats >= lat_min) & (lats <= lat_max)
        valid_mask = valid_mask & (lat_ok[:, None] & lon_ok[None, :])
    else:
        lon_min, lon_max = -180, 180
        lat_min, lat_max = -90, 90

    # Create grid cell assignments
    n_lon_cells = int(np.ceil((lon_max - lon_min) / grid_size))
    n_lat_cells = int(np.ceil((lat_max - lat_min) / grid_size))
    total_cells = n_lon_cells * n_lat_cells

    # Randomly assign cells to train or test
    cell_indices = np.arange(total_cells)
    np.random.shuffle(cell_indices)
    n_test_cells = int(total_cells * test_ratio)
    test_cells = cell_indices[:n_test_cells]

    # Get valid points
    valid_idx = np.where(valid_mask)
    n_valid = len(valid_idx[0])

    if n_valid < n_samples:
        sample_idx = np.arange(n_valid)
    else:
        sample_idx = np.random.choice(n_valid, n_samples, replace=False)

    row_idx = valid_idx[0][sample_idx]
    col_idx = valid_idx[1][sample_idx]

    sample_lons = lons[col_idx]
    sample_lats = lats[row_idx]
    sample_values = data[row_idx, col_idx]

    # Assign each sample to its grid cell. Offsets are non-negative inside the
    # bounds, so integer truncation is a floor. Points exactly on the upper
    # edge are clamped into the last cell.
    lon_cell = np.minimum(((sample_lons - lon_min) / grid_size).astype(np.int64),
                          n_lon_cells - 1)
    lat_cell = np.minimum(((sample_lats - lat_min) / grid_size).astype(np.int64),
                          n_lat_cells - 1)
    cell_id = lat_cell * n_lon_cells + lon_cell

    # Always a boolean array, even for zero samples, so the masking below
    # works on empty regions too.
    train_mask = ~np.isin(cell_id, test_cells)

    coords_arr = np.stack([sample_lons, sample_lats], axis=1)

    coords_train = coords_arr[train_mask]
    coords_test = coords_arr[~train_mask]
    values_train = sample_values[train_mask]
    values_test = sample_values[~train_mask]

    print(f"  Spatial blocking: {n_lon_cells}x{n_lat_cells} grid, {len(test_cells)} test cells")
    print(f"  Train: {len(coords_train)}, Test: {len(coords_test)}")

    return coords_train, values_train, coords_test, values_test


# Load data
print("Loading population data...")
pop_data, pop_coords = load_gpw_raster('15_min')
if pop_data is None:
    # load_gpw_raster reports a missing file as (None, None). Stop here with
    # a clear message instead of failing on `.shape` with an AttributeError.
    raise FileNotFoundError(
        f"GPW 15-minute raster not found in {GPW_DIR}; run the data extraction cell first."
    )
print(f"Shape: {pop_data.shape}")

# Evaluation regions as (lon_min, lat_min, lon_max, lat_max); None = whole globe.
REGIONS = {
    'Global': None,
    'USA': (-125, 24, -66, 50),
    'Europe': (-10, 35, 40, 70),
    'China': (73, 18, 135, 54),
}

---
## 2. Corrected Model Architectures

### Key fixes:
1. **SIREN**: w0_initial=30 for first layer, w0=1 for subsequent (matches SatCLIP)
2. **SpatiallyVaryingActivation**: Shared gating network across all layers
3. **Proper initialization**: SIREN uses its own init, others use Kaiming

In [None]:
# =============================================================================
# LEARNED ACTIVATION FUNCTION
# =============================================================================

class LearnedActivation(nn.Module):
    """Element-wise activation expressed as a truncated Fourier series.

    g(x) = scale * (sum_k a_k*sin(w_k*x) + b_k*cos(w_k*x)) + bias

    The coefficients a_k, b_k, scale and bias are always trainable. The
    frequencies w_k are trainable only when learnable_freq is True; otherwise
    they are stored as a buffer.

    Args:
        n_frequencies: Number of Fourier components K
        freq_init: 'linear' or 'log' spacing from 0.1 to max_freq; any other
            value draws frequencies uniformly at random from [0, max_freq)
        learnable_freq: Whether frequencies are learnable
        max_freq: Maximum frequency value
    """
    def __init__(self, n_frequencies=25, freq_init='linear', learnable_freq=False, max_freq=10.0):
        super().__init__()
        self.n_frequencies = n_frequencies

        initial = self._initial_frequencies(freq_init, n_frequencies, max_freq)
        if not learnable_freq:
            self.register_buffer('frequencies', initial)
        else:
            self.frequencies = nn.Parameter(initial)

        # Small random coefficients; scale/bias start as the identity affine map.
        self.sin_coeffs = nn.Parameter(0.1 * torch.randn(n_frequencies))
        self.cos_coeffs = nn.Parameter(0.1 * torch.randn(n_frequencies))
        self.bias = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.ones(1))

    @staticmethod
    def _initial_frequencies(freq_init, n_frequencies, max_freq):
        """Return the starting frequency vector for the requested scheme."""
        if freq_init == 'linear':
            return torch.linspace(0.1, max_freq, n_frequencies)
        if freq_init == 'log':
            return torch.logspace(-1, np.log10(max_freq), n_frequencies)
        return torch.rand(n_frequencies) * max_freq

    def forward(self, x):
        # Broadcast every input element against all K frequencies.
        phase = x.unsqueeze(-1) * self.frequencies
        series = torch.sin(phase) * self.sin_coeffs + torch.cos(phase) * self.cos_coeffs
        return self.scale * series.sum(dim=-1) + self.bias


# =============================================================================
# SHARED GATING NETWORK (FIX: shared across all layers)
# =============================================================================

class SharedGatingNetwork(nn.Module):
    """Location-conditioned gate producing mixture weights over experts.

    A three-layer ReLU MLP maps a 2-D coordinate to n_experts logits, which
    are turned into a probability vector with softmax. One instance is meant
    to be handed to several activation layers so they all share it.

    Args:
        n_experts: Number of expert activations
        hidden_dim: Hidden dimension for gating MLP
    """
    def __init__(self, n_experts=8, hidden_dim=64):
        super().__init__()
        self.n_experts = n_experts

        widths = [2, hidden_dim, hidden_dim, n_experts]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the logits layer has no non-linearity after it
        self.gate = nn.Sequential(*stack)

    def forward(self, coords):
        """Compute expert weights for a batch of coordinates.

        Args:
            coords: (batch, 2) normalized coordinates in [-1, 1]

        Returns:
            weights: (batch, n_experts) softmax weights
        """
        logits = self.gate(coords)
        return F.softmax(logits, dim=-1)


class SpatiallyVaryingActivation(nn.Module):
    """Location-dependent blend of several learned activation functions.

    Every instance owns its own bank of expert activations, but the mixing
    weights come from a gating network supplied by the caller, so several
    layers can share one gate.

    Args:
        shared_gate: SharedGatingNetwork instance (shared across layers)
        n_experts: Number of expert activation functions
        n_frequencies: Frequencies per expert activation
    """
    def __init__(self, shared_gate, n_experts=8, n_frequencies=25):
        super().__init__()
        self.shared_gate = shared_gate
        self.n_experts = n_experts

        # Layer-local experts; only the gate is shared.
        expert_bank = []
        for _ in range(n_experts):
            expert_bank.append(LearnedActivation(n_frequencies=n_frequencies))
        self.experts = nn.ModuleList(expert_bank)

    def forward(self, x, coords):
        """Blend the expert activations of x using weights computed from coords.

        Args:
            x: Input tensor (batch, features)
            coords: Normalized coordinates (batch, 2)

        Returns:
            Activated tensor, same shape as x
        """
        # (batch, n_experts) -> (batch, 1, n_experts) so it broadcasts over features
        mix = self.shared_gate(coords).unsqueeze(1)

        # (batch, features, n_experts): one slice per expert
        per_expert = torch.stack([expert(x) for expert in self.experts], dim=-1)

        return (per_expert * mix).sum(dim=-1)


# =============================================================================
# CORRECTED SIREN (matches SatCLIP: w0_initial=30, w0=1)
# =============================================================================

class Sine(nn.Module):
    """Element-wise sin(w0 * x), with the frequency factor w0 fixed at construction."""
    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)


class SirenLayer(nn.Module):
    """Linear layer followed by a sine activation, with SIREN-style init.

    Intended to mirror SatCLIP's SIREN layer (location_encoder.py).

    Args:
        dim_in: Input dimension
        dim_out: Output dimension
        w0: Frequency for sine activation
        is_first: Whether this is the first layer (uses different init)
        c: Constant for initialization (default 6.0)
    """
    def __init__(self, dim_in, dim_out, w0=1.0, is_first=False, c=6.0):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.w0 = w0

        self.linear = nn.Linear(dim_in, dim_out)
        self.activation = Sine(w0)

        # Overwrite nn.Linear's default init with the SIREN scheme.
        self._init_weights(c)

    def _init_weights(self, c):
        """Draw weight and bias from U(-bound, bound).

        bound is 1/dim_in for the first layer and sqrt(c/dim_in)/w0 otherwise.
        """
        bound = (1.0 / self.dim_in) if self.is_first else (math.sqrt(c / self.dim_in) / self.w0)

        nn.init.uniform_(self.linear.weight, -bound, bound)
        if self.linear.bias is not None:
            nn.init.uniform_(self.linear.bias, -bound, bound)

    def forward(self, x):
        pre_activation = self.linear(x)
        return self.activation(pre_activation)


# Notebook progress marker: printed once this cell's class definitions have run.
print("Model components defined.")

In [None]:
# =============================================================================
# LOCATION ENCODERS
# =============================================================================

class LocationEncoderReLU(nn.Module):
    """Plain ReLU MLP mapping (lon, lat) in degrees to an embedding.

    The network has n_layers hidden layers of width hidden_dim and a linear
    output layer. Weights use Kaiming-normal init; biases start at zero.
    """
    def __init__(self, input_dim=2, hidden_dim=256, output_dim=256, n_layers=3):
        super().__init__()

        widths = [input_dim, *([hidden_dim] * n_layers), output_dim]
        last_linear = len(widths) - 2
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            if idx != last_linear:
                # Every linear layer except the output one is followed by ReLU.
                stack.append(nn.ReLU())

        self.net = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, coords):
        # Work on a copy so the caller's tensor is left untouched.
        scaled = coords.clone()
        scaled[:, 0] = scaled[:, 0] / 180.0  # lon -> [-1, 1]
        scaled[:, 1] = scaled[:, 1] / 90.0   # lat -> [-1, 1]
        return self.net(scaled)


class LocationEncoderSIREN(nn.Module):
    """SIREN location encoder configured like SatCLIP's.

    The first sine layer uses w0_initial, all later sine layers use w0. The
    output layer is linear (no sine) but is initialised with the same
    uniform bound as a non-first SIREN layer.
    """
    def __init__(self, input_dim=2, hidden_dim=256, output_dim=256, n_layers=3,
                 w0_initial=30.0, w0=1.0):
        super().__init__()

        sine_layers = [
            SirenLayer(
                dim_in=input_dim if idx == 0 else hidden_dim,
                dim_out=hidden_dim,
                w0=w0_initial if idx == 0 else w0,
                is_first=(idx == 0),
            )
            for idx in range(n_layers)
        ]
        self.layers = nn.ModuleList(sine_layers)

        # Linear read-out with SIREN-style uniform init.
        self.final = nn.Linear(hidden_dim, output_dim)
        limit = math.sqrt(6.0 / hidden_dim) / w0
        with torch.no_grad():
            self.final.weight.uniform_(-limit, limit)
            self.final.bias.uniform_(-limit, limit)

    def forward(self, coords):
        # Normalise degrees to [-1, 1] on a copy of the input.
        h = coords.clone()
        h[:, 0] = h[:, 0] / 180.0
        h[:, 1] = h[:, 1] / 90.0

        for sine_layer in self.layers:
            h = sine_layer(h)

        return self.final(h)


class LocationEncoderLearned(nn.Module):
    """MLP location encoder whose hidden activations are LearnedActivation modules.

    Each hidden layer has its own LearnedActivation; the output layer is
    purely linear. Linear weights use Kaiming-normal init with zero biases.
    """
    def __init__(self, input_dim=2, hidden_dim=256, output_dim=256, n_layers=3,
                 n_frequencies=25):
        super().__init__()

        widths = [input_dim, *([hidden_dim] * n_layers), output_dim]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.activations = nn.ModuleList(
            [LearnedActivation(n_frequencies=n_frequencies) for _ in range(n_layers)]
        )

        self._init_weights()

    def _init_weights(self):
        for layer in self.linears:
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, coords):
        # Normalise degrees to [-1, 1] on a copy of the input.
        h = coords.clone()
        h[:, 0] = h[:, 0] / 180.0
        h[:, 1] = h[:, 1] / 90.0

        # All but the last linear layer are paired with an activation.
        hidden_layers = self.linears[:-1]
        for layer, activation in zip(hidden_layers, self.activations):
            h = activation(layer(h))
        return self.linears[-1](h)


class LocationEncoderSpatial(nn.Module):
    """MLP location encoder with spatially-varying (gated) activations.

    A single SharedGatingNetwork is created here and passed to every hidden
    layer's SpatiallyVaryingActivation, so all layers mix their own experts
    with the same location-dependent weights.
    """
    def __init__(self, input_dim=2, hidden_dim=256, output_dim=256, n_layers=3,
                 n_experts=8, n_frequencies=25, gate_hidden=64):
        super().__init__()

        # One gate for the whole encoder.
        self.shared_gate = SharedGatingNetwork(n_experts=n_experts, hidden_dim=gate_hidden)

        widths = [input_dim, *([hidden_dim] * n_layers), output_dim]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        # Per-layer expert banks, all wired to the same gate instance.
        gated = []
        for _ in range(n_layers):
            gated.append(SpatiallyVaryingActivation(
                self.shared_gate, n_experts=n_experts, n_frequencies=n_frequencies))
        self.activations = nn.ModuleList(gated)

        self._init_weights()

    def _init_weights(self):
        for layer in self.linears:
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, coords):
        # Normalise degrees to [-1, 1] on a copy of the input.
        h = coords.clone()
        h[:, 0] = h[:, 0] / 180.0
        h[:, 1] = h[:, 1] / 90.0

        # Keep a separate copy of the normalised coordinates for the gate.
        gate_input = h.clone()

        hidden_layers = self.linears[:-1]
        for layer, activation in zip(hidden_layers, self.activations):
            h = activation(layer(h), gate_input)
        return self.linears[-1](h)


# Smoke test: build each encoder with default settings and report its size
print("\nArchitecture Parameter Counts:")
print("-" * 50)
encoder_classes = (
    ('ReLU', LocationEncoderReLU),
    ('SIREN', LocationEncoderSIREN),
    ('Learned', LocationEncoderLearned),
    ('Spatial', LocationEncoderSpatial),
)
for name, enc_cls in encoder_classes:
    enc = enc_cls()
    n_params = 0
    for p in enc.parameters():
        n_params += p.numel()
    print(f"  {name:10s}: {n_params:>10,} params")

In [None]:
# =============================================================================
# HYBRID ENCODER (SH features + custom activations)
# =============================================================================

class HybridEncoder(nn.Module):
    """Frozen spherical-harmonic features followed by a trainable network.

    Coordinates are passed through sh_model.posenc (no gradients) and the
    resulting features feed either a SIREN stack (activation='siren', with
    SIREN-specific init) or a Kaiming-initialised MLP with learned Fourier
    activations ('learned') or ReLU (any other value).

    Note: constructing this module sets requires_grad=False on every
    parameter of the sh_model that is passed in.

    Args:
        sh_model: SatCLIP location encoder (provides .posenc)
        hidden_dim: Hidden layer dimension
        output_dim: Output embedding dimension
        n_layers: Number of hidden layers
        activation: 'relu', 'siren', or 'learned'
        n_frequencies: For learned activations
        w0_initial: For SIREN first layer
        w0: For SIREN subsequent layers
    """
    def __init__(self, sh_model, hidden_dim=256, output_dim=256, n_layers=3,
                 activation='learned', n_frequencies=25, w0_initial=30.0, w0=1.0):
        super().__init__()
        self.sh_model = sh_model
        self.activation_type = activation

        # The positional encoding stays fixed during training.
        for frozen in self.sh_model.parameters():
            frozen.requires_grad = False

        # Probe posenc with a single coordinate to learn the feature width.
        probe_device = next(sh_model.parameters()).device
        with torch.no_grad():
            probe = torch.tensor([[0.0, 0.0]]).double().to(probe_device)
            sh_dim = sh_model.posenc(probe).shape[-1]

        self.sh_dim = sh_dim
        print(f"  SH dim: {sh_dim}")

        if activation == 'siren':
            self._build_siren(sh_dim, hidden_dim, output_dim, n_layers, w0_initial, w0)
        else:
            self._build_mlp(sh_dim, hidden_dim, output_dim, n_layers, activation, n_frequencies)

    def _build_siren(self, sh_dim, hidden_dim, output_dim, n_layers, w0_initial, w0):
        """Create the sine-layer stack plus a SIREN-initialised linear read-out."""
        stack = []
        for idx in range(n_layers):
            first = idx == 0
            stack.append(SirenLayer(
                sh_dim if first else hidden_dim,
                hidden_dim,
                w0=w0_initial if first else w0,
                is_first=first,
            ))

        self.layers = nn.ModuleList(stack)
        self.final = nn.Linear(hidden_dim, output_dim)

        limit = math.sqrt(6.0 / hidden_dim) / w0
        with torch.no_grad():
            self.final.weight.uniform_(-limit, limit)
            self.final.bias.uniform_(-limit, limit)

        self.use_siren = True

    def _build_mlp(self, sh_dim, hidden_dim, output_dim, n_layers, activation, n_frequencies):
        """Create Linear layers with learned or ReLU activations, Kaiming-initialised."""
        widths = [sh_dim, *([hidden_dim] * n_layers), output_dim]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        if activation == 'learned':
            hidden_acts = [LearnedActivation(n_frequencies=n_frequencies) for _ in range(n_layers)]
        else:  # relu
            hidden_acts = [nn.ReLU() for _ in range(n_layers)]
        self.activations = nn.ModuleList(hidden_acts)

        for layer in self.linears:
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

        self.use_siren = False

    def forward(self, coords):
        # posenc is evaluated in float64 and cast back to float32.
        with torch.no_grad():
            h = self.sh_model.posenc(coords.double()).float()

        if self.use_siren:
            for sine_layer in self.layers:
                h = sine_layer(h)
            return self.final(h)

        for layer, activation in zip(self.linears[:-1], self.activations):
            h = activation(layer(h))
        return self.linears[-1](h)


# Notebook progress marker: printed once the HybridEncoder class definition has run.
print("HybridEncoder defined.")

---
## 3. Evaluation Functions

In [None]:
def get_satclip_embeddings(model, coords, device, batch_size=512):
    """Run coords through model in batches and return the stacked outputs.

    The model is switched to eval mode and called without gradient tracking.
    Coordinates are converted to a float64 tensor and moved to device one
    batch at a time.

    Returns:
        NumPy array with one output row per input coordinate.
    """
    model.eval()
    coords_tensor = torch.tensor(coords, dtype=torch.float64)
    batch_starts = range(0, len(coords), batch_size)

    with torch.no_grad():
        chunks = [
            model(coords_tensor[start:start + batch_size].to(device)).cpu().numpy()
            for start in batch_starts
        ]

    return np.vstack(chunks)


def evaluate_sklearn(X_train, y_train, X_test, y_test, alpha=1.0):
    """Fit Ridge(alpha) on the training features and return R² on the test set."""
    ridge = Ridge(alpha=alpha)
    ridge.fit(X_train, y_train)
    return r2_score(y_test, ridge.predict(X_test))


def _validation_block_mask(coords, val_fraction, grid_size, seed):
    """Return a boolean mask marking points that fall in held-out grid cells.

    Points are binned into grid_size-degree cells and roughly val_fraction of
    the occupied cells are selected at random (using a private RandomState).
    An all-False mask is returned when no sensible split exists.
    """
    coords = np.asarray(coords, dtype=np.float64)
    no_validation = np.zeros(len(coords), dtype=bool)
    if val_fraction <= 0 or len(coords) == 0:
        return no_validation

    cells = np.floor(coords[:, :2] / grid_size).astype(np.int64)
    _, cell_ids = np.unique(cells, axis=0, return_inverse=True)
    cell_ids = np.asarray(cell_ids).reshape(-1)
    n_cells = int(cell_ids.max()) + 1

    n_val_cells = int(round(n_cells * val_fraction))
    if n_val_cells < 1 or n_val_cells >= n_cells:
        return no_validation

    rng = np.random.RandomState(seed)
    val_cells = rng.permutation(n_cells)[:n_val_cells]
    return np.isin(cell_ids, val_cells)


def evaluate_neural(encoder, coords_train, y_train, coords_test, y_test,
                   epochs=100, lr=1e-3, batch_size=256, device='cuda', verbose=True,
                   val_fraction=0.1, val_grid_size=5.0, val_seed=0):
    """Train encoder + regression head end-to-end and return a test R².

    Targets are log1p-transformed and the model is trained with Adam on MSE.

    Model selection does NOT look at the test set: a spatially blocked
    validation subset is held out from the training data, the epoch with the
    best validation R² is chosen, and the test R² at that epoch is returned.
    Returning the maximum test R² over epochs would be optimistically biased,
    because the stopping epoch would be picked using test labels.

    If no usable validation split exists (val_fraction <= 0, too few occupied
    grid cells, or fewer than 2 validation samples), all training data is
    used and the final-epoch test R² is returned.

    Args:
        encoder: Module mapping (batch, 2) lon/lat to a 256-dim embedding
            (the head below is hard-wired to 256 inputs).
        coords_train, y_train: Training coordinates and raw targets.
        coords_test, y_test: Test coordinates and raw targets.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        batch_size: Mini-batch size for training.
        device: Torch device used for training and inference.
        verbose: Print the test R² every 25 epochs (monitoring only).
        val_fraction: Fraction of occupied grid cells held out for validation.
        val_grid_size: Grid cell size in degrees for the validation split.
        val_seed: Seed for choosing validation cells.

    Returns:
        Test-set R² (in log1p space) at the selected epoch; -inf if epochs == 0.
    """

    class Predictor(nn.Module):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder
            self.head = nn.Sequential(
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 1)
            )

        def forward(self, x):
            return self.head(self.encoder(x)).squeeze(-1)

    model = Predictor(encoder).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    coords_train = np.asarray(coords_train)
    y_train = np.asarray(y_train)

    # Carve a validation subset out of the training data only.
    val_mask = _validation_block_mask(coords_train, val_fraction, val_grid_size, val_seed)
    use_val = val_mask.sum() >= 2 and (~val_mask).sum() >= 1
    if not use_val:
        val_mask = np.zeros(len(coords_train), dtype=bool)
    fit_mask = ~val_mask

    fit_coords = torch.tensor(coords_train[fit_mask], dtype=torch.float32)
    fit_y = torch.tensor(np.log1p(y_train[fit_mask]), dtype=torch.float32)
    val_coords = torch.tensor(coords_train[val_mask], dtype=torch.float32)
    val_y = torch.tensor(np.log1p(y_train[val_mask]), dtype=torch.float32)
    test_coords = torch.tensor(coords_test, dtype=torch.float32)
    test_y = torch.tensor(np.log1p(y_test), dtype=torch.float32)

    train_dataset = torch.utils.data.TensorDataset(fit_coords, fit_y)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def predict(coords_tensor):
        model.eval()
        with torch.no_grad():
            return model(coords_tensor.to(device)).cpu().numpy()

    best_val_r2 = -float('inf')
    selected_test_r2 = -float('inf')

    for epoch in range(epochs):
        model.train()
        for coords_batch, y_batch in train_loader:
            coords_batch, y_batch = coords_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(coords_batch), y_batch)
            loss.backward()
            optimizer.step()

        test_r2 = r2_score(test_y.numpy(), predict(test_coords))

        if use_val:
            # Only the validation score decides which epoch is reported.
            val_r2 = r2_score(val_y.numpy(), predict(val_coords))
            if val_r2 > best_val_r2:
                best_val_r2 = val_r2
                selected_test_r2 = test_r2
        else:
            selected_test_r2 = test_r2

        if verbose and (epoch + 1) % 25 == 0:
            print(f"    Epoch {epoch+1}/{epochs}: R²={test_r2:.4f}")

    return selected_test_r2

---
## 4. Run Experiments

In [None]:
print("="*80)
print("EXPERIMENT: Corrected Learned Activations vs SatCLIP")
print("="*80)
print("\nFixes applied:")
print("  1. Spatial blocking for train/test splits")
print("  2. Ridge regression for SatCLIP baseline")
print("  3. Correct SIREN config (w0_initial=30, w0=1)")
print("  4. Shared gating network for spatial activations")

N_SAMPLES = 20000  # More samples to compensate for blocking
EPOCHS = 100
GRID_SIZE = 5.0  # 5 degree grid cells
# Torch is reseeded before every model is built, so weight init and
# mini-batch order are repeatable across runs and identical across models.
TORCH_SEED = 42

# (model name, frozen SatCLIP model) evaluated with Ridge on its embeddings
satclip_baselines = [
    ('SatCLIP L=10', satclip_l10),
    ('SatCLIP L=40', satclip_l40),
]

# (progress label, model name stored in results, encoder factory)
direct_encoders = [
    ('Direct + ReLU', 'Direct + ReLU',
     lambda: LocationEncoderReLU()),
    ('Direct + SIREN (corrected)', 'Direct + SIREN',
     lambda: LocationEncoderSIREN(w0_initial=30.0, w0=1.0)),  # Matches SatCLIP
    ('Direct + Learned', 'Direct + Learned',
     lambda: LocationEncoderLearned(n_frequencies=25)),
    ('Direct + Spatial (shared gate)', 'Direct + Spatial',
     lambda: LocationEncoderSpatial(n_experts=8, n_frequencies=25)),
]

all_results = []

for region_name, bounds in REGIONS.items():
    print(f"\n{'='*60}")
    print(f"Region: {region_name}")
    print(f"{'='*60}")

    # Sample with spatial blocking
    coords_train, values_train, coords_test, values_test = sample_with_spatial_blocking(
        pop_data, pop_coords, n_samples=N_SAMPLES, bounds=bounds,
        grid_size=GRID_SIZE, test_ratio=0.3
    )

    if len(coords_train) < 500 or len(coords_test) < 200:
        print("Skipping - too few samples")
        continue

    # Ridge is fit on log1p targets; evaluate_neural applies log1p itself.
    y_train_log = np.log1p(values_train)
    y_test_log = np.log1p(values_test)

    # =================================================================
    # BASELINES: SatCLIP with Ridge (fair comparison)
    # =================================================================
    print("\n--- SatCLIP Baselines (Ridge) ---")

    for model_name, satclip_model in satclip_baselines:
        print(f"  {model_name}...")
        emb_train = get_satclip_embeddings(satclip_model, coords_train, device)
        emb_test = get_satclip_embeddings(satclip_model, coords_test, device)
        r2 = evaluate_sklearn(emb_train, y_train_log, emb_test, y_test_log)
        print(f"    R²: {r2:.4f}")
        all_results.append({'region': region_name, 'model': model_name, 'r2': r2, 'type': 'baseline'})

    # =================================================================
    # DIRECT ENCODERS (our approaches)
    # =================================================================
    print("\n--- Direct Encoders (end-to-end) ---")

    for progress_label, model_name, build_encoder in direct_encoders:
        print(f"  {progress_label}...")
        torch.manual_seed(TORCH_SEED)
        encoder = build_encoder()
        r2 = evaluate_neural(encoder, coords_train, values_train, coords_test, values_test,
                             epochs=EPOCHS, device=device)
        print(f"    Best R²: {r2:.4f}")
        all_results.append({'region': region_name, 'model': model_name, 'r2': r2, 'type': 'direct'})

results_df = pd.DataFrame(all_results)
print(f"\n\nTotal experiments: {len(results_df)}")

In [None]:
# =================================================================
# HYBRID EXPERIMENTS (SH + learned activations)
# =================================================================

print("\n" + "="*80)
print("HYBRID EXPERIMENTS: SH features + different activations")
print("="*80)

# Torch is reseeded before every model so weight init and batch order are repeatable.
HYBRID_TORCH_SEED = 42

# (label, frozen SatCLIP model providing posenc, activations to try)
hybrid_configs = [
    ('L=10', satclip_l10, ['relu', 'siren', 'learned']),
    ('L=40', satclip_l40, ['relu', 'learned']),
]

hybrid_results = []

for region_name, bounds in REGIONS.items():
    print(f"\n{'='*60}")
    print(f"Region: {region_name}")
    print(f"{'='*60}")

    # NOTE(review): seed=123 gives a different train/test split from the
    # direct/baseline experiments (default seed), so hybrid scores are not
    # measured on the same test cells as those rows.
    coords_train, values_train, coords_test, values_test = sample_with_spatial_blocking(
        pop_data, pop_coords, n_samples=N_SAMPLES, bounds=bounds,
        grid_size=GRID_SIZE, test_ratio=0.3, seed=123  # Different seed for variety
    )

    # Same minimum-size guard as the direct experiments: a tiny or empty test
    # set makes R² meaningless (or makes the metric computation fail).
    if len(coords_train) < 500 or len(coords_test) < 200:
        print("Skipping - too few samples")
        continue

    for sh_label, sh_model, act_types in hybrid_configs:
        for act_type in act_types:
            print(f"\n  SH({sh_label}) + {act_type}...")
            torch.manual_seed(HYBRID_TORCH_SEED)
            encoder = HybridEncoder(sh_model, activation=act_type, n_frequencies=25)
            r2 = evaluate_neural(encoder, coords_train, values_train, coords_test, values_test,
                                epochs=EPOCHS, device=device)
            print(f"    Best R²: {r2:.4f}")
            # 'type' keeps these rows labelled when they are concatenated
            # with the direct/baseline results.
            hybrid_results.append({'region': region_name, 'model': f'SH({sh_label})+{act_type}',
                                   'r2': r2, 'type': 'hybrid'})

hybrid_df = pd.DataFrame(hybrid_results)

In [None]:
# =================================================================
# COMBINED RESULTS
# =================================================================

print("\n" + "="*80)
print("COMBINED RESULTS")
print("="*80)

# Preferred left-to-right region order for both tables
col_order = ['Global', 'USA', 'Europe', 'China']

# Main results: one row per model, one column per evaluated region
print("\n--- Direct Approaches ---")
pivot1 = results_df.pivot(index='model', columns='region', values='r2')
pivot1 = pivot1.loc[:, [name for name in col_order if name in pivot1.columns]]
print(pivot1.round(3).to_string())

# Hybrid results (only if the hybrid experiments produced any rows)
if len(hybrid_df) > 0:
    print("\n--- Hybrid Approaches ---")
    pivot2 = hybrid_df.pivot(index='model', columns='region', values='r2')
    pivot2 = pivot2.loc[:, [name for name in col_order if name in pivot2.columns]]
    print(pivot2.round(3).to_string())

# Save both result sets into a single CSV
all_combined = pd.concat([results_df, hybrid_df])
all_combined.to_csv('learned_activations_v3_results.csv', index=False)
print("\nResults saved to learned_activations_v3_results.csv")

In [None]:
# =================================================================
# VISUALIZATION
# =================================================================

def _r2_axis_floor(pivot):
    """Lower y-limit for an R² bar chart.

    Returns 0 when every score is non-negative, otherwise a value slightly
    below the worst score. R² is negative whenever a model does worse than
    predicting the mean, and a fixed lower limit of 0 would hide those bars.
    """
    scores = pivot.to_numpy(dtype=float)
    finite = scores[np.isfinite(scores)]
    if finite.size == 0 or finite.min() >= 0:
        return 0.0
    return float(finite.min()) * 1.05


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Direct approaches comparison
ax = axes[0]
pivot1.T.plot(kind='bar', ax=ax)
ax.set_ylabel('R² Score')
ax.set_xlabel('Region')
ax.set_title('Direct Approaches\n(with spatial blocking)')
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
ax.set_ylim(_r2_axis_floor(pivot1), 1)
ax.axhline(0, color='black', linewidth=0.8)  # reference line for R² = 0
ax.grid(True, alpha=0.3, axis='y')
plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

# Hybrid approaches
if len(hybrid_df) > 0:
    ax = axes[1]
    pivot2.T.plot(kind='bar', ax=ax)
    ax.set_ylabel('R² Score')
    ax.set_xlabel('Region')
    ax.set_title('Hybrid Approaches\n(SH features + activations)')
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
    ax.set_ylim(_r2_axis_floor(pivot2), 1)
    ax.axhline(0, color='black', linewidth=0.8)  # reference line for R² = 0
    ax.grid(True, alpha=0.3, axis='y')
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

plt.tight_layout()
plt.savefig('learned_activations_v3_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# =================================================================
# KEY FINDINGS
# =================================================================

def _r2_values(model_name, region_name):
    """R² values recorded in results_df for one (model, region) pair (may be empty)."""
    selector = (results_df['model'] == model_name) & (results_df['region'] == region_name)
    return results_df[selector]['r2'].values


report_regions = ['Global', 'USA', 'Europe', 'China']

print("="*80)
print("KEY FINDINGS")
print("="*80)

# Regions that were skipped have no rows and are silently omitted below.
print("\n1. SatCLIP BASELINE (with fair Ridge regression):")
for region in report_regions:
    l10 = _r2_values('SatCLIP L=10', region)
    l40 = _r2_values('SatCLIP L=40', region)
    if len(l10) == 0 or len(l40) == 0:
        continue
    print(f"  {region}: L=10={l10[0]:.3f}, L=40={l40[0]:.3f}, diff={l40[0]-l10[0]:+.3f}")

print("\n2. LEARNED ACTIVATIONS vs SatCLIP:")
for region in report_regions:
    l10 = _r2_values('SatCLIP L=10', region)
    learned = _r2_values('Direct + Learned', region)
    if len(l10) == 0 or len(learned) == 0:
        continue
    print(f"  {region}: Learned={learned[0]:.3f} vs L=10={l10[0]:.3f} ({learned[0]-l10[0]:+.3f})")

print("\n3. SPATIAL ACTIVATIONS (with shared gate):")
for region in report_regions:
    learned = _r2_values('Direct + Learned', region)
    spatial = _r2_values('Direct + Spatial', region)
    if len(learned) == 0 or len(spatial) == 0:
        continue
    print(f"  {region}: Spatial={spatial[0]:.3f} vs Learned={learned[0]:.3f} ({spatial[0]-learned[0]:+.3f})")

if len(hybrid_df) > 0:
    print("\n4. HYBRID APPROACHES (best per region):")
    for region in report_regions:
        region_data = hybrid_df[hybrid_df['region'] == region]
        if len(region_data) == 0:
            continue
        # Row with the highest R² among this region's hybrid models
        best = region_data.loc[region_data['r2'].idxmax()]
        print(f"  {region}: {best['model']} = {best['r2']:.3f}")

In [None]:
# Save summary
import json

# Human-readable list of what changed relative to v1/v2
fixes_applied = [
    'Spatial blocking for train/test splits (5 degree grid)',
    'Ridge regression for SatCLIP baseline (not MLP)',
    'Correct SIREN config (w0_initial=30, w0=1)',
    'Shared gating network across layers for spatial activations',
    'SIREN-specific initialization in HybridEncoder',
]

# Hybrid results are optional; store an empty list when none were produced.
if len(hybrid_df) > 0:
    hybrid_records = hybrid_df.to_dict('records')
else:
    hybrid_records = []

summary = {
    'fixes_applied': fixes_applied,
    'direct_results': results_df.to_dict('records'),
    'hybrid_results': hybrid_records,
}

with open('learned_activations_v3_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("Summary saved to learned_activations_v3_summary.json")

print("Summary saved to learned_activations_v3_summary.json")