## How This Geospatial Encoding Works

This notebook implements a **hierarchical geospatial encoder** that turns geographic locations, represented as geohash prefixes at several precisions, into learned embeddings. Here's how each component works:

### 1. `emb_dim(n)` Function
- Computes the embedding dimension for a vocabulary of `n` entries
- Formula: `min(64, max(8, round(2*sqrt(n))))`
- Clamps the result to the range 8 to 64 dimensions
- Larger vocabularies get larger embeddings (growing with √n), but the size is capped at 64
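You can see how the formula behaves by evaluating it for a few vocabulary sizes. The snippet copies `emb_dim` from the code cell. The sizes are arbitrary examples.

```python
import math

def emb_dim(n):
    # Same rule as the notebook: 2*sqrt(n), rounded, clamped to [8, 64]
    return int(min(64, max(8, round(2 * math.sqrt(n)))))

sizes = [10, 33, 100, 500, 1025, 5000]
print([emb_dim(n) for n in sizes])  # [8, 11, 20, 45, 64, 64]
```

The floor of 8 only matters for very small vocabularies (about 14 entries or fewer). A vocabulary of roughly 1,000 entries already hits the 64-dimension cap. That means a fully populated two-character geohash level (32² = 1,024 cells) is capped.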

### 2. `HierGeoEncoder` Class
This is the main encoder that processes **multi-level geohash data**:

**Key Concepts:**
- **Geohashes** are strings that encode lat/lng coordinates (e.g., "9q8y" → San Francisco area)
- **Hierarchical levels** use different precisions: "9" → "9q" → "9q8" → "9q8y" (coarse → fine)
- Each level captures geographic information at a different spatial scale; every extra character splits a cell into 32 sub-cells
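The notebook does not show how its geohashes are computed. To make the prefix hierarchy concrete, here is a sketch of the standard geohash algorithm. It alternately bisects the longitude and latitude ranges and records one bit per step, then emits one base-32 character for every 5 bits. `geohash_encode` and `hierarchy` are illustrative helpers written for this note. The coordinates are an approximate point in San Francisco.

```python
BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz"  # geohash alphabet (no a, i, l, o)

def geohash_encode(lat, lng, precision=5):
    """Encode a point as a geohash string of `precision` characters."""
    lat_rng, lng_rng = [-90.0, 90.0], [-180.0, 180.0]
    chars, bits, n_bits, use_lng = [], 0, 0, True
    while len(chars) < precision:
        rng, val = (lng_rng, lng) if use_lng else (lat_rng, lat)
        mid = (rng[0] + rng[1]) / 2
        if val >= mid:            # upper half -> bit 1, keep [mid, hi]
            bits = (bits << 1) | 1
            rng[0] = mid
        else:                     # lower half -> bit 0, keep [lo, mid]
            bits <<= 1
            rng[1] = mid
        use_lng = not use_lng     # geohash starts with longitude, then alternates
        n_bits += 1
        if n_bits == 5:           # every 5 bits become one base-32 character
            chars.append(BASE32[bits])
            bits, n_bits = 0, 0
    return "".join(chars)

def hierarchy(geohash):
    """All prefixes from coarse to fine, e.g. '9q8yv' -> ['9', '9q', ..., '9q8yv']."""
    return [geohash[:k] for k in range(1, len(geohash) + 1)]

gh = geohash_encode(37.7749, -122.4194, precision=5)
print(gh)             # 9q8yy
print(hierarchy(gh))  # ['9', '9q', '9q8', '9q8y', '9q8yy']
```

Every prefix of a geohash is itself the geohash of the enclosing cell. The levels fed to `HierGeoEncoder` are therefore just `gh[:k]` slices of one string.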

**Architecture:**
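- Creates a separate embedding layer (`nn.Embedding`) for each geohash level, sized with `emb_dim`
- Concatenates all level embeddings with optional extra numeric features (e.g., seasonal `month_sin`, `month_cos`)
- Passes the result through a two-layer MLP to produce the final geographic representation (`out_dim = 128` by default)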
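`HierGeoEncoder` expects `vocab_levels` as one dict per precision, and `padding_idx=0` implies that index 0 is kept free. Below is a minimal sketch of how those vocabularies and the per-level ids could be built. It assumes index 0 also serves as the id for prefixes never seen in training. `build_level_vocabs`, `encode_levels`, and the `"<pad>"` key are hypothetical. The precisions (3, 4, 5) follow the `geohash[:3], [:4], ...` example in the code's docstring.

```python
from collections import Counter

PAD_UNK = 0  # assumption: one shared id for padding and unseen prefixes

def build_level_vocabs(geohashes, precisions=(3, 4, 5), min_count=1):
    """Hypothetical helper: one {prefix: index} dict per precision level."""
    vocabs = []
    for p in precisions:
        counts = Counter(g[:p] for g in geohashes)
        vocab = {"<pad>": PAD_UNK}  # reserve index 0 so len(vocab) covers it
        for prefix, c in sorted(counts.items()):
            if c >= min_count:
                vocab[prefix] = len(vocab)
        vocabs.append(vocab)
    return vocabs

def encode_levels(geohash, vocabs, precisions=(3, 4, 5)):
    """Hypothetical helper: map one geohash to per-level ids (0 if unseen)."""
    return [v.get(geohash[:p], PAD_UNK) for v, p in zip(vocabs, precisions)]

vocabs = build_level_vocabs(["9q8yy", "9q8yv", "9q9p1", "dr5ru"])
print([len(v) for v in vocabs])         # [4, 4, 5]
print(encode_levels("9q8yy", vocabs))   # [1, 1, 2]
print(encode_levels("9q8zz", vocabs))   # [1, 0, 0]: finer levels unseen
```

`nn.Embedding` keeps the `padding_idx` row at zero and excludes it from gradient updates. An unseen fine-level prefix therefore contributes a zero vector, and the representation falls back on the coarser levels. For a batch, each level's ids would be stacked into one `LongTensor` of shape `[N]`, which matches the `level_ids` argument of `forward`.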

### 3. `GeoLateHead` Class
- Wraps a `HierGeoEncoder` and adds a linear classification head
- Maps the encoder's output to per-class logits (`n_classes = 183` by default, one per fungi species)
- Predicts species from location (plus any `extra` features) alone; the code labels these outputs "metadata logits"

## Forward Pass Example

Let's say you have a sample with geohash "9q8yv" and want to encode it hierarchically:

```
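Input (a batch of N = 1):
- level_ids = [tensor([i_0]), tensor([i_1]), tensor([i_2]), tensor([i_3]), tensor([i_4])]
  # i_k = index of "9", "9q", "9q8", "9q8y", "9q8yv" in vocab_levels[k] (0 = padding)
- extra = tensor([[0.5, -0.87]])  # shape (N, 2), e.g., month_sin, month_cos

Process:
1. Each embedding layer looks up the ids for its level:
   - embs[0](level_ids[0])  # "9"   → (N, 16)
   - embs[1](level_ids[1])  # "9q"  → (N, 20)
   - embs[2](level_ids[2])  # "9q8" → (N, 24)
   - etc. (sizes are illustrative; each is emb_dim(len(vocab_levels[k])))

2. Concatenate all embeddings and extra: (N, 16 + 20 + 24 + ... + num_extra)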

3. Pass through MLP: 
   - Linear(total_dims → 256) → ReLU → Dropout
   - Linear(256 → 128) → ReLU → Dropout
   - Output: 128-dim geographic representation
```
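You can check the dimension bookkeeping in steps 2 and 3 without PyTorch. This sketch reuses the notebook's `emb_dim` rule. `mlp_input_dim` is a hypothetical helper. The vocabulary sizes (33, 500 and 2000 for `geohash[:1]`, `[:2]` and `[:3]`, each including the padding slot) are made-up numbers.

```python
import math

def emb_dim(n):
    # Same sizing rule as the notebook
    return int(min(64, max(8, round(2 * math.sqrt(n)))))

def mlp_input_dim(vocab_sizes, num_extra=0):
    """Hypothetical helper: per-level embedding sizes and the first Linear layer's input width."""
    dims = [emb_dim(n) for n in vocab_sizes]
    return dims, sum(dims) + num_extra

# Made-up vocabulary sizes (padding slot included) for three geohash levels
dims, d_in = mlp_input_dim([33, 500, 2000], num_extra=2)
print(dims, d_in)  # [11, 45, 64] 122
```

With these sizes the first MLP layer would be `Linear(122 → 256)`. The finest level is clamped to 64 even though its vocabulary is four times larger than the middle one.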

**Why This Works:**
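- **Multi-scale learning**: Coarse levels capture broad regions and pool information across many sparsely sampled fine cells, while fine levels capture precise locations
- **Learned representations**: The embeddings are trained end-to-end, so the network learns which regions are informative instead of relying on hand-crafted geographic features
- **Flexible**: Temporal or seasonal information can be added through the `extra` features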
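The `month_sin`/`month_cos` pair mentioned for `extra` is typically a cyclic encoding that places the months on a unit circle. The sketch below assumes that convention, with months numbered 1 to 12 and January at angle 0. `month_features` is a hypothetical helper, not taken from the notebook. Under this convention June maps to roughly (0.5, -0.87), which is close to the `extra` values in the forward-pass example.

```python
import math

def month_features(month):
    """Hypothetical helper: cyclic (sin, cos) encoding of a month in 1..12."""
    angle = 2 * math.pi * (month - 1) / 12  # assumption: January -> angle 0
    return math.sin(angle), math.cos(angle)

jan, feb, dec = month_features(1), month_features(2), month_features(12)
print(jan)  # (0.0, 1.0)
# December and January end up exactly as close as January and February
print(math.isclose(math.dist(dec, jan), math.dist(jan, feb)))  # True
```

A raw month number would place December (12) far from January (1). The circular encoding removes that artificial gap, which is why two features are used instead of one.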

import torch, torch.nn as nn, math

def emb_dim(n):
    # Embedding size grows with sqrt(vocab size), clamped to [8, 64]
    return int(min(64, max(8, round(2*math.sqrt(n)))))

class HierGeoEncoder(nn.Module):
    def __init__(self, vocab_levels, num_extra=0, out_dim=128, dropout=0.2):
        """
        vocab_levels: list of dicts (level vocab), one per precision (e.g., geohash[:3], [:4], ...)
        num_extra: extra numeric dims (e.g., month_sin, month_cos)
        """
        super().__init__()
        self.embs = nn.ModuleList([
            nn.Embedding(len(v), emb_dim(len(v)), padding_idx=0) for v in vocab_levels
        ])
        cat_dim = sum(emb.embedding_dim for emb in self.embs)
        self.num_extra = num_extra
        self.mlp = nn.Sequential(
            nn.Linear(cat_dim + num_extra, 256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, out_dim), nn.ReLU(), nn.Dropout(dropout)
        )

    def forward(self, level_ids, extra=None):
        # level_ids: list of LongTensors [N] (one per level)
        zs = [emb(ids) for emb, ids in zip(self.embs, level_ids)]   # [(N,d_i)...]
        z = torch.cat(zs, dim=-1)                                   # (N, sum d_i)
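        if self.num_extra:
            # extra: FloatTensor (N, num_extra), e.g. month_sin, month_cos
            if extra is None:
                raise ValueError(f"expected extra of shape (N, {self.num_extra})")
            z = torch.cat([z, extra], dim=-1)                       # (N, sum d_i + num_extra)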
        return self.mlp(z)                                          # (N, out_dim)

class GeoLateHead(nn.Module):
    def __init__(self, geo_enc: HierGeoEncoder, n_classes=183):
        super().__init__()
        self.geo = geo_enc
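        # mlp ends with ReLU, Dropout; take the output width from its last Linear layer
        last_linear = [m for m in geo_enc.mlp if isinstance(m, nn.Linear)][-1]
        self.clf = nn.Linear(last_linear.out_features, n_classes)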
    def forward(self, level_ids, extra=None):
        z = self.geo(level_ids, extra)      # (N, D)
        return self.clf(z)                  # (N, C) metadata logits


v2 with date