# Phase 1: Core 2x2 Comparison

**Goal**: Establish whether learned activations work at all

## The 2x2 Grid

| | SIREN | Learned Acts |
|---|-------|--------------|
| **Raw coords** | Baseline | Test: Can learned acts discover frequencies? |
| **SH features** | Baseline | Test: Better nonlinearity than SIREN? |

## Key Metrics
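- R² score (best per-epoch value on the spatially held-out test cells, computed on log1p-transformed population density)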
- Parameter count  
- **Efficiency**: R² per 10K parameters

In [None]:
# Setup: when running on Colab, wipe leftovers from earlier runs, clone the
# SatCLIP repository and install extra packages (presumably the ones the
# SatCLIP code needs — confirm against its requirements).
import os
import sys

# The presence of the COLAB_GPU environment variable is used as the
# "running on Colab" switch; nothing below runs elsewhere.
if 'COLAB_GPU' in os.environ:
    # Lines starting with `!` are IPython shell escapes, so this cell only
    # works inside a notebook kernel, not as a plain Python script.
    !rm -rf sample_data .config satclip gpw_data 2>/dev/null
    !git clone https://github.com/1hamzaiqbal/satclip.git
    !pip install lightning torchgeo huggingface_hub rasterio --quiet

In [None]:
import os
import zipfile

from google.colab import drive

# Mount Google Drive so the archive stored there becomes readable.
drive.mount('/content/drive')

GPW_DIR = './gpw_data'
SOURCE_ZIP_PATH = '/content/drive/MyDrive/grad/learned_activations/dataverse_files.zip'

os.makedirs(GPW_DIR, exist_ok=True)


def _extract_all(archive_path, destination):
    """Unpack every member of the zip at `archive_path` into `destination`."""
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(destination)


print("Extracting GPW data...")
_extract_all(SOURCE_ZIP_PATH, GPW_DIR)

# The outer archive may contain a second zip holding the GeoTIFF; unpack it
# too when it is there, and carry on silently when it is not.
zip_path = os.path.join(GPW_DIR, 'gpw-v4-population-density-rev11_2020_15_min_tif.zip')
if os.path.exists(zip_path):
    _extract_all(zip_path, GPW_DIR)

print("Done!")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import math
import time
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from sklearn.metrics import r2_score

# Make the SatCLIP sources importable: the fresh clone when on Colab,
# otherwise a `satclip` directory located in the parent of the current
# working directory.
satclip_src = (
    './satclip/satclip'
    if 'COLAB_GPU' in os.environ
    else os.path.join(os.path.dirname(os.getcwd()), 'satclip')
)
sys.path.append(satclip_src)

# These imports must come after the sys.path tweak above.
from huggingface_hub import hf_hub_download
from load import get_satclip

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Load SatCLIP L=10 for SH features
print("Loading SatCLIP L=10...")
ckpt_path = hf_hub_download("microsoft/SatCLIP-ViT16-L10", "satclip-vit16-l10.ckpt")
satclip_l10 = get_satclip(ckpt_path, device=device)
satclip_l10.eval()
print("Done!")

---
## Configuration

In [None]:
# =============================================================================
# CONFIGURATION - Keep it simple for Phase 1
# =============================================================================

CONFIG = dict(
    # Architecture: number of activated layers, their width, and the width of
    # the encoder output.
    n_layers=3,
    hidden_dim=256,
    output_dim=256,

    # SIREN (matches SatCLIP): w0_initial scales the first layer, w0 the rest.
    w0_initial=30.0,
    w0=1.0,

    # Learned activations: number of fixed frequencies per activation.
    n_frequencies=25,

    # Training
    n_samples=15000,
    epochs=100,
    batch_size=256,
    lr=1e-3,

    # Spatial blocking: grid cell size (same units as the lon/lat coordinates)
    # and the fraction of grid cells held out for testing.
    grid_size=5.0,
    test_ratio=0.3,
)

print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))

---
## Data Loading

In [None]:
# Lift Pillow's pixel-count limit (None is understood to disable its
# oversized-image check — confirm against the Pillow docs) so the raster opens.
Image.MAX_IMAGE_PIXELS = None


def load_gpw_raster():
    """Load the GPW population-density GeoTIFF and build its lon/lat axes.

    The file is read from the module-level ``GPW_DIR``. The coordinate axes
    are the centres of the pixels of a grid assumed to span longitude
    [-180, 180] and latitude [90, -90] (north to south), hence the half-pixel
    offsets ``180/w`` and ``90/h``.

    Returns:
        tuple: ``(data, (lons, lats))`` where ``data`` is the 2-D raster as a
        NumPy array, ``lons`` has one entry per column and ``lats`` one entry
        per row.

    Raises:
        ValueError: if the decoded image is not 2-D (shape unpacking fails).
    """
    tif_file = f"{GPW_DIR}/gpw_v4_population_density_rev11_2020_15_min.tif"
    # Use a context manager so the underlying file handle is released as soon
    # as the pixels have been copied into the array.
    with Image.open(tif_file) as img:
        data = np.array(img)
    h, w = data.shape
    lons = np.linspace(-180 + 180/w, 180 - 180/w, w)
    lats = np.linspace(90 - 90/h, -90 + 90/h, h)
    return data, (lons, lats)


def sample_with_spatial_blocking(data, coords, cfg, bounds=None, seed=42):
    """Sample valid raster pixels and split them into train/test by grid cell.

    The region is tiled into square cells of ``cfg['grid_size']``; a random
    ``cfg['test_ratio']`` fraction of the cells is reserved for testing, and
    every sampled pixel goes to the split of the cell that contains it, so
    train and test points never share a cell.

    Args:
        data: 2-D array of raster values; entries ``<= -1e30`` (presumably the
            raster's no-data fill — confirm) and NaNs are treated as invalid.
        coords: ``(lons, lats)`` 1-D axes matching the columns / rows of
            ``data``.
        cfg: Mapping providing ``grid_size``, ``test_ratio`` and ``n_samples``.
        bounds: Optional ``(lon_min, lat_min, lon_max, lat_max)`` box; when
            falsy the whole globe (-180..180, -90..90) is used.
        seed: Seed applied to NumPy's global RNG before any random draw.

    Returns:
        tuple: ``(train_coords, train_vals, test_coords, test_vals)``; the
        coordinate arrays have columns ``(lon, lat)``.
    """
    np.random.seed(seed)
    lons, lats = coords
    valid_mask = data > -1e30

    if bounds:
        lon_min, lat_min, lon_max, lat_max = bounds
        lon_grid, lat_grid = np.meshgrid(lons, lats)
        valid_mask &= (lon_grid >= lon_min) & (lon_grid <= lon_max)
        valid_mask &= (lat_grid >= lat_min) & (lat_grid <= lat_max)
    else:
        lon_min, lon_max, lat_min, lat_max = -180, 180, -90, 90

    # Grid cells
    grid_size = cfg['grid_size']
    n_lon = int(np.ceil((lon_max - lon_min) / grid_size))
    n_lat = int(np.ceil((lat_max - lat_min) / grid_size))
    n_cells = n_lon * n_lat

    # Assign cells to test
    cell_ids = np.arange(n_cells)
    np.random.shuffle(cell_ids)
    test_cells = cell_ids[:int(n_cells * cfg['test_ratio'])]

    # Sample points
    valid_rows, valid_cols = np.where(valid_mask)
    n_valid = len(valid_rows)
    sample_idx = np.random.choice(n_valid, min(cfg['n_samples'], n_valid), replace=False)

    rows, cols = valid_rows[sample_idx], valid_cols[sample_idx]
    sample_lons, sample_lats = lons[cols], lats[rows]
    sample_vals = data[rows, cols]

    # Split by grid cell. The row and column indices are clamped separately:
    # a point lying exactly on the upper lon/lat edge would otherwise get an
    # index one past the last cell, and an out-of-range column silently wraps
    # into the first column of the next row once the index is flattened.
    cell_row = np.minimum(((sample_lats - lat_min) / grid_size).astype(int), n_lat - 1)
    cell_col = np.minimum(((sample_lons - lon_min) / grid_size).astype(int), n_lon - 1)
    cell = cell_row * n_lon + cell_col
    train_mask = ~np.isin(cell, test_cells)

    coords_arr = np.stack([sample_lons, sample_lats], axis=1)

    return (
        coords_arr[train_mask], sample_vals[train_mask],
        coords_arr[~train_mask], sample_vals[~train_mask]
    )


# Load data: the full raster plus its (lons, lats) pixel-centre axes.
print("Loading data...")
pop_data, pop_coords = load_gpw_raster()
print(f"Raster shape: {pop_data.shape}")

# Sample with blocking: default bounds and seed, sizes taken from CONFIG.
spatial_split = sample_with_spatial_blocking(data=pop_data, coords=pop_coords, cfg=CONFIG)
coords_train, vals_train, coords_test, vals_test = spatial_split
print(f"Train: {len(coords_train)}, Test: {len(coords_test)}")

---
## Model Definitions (Minimal)

In [None]:
# =============================================================================
# LEARNED ACTIVATION
# =============================================================================

class LearnedActivation(nn.Module):
    """Element-wise activation expressed as a trainable Fourier series.

    Computes ``scale * Σ_k (a_k sin(ω_k x) + b_k cos(ω_k x)) + bias``. The
    frequencies ``ω_k`` are fixed (``n_freq`` values evenly spaced between 0.1
    and ``max_freq``); the coefficients ``a_k``/``b_k`` and the scalar
    ``scale``/``bias`` are learned. The output has the same shape as the
    input.
    """

    def __init__(self, n_freq=25, max_freq=10.0):
        super().__init__()
        frequencies = torch.linspace(0.1, max_freq, n_freq)
        # A buffer, not a parameter: it follows the module across devices but
        # is never updated by the optimizer.
        self.register_buffer('freqs', frequencies)
        # Small random coefficients; sin is drawn before cos.
        self.sin_c = nn.Parameter(0.1 * torch.randn(n_freq))
        self.cos_c = nn.Parameter(0.1 * torch.randn(n_freq))
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Add a trailing axis so each input value meets every frequency.
        phase = x[..., None] * self.freqs
        sin_terms = torch.sin(phase) * self.sin_c
        cos_terms = torch.cos(phase) * self.cos_c
        series = (sin_terms + cos_terms).sum(dim=-1)
        return self.scale * series + self.bias


# =============================================================================
# SIREN LAYER (matches SatCLIP)
# =============================================================================

class SirenLayer(nn.Module):
    """Linear layer followed by ``sin(w0 * ·)``, with SIREN-style uniform init.

    Weights and biases are both drawn from ``U(-bound, bound)`` where
    ``bound = 1 / in_dim`` for the first layer and
    ``sqrt(6 / in_dim) / w0`` otherwise.
    """

    def __init__(self, in_dim, out_dim, w0=1.0, is_first=False):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.w0 = w0

        # SIREN init
        bound = 1.0 / in_dim if is_first else math.sqrt(6.0 / in_dim) / w0
        with torch.no_grad():
            # Weight first, then bias.
            for tensor in (self.linear.weight, self.linear.bias):
                tensor.uniform_(-bound, bound)

    def forward(self, x):
        pre_activation = self.linear(x)
        return torch.sin(self.w0 * pre_activation)


# =============================================================================
# THE 4 MODELS
# =============================================================================

class RawSIREN(nn.Module):
    """Baseline encoder: raw (lon, lat) coordinates fed through a SIREN stack.

    ``cfg`` supplies ``hidden_dim``, ``n_layers``, ``w0_initial`` (first
    layer), ``w0`` (remaining layers and the output init) and ``output_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['hidden_dim']

        # First layer takes the 2-D coordinate; the rest are width -> width.
        siren_stack = [SirenLayer(2, width, w0=cfg['w0_initial'], is_first=True)]
        for _ in range(cfg['n_layers'] - 1):
            siren_stack.append(SirenLayer(width, width, w0=cfg['w0']))
        self.layers = nn.ModuleList(siren_stack)

        self.final = nn.Linear(width, cfg['output_dim'])
        # SIREN init for final: same uniform bound as a non-first SIREN layer.
        final_bound = math.sqrt(6.0 / width) / cfg['w0']
        with torch.no_grad():
            self.final.weight.uniform_(-final_bound, final_bound)
            self.final.bias.uniform_(-final_bound, final_bound)

    def forward(self, coords):
        # Divide lon by 180 and lat by 90 before the first layer.
        half_extent = torch.tensor([180., 90.], device=coords.device)
        hidden = coords / half_extent
        for siren in self.layers:
            hidden = siren(hidden)
        return self.final(hidden)


class RawLearned(nn.Module):
    """Test encoder: raw (lon, lat) coordinates through an MLP whose hidden
    nonlinearities are ``LearnedActivation`` modules (one per hidden layer).

    ``cfg`` supplies ``hidden_dim``, ``n_layers``, ``n_frequencies`` and
    ``output_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['hidden_dim']
        depth = cfg['n_layers']

        # Layer widths: 2 -> width -> ... -> width -> output_dim.
        dims = [2, width] + [width] * (depth - 1) + [cfg['output_dim']]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(dims[:-1], dims[1:])]
        )
        self.acts = nn.ModuleList(
            [LearnedActivation(cfg['n_frequencies']) for _ in range(depth)]
        )

        # Kaiming-normal weights and zero biases for every linear layer,
        # including the output projection.
        for linear in self.linears:
            nn.init.kaiming_normal_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, coords):
        # Divide lon by 180 and lat by 90 before the first layer.
        half_extent = torch.tensor([180., 90.], device=coords.device)
        hidden = coords / half_extent
        # The last linear layer is the un-activated output projection.
        for linear, activation in zip(self.linears[:-1], self.acts):
            hidden = activation(linear(hidden))
        return self.linears[-1](hidden)


class SHSIREN(nn.Module):
    """Baseline encoder: frozen ``sh_model.posenc`` features fed to a SIREN stack.

    All parameters of ``sh_model`` are frozen (``requires_grad = False``) and
    its ``posenc`` is always evaluated under ``torch.no_grad()``.
    ``posenc`` is presumably SatCLIP's spherical-harmonics positional
    encoder — confirm against the SatCLIP sources.
    """

    def __init__(self, cfg, sh_model):
        super().__init__()
        self.sh_model = sh_model
        for frozen in sh_model.parameters():
            frozen.requires_grad = False

        # Get SH dim: push one dummy float64 coordinate through the encoder,
        # on whatever device its parameters live, and read the feature width.
        probe_device = next(sh_model.parameters()).device
        probe = torch.zeros(1, 2).double().to(probe_device)
        with torch.no_grad():
            self.sh_dim = sh_model.posenc(probe).shape[-1]

        width = cfg['hidden_dim']
        siren_stack = [SirenLayer(self.sh_dim, width, w0=cfg['w0_initial'], is_first=True)]
        for _ in range(cfg['n_layers'] - 1):
            siren_stack.append(SirenLayer(width, width, w0=cfg['w0']))
        self.layers = nn.ModuleList(siren_stack)

        self.final = nn.Linear(width, cfg['output_dim'])
        # Output layer uses the same uniform bound as a non-first SIREN layer.
        final_bound = math.sqrt(6.0 / width) / cfg['w0']
        with torch.no_grad():
            self.final.weight.uniform_(-final_bound, final_bound)
            self.final.bias.uniform_(-final_bound, final_bound)

    def forward(self, coords):
        # The encoder is called with float64 input; its output is cast back to
        # float32 for the trainable layers.
        with torch.no_grad():
            features = self.sh_model.posenc(coords.double()).float()
        for siren in self.layers:
            features = siren(features)
        return self.final(features)


class SHLearned(nn.Module):
    """Test encoder: frozen ``sh_model.posenc`` features fed to an MLP with
    ``LearnedActivation`` nonlinearities (one per hidden layer).

    All parameters of ``sh_model`` are frozen (``requires_grad = False``) and
    its ``posenc`` is always evaluated under ``torch.no_grad()``.
    ``posenc`` is presumably SatCLIP's spherical-harmonics positional
    encoder — confirm against the SatCLIP sources.
    """

    def __init__(self, cfg, sh_model):
        super().__init__()
        self.sh_model = sh_model
        for frozen in sh_model.parameters():
            frozen.requires_grad = False

        # Push one dummy float64 coordinate through the encoder, on whatever
        # device its parameters live, and read the feature width.
        probe_device = next(sh_model.parameters()).device
        probe = torch.zeros(1, 2).double().to(probe_device)
        with torch.no_grad():
            self.sh_dim = sh_model.posenc(probe).shape[-1]

        width = cfg['hidden_dim']
        depth = cfg['n_layers']
        # Layer widths: sh_dim -> width -> ... -> width -> output_dim.
        dims = [self.sh_dim, width] + [width] * (depth - 1) + [cfg['output_dim']]
        self.linears = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(dims[:-1], dims[1:])]
        )
        self.acts = nn.ModuleList(
            [LearnedActivation(cfg['n_frequencies']) for _ in range(depth)]
        )

        # Kaiming-normal weights and zero biases for every linear layer,
        # including the output projection.
        for linear in self.linears:
            nn.init.kaiming_normal_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, coords):
        # The encoder is called with float64 input; its output is cast back to
        # float32 for the trainable layers.
        with torch.no_grad():
            features = self.sh_model.posenc(coords.double()).float()
        # The last linear layer is the un-activated output projection.
        for linear, activation in zip(self.linears[:-1], self.acts):
            features = activation(linear(features))
        return self.linears[-1](features)


# Print param counts
print("\nParameter counts:")

# Raw-coordinate encoders own all of their parameters, so count everything.
for label, build in (('RawSIREN', RawSIREN), ('RawLearned', RawLearned)):
    probe_model = build(CONFIG)
    total = sum(p.numel() for p in probe_model.parameters())
    print(f"  {label}: {total:,}")

# SH encoders also hold the frozen sh_model, so count only the parameters
# that still require gradients.
for label, build in (('SHSIREN', SHSIREN), ('SHLearned', SHLearned)):
    probe_model = build(CONFIG, satclip_l10)
    trainable = sum(p.numel() for p in probe_model.parameters() if p.requires_grad)
    print(f"  {label} (SH dim={probe_model.sh_dim}): {trainable:,}")

---
## Training

In [None]:
class Predictor(nn.Module):
    """Encoder followed by a small MLP head that emits one scalar per sample.

    Args:
        encoder: Module whose output has ``in_dim`` features on its last axis.
        in_dim: Feature width produced by ``encoder``. Defaults to 256, the
            value that used to be hard-coded; pass the encoder's real output
            width when it differs.
        hidden_dim: Width of the head's single hidden layer (default 128).
    """
    def __init__(self, encoder, in_dim=256, hidden_dim=128):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # Drop the trailing singleton axis so the output is one value per row.
        return self.head(self.encoder(x)).squeeze(-1)


def train_and_eval(encoder, coords_train, vals_train, coords_test, vals_test, cfg):
    """Train ``encoder`` with a regression head and score it on the test split.

    The encoder is wrapped in a ``Predictor``, moved to the module-level
    ``device``, and fitted with Adam on an MSE loss against ``log1p`` of the
    training values. After every epoch the whole test split is scored with R²
    (also against ``log1p`` targets).

    Args:
        encoder: Module mapping a batch of coordinates to the features that
            ``Predictor`` expects.
        coords_train, coords_test: Array-likes of coordinates; converted to
            float32 tensors.
        vals_train, vals_test: Raw target values; ``np.log1p`` is applied.
        cfg: Mapping providing ``lr``, ``batch_size`` and ``epochs``.

    Returns:
        dict with keys:
            ``r2``: highest per-epoch test R². NOTE: the epoch is picked
                using the test set itself, so this is an optimistic estimate.
            ``final_r2``: test R² after the last epoch (no selection);
                ``-inf`` when ``cfg['epochs']`` is 0.
            ``params``: number of trainable parameters (encoder + head).
            ``efficiency``: ``r2`` per 10K trainable parameters.
            ``time``: wall-clock seconds for the loop, evaluation included.
    """
    model = Predictor(encoder).to(device)
    opt = optim.Adam(model.parameters(), lr=cfg['lr'])
    loss_fn = nn.MSELoss()

    # Data
    train_X = torch.tensor(coords_train, dtype=torch.float32)
    train_y = torch.tensor(np.log1p(vals_train), dtype=torch.float32)
    test_X = torch.tensor(coords_test, dtype=torch.float32).to(device)
    test_y = torch.tensor(np.log1p(vals_test), dtype=torch.float32)
    # Loop-invariant: convert the test targets to NumPy once, not every epoch.
    test_y_np = test_y.numpy()

    loader = DataLoader(TensorDataset(train_X, train_y), batch_size=cfg['batch_size'], shuffle=True)

    start = time.time()
    best_r2 = -float('inf')
    final_r2 = -float('inf')

    for epoch in range(cfg['epochs']):
        model.train()
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            loss_fn(model(X), y).backward()
            opt.step()

        model.eval()
        with torch.no_grad():
            pred = model(test_X).cpu().numpy()
        r2 = r2_score(test_y_np, pred)
        best_r2 = max(best_r2, r2)
        final_r2 = r2

        if (epoch + 1) % 20 == 0:
            print(f"    Epoch {epoch+1}: R²={r2:.4f}")

    train_time = time.time() - start
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    return {
        'r2': best_r2,
        'params': n_params,
        'efficiency': best_r2 / (n_params / 10000),
        'time': train_time,
        # Added alongside the existing keys so the selection-free number is
        # available without changing what 'r2' means for current consumers.
        'final_r2': final_r2,
    }

---
## Run the 2x2 Comparison

In [None]:
print("="*70)
print("PHASE 1: Core 2x2 Comparison")
print("="*70)

results = []

# One row per grid entry:
# (progress banner, encoder factory, model label, encoding label, activation label).
# Factories are lambdas so each encoder is only built right before it trains.
EXPERIMENTS = [
    # 1. Raw + SIREN
    ("\n[1/4] Raw + SIREN (baseline)...", lambda: RawSIREN(CONFIG),
     'Raw + SIREN', 'Raw', 'SIREN'),
    # 2. Raw + Learned
    ("\n[2/4] Raw + Learned (test)...", lambda: RawLearned(CONFIG),
     'Raw + Learned', 'Raw', 'Learned'),
    # 3. SH + SIREN
    ("\n[3/4] SH + SIREN (baseline)...", lambda: SHSIREN(CONFIG, satclip_l10),
     'SH + SIREN', 'SH(L=10)', 'SIREN'),
    # 4. SH + Learned
    ("\n[4/4] SH + Learned (test)...", lambda: SHLearned(CONFIG, satclip_l10),
     'SH + Learned', 'SH(L=10)', 'Learned'),
]

for banner, make_encoder, model_label, encoding_label, activation_label in EXPERIMENTS:
    print(banner)
    encoder = make_encoder()
    res = train_and_eval(encoder, coords_train, vals_train, coords_test, vals_test, CONFIG)
    # Tag the metrics so the results table can be pivoted into the 2x2 grid.
    res['model'] = model_label
    res['encoding'] = encoding_label
    res['activation'] = activation_label
    results.append(res)
    print(f"    -> R²={res['r2']:.4f}, Params={res['params']:,}, Efficiency={res['efficiency']:.4f}")

df = pd.DataFrame(results)
print("\nDone!")

---
## Results

In [None]:
rule = "=" * 70
print(rule)
print("RESULTS")
print(rule)


def _as_grid(metric):
    """Pivot `metric` into an encoding (rows) x activation (columns) table."""
    return df.pivot(index='encoding', columns='activation', values=metric)


# Format as 2x2 table
print("\n--- R² Scores ---")
pivot_r2 = _as_grid('r2')
print(pivot_r2.round(4).to_string())

print("\n--- Efficiency (R² per 10K params) ---")
pivot_eff = _as_grid('efficiency')
print(pivot_eff.round(4).to_string())

print("\n--- Full Results ---")
summary_columns = ['model', 'r2', 'params', 'efficiency', 'time']
print(df[summary_columns].round(4).to_string(index=False))

In [None]:
# Visualization: two grouped bar charts (R² and efficiency), SIREN vs Learned
# for each input encoding.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
x = np.arange(2)
w = 0.35


def _values_for(column, *model_names):
    """Return `column` from the first results row of each named model."""
    return [df[df['model'] == model_name][column].values[0] for model_name in model_names]


def _draw_grouped_bars(target_ax, siren_vals, learned_vals, ylabel, title, ylim=None):
    """Draw the SIREN/Learned bar pair per encoding on `target_ax`."""
    target_ax.bar(x - w/2, siren_vals, w, label='SIREN', color='steelblue')
    target_ax.bar(x + w/2, learned_vals, w, label='Learned', color='coral')
    target_ax.set_xticks(x)
    target_ax.set_xticklabels(['Raw coords', 'SH(L=10)'])
    target_ax.set_ylabel(ylabel)
    target_ax.set_title(title)
    target_ax.legend()
    if ylim is not None:
        target_ax.set_ylim(*ylim)
    target_ax.grid(True, alpha=0.3, axis='y')


# R² comparison
siren_r2 = _values_for('r2', 'Raw + SIREN', 'SH + SIREN')
learned_r2 = _values_for('r2', 'Raw + Learned', 'SH + Learned')
ax = axes[0]
_draw_grouped_bars(ax, siren_r2, learned_r2, 'R² Score', 'Performance Comparison', ylim=(0, 1))

# Efficiency comparison
siren_eff = _values_for('efficiency', 'Raw + SIREN', 'SH + SIREN')
learned_eff = _values_for('efficiency', 'Raw + Learned', 'SH + Learned')
ax = axes[1]
_draw_grouped_bars(ax, siren_eff, learned_eff, 'Efficiency (R² / 10K params)', 'Parameter Efficiency')

plt.tight_layout()
plt.savefig('phase1_results.png', dpi=150)
plt.show()

In [None]:
# Key findings
rule = "=" * 70
print(rule)
print("KEY FINDINGS")
print(rule)


def _row_for(model_name):
    """Return the first results row whose 'model' equals `model_name`."""
    return df.loc[df['model'] == model_name].iloc[0]


def _report_gap(learned, siren):
    """Print YES / NO / SIMILAR depending on the Learned-minus-SIREN R² gap.

    Gaps within ±0.01 are reported as SIMILAR.
    """
    gap = learned['r2'] - siren['r2']
    if gap > 0.01:
        print(f"   YES! Learned ({learned['r2']:.3f}) > SIREN ({siren['r2']:.3f}) by {gap:+.3f}")
    elif gap < -0.01:
        print(f"   NO. SIREN ({siren['r2']:.3f}) > Learned ({learned['r2']:.3f}) by {-gap:.3f}")
    else:
        print(f"   SIMILAR. SIREN={siren['r2']:.3f}, Learned={learned['r2']:.3f}")


raw_siren = _row_for('Raw + SIREN')
raw_learned = _row_for('Raw + Learned')
sh_siren = _row_for('SH + SIREN')
sh_learned = _row_for('SH + Learned')

print("\n1. Can learned activations discover frequencies (Raw coords)?")
_report_gap(raw_learned, raw_siren)

print("\n2. Are learned activations a better nonlinearity than SIREN (with SH)?")
_report_gap(sh_learned, sh_siren)

print("\n3. Does SH encoding help?")
best_raw = max(raw_siren['r2'], raw_learned['r2'])
best_sh = max(sh_siren['r2'], sh_learned['r2'])
print(f"   Best Raw: {best_raw:.3f}, Best SH: {best_sh:.3f} ({best_sh - best_raw:+.3f})")

print("\n4. Parameter efficiency winner:")
best_eff = df.loc[df['efficiency'].idxmax()]
print(f"   {best_eff['model']}: {best_eff['efficiency']:.4f} R²/10K params")

In [None]:
# Save results: write the full results table next to the notebook.
df.to_csv('phase1_results.csv', index=False)
print("Results saved to phase1_results.csv")

# Next steps: static guidance text, printed as-is.
NEXT_STEPS = """
Based on these results:

If Learned ≈ SIREN on Raw coords:
  -> Learned activations CAN discover frequencies
  -> Proceed to Phase 2: Try different activation types (splines, etc.)

If Learned < SIREN on Raw coords:
  -> May need more frequencies or different architecture
  -> Focus on SH + Learned path

If SH + Learned > SH + SIREN:
  -> Better nonlinearity confirmed!
  -> This is the promising path for improvement

See EXPERIMENT_ROADMAP.md for full plan.
"""

print("\n" + "="*70, "NEXT STEPS", "="*70, sep="\n")
print(NEXT_STEPS)