# <span style="color: orange;">Adapting the Biodiversity Foundation Model (BFM) for different data formats</span>

## <span style="color: orange;">Introduction</span>
The Biodiversity Foundation Model (BFM) is a flexible architecture designed to handle complex environmental data - from extinction indexes to atmospheric data. This notebook aims to demonstrate how one could adapt this architecture for different types of data and use cases.

We will (try to) cover:
- Understanding the BFM architecture;
- Key components that make BFM adaptable;
- Step-by-step guide to adapting BFM for new data formats;
- Practical example with Air Quality data.

Hope you'll get a better understanding of what this project is all about and what can be done with it (:  
Cheers,  
Sebastian Gribincea

## <span style="color: orange;">BFM Architecture Overview</span>

The BFM consists of three main components:
1. **Encoder**: Processes various input data types and formats, and creates a unified representation of shape [B, L, D], i.e. batch size, number of latent tokens and embedding dimension (check @encoder.py in )
2. **Backbone**: Handles temporal and spatial relationships (using either MViT or Swin Transformer)
3. **Decoder**: Generates predictions in the required output format

Key features that make BFM adaptable:
- Highly modular design, both within and between BFM building blocks;
- Flexible input handling (thanks to Perceiver IO);
- Lots of configurable hyperparameters (<span style="color: orange;">**careful with this one!**</span> Multiple assertions were added across the code files to ensure the data is handled properly);
- Adaptable position encodings;
- ...and many others.
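Those assertions mostly guard tensor shapes at the boundaries between components. The snippet below is a minimal sketch of that kind of check, not code from the BFM repository: `assert_shape` is a hypothetical helper, and `None` acts as a wildcard for axes (such as the batch size) that are allowed to vary.

```python
import torch


def assert_shape(tensor: torch.Tensor, expected: tuple, name: str = "tensor") -> None:
    """Raise AssertionError if `tensor` does not match `expected` (None = any size)."""
    assert tensor.dim() == len(expected), (
        f"{name}: expected {len(expected)} dims, got shape {tuple(tensor.shape)}"
    )
    for axis, (actual, wanted) in enumerate(zip(tensor.shape, expected)):
        assert wanted is None or actual == wanted, (
            f"{name}: axis {axis} has size {actual}, expected {wanted}"
        )


# an encoder output should look like [B, num_latent_tokens, embed_dim]
encoded = torch.zeros(2, 8, 128)
assert_shape(encoded, (None, 8, 128), name="encoded")  # passes for any batch size
```

Failing loudly at the first mismatched axis is usually far easier to debug than a cryptic matmul error several layers later.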

In [31]:
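from typing import Literal
import torch
import torch.nn as nn

# simplified BFM structure, for a smoother start
class SimplifiedBFM(nn.Module):
    def __init__(
        self,
        input_vars: dict,
        embed_dim: int = 1024,
        num_latent_tokens: int = 8,
        backbone_type: Literal["swin", "mvit"] = "mvit",
        **kwargs
    ):
        super().__init__()
        # Three main components
        self.encoder = self._create_encoder(input_vars, embed_dim, num_latent_tokens, **kwargs)
        self.backbone = self._create_backbone(backbone_type, embed_dim, **kwargs)
        self.decoder = self._create_decoder(input_vars, embed_dim, **kwargs)

    # factory hooks: placeholders here, a concrete model overrides them with its own components
    def _create_encoder(self, input_vars, embed_dim, num_latent_tokens, **kwargs) -> nn.Module:
        raise NotImplementedError("build an encoder that outputs [B, num_latent_tokens, embed_dim]")

    def _create_backbone(self, backbone_type, embed_dim, **kwargs) -> nn.Module:
        raise NotImplementedError("build an MViT or Swin backbone that preserves [B, L, D]")

    def _create_decoder(self, input_vars, embed_dim, **kwargs) -> nn.Module:
        raise NotImplementedError("build a decoder that maps [B, L, D] back to the output variables")
    
    def forward(self, x, metadata):
        # encode
        encoded = self.encoder(x)
        ### any other processing necessary to the embeddings before passing them to the backbone ###
        # process
        processed = self.backbone(encoded) # should have the same [B, L, D] shape as the encoded (and maybe slightly modified) embeddings
        # decode
        output = self.decoder(processed)
        return output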

## <span style="color: orange;">Why Adapt BFM?</span>

The BFM was originally designed for biodiversity data (and related variables, such as atmospheric ones), but its architecture makes it potentially suitable for a range of environmental and time-series prediction tasks. Benefits of adapting it include:

1. **Reusable components**: Core mechanisms like Perceiver IO and temporal processing can be preserved;
2. **Proven architecture**: Built on tested transformer-based approaches;
3. **Flexible I/O**: Can handle various data formats and tasks (i.e., it is modality- and task-agnostic by design);
4. **Scalable design**: Works with different data dimensions and temporal scales.

In the following sections, we'll explore how to adapt this architecture for different use cases, using the Air Quality Forecasting Model (AQFM) as our example.

# <span style="color: orange;">Original BFM architecture analysis</span>

In this section, we will analyze the original components of the BFM. Understanding these components is, as you might've already guessed, rather important for adapting the architecture to different data formats.

As mentioned, we have these key components:
1. **Encoder**;
2. **Decoder**;
3. **Backbone**.

Let's analyze a simplified version that handles just two types of variables:
- Surface variables (2D spatial data)
- Atmospheric variables (3D data with pressure levels)

This simplification will help us understand the essential patterns without getting lost in complexity.

Let's start by examining the encoder in detail!

In [32]:
# all the necessary imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Tuple
from datetime import datetime, timedelta
from einops import rearrange, repeat


In [33]:
# again, we use a simplified batch structure to keep the notebook from getting too long
@dataclass
class SimpleBatchMetadata:
    """Simplified metadata structure"""
    latitudes: torch.Tensor
    longitudes: torch.Tensor
    pressure_levels: tuple[int, ...]
    timestamp: tuple[datetime, datetime]

@dataclass
class SimpleBatch:
    """Simplified batch structure with some of the used variable types"""
    surface_variables: dict[str, torch.Tensor]
    atmospheric_variables: dict[str, torch.Tensor]
    batch_metadata: SimpleBatchMetadata

## <span style="color: orange;">BFM Encoder Analysis</span>

The encoder is responsible for converting various physical variables into a unified latent representation.

Key responsibilities:
1. Processing multiple variable types (surface, atmospheric, agriculture/land etc., you name it);
2. Organizing spatial data into patches;
3. Creating embeddings that capture variable relationships;
4. Enriching the data with temporal and spatial context.
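The spatial and temporal context in point 4 boils down to two small transformations, which `SimpleBFMEncoder` below also applies: the latitude/longitude grid is min-max scaled to [-1, 1], and the lead time is expressed in hours. A standalone sketch with hypothetical function names; the 38 × 80 grid corresponds to a 152 × 320 field sampled once per 4 × 4 patch:

```python
from datetime import timedelta

import torch


def normalize_positions(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
    """Build a flattened (lat, lon) grid and min-max scale each coordinate to [-1, 1]."""
    grid = torch.stack(torch.meshgrid(lat, lon, indexing="ij"), dim=-1).reshape(-1, 2)
    lo, hi = grid.min(dim=0).values, grid.max(dim=0).values
    return 2 * (grid - lo) / (hi - lo) - 1


def lead_time_hours(lead_time: timedelta) -> torch.Tensor:
    """Express the forecast horizon in hours as a [1, 1] float tensor."""
    return torch.tensor([[lead_time.total_seconds() / 3600]])


pos = normalize_positions(torch.linspace(72.0, 34.0, 38), torch.linspace(0.0, 359.75, 80))
print(pos.shape)                             # torch.Size([3040, 2])
print(lead_time_hours(timedelta(hours=24)))  # tensor([[24.]])
```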

Key design patterns:
- Separate embedding layers for each variable type;
- Patch-based processing for spatial data;
- Position and time encodings;
- Perceiver IO for unified representation.
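The Perceiver IO step is what lets the encoder accept any number of input tokens: a fixed set of learned latent queries cross-attends over all tokens and always returns the same number of latents. The sketch below illustrates only that idea, using PyTorch's `nn.MultiheadAttention` as a swapped-in single cross-attention layer. It is not the project's `PerceiverIO` class, which (judging by its arguments) also stacks latent self-attention layers, and `LatentCrossAttention` is a hypothetical name.

```python
import torch
import torch.nn as nn


class LatentCrossAttention(nn.Module):
    """Perceiver-style bottleneck: learned latents attend over a variable-length token sequence."""

    def __init__(self, embed_dim: int = 128, num_latent_tokens: int = 8, num_heads: int = 4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latent_tokens, embed_dim))
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [B, N, D] for any N; the latents are broadcast to [B, L, D]
        queries = self.latents.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        out, _ = self.cross_attn(queries, tokens, tokens)
        return self.norm(out + queries)  # residual + norm, shape [B, L, D]


block = LatentCrossAttention()
print(block(torch.randn(1, 9120, 128)).shape)  # torch.Size([1, 8, 128])
print(block(torch.randn(1, 500, 128)).shape)   # torch.Size([1, 8, 128])
```

Because the output size depends only on the number of latents, adding a variable group or changing the grid resolution changes `N` but not the shape seen by the backbone.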

Architecture flow:
1. **Input Processing**
   - Takes variable groups (B: batch size, V: number of variables, L: pressure levels, if any, H: height, W: width):
     - surface [B, V, H, W]
     - atmospheric [B, V, L, H, W]
   - Processes each variable type separately
   - Converts spatial data into patches [num_patches, patch_dim]

2. **Embedding Process**
   - Variable-specific embeddings
   - Position encoding for spatial relationships
   - Time encoding for temporal context
   - Combines all embeddings into unified representation

3. **Latent Generation**
   - Uses Perceiver IO to create fixed-size latent representation
   - Maintains relationships between variables
   - Creates output of shape [B, num_latent_tokens, embed_dim]

The encoder creates a compact, unified representation that captures the relationships between different variables while preserving spatial and temporal context.
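Patching is the step that turns a `[V, H, W]` field into tokens. With `H = 152`, `W = 320` and a patch size of 4 there are 38 × 80 = 3040 patches, each a vector of length `V · 4 · 4`. The sketch below reproduces the encoder's einops pattern `'v (h p1) (w p2) -> (h w) (v p1 p2)'` with plain `reshape`/`permute` and adds the inverse, which is what a decoder needs to get back to a grid. It assumes `H` and `W` are divisible by the patch size; the function names are hypothetical.

```python
import torch


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """[V, H, W] -> [(H/p)*(W/p), V*p*p], same layout as 'v (h p1) (w p2) -> (h w) (v p1 p2)'."""
    V, H, W = x.shape
    h, w = H // p, W // p
    return x.reshape(V, h, p, w, p).permute(1, 3, 0, 2, 4).reshape(h * w, V * p * p)


def unpatchify(patches: torch.Tensor, V: int, H: int, W: int, p: int) -> torch.Tensor:
    """Inverse of patchify: [(H/p)*(W/p), V*p*p] -> [V, H, W]."""
    h, w = H // p, W // p
    return patches.reshape(h, w, V, p, p).permute(2, 0, 3, 1, 4).reshape(V, H, W)


x = torch.randn(3, 152, 320)  # e.g. z, u, v at one pressure level
patches = patchify(x, p=4)
print(patches.shape)                                        # torch.Size([3040, 48])
print(torch.equal(unpatchify(patches, 3, 152, 320, 4), x))  # True
```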

In [34]:
import torch
import torch.nn as nn
from einops import rearrange
from bfm_model.perceiver_core.perceiver_io import PerceiverIO

class SimpleBFMEncoder(nn.Module):
    def __init__(
        self,
        surface_vars: tuple[str, ...],
        atmos_vars: tuple[str, ...],
        atmos_levels: tuple[int, ...],
        patch_size: int = 4,
        embed_dim: int = 128,
        num_heads: int = 4,
        num_latent_tokens: int = 8,
        H: int = 152,
        W: int = 320,
    ):
        super().__init__()
        # some basic configuration variables
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.H = H
        self.W = W
        
        # names of the used variables
        self.surface_vars = surface_vars
        self.atmos_vars = atmos_vars
        self.atmos_levels = atmos_levels
        
        # we will use separate embeddings to encode each variable type
        self.surface_embed = nn.Linear(len(surface_vars) * patch_size * patch_size, embed_dim)
        self.atmos_embed = nn.Linear(len(atmos_vars) * patch_size * patch_size, embed_dim)
        
        # and universal position and time embeddings
        self.pos_embed = nn.Linear(2, embed_dim)
        self.time_embed = nn.Linear(1, embed_dim)
        
        # initialize Perceiver IO - the core of the encoder
        num_patches = (H // patch_size) * (W // patch_size)
        total_latents = num_latent_tokens
        
        self.perceiver = PerceiverIO(
            num_layers=2,
            dim=embed_dim,
            queries_dim=embed_dim,
            logits_dimension=None,
            num_latent_tokens=total_latents,
            latent_dimension=embed_dim,
            cross_attention_heads=num_heads,
            latent_attention_heads=num_heads,
            cross_attention_head_dim=embed_dim // num_heads,
            latent_attention_head_dim=embed_dim // num_heads,
            sequence_dropout_prob=0.1,
            num_input_axes=1
        )
        
        # also the latent queries - they will be used to encode the data to the format we want
        self.latents = nn.Parameter(torch.randn(total_latents, embed_dim))
        
        self.pre_perceiver_norm = nn.LayerNorm(embed_dim)
        self.pos_drop = nn.Dropout(p=0.1)

    def process_variables(self, variables: dict, embed_layer: nn.Module, name: str):
        """Process a group of variables"""
        if not variables:
            print(f"\n--- Processing {name}: No variables found ---")
            return None
            
        # stack variables [V, B, H, W]
        x = torch.stack(list(variables.values()), dim=0)
        print(f"\n--- Processing {name} ---")
        print(f"Variables in group: {list(variables.keys())}")
        print(f"Initial shape: {x.shape}")
        print(f"  - num_variables (V): {x.shape[0]}")
        print(f"  - batch_size (B): {x.shape[1]}")
        print(f"  - height (H): {x.shape[2]}")
        print(f"  - width (W): {x.shape[3]}")
        
        # now we go over the first (and only) batch
        x = x[:, 0]  # [V, H, W]
        print(f"\nAfter selecting first batch:")
        print(f"Shape: {x.shape}")
        print(f"  - num_variables (V): {x.shape[0]}")
        print(f"  - height (H): {x.shape[1]}")
        print(f"  - width (W): {x.shape[2]}")
        
        # reshape to patches - important step, as it allows us to handle the data in smaller chunks
        x = rearrange(x, 'v (h p1) (w p2) -> (h w) (v p1 p2)', p1=self.patch_size, p2=self.patch_size)
        num_patches_h = x.shape[0] // (self.W // self.patch_size)
        print(f"\nAfter patch reshape:")
        print(f"Shape: {x.shape}")
        print(f"  - num_patches (h*w): {x.shape[0]} ({num_patches_h}×{x.shape[0]//num_patches_h})")
        print(f"  - patch_dim (v*p1*p2): {x.shape[1]}")
        print(f"  - patch_size: {self.patch_size}")
        
        # embed - as simple as it gets (:
        x = embed_layer(x)
        print(f"\nAfter embedding:")
        print(f"Shape: {x.shape}")
        print(f"  - num_patches: {x.shape[0]}")
        print(f"  - embed_dim: {x.shape[1]}")
        
        return x

    def forward(self, batch, lead_time):
        device = next(self.parameters()).device
        
        # surface variables
        surface_embed = self.process_variables(
            batch.surface_variables, 
            self.surface_embed,
            "Surface Variables"
        )
        
        # atmospheric variables per level
        atmos_embeds = []
        for level_idx, level in enumerate(self.atmos_levels):
            level_vars = {k: v[:, level_idx] for k, v in batch.atmospheric_variables.items()}
            level_embed = self.process_variables(
                level_vars,
                self.atmos_embed,
                f"Atmospheric Level {level}"
            )
            if level_embed is not None:
                atmos_embeds.append(level_embed)
        
        # combine the embeddings
        embeddings = []
        if surface_embed is not None:
            embeddings.append(surface_embed)
        if atmos_embeds:
            atmos_embed = torch.cat(atmos_embeds, dim=0)
            embeddings.append(atmos_embed)
        
        # concat all embeddings
        x = torch.cat(embeddings, dim=0)
        x = x.unsqueeze(0)  # add a batch dimension (not strictly necessary here, like some other steps, but good practice is good practice)
        print(f"\nCombined embeddings shape: {x.shape}")
        
        # make the positions grid
        pos = torch.stack(
            torch.meshgrid(
                batch.batch_metadata.latitudes[::self.patch_size],
                batch.batch_metadata.longitudes[::self.patch_size],
                indexing="ij"
            ),
            dim=-1
        )
        
        # we flatten the positions to a single axis - fewer dimensions, fewer headaches
        pos = pos.reshape(-1, 2)  # shape: [num_patches, 2]
        
        # normalize positions to [-1, 1], for numerical stability
        pos = 2 * (pos - pos.min(dim=0)[0]) / (pos.max(dim=0)[0] - pos.min(dim=0)[0]) - 1
        
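        # again, good practice; also move the grid to the model's device
        pos = pos.unsqueeze(0).to(device)  # shape: [1, num_patches, 2]
        
        # embedding the positions
        pos_embed = self.pos_embed(pos)  # shape: [1, num_patches, embed_dim]
        
        # tile the position embeddings across variable groups: tokens are laid out group by group
        # (surface patches, then each pressure level), so `repeat` (not `repeat_interleave`) keeps them aligned
        num_var_groups = x.shape[1] // pos_embed.shape[1]
        pos_embed = pos_embed.repeat(1, num_var_groups, 1)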
        
        # add the position encodings to the data
        x = x + pos_embed
        
        # and the time encoding
        time = torch.tensor([[lead_time.total_seconds() / 3600]], device=device)
        time_embed = self.time_embed(time)
        x = x + time_embed.unsqueeze(1)
        
        # lastly, normalize and apply Perceiver IO to get our final embeddings
        x = self.pre_perceiver_norm(x)
        latents = self.latents.unsqueeze(0)
        x = self.perceiver(x, queries=latents)
        
        return self.pos_drop(x)
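A detail worth checking in the position encoding: the encoder concatenates tokens group by group (all surface patches, then all patches of each pressure level), so the per-patch position embeddings must be *tiled* across groups rather than *interleaved*. The toy check below uses patch indices in place of real embeddings to show the difference between `Tensor.repeat` and `Tensor.repeat_interleave`:

```python
import torch

num_patches, num_groups = 4, 3  # e.g. surface + 2 pressure levels

# token layout used by the encoder: patches 0..N-1 of group 0, then of group 1, then of group 2
patch_ids = torch.arange(num_patches).repeat(num_groups)
print(patch_ids)  # tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])

# toy "position embedding" of shape [1, num_patches, 1] whose value is the patch index
pos_embed = torch.arange(num_patches, dtype=torch.float32).reshape(1, num_patches, 1)

tiled = pos_embed.repeat(1, num_groups, 1)                    # 0, 1, 2, 3, 0, 1, 2, 3, ...
interleaved = pos_embed.repeat_interleave(num_groups, dim=1)  # 0, 0, 0, 1, 1, 1, ...

print(torch.equal(tiled[0, :, 0].long(), patch_ids))        # True: aligned with the tokens
print(torch.equal(interleaved[0, :, 0].long(), patch_ids))  # False: positions are scrambled
```

With `repeat_interleave`, token `i` of the surface group would receive the position of patch `i // num_groups`, silently scrambling the spatial information.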

In [35]:
def create_sample_batch(batch_size=1, H=152, W=320):
    """
    Create a sample batch with the same structure as the real data but simplified.
    
    Returns:
        SimpleBatch with:
        - surface_variables: [B, H, W]
        - atmospheric_variables: [B, L, H, W]
        where:
        B = batch_size (default 1)
        L = pressure levels (2)
        H = height (152)
        W = width (320)
    """
    # make metadata
    metadata = SimpleBatchMetadata(
        latitudes=torch.linspace(72.0, 34.0, H),  # fun fact: I have no idea where this area would be
        longitudes=torch.linspace(0.0, 359.75, W),
        pressure_levels=(1000, 50),
        timestamp=(
            datetime(2024, 1, 1, 0, 0),
            datetime(2024, 1, 2, 0, 0)
        )
    )
    
    # surface variables (t2m, msl)
    surface_vars = {
        "t2m": torch.randn(batch_size, H, W),  # temperature
        "msl": torch.randn(batch_size, H, W)   # mean sea level pressure
    }
    
    # and some atmospheric variables (z, u, v)
    atmos_vars = {
        "z": torch.randn(batch_size, len(metadata.pressure_levels), H, W),
        "u": torch.randn(batch_size, len(metadata.pressure_levels), H, W),
        "v": torch.randn(batch_size, len(metadata.pressure_levels), H, W)
    }
    
    return SimpleBatch(
        surface_variables=surface_vars,
        atmospheric_variables=atmos_vars,
        batch_metadata=metadata
    )

In [36]:
def main():
    encoder = SimpleBFMEncoder(
        surface_vars=("t2m", "msl"),
        atmos_vars=("z", "u", "v"),
        atmos_levels=(1000, 50),
        patch_size=4,
        embed_dim=128,
        num_heads=4,
        num_latent_tokens=8,
        H=152,
        W=320
    )
    batch = create_sample_batch()
    print("\nProcessing sample batch...")
    lead_time = timedelta(hours=24)  # that means we want to predict 24 hours ahead
    
    try:
        output = encoder(batch, lead_time)
        print("\nSuccess!")
        print(f"Input shapes:")
        print(f"Surface variables: {next(iter(batch.surface_variables.values())).shape}")
        print(f"Atmospheric variables: {next(iter(batch.atmospheric_variables.values())).shape}")
        print(f"Output shape: {output.shape}")
        print(f"Expected shape: [batch_size=1, num_latent_tokens=8, embed_dim=128]")
        
    except Exception as e:
        print(f"\nError processing batch")
        print(f"Error message: {str(e)}")
        raise e

if __name__ == "__main__":
    main()


Processing sample batch...

--- Processing Surface Variables ---
Variables in group: ['t2m', 'msl']
Initial shape: torch.Size([2, 1, 152, 320])
  - num_variables (V): 2
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([2, 152, 320])
  - num_variables (V): 2
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 32])
  - num_patches (h*w): 3040 (38×80)
  - patch_dim (v*p1*p2): 32
  - patch_size: 4

After embedding:
Shape: torch.Size([3040, 128])
  - num_patches: 3040
  - embed_dim: 128

--- Processing Atmospheric Level 1000 ---
Variables in group: ['z', 'u', 'v']
Initial shape: torch.Size([3, 1, 152, 320])
  - num_variables (V): 3
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([3, 152, 320])
  - num_variables (V): 3
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 48])
  - num_patches (h*w): 3040 (3

Did you see what happened? Let's summarize it.  
Our data consists of two main variable types:

1. **Surface Variables** [B, H, W]  
   - `t2m`: 2-meter temperature  
   - `msl`: mean sea level pressure  

2. **Atmospheric Variables** [B, L, H, W]
   - `z`: geopotential
   - `u`: u-wind component
   - `v`: v-wind component
   - Measured at pressure levels: 1000 hPa and 50 hPa

Where:
- `B`: batch size (1)
- `L`: pressure levels (2)
- `H`: height (152)
- `W`: width (320)
- Spatial coverage: 72°N to 34°N, 0° to 359.75°E
- Patch size: 4 × 4

Each variable group is split into patches and embedded; the groups are then concatenated and passed through Perceiver IO, which gives us a compact, information-rich representation of the data.
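To put numbers on "compact": with the sample configuration, the token count before Perceiver IO follows directly from the grid size, the patch size and the number of variable groups used by `SimpleBFMEncoder` (one surface group plus one group per pressure level).

```python
H, W, patch_size = 152, 320, 4
num_levels, num_latent_tokens = 2, 8
num_surface_vars, num_atmos_vars = 2, 3

patches_per_group = (H // patch_size) * (W // patch_size)  # 38 * 80
num_groups = 1 + num_levels                                 # surface + one group per pressure level
num_tokens = patches_per_group * num_groups

print(patches_per_group)                # 3040
print(num_surface_vars * patch_size**2) # 32 (surface patch_dim)
print(num_atmos_vars * patch_size**2)   # 48 (atmospheric patch_dim)
print(num_tokens)                       # 9120 tokens enter Perceiver IO
print(num_tokens // num_latent_tokens)  # 1140 input tokens per latent token
```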


## <span style="color: orange;">BFM Decoder Analysis</span>

The decoder is responsible for:
1. Converting latent representations back to physical space;
2. Reconstructing different variable types from the encoded embeddings;
3. Maintaining spatial relationships through position encodings;
4. Preserving temporal context in reconstructions.

Key design patterns:
- Token projections for each variable type;
- Inverse position encoding to maintain spatial structure;
- Query-based decoding through Perceiver IO;
- Variable-specific output heads;
- Structured output format matching input variables.

Architecture flow:
1. **Input Processing**
   - Takes encoded latent representation [B, L, D];
   - Creates variable-specific queries;
   - Applies position and time encodings to queries.

2. **Decoding Process**
   - Uses Perceiver IO for cross-attention between latents and queries;
   - Maintains separate processing paths for each variable type;
   - Preserves spatial dimensions through token projections.

3. **Output Generation**
   - Reconstructs surface variables [B, V, H, W];
   - Reconstructs atmospheric variables [B, V, L, H, W];
   - Maintains variable relationships and physical constraints.

The decoder mirrors the encoder's structure but in reverse, ensuring that the output matches the input format while preserving learned representations.
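The decoding flow above can be sketched in a few lines: one position-derived query per patch cross-attends to the latent tokens, a linear head maps each query to the output variables, and the patch grid is upsampled back to full resolution. Here `nn.MultiheadAttention` stands in for Perceiver IO, and `QueryDecoder` is a hypothetical name; the defaults mirror the sample 152 × 320 grid with 4 × 4 patches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QueryDecoder(nn.Module):
    """Sketch: one positional query per patch attends to the latents, then is projected to variables."""

    def __init__(self, num_vars, embed_dim=128, num_heads=4, grid_hw=(38, 80), out_hw=(152, 320)):
        super().__init__()
        self.grid_hw, self.out_hw = grid_hw, out_hw
        self.pos_embed = nn.Linear(2, embed_dim)  # (y, x) in [-1, 1] -> query vector
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.head = nn.Linear(embed_dim, num_vars)  # output head for this variable type

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        B = latents.shape[0]
        h, w = self.grid_hw
        ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
        pos = torch.stack([ys, xs], dim=-1).reshape(-1, 2).to(latents.device)
        queries = self.pos_embed(pos).unsqueeze(0).expand(B, -1, -1)  # [B, h*w, D]
        decoded, _ = self.cross_attn(queries, latents, latents)         # [B, h*w, D]
        out = self.head(decoded).transpose(1, 2).reshape(B, -1, h, w)   # [B, V, h, w]
        return F.interpolate(out, size=self.out_hw, mode="bilinear", align_corners=False)


decoder = QueryDecoder(num_vars=2)
print(decoder(torch.randn(1, 8, 128)).shape)  # torch.Size([1, 2, 152, 320])
```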

In [37]:
class SimpleBFMDecoder(nn.Module):
    def __init__(
        self,
        surface_vars: tuple[str, ...],
        atmos_vars: tuple[str, ...],
        atmos_levels: tuple[int, ...],
        patch_size: int = 4,
        embed_dim: int = 128,
        num_heads: int = 4,
        H: int = 152,
        W: int = 320,
    ):
        super().__init__()
        # again, some basic configuration variables
        self.embed_dim = embed_dim
        self.H = H
        self.W = W
        self.patch_size = patch_size

        # names of the used variables
        self.surface_vars = surface_vars
        self.atmos_vars = atmos_vars
        self.atmos_levels = atmos_levels
        
        # calculate the number of patches in the grid
        self.num_patches_h = H // patch_size
        self.num_patches_w = W // patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        
        # position and time embeddings, just like in the encoder
        self.pos_embed = nn.Linear(2, embed_dim)
        self.time_embed = nn.Linear(1, embed_dim)
        
        # output projections for each variable type
        self.surface_proj = nn.Linear(embed_dim, len(surface_vars))
        self.atmos_proj = nn.Linear(embed_dim, len(atmos_vars))
        
        # initialize Perceiver IO
        total_queries = self.num_patches  # one query per patch, since we want a prediction for every patch of the grid
        
        self.perceiver = PerceiverIO(
            num_layers=2,
            dim=embed_dim,
            queries_dim=embed_dim,
            logits_dimension=None,
            num_latent_tokens=total_queries,
            latent_dimension=embed_dim,
            cross_attention_heads=num_heads,
            latent_attention_heads=num_heads,
            cross_attention_head_dim=embed_dim // num_heads,
            latent_attention_head_dim=embed_dim // num_heads,
            sequence_dropout_prob=0.1,
            num_input_axes=1
        )
        
        self.pos_drop = nn.Dropout(p=0.1)
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        # this method is also used in the original decoder - it basically just initializes the weights with a truncated normal distribution
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, batch, lead_time):
        B = x.shape[0] 
        device = x.device
        
        print(f"\n--- Decoder Forward Pass ---")
        print(f"Input shape: {x.shape}")
        print(f"Grid dimensions (H×W): {self.H}×{self.W}")
        print(f"Patch dimensions: {self.num_patches_h}×{self.num_patches_w}")
        
        # make the positions grid, as in the encoder
        pos = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, self.num_patches_h, device=device),
                torch.linspace(-1, 1, self.num_patches_w, device=device),
                indexing="ij"
            ),
            dim=-1
        ).reshape(-1, 2)
        
        # make queries with position encoding
        queries = self.pos_embed(pos)  # shape: [num_patches, embed_dim]
        queries = queries.unsqueeze(0).expand(B, -1, -1)  # shape: [B, num_patches, embed_dim]
        
        # add the time encoding
        time = torch.tensor([[lead_time.total_seconds() / 3600]], device=device)
        time_embed = self.time_embed(time)
        queries = queries + time_embed.unsqueeze(1)
        
        # apply Perceiver IO
        decoded = self.perceiver(x, queries=queries)
        print(f"Decoded patches shape: {decoded.shape}")
        
        # process surface variables
        surface_out = self.surface_proj(decoded)  # shape: [B, num_patches, num_surface_vars]
        surface_out = surface_out.permute(0, 2, 1)  # shape: [B, num_surface_vars, num_patches]
        surface_out = surface_out.view(B, len(self.surface_vars), 
                                     self.num_patches_h, self.num_patches_w)
        surface_out = F.interpolate(surface_out, size=(self.H, self.W), 
                                  mode='bilinear', align_corners=False)
        
        # process atmospheric variables
        atmos_out = self.atmos_proj(decoded)  # shape: [B, num_patches, num_atmos_vars]
        atmos_out = atmos_out.permute(0, 2, 1)  # shape: [B, num_atmos_vars, num_patches]
        atmos_out = atmos_out.view(B, len(self.atmos_vars), 
                                  self.num_patches_h, self.num_patches_w)
        atmos_out = F.interpolate(atmos_out, size=(self.H, self.W), 
                                 mode='bilinear', align_corners=False)
        
        # expand atmospheric variables for each level (simplification: every level receives the same prediction)
        atmos_out = atmos_out.unsqueeze(2).expand(-1, -1, len(self.atmos_levels), -1, -1)
        
        print(f"\nOutput shapes:")
        print(f"Surface variables: {surface_out.shape}")
        print(f"Atmospheric variables: {atmos_out.shape}")
        
        return {
            "surface_variables": {
                name: surface_out[:, i] 
                for i, name in enumerate(self.surface_vars)
            },
            "atmospheric_variables": {
                name: atmos_out[:, i] 
                for i, name in enumerate(self.atmos_vars)
            }
        }
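Note that `SimpleBFMDecoder` predicts one field per atmospheric variable and copies it to every pressure level with `expand`, so both levels receive identical values. If per-level outputs are needed, one option (an assumption on our part, not necessarily what the full BFM does) is to give the output head one channel per (variable, level) pair and split that axis afterwards:

```python
import torch
import torch.nn as nn

embed_dim, num_atmos_vars, num_levels = 128, 3, 2
B, h, w = 1, 38, 80

# one output channel per (variable, level) pair instead of one per variable
atmos_proj = nn.Linear(embed_dim, num_atmos_vars * num_levels)

decoded = torch.randn(B, h * w, embed_dim)  # stand-in for the Perceiver IO output
atmos = atmos_proj(decoded)                 # [B, h*w, V*L]
# channel index is v * L + l, so splitting it as (V, L) recovers the two axes
atmos = atmos.transpose(1, 2).reshape(B, num_atmos_vars, num_levels, h, w)  # [B, V, L, h, w]

print(atmos.shape)                                  # torch.Size([1, 3, 2, 38, 80])
print(torch.equal(atmos[:, :, 0], atmos[:, :, 1]))  # False: each level has its own prediction
```

To reach full resolution, the `V` and `L` axes can be merged for `F.interpolate` (which expects `[B, C, h, w]` for bilinear mode) and split again afterwards.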

In [38]:
def test_decoder():
    # Create encoder and decoder
    encoder = SimpleBFMEncoder(
        surface_vars=("t2m", "msl"),
        atmos_vars=("z", "u", "v"),
        atmos_levels=(1000, 50),
        patch_size=4,
        embed_dim=128,
        num_heads=4,
        num_latent_tokens=8,
        H=152,
        W=320
    )
    
    decoder = SimpleBFMDecoder(
        surface_vars=("t2m", "msl"),
        atmos_vars=("z", "u", "v"),
        atmos_levels=(1000, 50),
        patch_size=4,
        embed_dim=128,
        num_heads=4,
        H=152,
        W=320
    )
    
    # Create sample batch
    batch = create_sample_batch()
    lead_time = timedelta(hours=24)
    
    # Process through encoder and decoder
    try:
        encoded = encoder(batch, lead_time)
        print("\nEncoded shape:", encoded.shape)
        
        decoded = decoder(encoded, batch, lead_time)
        print("\nDecoded outputs:")
        for var_type, variables in decoded.items():
            print(f"{var_type}:")
            for var_name, var_data in variables.items():
                print(f"  - {var_name}: {var_data.shape}")
                
    except Exception as e:
        print(f"\nError in processing:")
        print(f"Error message: {str(e)}")
        raise e

if __name__ == "__main__":
    test_decoder()


--- Processing Surface Variables ---
Variables in group: ['t2m', 'msl']
Initial shape: torch.Size([2, 1, 152, 320])
  - num_variables (V): 2
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([2, 152, 320])
  - num_variables (V): 2
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 32])
  - num_patches (h*w): 3040 (38×80)
  - patch_dim (v*p1*p2): 32
  - patch_size: 4

After embedding:
Shape: torch.Size([3040, 128])
  - num_patches: 3040
  - embed_dim: 128

--- Processing Atmospheric Level 1000 ---
Variables in group: ['z', 'u', 'v']
Initial shape: torch.Size([3, 1, 152, 320])
  - num_variables (V): 3
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([3, 152, 320])
  - num_variables (V): 3
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 48])
  - num_patches (h*w): 3040 (38×80)
  - patch_dim (v*p1*p2

## <span style="color: orange;">Biodiversity Foundation Model</span>
Below, we will create a simplified version of the BFM, which will be used to process our data and demonstrate how its bigger brother works.
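In the model below the backbone is `nn.Identity()`, which works because MViT and Swin both map `[B, num_tokens, embed_dim]` to the same shape. Any module with that contract can be dropped in for experimentation; for example, PyTorch's generic `nn.TransformerEncoder`, used here purely as a swapped-in stand-in and not as one of the BFM backbones:

```python
import torch
import torch.nn as nn

embed_dim, num_latent_tokens = 128, 8

# generic transformer encoder used as a stand-in backbone: NOT MViT or Swin,
# but it honours the same contract of [B, L, D] in, [B, L, D] out
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=4, dim_feedforward=4 * embed_dim, batch_first=True
)
backbone = nn.TransformerEncoder(layer, num_layers=2)

encoded = torch.randn(1, num_latent_tokens, embed_dim)  # stand-in for the encoder output
processed = backbone(encoded)
print(processed.shape)  # torch.Size([1, 8, 128])
```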

In [39]:
import torch
import torch.nn as nn
from datetime import timedelta

class SimpleBFM(nn.Module):
    def __init__(
        self,
        surface_vars: tuple[str, ...],
        atmos_vars: tuple[str, ...],
        atmos_levels: tuple[int, ...],
        H: int = 152,
        W: int = 320,
        embed_dim: int = 128,
        num_latent_tokens: int = 8,
        patch_size: int = 4,
        num_heads: int = 4,
    ):
        super().__init__()
        
        # initialize encoder
        self.encoder = SimpleBFMEncoder(
            surface_vars=surface_vars,
            atmos_vars=atmos_vars,
            atmos_levels=atmos_levels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_latent_tokens=num_latent_tokens,
            H=H,
            W=W,
        )
        
        # simplified backbone (identity function)
        # Note: In the full BFM, this would be either MViT or Swin Transformer
        # both maintain input dimensions (i.e., shapes and sizes), so for demonstration we use identity
        self.backbone = nn.Identity()
        
        # initialize decoder
        self.decoder = SimpleBFMDecoder(
            surface_vars=surface_vars,
            atmos_vars=atmos_vars,
            atmos_levels=atmos_levels,
            patch_size=patch_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            H=H,
            W=W,
        )

    def forward(self, batch, lead_time):
        # encode
        encoded = self.encoder(batch, lead_time)
        
        # process through backbone
        # in full BFM, this would do temporal/spatial processing
        # both MViT and Swin maintain shape: [B, num_tokens, embed_dim]
        processed = self.backbone(encoded)
        
        # decode
        output = self.decoder(processed, batch, lead_time)
        
        return output


def test_simple_bfm():
    model = SimpleBFM(
        surface_vars=("t2m", "msl"),
        atmos_vars=("z", "u", "v"),
        atmos_levels=(1000, 50),
    )
    
    # create sample batch
    batch = create_sample_batch()
    lead_time = timedelta(hours=24)
    
    # process batch
    try:
        output = model(batch, lead_time)
        
        print("\nModel outputs:")
        for var_type, variables in output.items():
            print(f"\n{var_type}:")
            for var_name, var_data in variables.items():
                print(f"  - {var_name}: {var_data.shape}")
                
    except Exception as e:
        print(f"\nError in processing:")
        print(f"Error message: {str(e)}")
        raise e

if __name__ == "__main__":
    test_simple_bfm()


--- Processing Surface Variables ---
Variables in group: ['t2m', 'msl']
Initial shape: torch.Size([2, 1, 152, 320])
  - num_variables (V): 2
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([2, 152, 320])
  - num_variables (V): 2
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 32])
  - num_patches (h*w): 3040 (38×80)
  - patch_dim (v*p1*p2): 32
  - patch_size: 4

After embedding:
Shape: torch.Size([3040, 128])
  - num_patches: 3040
  - embed_dim: 128

--- Processing Atmospheric Level 1000 ---
Variables in group: ['z', 'u', 'v']
Initial shape: torch.Size([3, 1, 152, 320])
  - num_variables (V): 3
  - batch_size (B): 1
  - height (H): 152
  - width (W): 320

After selecting first batch:
Shape: torch.Size([3, 152, 320])
  - num_variables (V): 3
  - height (H): 152
  - width (W): 320

After patch reshape:
Shape: torch.Size([3040, 48])
  - num_patches (h*w): 3040 (38×80)
  - patch_dim (v*p1*p2

# <span style="color: orange;">Adapting BFM Architecture for Air Quality Prediction</span>

In this section, we'll explore how the BFM architecture can be adapted for different types of environmental data. We'll use air quality prediction as an example, demonstrating the flexibility of our encoder-decoder architecture.

Key differences from climate prediction, in the current context:
1. **Data structure**
   - Climate: Spatial grids (H x W) with multiple variable types
   - Air Quality: Time series data with sensor readings and ground truth

2. **Variable types**
   - Climate: Surface, atmospheric, species, land variables
   - Air Quality: Sensor readings (PT08.S*), ground truth measurements (*GT), physical parameters

3. **Temporal aspect**
   - Climate: Predicting spatial patterns at future timesteps
   - Air Quality: Predicting sensor values for the next hour based on historical readings
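Because the air-quality data is a time series rather than a spatial grid, the natural unit of input is a window of past readings paired with the reading one step ahead. A minimal sketch of that windowing, assuming hourly readings stored as a `[T, F]` tensor (`make_windows` is a hypothetical helper):

```python
import torch


def make_windows(series: torch.Tensor, history: int, horizon: int = 1):
    """Split a [T, F] series into inputs [N, history, F] and targets [N, F].

    Each target is the reading `horizon` steps after the last step of its input window.
    """
    num_windows = series.shape[0] - history - horizon + 1
    if num_windows <= 0:
        raise ValueError("series is too short for the requested history and horizon")
    inputs = torch.stack([series[i : i + history] for i in range(num_windows)])
    targets = torch.stack([series[i + history + horizon - 1] for i in range(num_windows)])
    return inputs, targets


hourly = torch.arange(10, dtype=torch.float32).unsqueeze(-1)  # 10 hourly readings, 1 feature
x, y = make_windows(hourly, history=4)
print(x.shape, y.shape)        # torch.Size([6, 4, 1]) torch.Size([6, 1])
print(x[0].squeeze(-1), y[0])  # tensor([0., 1., 2., 3.]) tensor([4.])
```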

We'll create a simplified version of the Air Quality Foundation Model (AQFM) to demonstrate how the core architecture can be adapted while maintaining the same conceptual structure of encoder → backbone → decoder.

## <span style="color: orange;">AQFM Encoder</span>

In [40]:
class SimpleAQEncoder(nn.Module):
    def __init__(
        self, 
        feature_names, 
        embed_dim, 
        num_latent_tokens, 
        max_history_size,
        num_heads=8,
        head_dim=64,
        depth=2,
        drop_rate=0.1
    ):
        super().__init__()
        self.feature_names = feature_names
        self.embed_dim = embed_dim
        self.num_latent_tokens = num_latent_tokens
        self.max_history_size = max_history_size
        
        # create embeddings for each feature group
        self.sensor_embed = nn.Linear(len(feature_names['sensor']), embed_dim)
        self.ground_truth_embed = nn.Linear(len(feature_names['ground_truth']), embed_dim)
        self.physical_embed = nn.Linear(len(feature_names['physical']), embed_dim)
        
        # initialize latent queries
        self.latents = nn.Parameter(torch.randn(num_latent_tokens, embed_dim))
        
        # add perceiver io
        self.perceiver = PerceiverIO(
            num_layers=depth,
            dim=embed_dim,
            queries_dim=embed_dim,
            logits_dimension=None,
            num_latent_tokens=num_latent_tokens,
            latent_dimension=embed_dim,
            cross_attention_heads=num_heads,
            latent_attention_heads=num_heads,
            cross_attention_head_dim=head_dim,
            latent_attention_head_dim=head_dim,
            sequence_dropout_prob=drop_rate,
            num_fourier_bands=32,
            max_frequency=max_history_size,
            num_input_axes=1,
            position_encoding_type="fourier"
        )
        
        self.pos_drop = nn.Dropout(p=drop_rate)
        
    def forward(self, batch, lead_time):
        # process sensor variables
        sensor_vars = torch.stack([batch.sensor_vars[name] for name in self.feature_names['sensor']], dim=-1)
        sensor_embed = self.sensor_embed(sensor_vars)  # [b, t, embed_dim]
        
        # process ground truth variables
        ground_truth_vars = torch.stack([batch.ground_truth_vars[name] for name in self.feature_names['ground_truth']], dim=-1)
        ground_truth_embed = self.ground_truth_embed(ground_truth_vars)  # [b, t, embed_dim]
        
        # process physical variables
        physical_vars = torch.stack([batch.physical_vars[name] for name in self.feature_names['physical']], dim=-1)
        physical_embed = self.physical_embed(physical_vars)  # [b, t, embed_dim]
        
        # combine embeddings
        combined_embed = sensor_embed + ground_truth_embed + physical_embed  # [b, t, embed_dim]
        
        # get batch size
        B = combined_embed.size(0) if len(combined_embed.size()) > 2 else 1
        
        # prepare latent queries
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)  # [b, num_latent_tokens, embed_dim]
        
        # process through perceiver io
        encoded = self.perceiver(combined_embed, queries=latents)
        encoded = self.pos_drop(encoded)
        
        return encoded

The transition from BFM's spatial climate data to AQFM's temporal air quality data required several fundamental changes in the encoder architecture. Let's explore these adaptations in detail:

### 1. Data Structure and Variable Handling
In the BFM, we dealt with spatial grids where each variable was represented as a 2D map (H x W). This required patch-based processing, where the encoder would divide these spatial maps into smaller patches (patch_size x patch_size) before embedding. The AQFM, however, handles time series data where each variable is a 1D sequence of measurements over time. This fundamental difference eliminated the need for patch-based processing, simplifying our embedding approach.

Instead of the BFM's separate handling of surface variables and atmospheric variables at different pressure levels (the groups used in the simplified version; the original model has many more), the AQFM groups its variables into three categories:
- Sensor readings (PT08.S* series)
- Ground truth measurements (*GT series)
- Physical parameters (temperature, relative humidity, absolute humidity)

### 2. Embedding Strategy
The BFM used patch-based embeddings:
```python
self.surface_embed = nn.Linear(len(surface_vars) * patch_size * patch_size, embed_dim)
```
which processed spatial patches of each variable type. In contrast, the AQFM uses direct linear projections:
```python
self.sensor_embed = nn.Linear(len(feature_names['sensor']), embed_dim)
```
This simpler approach is possible because we're dealing with 1D temporal sequences rather than 2D spatial grids.
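A quick shape check shows what the direct projection does. The sizes below (`B=2`, `T=24`, `embed_dim=64`) are arbitrary choices for illustration; the stacking step mirrors what `SimpleAQEncoder.forward` does with `batch.sensor_vars`.

```python
import torch
import torch.nn as nn

B, T, embed_dim = 2, 24, 64
sensor_names = ["PT08.S1(CO)", "PT08.S2(NMHC)", "PT08.S3(NOx)", "PT08.S4(NO2)", "PT08.S5(O3)"]

# one [B, T] series per sensor, as in AQBatch.sensor_vars
sensor_vars = {name: torch.rand(B, T) for name in sensor_names}

# stack the variables along a new last axis: [B, T, V]
stacked = torch.stack([sensor_vars[n] for n in sensor_names], dim=-1)

# direct projection: each timestep becomes one embed_dim-sized token
sensor_embed = nn.Linear(len(sensor_names), embed_dim)
tokens = sensor_embed(stacked)
print(stacked.shape, tokens.shape)  # torch.Size([2, 24, 5]) torch.Size([2, 24, 64])
```

For comparison, each BFM surface token needs `len(surface_vars) * patch_size * patch_size` inputs, which is 2 * 4 * 4 = 32 in the printed output earlier.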

### 3. Position Encoding
The BFM needed to encode both latitude and longitude positions:
```python
self.pos_embed = nn.Linear(2, embed_dim) # lat, lon
```
and built a 2D position grid with `torch.meshgrid`. The AQFM leaves position encoding entirely to Perceiver IO's built-in Fourier features, which here only need to cover a single (temporal) axis. This is why we specify:
```python
position_encoding_type="fourier",
num_fourier_bands=32,
max_frequency=max_history_size,
num_input_axes=1 # temporal dimension only
```
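The exact formula inside the `PerceiverIO` class used here is not shown in this notebook, so the sketch below follows a common Perceiver-style recipe and is an assumption about what `num_fourier_bands`, `max_frequency` and `num_input_axes=1` control: positions scaled to [-1, 1] are multiplied by a bank of frequencies, and the sine and cosine of each product are appended to the raw position.

```python
import math
import torch

def fourier_encode(pos, max_freq, num_bands):
    """Sketch of sinusoidal Fourier features for 1D positions scaled to [-1, 1]."""
    pos = pos.unsqueeze(-1)                                # [T, 1]
    # frequencies spread linearly from 1 to max_freq / 2
    freqs = torch.linspace(1.0, max_freq / 2, num_bands)   # [num_bands]
    angles = pos * freqs * math.pi                          # [T, num_bands]
    # sin and cos of every band, plus the raw position itself
    return torch.cat([angles.sin(), angles.cos(), pos], dim=-1)  # [T, 2 * num_bands + 1]

T = 24
positions = torch.linspace(-1.0, 1.0, T)   # one position per hourly timestep
features = fourier_encode(positions, max_freq=T, num_bands=32)
print(features.shape)  # torch.Size([24, 65])
```

With one input axis, each timestep gets `2 * num_fourier_bands + 1` positional features; a 2D grid would need this per axis, which is why the axis count matters.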
### 4. Variable Combination
In the BFM, variables were processed separately and then concatenated:
```python
embeddings = []
if surface_embed is not None:
    embeddings.append(surface_embed)
if atmos_embeds:
    embeddings.append(torch.cat(atmos_embeds, dim=0))
x = torch.cat(embeddings, dim=0)
```
The AQFM uses addition instead:
```python
combined_embed = sensor_embed + ground_truth_embed + physical_embed
```

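Summing works because all three groups share the same time axis: the result is a single token per timestep that already mixes every measurement taken at that timestep, so the sequence length stays at T instead of growing to 3T, as it would with concatenation along the token axis. The trade-off is that the encoder no longer sees the groups as separate tokens and has to untangle their contributions from the summed embedding.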
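The difference in sequence length is easy to see with random embeddings of the sizes used in this sketch (`B=2`, `T=24`, `D=64`, chosen arbitrarily):

```python
import torch

B, T, D = 2, 24, 64
sensor_embed = torch.rand(B, T, D)
ground_truth_embed = torch.rand(B, T, D)
physical_embed = torch.rand(B, T, D)

# AQFM-style: element-wise sum keeps one token per timestep
summed = sensor_embed + ground_truth_embed + physical_embed

# BFM-style: concatenating along the token axis puts the groups one after another
# (dim=1 here because of the batch dimension; the BFM snippet uses dim=0 on a single sample)
concatenated = torch.cat([sensor_embed, ground_truth_embed, physical_embed], dim=1)

print(summed.shape)        # torch.Size([2, 24, 64])
print(concatenated.shape)  # torch.Size([2, 72, 64])
```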

### 5. Perceiver IO Configuration
While both versions use Perceiver IO as their core processing component, the configurations differ:
- BFM focuses on spatial relationships: attention operates over patches, including each patch's versions at different time steps
- AQFM emphasizes temporal relationships: attention operates over time steps
- Both keep the concept of latent tokens, but the tokens represent different things:
  * BFM: spatial-temporal patterns across the globe
  * AQFM: temporal patterns and variable interactions in the air quality measurements (at a single location, if you want to think of it that way)
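The key property of the latent tokens is that they fix the size of the encoded output regardless of how many timesteps come in. The sketch below illustrates only that cross-attention step, using PyTorch's `nn.MultiheadAttention` as a stand-in; `LatentCrossAttention` is a hypothetical name, and the real `PerceiverIO` class also adds latent self-attention layers and positional features.

```python
import torch
import torch.nn as nn

class LatentCrossAttention(nn.Module):
    """Hypothetical stand-in for the Perceiver IO read step: latents attend over the input."""

    def __init__(self, embed_dim=64, num_latent_tokens=8, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latent_tokens, embed_dim))
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):  # x: [B, T, D], one token per timestep
        queries = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)  # [B, N, D]
        out, _ = self.cross_attn(queries, x, x)  # every latent attends over all T timesteps
        return out  # [B, N, D], independent of T

module = LatentCrossAttention()
for T in (12, 24, 48):
    print(T, module(torch.rand(2, T, 64)).shape)  # always torch.Size([2, 8, 64])
```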

## <span style="color: orange;">AQFM Decoder</span>

In [46]:
from einops import repeat  # used to tile the query tokens across the batch

class SimpleAQDecoder(nn.Module):
    def __init__(self, feature_names, embed_dim, num_heads=8, head_dim=64, depth=2, drop_rate=0.1):
        super().__init__()
        self.feature_names = feature_names
        self.embed_dim = embed_dim
        
        # create clean target names mapping
        self.target_names = {
            name: f"target_{i}" 
            for i, name in enumerate(feature_names["ground_truth"])
        }
        
        # initialize perceiver IO for decoding
        self.perceiver = PerceiverIO(
            num_layers=depth,
            dim=embed_dim,
            queries_dim=embed_dim,
            logits_dimension=None,
            num_latent_tokens=len(feature_names["ground_truth"]),  # one token per target
            latent_dimension=embed_dim,
            cross_attention_heads=num_heads,
            latent_attention_heads=num_heads,
            cross_attention_head_dim=head_dim,
            latent_attention_head_dim=head_dim,
            sequence_dropout_prob=drop_rate,
            num_fourier_bands=32,
            max_frequency=24,  # for 24 hours of history
            num_input_axes=1,
            position_encoding_type="fourier",
        )
        
        # prediction heads for each target variable
        self.prediction_heads = nn.ModuleDict({
            self.target_names[name]: nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.ReLU(),
                nn.Dropout(drop_rate),
                nn.Linear(embed_dim // 2, embed_dim // 4),
                nn.ReLU(),
                nn.Dropout(drop_rate),
                nn.Linear(embed_dim // 4, 1),  # single value prediction
            )
            for name in feature_names["ground_truth"]
        })
        
        # query tokens for each target variable
        self.query_tokens = nn.Parameter(
            torch.randn(len(feature_names["ground_truth"]), embed_dim)
        )
        
        # lead time embedding
        self.lead_time_embed = nn.Linear(1, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)
        
    def forward(self, x, batch, lead_time):
        B = x.shape[0]  # shape: [b, num_latent_tokens, embed_dim]
        device = x.device
        
        # prepare queries
        queries = repeat(self.query_tokens, "n d -> b n d", b=B)
        
        # add lead time information to queries
        lead_hours = torch.tensor(
            [[lead_time.total_seconds() / 3600]], 
            device=device, 
            dtype=torch.float
        )
        lead_time_encoding = self.lead_time_embed(lead_hours)
        queries = queries + lead_time_encoding.unsqueeze(1)
        
        # process through perceiver IO
        decoded = self.perceiver(x, queries=queries)
        decoded = self.pos_drop(decoded)
        
        # generate predictions for each target variable
        predictions = {}
        for idx, (orig_name, safe_name) in enumerate(self.target_names.items()):
            token_embedding = decoded[:, idx]  # shape: [B, embed_dim]
            pred = self.prediction_heads[safe_name](token_embedding)  # shape: [B, 1]
            predictions[orig_name] = pred
            
        return predictions  # return only ground truth predictions

As you can see, moving from BFM's spatial-temporal predictions to AQFM's temporal air quality forecasting required a number of changes to the decoder architecture. Let's go through these adaptations, and what one could do to adapt a decoder to a new data format:

### 1. Output Structure and Prediction Targets
The BFM decoder needed to reconstruct complete spatial fields:
- Surface variables (2D fields: [B, V, H, W]);
- Atmospheric variables (3D fields: [B, V, L, H, W]);

In contrast, the AQFM decoder focuses on point predictions:
- Only ground truth variables;
- Single value per variable: [B, 1];
- No spatial reconstruction needed;

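The output projections change accordingly. The BFM maps tokens to variables with a single linear layer per variable group and then reconstructs the spatial fields, while the AQFM skips reconstruction entirely and instead gives each target its own small MLP that outputs one scalar: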
```python
# BFM's spatial projections
self.surface_proj = nn.Linear(embed_dim, len(surface_vars))
self.atmos_proj = nn.Linear(embed_dim, len(atmos_vars))

# AQFM's prediction heads
self.prediction_heads = nn.ModuleDict({
    self.target_names[name]: nn.Sequential(
        nn.Linear(embed_dim, embed_dim // 2),
        nn.ReLU(),
        nn.Dropout(drop_rate),
        nn.Linear(embed_dim // 2, embed_dim // 4),
        nn.ReLU(),
        nn.Dropout(drop_rate),
        nn.Linear(embed_dim // 4, 1),
    )
    for name in feature_names["ground_truth"]
})
```
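To put numbers on that difference, here is a quick parameter count with `embed_dim=512` (the `SimpleAQFM` default) and the five ground-truth targets. For a like-for-like comparison the linear projection is sized for the same five outputs; this compares the output heads only and says nothing about accuracy.

```python
import torch.nn as nn

embed_dim, drop_rate = 512, 0.1
ground_truth = ["CO(GT)", "NMHC(GT)", "C6H6(GT)", "NOx(GT)", "NO2(GT)"]

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# BFM-style: one linear layer shared by a whole variable group
linear_proj = nn.Linear(embed_dim, len(ground_truth))

# AQFM-style: one 3-layer MLP per target (same layout as SimpleAQDecoder)
heads = nn.ModuleDict({
    f"target_{i}": nn.Sequential(
        nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(drop_rate),
        nn.Linear(embed_dim // 2, embed_dim // 4), nn.ReLU(), nn.Dropout(drop_rate),
        nn.Linear(embed_dim // 4, 1),
    )
    for i in range(len(ground_truth))
})

print(count_params(linear_proj))  # 2565
print(count_params(heads))        # 821765
```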
<span style="color: orange;">**NOTE**</span>: We can actually leverage Perceiver IO's query system directly for predictions, as demonstrated in `perceiver_io_prediction.py`. Here's how:

1. Instead of using prediction heads, we can define learnable query tokens:
```python
self.query = nn.Parameter(torch.randn(1, queries_dim))
```

2. Configure Perceiver IO to output the exact dimensions we need:
```python
self.perceiver_io = PerceiverIO(
    # ... other params ...
    logits_dimension=1,  # Direct prediction output
    queries_dim=queries_dim,
    # ... other params ...
)
```

3. Use the query in the forward pass:
```python
def forward(self, x):
    batch_size = x.shape[0]
    queries = self.query.expand(batch_size, 1, -1)
    return self.perceiver_io(x, queries=queries)
```
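The three steps can be combined into one small module. This is a sketch under stated assumptions: `QueryPredictor` is a hypothetical name, PyTorch's `nn.MultiheadAttention` stands in for the Perceiver IO cross-attention, a final `nn.Linear` plays the role of `logits_dimension=1`, and the query width is assumed equal to the encoded token width.

```python
import torch
import torch.nn as nn

class QueryPredictor(nn.Module):
    """Hypothetical sketch: one learnable query decodes a single value from encoded tokens."""

    def __init__(self, queries_dim=64, num_heads=4, logits_dimension=1):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, queries_dim))
        self.cross_attn = nn.MultiheadAttention(queries_dim, num_heads, batch_first=True)
        # stands in for Perceiver IO's logits_dimension output projection
        self.to_logits = nn.Linear(queries_dim, logits_dimension)

    def forward(self, x):  # x: [B, N, queries_dim]
        queries = self.query.expand(x.shape[0], 1, -1)  # [B, 1, queries_dim]
        decoded, _ = self.cross_attn(queries, x, x)     # the query reads from all N tokens
        return self.to_logits(decoded)                  # [B, 1, logits_dimension]

model = QueryPredictor()
print(model(torch.rand(2, 8, 64)).shape)  # torch.Size([2, 1, 1])
```

Predicting several targets this way would mean one query per target, which is essentially what `SimpleAQDecoder` does before handing each decoded token to its MLP head.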

### 2. Query System
Both decoders use queries to extract information from the encoded representation, but with different purposes:

BFM:
- One query per spatial patch;
- Position-based queries using lat/lon coordinates;
- Queries aim to reconstruct spatial patterns;

AQFM:
- One query per target variable;
- Learnable query tokens;
- Queries focus on temporal patterns specific to each variable;

### 3. Position and Time Encoding
BFM needed explicit spatial position encoding:
```python
pos = torch.stack(torch.meshgrid(
    torch.linspace(-1, 1, self.num_patches_h),
    torch.linspace(-1, 1, self.num_patches_w),
    indexing="ij"
), dim=-1)
```

AQFM uses:
- Fourier features for temporal encoding;
- Lead time embedding added to queries;
- No need for spatial position encoding;
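The lead-time step in `SimpleAQDecoder.forward` relies on broadcasting: a single `[1, 1, D]` vector is added to every query of every sample. The sketch below isolates those lines; the sizes are arbitrary.

```python
from datetime import timedelta

import torch
import torch.nn as nn

B, num_targets, embed_dim = 2, 5, 64
query_tokens = nn.Parameter(torch.randn(num_targets, embed_dim))
lead_time_embed = nn.Linear(1, embed_dim)

lead_time = timedelta(hours=3)
# timedelta -> hours as a [1, 1] float tensor
lead_hours = torch.tensor([[lead_time.total_seconds() / 3600]])
# [1, D] -> [1, 1, D] so it broadcasts over the batch and over every query token
lead_encoding = lead_time_embed(lead_hours).unsqueeze(1)

queries = query_tokens.unsqueeze(0).expand(B, -1, -1) + lead_encoding
print(lead_hours.item(), queries.shape)  # 3.0 torch.Size([2, 5, 64])
```

Because the same vector shifts all target queries, the lead time conditions every prediction head equally; per-target lead-time embeddings would be a possible variation.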

### 4. Output Processing
BFM required complex output processing:
- Reshape from patches to full fields;
- Interpolation to original resolution;
- Handling multiple pressure levels;

AQFM's output processing is straightforward:
- Direct projection to scalar values;
- No spatial reconstruction;
- No need for interpolation;
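To show what the BFM's "reshape from patches to full fields" step involves, here is a patchify/unpatchify pair written with plain `reshape` and `permute`. The patch layout, `(h w)` patches with `(v p1 p2)` features, is an assumption inferred from the printed output earlier in this notebook; the actual BFM code may implement it differently (for example with einops).

```python
import torch

def patchify(field, p):
    """[V, H, W] -> [h*w, V*p*p], assuming the (h w) (v p1 p2) layout printed earlier."""
    V, H, W = field.shape
    h, w = H // p, W // p
    x = field.reshape(V, h, p, w, p)   # v h p1 w p2
    x = x.permute(1, 3, 0, 2, 4)       # h w v p1 p2
    return x.reshape(h * w, V * p * p)

def unpatchify(patches, V, H, W, p):
    """Inverse of patchify: [h*w, V*p*p] -> [V, H, W]."""
    h, w = H // p, W // p
    x = patches.reshape(h, w, V, p, p)  # h w v p1 p2
    x = x.permute(2, 0, 3, 1, 4)        # v h p1 w p2
    return x.reshape(V, H, W)

field = torch.rand(2, 152, 320)         # e.g. t2m and msl
patches = patchify(field, p=4)
restored = unpatchify(patches, V=2, H=152, W=320, p=4)
print(patches.shape, torch.equal(field, restored))  # torch.Size([3040, 32]) True
```

The AQFM needs none of this, since each prediction head already returns the final `[B, 1]` value.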

### 5. Perceiver IO Configuration
While both use Perceiver IO, their configurations reflect different goals:

BFM:
- Queries based on spatial positions;
- Focus on spatial reconstruction;
- Simple cross-attention setup;

AQFM:
- Learnable query tokens;
- Emphasis on temporal patterns;
- More sophisticated prediction heads;

### 6. Initialization and Head Design
BFM uses:
- Truncated normal initialization;
- Simple linear projections;
- Shared processing for variable types;

AQFM uses:
- An MLP for each target;
- Separate processing paths for each variable;

These adaptations reflect the fundamental shift from spatial field reconstruction to temporal point prediction, resulting in a more focused architecture for air quality forecasting.

## <span style="color: orange;">AQFM</span>

In [42]:
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List


@dataclass
class AQMetadata:
    """Metadata for air quality data, just like climate data"""

    time: List[datetime]
    feature_names: Dict[str, List[str]]
    sequence_length: int
    prediction_horizon: int


@dataclass
class AQBatch:
    """Batch structure for air quality data, again just like climate data"""

    sensor_vars: Dict[str, torch.Tensor]  # PT08.S* sensor readings
    ground_truth_vars: Dict[str, torch.Tensor]  # ground truth ones (*GT)
    physical_vars: Dict[str, torch.Tensor]  # any other things like T, RH, AH
    metadata: AQMetadata

In [43]:
class SimpleAQFM(nn.Module):
    def __init__(
        self,
        feature_names: dict[str, list[str]],
        embed_dim: int = 512,  # increased to match original
        num_latent_tokens: int = 8,
        max_history_size: int = 24,
        backbone_type: Literal["swin", "mvit", "identity"] = "identity",
        # encoder params
        encoder_num_heads: int = 8,
        encoder_head_dim: int = 64,
        encoder_depth: int = 2,
        encoder_drop_rate: float = 0.1,
        # decoder params
        decoder_num_heads: int = 8,
        decoder_head_dim: int = 64,
        decoder_depth: int = 2,
        decoder_drop_rate: float = 0.1,
    ):
        super().__init__()
        
        # initialize encoder with more parameters
        self.encoder = SimpleAQEncoder(
            feature_names=feature_names,
            embed_dim=embed_dim,
            num_latent_tokens=num_latent_tokens,
            max_history_size=max_history_size,
            num_heads=encoder_num_heads,
            head_dim=encoder_head_dim,
            depth=encoder_depth,
            drop_rate=encoder_drop_rate,
        )
        
        # initialize backbone based on type
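        if backbone_type == "identity":
            self.backbone = nn.Identity()
        elif backbone_type in ("swin", "mvit"):
            # for simplicity, swin and mvit backbones are not wired up in this simplified version
            raise NotImplementedError(f"backbone_type={backbone_type!r} is not supported in SimpleAQFM")
        else:
            raise ValueError(f"Unknown backbone_type: {backbone_type!r}")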
        
        # initialize decoder with more parameters
        self.decoder = SimpleAQDecoder(
            feature_names=feature_names,
            embed_dim=embed_dim,
            num_heads=decoder_num_heads,
            head_dim=decoder_head_dim,
            depth=decoder_depth,
            drop_rate=decoder_drop_rate,
        )
        
        self.backbone_type = backbone_type

    def forward(self, batch, lead_time):
        # encode
        encoded = self.encoder(batch, lead_time)
        
        # process through backbone
        if self.backbone_type in ["swin", "mvit"]:
            patch_shape = [self.encoder.num_latent_tokens, 1, 1]
            processed = self.backbone(encoded, lead_time=lead_time, 
                                   rollout_step=0, patch_shape=patch_shape)
        else:
            processed = self.backbone(encoded)
        
        # decode
        output = self.decoder(processed, batch, lead_time)
        
        return output

As for the whole model, which wires all the components into one class, it is essentially the same as `SimplifiedBFM`; only the encoder and decoder differ. The same holds for the full implementations in `aqfm.py` and `bfm.py`.


In [47]:
from datetime import datetime, timedelta

def create_sample_aq_batch():
    # define feature names
    feature_names = {
        "sensor": ["PT08.S1(CO)", "PT08.S2(NMHC)", "PT08.S3(NOx)", "PT08.S4(NO2)", "PT08.S5(O3)"],
        "ground_truth": ["CO(GT)", "NMHC(GT)", "C6H6(GT)", "NOx(GT)", "NO2(GT)"],
        "physical": ["T", "RH", "AH"],
    }
    
    # create random tensors for each feature group with batch dimension
    sensor_vars = {name: torch.rand(1, 24) for name in feature_names["sensor"]}  # shape: [B, T]
    ground_truth_vars = {name: torch.rand(1, 24) for name in feature_names["ground_truth"]}  # shape: [B, T]
    physical_vars = {name: torch.rand(1, 24) for name in feature_names["physical"]}  # shape: [B, T]
    
    # create metadata
    metadata = AQMetadata(
        time=[datetime.now() for _ in range(24)],
        feature_names=feature_names,
        sequence_length=24,
        prediction_horizon=1,
    )
    
    # create a batch
    batch = AQBatch(
        sensor_vars=sensor_vars,
        ground_truth_vars=ground_truth_vars,
        physical_vars=physical_vars,
        metadata=metadata
    )
    
    return batch

def test_simple_aqfm():
    # create model
    feature_names = {
        "sensor": ["PT08.S1(CO)", "PT08.S2(NMHC)", "PT08.S3(NOx)", "PT08.S4(NO2)", "PT08.S5(O3)"],
        "ground_truth": ["CO(GT)", "NMHC(GT)", "C6H6(GT)", "NOx(GT)", "NO2(GT)"],
        "physical": ["T", "RH", "AH"],
    }
    model = SimpleAQFM(feature_names=feature_names)
    
    # create sample batch
    batch = create_sample_aq_batch()
    lead_time = timedelta(hours=1)
    
    # process batch
    try:
        predictions = model(batch, lead_time)
        
        print("\nModel outputs:")
        for var_name, pred in predictions.items():
            print(f"  - {var_name}: {pred.shape}")
                
    except Exception as e:
        print(f"\nError in processing:")
        print(f"Error message: {str(e)}")
        raise e

test_simple_aqfm()


Model outputs:
  - CO(GT): torch.Size([1, 1])
  - NMHC(GT): torch.Size([1, 1])
  - C6H6(GT): torch.Size([1, 1])
  - NOx(GT): torch.Size([1, 1])
  - NO2(GT): torch.Size([1, 1])
